#!/usr/bin/env python3
# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from CliModel import Dict, List, Model, Str, Submodel, Int, Bool
import Arnet
import ArnetModel
from IntfModels import Interface
from Intf.IntfRange import intfListToCanonical
from CliPlugin.TrafficPolicyCliModel import Rule
from Toggles.PostcardTelemetryCommonToggleLib import (
   toggleFeaturePostcardTelemetryTcpControlEnabled )

class IntfList( Model ):
   # Thin wrapper around a list of interfaces; ProfileModel uses it as the
   # value type of its per-failure-message dicts.
   interfaces = List( optional=True, valueType=Interface,
                      help='Interface list' )

class ProfileModel( Model ):
   # Status of one profile: where it is configured, which sample policy it
   # uses, and for IPv4/IPv6 rules the interfaces where they were applied or
   # failed (failures keyed by failure message).
   interfaces = List( valueType=Interface, help='Interface list' )
   appliedInterfaces = List( valueType=Interface, help='Applied interface list '
                                                       'for IPv4 rules' )
   samplePolicy = Str( help="Sample Policy", optional=True )
   failures = Dict( keyType=str, valueType=IntfList,
                    help="Failed interfaces for IPv4 rules, "
                         "indexed by the failure message" )
   appliedInterfaces6 = List( valueType=Interface, help='Applied interface list '
                                                         'for IPv6 rules' )
   failures6 = Dict( keyType=str, valueType=IntfList,
                     help="Failed interfaces for IPv6 rules, "
                           "indexed by the failure message" )

   def render( self ):
      """Print this profile's sample policy and per-family interface status.

      Interface lists are ordered with Arnet.sortIntf, condensed with
      intfListToCanonical and comma-joined. For IPv4 and then IPv6, one
      "applied on" line is printed, then either one "failed on" line per
      failure message (messages in sorted order) or a bare "failed on" line
      when there are no failures. A blank line terminates the output.
      """
      def canonical( intfs ):
         # Sort, condense and comma-join an interface list for display.
         return ','.join( intfListToCanonical( Arnet.sortIntf( intfs ) ) )

      print( 'Sample policy:', self.samplePolicy )
      print( 'Configured on:', canonical( self.interfaces ) )

      families = (
         ( 'IPv4', self.appliedInterfaces, self.failures ),
         ( 'IPv6', self.appliedInterfaces6, self.failures6 ),
      )
      for ipType, applied, failed in families:
         appliedHeader = f'{ipType} rules applied on interfaces:'
         if applied:
            print( appliedHeader, canonical( applied ) )
         else:
            print( appliedHeader )

         failedHeader = f'{ipType} rules failed on interfaces:'
         if not failed:
            print( failedHeader )
            continue
         for message in sorted( failed ):
            print( failedHeader, canonical( failed[ message ].interfaces ),
                   f'({message})' )
      print( "" )

class ModelPostcardProfile( Model ):
   # Collection of profile status models, keyed by profile name.
   profiles = Dict( keyType=str, valueType=ProfileModel,
                    help="Maps core profiles to their corresponding configuration." )

   def render( self ):
      """Print every profile in name order under a 'Profiles' heading.

      Nothing at all is printed when there are no profiles.
      """
      if not self.profiles:
         return
      print( 'Profiles' )
      for name, profile in sorted( self.profiles.items() ):
         print( 'Name:', name )
         profile.render()

class SamplePolicyModel( Model ):
   # Match rules of one sample policy, plus the TCP flags string recorded for
   # each rule (keyed by the rule's ruleString).
   rules = List( valueType=Rule, help="Detailed information of match rules" )
   tcpFlags = Dict( keyType=str, valueType=str,
                   help="Maps tcp flags to their corresponding match rules." )

   def render( self ):
      """Print every match rule configured in this sample policy.

      Each rule gets a "match ..." header followed by its source/destination
      prefixes and one "Protocol:" entry per protocol. Protocol entries are
      sorted by the low end of their range and include the first port entry's
      source/destination ports. The sample action comes last.

      When the TCP-control feature toggle is enabled and ``tcpFlags`` holds a
      non-empty value for the rule's ``ruleString``, that value is shown as
      flags. It is appended to the tcp protocol entry if the rule has one. If
      the rule lists no protocols at all, it is shown as a standalone "tcp"
      entry.
      """
      def formatPorts( portRanges ):
         # Port ranges sorted by low end, e.g. "22, 80-89, 443".
         return ', '.join(
            str( port.low ) if port.low == port.high else
            str( port.low ) + '-' + str( port.high )
            for port in sorted( portRanges, key=lambda port: port.low ) )

      # Loop-invariant: query the feature toggle once per render instead of
      # once per protocol entry.
      tcpControlEnabled = toggleFeaturePostcardTelemetryTcpControlEnabled()

      print( "Total number of rules configured:", len( self.rules ) )
      for rule in self.rules:
         print( f"match {rule.matchOption} {rule.ruleString}:" )
         matches = rule.matches
         if matches.srcPrefixes:
            srcPrefixes = [ "%s" % src for src in matches.srcPrefixes ]
            print( "\tSource:", ' '.join( srcPrefixes ).rstrip( ',' ) )
         if matches.destPrefixes:
            destPrefixes = [ "%s" % dst for dst in matches.destPrefixes ]
            print( "\tDestination:", ' '.join( destPrefixes ).rstrip( ',' ) )

         # Flags to display for this rule; stays falsy when the toggle is off
         # or nothing (or an empty value) was recorded for the rule.
         flags = None
         if tcpControlEnabled and rule.ruleString in self.tcpFlags:
            flags = self.tcpFlags[ rule.ruleString ]

         if matches.protocols:
            for proto in sorted( matches.protocols,
                                 key=lambda proto: proto.protocolRange.low ):
               low = proto.protocolRange.low
               high = proto.protocolRange.high
               isTcp = low == high == 6
               # Only single-protocol tcp (6) / udp (17) entries get a name;
               # any other range prints an empty name.
               if isTcp:
                  protoStr = "tcp"
               elif low == high == 17:
                  protoStr = "udp"
               else:
                  protoStr = ""
               if proto.ports:
                  # Only the first port entry of a protocol is displayed.
                  protoPort = proto.ports[ 0 ]
                  if protoPort.srcPorts:
                     protoStr += '\n\t\tSource port: '
                     protoStr += formatPorts( protoPort.srcPorts )
                  if protoPort.destPorts:
                     protoStr += '\n\t\tDestination port: '
                     protoStr += formatPorts( protoPort.destPorts )
               if isTcp and flags:
                  # NOTE(review): no space follows "flags:" here, while the
                  # protocol-less branch below uses four spaces. Confirm the
                  # intended layout before unifying (CLI output may be pinned).
                  protoStr += '\n\t\tflags:' + flags
               print( "\tProtocol:", protoStr )
         elif flags:
            # The rule lists no protocols but has flags recorded: show it as tcp.
            print( "\tProtocol:", 'tcp' + '\n\t\tflags:    ' + flags )

         actions = rule.actions
         if actions.sample or actions.sampleAll:
            action = "sample" if actions.sample else "sample all"
            print( "\tActions:", action )
      print( "" )

class SamplePoliciesModel( Model ):
   # All sample policies, keyed by policy name.
   policies = Dict( keyType=str, valueType=SamplePolicyModel,
                    help="Maps sample policy name to its configuration" )

   def render( self ):
      """Print each sample policy's name and rules, ordered by policy name.

      Nothing is printed when there are no policies.
      """
      if not self.policies:
         return
      for name, policy in sorted( self.policies.items() ):
         print( "Sample policy", name )
         policy.render()

class Collection( Model ):
   # Ingress collection endpoint addresses and GRE postcard version.
   srcIp = ArnetModel.IpGenericAddress( help="Source IP address" )
   dstIp = ArnetModel.IpGenericAddress( help="Destination IP address" )
   version = Int( help="GRE postcard version" )

   def render( self ):
      """Print the collection source, destination and GRE postcard version."""
      rows = (
         ( "Ingress collection source:", self.srcIp ),
         ( "Ingress collection destination:", self.dstIp ),
         ( "Ingress collection type: GRE postcard version", self.version ),
      )
      for label, value in rows:
         print( label, value )

class SampleTcpUdpChecksum( Model ):
   # Value/mask pair used for checksum-based sampling.
   value = Int( help="TCP/UDP checksum or IP ID value" )
   mask = Int( help="Mask of TCP/UDP checksum or IP ID" )

   def render( self ):
      """Print the checksum value and mask, each formatted with hex()."""
      rows = (
         ( "Ingress collection sample TCP/UDP checksum value:", self.value ),
         ( "Ingress collection sample TCP/UDP checksum mask:", self.mask ),
      )
      for label, number in rows:
         print( label, hex( number ) )

class PostcardTelemetry( Model ):
   # Global postcard telemetry configuration; the checksum and collection
   # submodels are optional.
   enabled = Bool( help="Postcard telemetry is enabled" )
   sampleRate = Int( help="Sampling rate of 1 in 16k, 32k or 64k packets" )
   sampleTcpUdpChecksum = Submodel( valueType=SampleTcpUdpChecksum,
                                    help="TCP/UDP checksum and mask",
                                    optional=True )
   vxlanEnabled = Bool( help="VXLAN marker mode is enabled" )
   vxlanMarkerBit = Int( help="VXLAN marker bit", optional=True )
   collection = Submodel( valueType=Collection, help="Collection attributes",
                          optional=True )

   def render( self ):
      """Print the postcard telemetry configuration.

      Booleans are printed lower-cased ("true"/"false"). The VXLAN marker bit
      is shown only when VXLAN is enabled. The checksum lines read
      "unconfigured" when no checksum submodel is set. The collection section
      is printed only when a collection submodel is set.
      """
      for label, flag in ( ( "Enabled:", self.enabled ),
                           ( "VXLAN enabled:", self.vxlanEnabled ) ):
         print( label, str( flag ).lower() )
      if self.vxlanEnabled:
         print( "VXLAN marker: header word 0 bit", self.vxlanMarkerBit )
      print( "Ingress collection sample rate:", self.sampleRate )

      checksum = self.sampleTcpUdpChecksum
      if not checksum:
         print( "Ingress collection sample TCP/UDP checksum value: unconfigured" )
         print( "Ingress collection sample TCP/UDP checksum mask: unconfigured" )
      else:
         checksum.render()

      collection = self.collection
      if collection:
         collection.render()

class PostcardTelemetryCounters( Model ):
   # Postcard telemetry sampling and export counters.
   sampleRcvd = Int( help="Total packets sampled by GreenT" )
   sampleDiscarded = Int( help="Total samples discarded" )
   # Help text previously misspelled "Multi" as "Muti".
   multiDstSampleRcvd = Int( help="Multi destination packets sampled by GreenT" )
   grePktSent = Int( help="GRE packets sent to collector" )
   sampleSent = Int( help="Total samples sent to collector" )

   def render( self ):
      """Print each counter on its own labelled line."""
      print( "Total packets sampled:", self.sampleRcvd )
      print( "Total samples discarded:", self.sampleDiscarded )
      print( "Multi destination packets sampled:", self.multiDstSampleRcvd )
      print( "GRE packets sent to collector:", self.grePktSent )
      print( "Total samples sent to collector:", self.sampleSent )
