# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Cell
import CliGlobal
import CliCommand
import CliParser
import ConfigMount
import LazyMount
import Tac
import CliMatcher
from CliMode.TrafficPolicy import (
      TrafficPoliciesConfigMode,
      TrafficPoliciesVrfConfigMode,
      MatchRuleIpv4ConfigMode,
      MatchRuleIpv6ConfigMode,
      FieldSetServiceConfigMode,
      TrafficPolicyConfigMode,
)
from CliParser import guardNotThisPlatform
from CliPlugin import ( VlanCli,
                        VxlanCli )
from CliPlugin.VrfCli import (
      VrfExprFactory,
)
from CliPlugin.TrafficPolicyCli import (
      TrafficPoliciesConfigCmd,
      TrafficPolicyConfigCmd,
      FieldSetIpPrefixConfigCmd,
      FieldSetIpv6PrefixConfigCmd,
      FieldSetL4PortConfigCmd,
      FieldSetVlanConfigCmd,
      FieldSetIntegerConfigCmd,
      FieldSetMacAddrConfigCmd,
      nodeCounterGlobalConfig,
      nodeInterface,
)
from CliPlugin.TrafficPolicyCliLib import (
      ActionType,
)
from CliPlugin.PolicyMapCliLib import (
      handleVrfPolicy,
)
from CliPlugin.ClassificationCliLib import (
      AppProfileMatchBaseConfigCmd,
      ServiceFieldSetCmdBase,
      FieldSetServiceBaseConfigCmd,
      CommitAbortModelet,
      ProtocolOredBase,
      protectedFieldSetNamesRegex,
      generateFieldSetExpression,
      generateTcpFlagExpression,
)
from Toggles.TrafficPolicyToggleLib import (
   toggleTrafficPoliciesVrfModeEnabled,
   toggleTrafficPolicyAppProfileMatchEnabled,
   toggleCpuTrafficPolicyPerVrfEnabled,
   toggleTrafficPolicyAndMatchDescriptionEnabled,
)

from Toggles.ClassificationToggleLib import (
   toggleTrafficPolicyFieldSetServiceEnabled,
)

from CliMatcher import (
      DynamicNameMatcher,
      IntegerMatcher,
      KeywordMatcher,
)
from VxlanVniLib import VniFormat

# Module-wide state shared by the handlers in this file. Every entry except
# 'feature' starts as None and is assigned a mounted entity by Plugin().
gv = CliGlobal.CliGlobal(
   vrfConfig=None,
   policiesCliConfig=None,
   policiesIntfParamConfig=None,
   policiesStatusRequestDir=None,
   policiesStatus=None,
   fieldSetConfig=None,
   locationAliasConfig=None,
   l3IntfConfigDir=None,
   switchIntfConfigDir=None,
   vlanConfigDir=None,
   intfTrafficPolicyHwStatus=None,
   cpuTrafficPolicyHwStatus=None,
   routingHwRouteStatus=None,
   ingressTrafficPolicyIntfConfig=None,
   cpuPoliciesVrfConfig=None,
   cliAppProfileConfig=None,
   # Feature name: passed to the VRF sub-mode and used to build the
   # handleVrfPolicy session key ('intf-traffic-policy').
   feature='traffic-policy',
)

# ---------------------------------------------------------
# vxlan vni <VNI> traffic-policy input PMAP_NAME handlers
# ---------------------------------------------------------
def setTrafficPolicyOnVni( mode, args ):
   """Handle 'vxlan vni VNI traffic-policy input PMAP_NAME'.

   Records PMAP_NAME as the ingress traffic policy of the VNI. If the VNI maps
   to a VLAN that is not in the dynamic VLAN set (a static VLAN), a warning is
   added because that configuration will be ignored.
   """
   vniNum = VniFormat( args[ 'VNI' ] ).toNum()
   if not VxlanCli.isValidVniWithError( mode, vniNum, mode.intf.name ):
      return

   # ingressTrafficPolicyIntfConfig accepts VNIs with either kind of VLAN
   # (static or dynamic), but config for a VNI with a static VLAN is ignored.
   policyName = args[ 'PMAP_NAME' ]
   gv.ingressTrafficPolicyIntfConfig.vni[ vniNum ] = policyName

   dynVlanSet = VlanCli.Vlan.getDynVlanSet( mode )
   vlan = VxlanCli.vniToVlanMap.getVlan( int( vniNum ), mode.intf.name )
   if vlan and vlan not in dynVlanSet:
      # Both halves of the implicitly-concatenated message must be f-strings,
      # otherwise '{vlan}' is emitted literally.
      mode.addWarning( f"Traffic policy on VNI {vniNum} with static VLAN "
                       f"{vlan} will be ignored" )

def noTrafficPolicyOnVni( mode, args ):
   """Handle 'no|default vxlan vni VNI traffic-policy input [ PMAP_NAME ]'.

   Removes the ingress traffic-policy binding of the VNI. When PMAP_NAME is
   given and differs from the currently bound policy, nothing is removed.
   Nothing is done when the VNI has no binding at all.
   """
   vniNum = VniFormat( args[ 'VNI' ] ).toNum()
   if not VxlanCli.isValidVniWithError( mode, vniNum, mode.intf.name ):
      return

   vniPolicies = gv.ingressTrafficPolicyIntfConfig.vni
   currentPolicyName = vniPolicies.get( vniNum )
   if currentPolicyName is None:
      # No binding for this VNI: avoid deleting a key that is not present.
      return

   inputPolicyName = args.get( 'PMAP_NAME' )
   if inputPolicyName and currentPolicyName and \
         inputPolicyName != currentPolicyName:
      # A different policy is bound; leave it alone.
      return

   del vniPolicies[ vniNum ]


def guardTrafficPoliciesVrfMode( mode, args ):
   """Guard for the 'vrf VRF' sub-mode command.

   Returns None (allowed) when the interface hardware status reports ingress
   traffic-policy support for a PmapVrfIntf, or the CPU hardware status reports
   per-VRF CPU traffic-policy support; otherwise returns guardNotThisPlatform.
   """
   # Any PmapVrfIntf serves as a probe for VRF support.
   intfHwStatus = gv.intfTrafficPolicyHwStatus
   ingressSupported = intfHwStatus.ingressTrafficPolicySupportedForIntf(
         "PmapVrfIntf1" )
   cpuHwStatus = gv.cpuTrafficPolicyHwStatus
   cpuPerVrfSupported = cpuHwStatus.cpuTrafficPolicyPerVrfSupported
   if not ( ingressSupported or cpuPerVrfSupported ):
      return guardNotThisPlatform
   return None

def guardTrafficPoliciesVrfModeInput( mode, args ):
   """Guard for the 'input' keyword in VRF sub-mode.

   Returns None (allowed) only when the interface hardware status reports
   ingress traffic-policy support for a PmapVrfIntf; otherwise returns
   guardNotThisPlatform.
   """
   # Any PmapVrfIntf serves as a probe for VRF support.
   hwStatus = gv.intfTrafficPolicyHwStatus
   if not hwStatus.ingressTrafficPolicySupportedForIntf( "PmapVrfIntf1" ):
      return guardNotThisPlatform
   return None

# Keyword matchers for the VRF-mode "traffic-policy input ..." command.
trafficPolicyKeyword = KeywordMatcher(
   'traffic-policy', helpdesc='Apply a traffic-policy' )
# 'input' is guarded by guardTrafficPoliciesVrfModeInput (ingress VRF support).
inputVrfKeyword = CliCommand.guardedKeyword(
   'input',
   helpdesc='Assign traffic-policy to the input of a VRF',
   guard=guardTrafficPoliciesVrfModeInput )

def getTrafficPolicyNames( mode=None ):
   """Return the configured traffic-policy collection (TRAFFIC_POLICY names)."""
   pmapType = gv.policiesCliConfig.pmapType
   return pmapType.pmap

def noTrafficPolicies( mode, args ):
   """No-handler for 'traffic-policies': clear all per-VRF traffic-policy config.

   Clears the VRF collection, the VRF traffic-policy bindings and the per-VRF
   CPU traffic-policy bindings.
   """
   vrfCfg = gv.vrfConfig
   vrfCfg.vrf.clear()
   vrfCfg.vrfConfig.trafficPolicies.clear()
   gv.cpuPoliciesVrfConfig.trafficPolicies.clear()

# --------------------------------------
# The "traffic-policies" mode command
# --------------------------------------
def doTrafficPoliciesHandler( mode, args ):
   """Enter the 'traffic-policies' configuration sub-mode."""
   trafficPoliciesMode = mode.childMode( TrafficPoliciesConfigMode )
   session = mode.session_
   session.gotoChildMode( trafficPoliciesMode )

def noTrafficPoliciesHandler( mode, args ):
   """Handle 'no traffic-policies': remove all traffic-policy configuration.

   Runs every handler registered with TrafficPoliciesConfigCmd, removes every
   traffic policy and every field set of each type, then clears the UDF
   location-alias config.
   """
   # pylint: disable=protected-access
   for handler in TrafficPoliciesConfigCmd._noHandlers:
      handler( mode, args )
   # The _remove* helpers presumably delete the named entry from the same
   # collection being walked, so iterate over a snapshot of the names rather
   # than the live collection (mutation during iteration).
   for name in list( gv.policiesCliConfig.pmapType.pmap ):
      TrafficPolicyConfigCmd._removePolicy( mode, name )
   fieldSetConfig = gv.fieldSetConfig
   for name in list( fieldSetConfig.fieldSetIpPrefix ):
      FieldSetIpPrefixConfigCmd._removeFieldSet( mode, name, "ipv4" )
   for name in list( fieldSetConfig.fieldSetIpv6Prefix ):
      FieldSetIpv6PrefixConfigCmd._removeFieldSet( mode, name, "ipv6" )
   for name in list( fieldSetConfig.fieldSetL4Port ):
      FieldSetL4PortConfigCmd._removeFieldSet( mode, name )
   for name in list( fieldSetConfig.fieldSetVlan ):
      FieldSetVlanConfigCmd._removeFieldSet( mode, name )
   for name in list( fieldSetConfig.fieldSetInteger ):
      FieldSetIntegerConfigCmd._removeFieldSet( mode, name )
   for name in list( fieldSetConfig.fieldSetMacAddr ):
      FieldSetMacAddrConfigCmd._removeFieldSet( mode, name )
   for name in list( fieldSetConfig.fieldSetService ):
      FieldSetServiceConfigCmd._removeFieldSet( mode, name )
   gv.locationAliasConfig.udf.clear()

# --------------------------------------
# The "vrf VRF" mode command
# --------------------------------------
class TrafficPolicyVrfConfigCmd( CliCommand.CliCommandClass ):
   """'[no|default] vrf VRF' in traffic-policies mode.

   The positive form creates the VRF config and enters the VRF sub-mode; the
   no/default form runs that sub-mode's vrfJanitor for the VRF.
   """
   syntax = "VRF"
   noOrDefaultSyntax = syntax
   data = {
      'VRF' : VrfExprFactory( helpdesc="Enter VRF sub-mode",
                              inclDefaultVrf=True,
                              guard=guardTrafficPoliciesVrfMode ),
   }

   @staticmethod
   def _vrfModeParams():
      """Keyword arguments shared by every TrafficPoliciesVrfConfigMode."""
      return {
         'feature' : gv.feature,
         'vrfConfig' : gv.vrfConfig,
         'cpuPoliciesVrfConfig' : gv.cpuPoliciesVrfConfig,
      }

   @classmethod
   def _vrfChildMode( cls, mode, args ):
      """Build the VRF sub-mode object for the VRF named in args."""
      vrfName = args[ 'VRF' ]
      return mode.childMode( TrafficPoliciesVrfConfigMode,
                             vrfName=vrfName,
                             **cls._vrfModeParams() )

   @classmethod
   def handler( cls, mode, args ):
      vrfMode = cls._vrfChildMode( mode, args )
      vrfMode.createVrfConfig()
      mode.session_.gotoChildMode( vrfMode )

   @classmethod
   def noOrDefaultHandler( cls, mode, args ):
      cls._vrfChildMode( mode, args ).vrfJanitor()

# ----------------------------------------
# Apply 'traffic-policy input PMAP' to VRF
# ----------------------------------------
class IntfTrafficPolicyOnVrf( CliCommand.CliCommandClass ):
   """'[no|default] traffic-policy input TRAFFIC_POLICY physical' in VRF mode.

   Applies (or removes) TRAFFIC_POLICY for the VRF of the current sub-mode
   (mode.vrfName) by delegating to handleVrfPolicy.
   """
   syntax = "traffic-policy input TRAFFIC_POLICY physical"
   noOrDefaultSyntax = "traffic-policy input [ TRAFFIC_POLICY physical ]"
   data = {
      'traffic-policy' : trafficPolicyKeyword,
      'input' : inputVrfKeyword,
      'TRAFFIC_POLICY' : DynamicNameMatcher( getTrafficPolicyNames,
                                             "Traffic Policy Name" ),
      'physical' : ( 'Apply traffic-policy to traffic arriving on physical '
                     'interfaces within VRF' ),
   }

   @staticmethod
   def _handleVrfPolicy( no, mode, args ):
      """Forward the (un)apply request to handleVrfPolicy.

      no: True for the no/default form. TRAFFIC_POLICY may be absent in that
      form, in which case None is passed.
      """
      policyName = args.get( 'TRAFFIC_POLICY' )
      # Without a hardware status object, fall back to not using the
      # PmapVrfIntf id.
      hwStatus = gv.intfTrafficPolicyHwStatus
      usePmapVrfIntfId = (
            hwStatus.vrfApplicationViaPmapVrfIntfId if hwStatus else False )
      handleVrfPolicy(
            mode, no, policyName,
            gv.policiesCliConfig,
            gv.policiesStatusRequestDir,
            gv.policiesStatus,
            gv.l3IntfConfigDir,
            gv.switchIntfConfigDir,
            gv.vlanConfigDir,
            mode.vrfName,
            gv.vrfConfig.vrfConfig,
            sessionKey=f'intf-{gv.feature}',
            routingHwRouteStatus=gv.routingHwRouteStatus,
            usePmapVrfIntfId=usePmapVrfIntfId )

   @staticmethod
   def handler( mode, args ):
      IntfTrafficPolicyOnVrf._handleVrfPolicy( False, mode, args )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      IntfTrafficPolicyOnVrf._handleVrfPolicy( True, mode, args )

class ProtocolListServiceFieldSetCmd( ProtocolOredBase ):
   """Protocol / port list command used inside service field-sets.

   FLAGS_EXPR is built with tcpFlagsSupported=False; all other tokens come
   from ProtocolOredBase._baseData (which wins on any key collision).
   """
   syntax = '''
protocol ( TCP_UDP | FLAGS_EXPR )
         ( source port SPORT [ destination port ( DPORT | all ) ] )
         | ( [ source port all ] destination port DPORT )
   '''
   noOrDefaultSyntax = '''
protocol [ ( TCP_UDP | FLAGS_EXPR )
             [ ( source port [ SPORT [ destination port ( DPORT | all ) ] ] )
             | ( source port all destination port DPORT )
             | ( destination port [ DPORT ] ) ] ]'''
   data = {
      'FLAGS_EXPR' : generateTcpFlagExpression( tcpFlagsSupported=False ),
      **ProtocolOredBase._baseData,
   }

class FieldSetServiceConfigNoAuthCmds( ProtocolListServiceFieldSetCmd ):
   """ProtocolListServiceFieldSetCmd variant with authz disabled.

   Registered in the service field-set mode. It takes its own copy of the
   parent's data so the two classes do not share one dict.
   """
   authz = False
   data = dict( ProtocolListServiceFieldSetCmd.data )

# --------------------------------------------------------------------------
# The "protocol service field-set { FIELD_SET }"
# command
# --------------------------------------------------------------------------
def getServiceFieldSetNames( mode ):
   """Return the service field-set names, or [] if the config is not mounted yet."""
   fieldSetCfg = gv.fieldSetConfig
   return [] if fieldSetCfg is None else fieldSetCfg.fieldSetService

# Matcher for service field-set names (protected names excluded by pattern),
# and the FIELD_SET expression built from it that accepts multiple names.
serviceFieldSetNameMatcher = DynamicNameMatcher(
   getServiceFieldSetNames,
   "service field-set name",
   pattern=protectedFieldSetNamesRegex( 'service' ),
   priority=CliParser.PRIO_LOW )
serviceFieldSetExpr = generateFieldSetExpression(
   serviceFieldSetNameMatcher, 'FIELD_SET', allowMultiple=True )

class ServiceFieldSetIpv4Cmd( ServiceFieldSetCmdBase ):
   """Service field-set match command for the IPv4 match-rule mode."""
   data = {
      "FIELD_SET" : serviceFieldSetExpr,
      **ServiceFieldSetCmdBase._baseData,
   }

class ServiceFieldSetIpv6Cmd( ServiceFieldSetCmdBase ):
   """Service field-set match command for the IPv6 match-rule mode."""
   data = {
      "FIELD_SET" : serviceFieldSetExpr,
      **ServiceFieldSetCmdBase._baseData,
   }

# --------------------------------------------------------------------------
# The "field-set service FIELD_SET_NAME" command
# --------------------------------------------------------------------------
class FieldSetServiceConfigCmd( FieldSetServiceBaseConfigCmd ):
   """'[no|default] field-set service FIELD_SET_NAME' in traffic-policies mode."""
   syntax = 'field-set service FIELD_SET_NAME'
   noOrDefaultSyntax = syntax
   _feature = "tp"
   data = {
      'FIELD_SET_NAME' : serviceFieldSetNameMatcher,
      **FieldSetServiceBaseConfigCmd._baseData,
   }

   @classmethod
   def _getContextKwargs( cls, fieldSetServiceName, mode=None ):
      """Keyword arguments presumably used by the base class to build the
      service field-set context for fieldSetServiceName."""
      return dict(
         fieldSetServiceName=fieldSetServiceName,
         fieldSetConfig=gv.fieldSetConfig,
         childMode=FieldSetServiceConfigMode,
         featureName='tp',
         # Config rollback is enabled in the ServiceFieldSetContext.
         rollbackSupported=True,
      )

# Attach CommitAbortModelet to the service field-set mode.
FieldSetServiceConfigMode.addModelet( CommitAbortModelet )

# The 'vrf VRF' sub-mode command, and the no-handler that clears all per-VRF
# traffic-policy config, are registered when either the VRF-mode toggle or the
# per-VRF CPU traffic-policy toggle is on. 'traffic-policy input ...' inside the
# VRF sub-mode additionally requires the VRF-mode toggle itself.
if toggleTrafficPoliciesVrfModeEnabled() or \
   toggleCpuTrafficPolicyPerVrfEnabled():
   # toggleCpuTrafficPolicyPerVrf is for per-vrf cpu policy.
   TrafficPoliciesConfigMode.addCommandClass( TrafficPolicyVrfConfigCmd )
   # pylint: disable=protected-access
   TrafficPoliciesConfigCmd._registerNoHandler( noTrafficPolicies )
   # pylint: enable=protected-access
   if toggleTrafficPoliciesVrfModeEnabled():
      TrafficPoliciesVrfConfigMode.addCommandClass( IntfTrafficPolicyOnVrf )

def getAppProfileNames( mode ):
   """Return the application profiles from the mounted app-recognition config."""
   appRecognitionConfig = gv.cliAppProfileConfig
   return appRecognitionConfig.appProfile

# Matcher for APP_PROFILE_NAME. The help text has no surrounding whitespace,
# matching the other name matchers in this file (stray spaces would show up
# in CLI help output).
appProfileNameMatcher = DynamicNameMatcher(
   getAppProfileNames, 'Application profile name',
   pattern=r'[A-Za-z0-9_:{}\[\]-]+' )

# --------------------------------------------------------------------------
# The "application-profile APP_PROFILE_NAME" command
# --------------------------------------------------------------------------
class AppProfileMatchConfigCmd( AppProfileMatchBaseConfigCmd ):
   """'application-profile APP_PROFILE_NAME' match for traffic-policy rules."""
   _feature = "tp"
   data = {
      'APP_PROFILE_NAME' : appProfileNameMatcher,
      **AppProfileMatchBaseConfigCmd._baseData,
   }

# -------------------------------------------------------------------------------
# The "[no] description <desc>" command, in "traffic-policy" mode.
# -------------------------------------------------------------------------------
class PolicyDescriptionCmd( CliCommand.CliCommandClass ):
   """'[no|default] description DESC' in traffic-policy mode."""
   syntax = "description DESC"
   noOrDefaultSyntax = "description ..."
   data = {
      'description' : 'Description for this traffic policy',
      'DESC' : CliMatcher.StringMatcher(
         helpname='LINE', helpdesc='Description for this traffic policy' ),
   }

   @classmethod
   def handler( cls, mode, args ):
      # Shared by both forms: the no/default form passes add=False, and DESC
      # defaults to '' when it was not supplied.
      policyContext = mode.getContext()
      description = args.get( 'DESC', '' )
      isAdd = not CliCommand.isNoOrDefaultCmd( args )
      errStr = policyContext.updatePolicyDesc( description, add=isAdd )
      if errStr:
         mode.addError( errStr )

   noOrDefaultHandler = handler

# Register 'description' in traffic-policy mode only when the description
# toggle is enabled.
if toggleTrafficPolicyAndMatchDescriptionEnabled():
   TrafficPolicyConfigMode.addCommandClass( PolicyDescriptionCmd )
# -------------------------------------------------------------------------------
# The "[no] description <desc>" command, in "match" mode.
# -------------------------------------------------------------------------------

class MatchDescriptionCmd( CliCommand.CliCommandClass ):
   """'[no|default] description DESC' in match-rule mode."""
   syntax = "description DESC"
   noOrDefaultSyntax = "description ..."
   data = {
      'description' : 'Description for this match rule',
      'DESC' : CliMatcher.StringMatcher(
         helpname='LINE', helpdesc='Description for this match rule' ),
   }

   @classmethod
   def handler( cls, mode, args ):
      # Shared by both forms: the no/default form passes add=False, and DESC
      # defaults to '' when it was not supplied.
      matchContext = mode.getContext()
      description = args.get( 'DESC', '' )
      isAdd = not CliCommand.isNoOrDefaultCmd( args )
      errStr = matchContext.updateMatchDesc( description, mode, add=isAdd )
      if errStr:
         mode.addError( errStr )

   noOrDefaultHandler = handler

# Address-family agnostic match-rule commands, registered identically in the
# IPv4 and IPv6 match-rule modes (each gated by its own toggle).
for _mode in ( MatchRuleIpv4ConfigMode, MatchRuleIpv6ConfigMode ):
   if toggleTrafficPolicyAndMatchDescriptionEnabled():
      _mode.addCommandClass( MatchDescriptionCmd )
   if toggleTrafficPolicyAppProfileMatchEnabled():
      _mode.addCommandClass( AppProfileMatchConfigCmd )

def baseCheckErrors( mode, args ):
   """Report an invalid action combination on the current match rule.

   Adds the error (when non-empty) to mode and returns True whenever
   actionCombinationError() returned something other than None.
   """
   errMsg = mode.context.matchRuleAction.actionCombinationError()
   if errMsg:
      mode.addError( errMsg )
   return errMsg is not None

# Override the ActionCmdBase handlers since we need to pass the msgType as the
# actionValue to setAction.
def dropActionCmdHandler( mode, args ):
   """Set the deny action, passing MSG_TYPE as the action value."""
   msgType = args[ 'MSG_TYPE' ]
   mode.context.setAction( ActionType.deny, msgType, no=False )

def dropActionCmdNoOrDefaultHandler( mode, args ):
   """Remove the deny action for MSG_TYPE.

   If the remaining actions form an invalid combination (baseCheckErrors
   reports an error), the deny action is set back.
   """
   msgType = args[ 'MSG_TYPE' ]
   mode.context.setAction( ActionType.deny, msgType, no=True )
   if not baseCheckErrors( mode, args ):
      return
   # Revert the removal.
   mode.context.setAction( ActionType.deny, msgType, no=False )

# --------------------------------------------------------------------
# The "counter interface poll interval <2-60 seconds> seconds" command
# --------------------------------------------------------------------
class CounterPollIntervalCmd( CliCommand.CliCommandClass ):
   """'[no|default] counter interface poll interval SECONDS seconds'.

   Sets policiesIntfParamConfig.counterPollInterval to SECONDS (2-60); the
   no/default form sets it to None.
   """
   syntax = "counter interface poll interval SECONDS seconds"
   noOrDefaultSyntax = "counter interface poll interval ..."
   data = {
      "counter" : nodeCounterGlobalConfig,
      "interface" : nodeInterface,
      "poll" : "Counters polling configuration",
      "interval" : "Counters update interval",
      "SECONDS" : IntegerMatcher(
         2, 60, helpdesc="Seconds between consecutive polls" ),
      "seconds" : "Time unit in seconds",
   }

   @staticmethod
   def handler( mode, args ):
      seconds = args[ 'SECONDS' ]
      gv.policiesIntfParamConfig.counterPollInterval = seconds

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      paramConfig = gv.policiesIntfParamConfig
      paramConfig.counterPollInterval = None

# Register the counter poll-interval command, and add its reset to the
# TrafficPoliciesConfigCmd no-handler list (iterated by
# noTrafficPoliciesHandler).
TrafficPoliciesConfigMode.addCommandClass( CounterPollIntervalCmd )
# pylint: disable=protected-access
TrafficPoliciesConfigCmd._registerNoHandler(
   CounterPollIntervalCmd.noOrDefaultHandler )
# pylint: enable=protected-access

# Service field-set commands: the match-rule commands, the field-set
# definition command, and the protocol list inside the field-set mode.
if toggleTrafficPolicyFieldSetServiceEnabled():
   MatchRuleIpv4ConfigMode.addCommandClass( ServiceFieldSetIpv4Cmd )
   MatchRuleIpv6ConfigMode.addCommandClass( ServiceFieldSetIpv6Cmd )

   TrafficPoliciesConfigMode.addCommandClass( FieldSetServiceConfigCmd )

   FieldSetServiceConfigMode.addCommandClass( FieldSetServiceConfigNoAuthCmds )

def Plugin( entityManager ):
   """Plugin entry point: mount the entities used by this file into gv."""
   lazyMount = LazyMount.mount
   configMount = ConfigMount.mount
   em = entityManager

   # Per-cell traffic-policy status directory.
   gv.policiesStatus = lazyMount(
      em, f'cell/{Cell.cellId()}/trafficPolicies/status', 'Tac::Dir', 'ri' )

   # Traffic-policy, field-set and UDF configuration.
   gv.vrfConfig = configMount(
      em, 'trafficPolicies/vrf/input/cli',
      'TrafficPolicy::TrafficPolicyVrfConfig', 'w' )
   gv.policiesCliConfig = configMount(
      em, 'trafficPolicies/input/cli',
      'TrafficPolicy::TrafficPolicyConfig', 'wi' )
   gv.policiesIntfParamConfig = configMount(
      em, 'trafficPolicies/param/config/interface',
      'TrafficPolicy::TrafficPolicyIntfParamConfig', 'w' )
   gv.policiesStatusRequestDir = lazyMount(
      em, 'trafficPolicies/statusRequest/cli',
      'PolicyMap::PolicyMapStatusRequestDir', 'wc' )
   gv.fieldSetConfig = configMount(
      em, 'trafficPolicies/fieldset/input/cli',
      'Classification::FieldSetConfig', 'w' )
   gv.locationAliasConfig = configMount(
      em, 'trafficPolicies/udfconfig/cli', 'Classification::UdfConfig', 'w' )

   # Interface and bridging state passed to handleVrfPolicy.
   gv.l3IntfConfigDir = lazyMount(
      em, 'l3/intf/config', 'L3::Intf::ConfigDir', 'r' )
   gv.switchIntfConfigDir = lazyMount(
      em, 'bridging/switchIntfConfig', 'Bridging::SwitchIntfConfigDir', 'r' )
   gv.vlanConfigDir = lazyMount(
      em, 'bridging/config', 'Bridging::Config', 'r' )

   # Hardware status consulted by the guards and VRF policy application.
   gv.intfTrafficPolicyHwStatus = lazyMount(
      em, "trafficPolicies/hardware/status/interface",
      "TrafficPolicy::HwStatus", "r" )
   gv.cpuTrafficPolicyHwStatus = lazyMount(
      em, 'trafficPolicies/hardware/status/cpu',
      'TrafficPolicy::HwStatus', 'r' )
   gv.routingHwRouteStatus = lazyMount(
      em, "routing/hardware/route/status",
      "Routing::Hardware::RouteStatus", "r" )

   # VXLAN VNI ingress bindings, per-VRF CPU policies, app profiles.
   gv.ingressTrafficPolicyIntfConfig = configMount(
      em, "trafficPolicies/intf/input/config/vxlan",
      "PolicyMap::IntfConfig", 'w' )
   gv.cpuPoliciesVrfConfig = configMount(
      em, "trafficPolicies/cpu/vrf", "PolicyMap::VrfConfig", 'wi' )
   gv.cliAppProfileConfig = lazyMount(
      em, 'classification/app-recognition/config',
      'Classification::AppRecognitionConfig', 'r' )
