#!/usr/bin/env python3
# Copyright (c) 2009, 2010, 2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import Tac, AclLib, WaitForWarmup
from AclLib import ( defaultVrf,
                     newServiceAcl, CopyToCpuDst )
import CliCommand
import CliMatcher
import Tracing
from TeCliLib import dscpAclNames
from TypeFuture import TacLazyType

th = Tracing.Handle( "AclCliLib" )
t0 = th.trace0

# Singleton holding the Acl name <-> value tables; several module-level
# tables below (ECN names, IP/IPv6 protocol names, ICMP messages) are read
# from it.
AclNameMaps = Tac.singleton( 'Acl::AclNameMaps' )
TacDscp = TacLazyType( 'Arnet::DscpValue' )
# Largest value of an unsigned 8-bit field (used as the ToS upper bound).
U8_MAX_VALUE = 0xFF

# Halo computations tend to be a lot more time consuming, so as a temporary
# stop-gap this timeout is much higher than the previous value of 60s.
hwWaitTimeout = 180

# { ECN CLI keyword: Acl::Ecn enum name }, built from AclNameMaps.ecnByName.
ecnAclNames = {
   ecnName: Tac.enumName( 'Acl::Ecn', enumVal )
      for ecnName, enumVal in AclNameMaps.ecnByName.items()
}

# Map from copy-to-CPU destination CLI name to its CopyToCpuDst value.
copyToCpuDestinations = {
   'none': CopyToCpuDst.cpuDstNone,
   'captive-portal': CopyToCpuDst.cpuDstCaptivePortal, 
}

# generic IP protocols (only support generic IP CLI)
# { token: ( AclLib.IPPROTO_* value, help description ) }
genericIpProtocols = {
   'ospf': ( AclLib.IPPROTO_OSPF, 'OSPF routing protocol' ),
   'vrrp': ( AclLib.IPPROTO_VRRP, 'Virtual Router Redundancy Protocol (VRRP)' ),
   'pim': ( AclLib.IPPROTO_PIM, 'Protocol Independent Multicast (PIM)' ),
   'igmp': ( AclLib.IPPROTO_IGMP, 'Internet Group Management Protocol (IGMP)' ),
   'ahp':  ( AclLib.IPPROTO_AHP, 'Authentication Header Protocol' ),
   'rsvp': ( AclLib.IPPROTO_RSVP, 'Resource Reservation Protocol (RSVP)' ),
   'esp': ( AclLib.IPPROTO_ESP, 'Encapsulation Security Payload (ESP)' ),
   }

# generic IPv6 protocols (only support generic IPv6 CLI)
# { token: ( AclLib.IPPROTO_* value, help description ) }
genericIp6Protocols = {
   'ospf': ( AclLib.IPPROTO_OSPF, 'OSPF routing protocol' ),
   'vrrp': ( AclLib.IPPROTO_VRRP, 'Virtual Router Redundancy Protocol (VRRP)' ),
   'rsvp': ( AclLib.IPPROTO_RSVP, 'Resource Reservation Protocol (RSVP)' ),
   'pim': ( AclLib.IPPROTO_PIM, 'Protocol Independent Multicast (PIM)' ),
   }

# ICMP message { token: ( helpdesc, type, code ) }
icmpMessages = AclNameMaps.icmpDesc

# ICMPV6 message { token: ( helpdesc, type, code ) }
icmp6Messages = AclNameMaps.icmp6Desc

# Warning text used when there is no hardware to program ACLs into.
noHwWarningMsg = "Hardware not present. ACL(s) not programmed in the hardware."

def findNameByValue( value, nameMap ):
   '''Reverse-lookup `value` in a { name : value } map and return its name.

   Map entries may be plain values or tuples whose first element is the
   value. If nothing matches, str( value ) is returned. This is a linear
   scan, so keep it out of performance-critical paths for big maps (tests).
   '''
   for name, entry in nameMap.items():
      candidate = entry[ 0 ] if isinstance( entry, tuple ) else entry
      if candidate == value:
         return name
   return str( value )

def dscpValueFromCli( mode, dscpVal ):
   '''Resolve a DSCP CLI token.

   Returns ( value, True ) when dscpVal is a named entry of dscpAclNames
   (value being the first element of that entry), else ( dscpVal, False ).
   `mode` is unused here.
   '''
   namedEntry = dscpAclNames.get( dscpVal )
   if not namedEntry:
      return ( dscpVal, False )
   return ( namedEntry[ 0 ], True )

def ecnValueFromCli( mode, ecnVal ):
   '''Translate an ECN CLI keyword into its Acl::Ecn enum name.

   Raises KeyError for keywords not in ecnAclNames. `mode` is unused here.
   '''
   enumName = ecnAclNames[ ecnVal ]
   return enumName

def ecnNameFromValue( ecnValue ):
   '''Return the ECN CLI keyword for an Acl::Ecn enum name (or str( value )).'''
   return findNameByValue( value=ecnValue, nameMap=ecnAclNames )

def dscpNameFromValue( dscpValue ):
   '''Return the DSCP keyword whose value is dscpValue (or str( value )).'''
   return findNameByValue( value=dscpValue, nameMap=dscpAclNames )

def copyToCpuDstFromCli( mode, copyToCpuDst ):
   '''Map a copy-to-CPU destination CLI name to its CopyToCpuDst value.

   Unknown names fall back to CopyToCpuDst.cpuDstNone. `mode` is unused.
   '''
   knownDestinations = copyToCpuDestinations
   return knownDestinations.get( copyToCpuDst, CopyToCpuDst.cpuDstNone )

def copyToCpuDstFromValue( cpuDstValue ):
   '''Return the CLI name of a CopyToCpuDst value (or str( value )).'''
   return findNameByValue( value=cpuDstValue, nameMap=copyToCpuDestinations )

def portNumberFromString( proto, port ):
   '''Return the numeric port for `port`.

   If AclLib.serviceMap has a service table for `proto` that knows `port` as
   a service name, that service's port is returned; otherwise `port` is
   converted with int() (raising ValueError if it is not numeric).
   '''
   servicesForProto = AclLib.serviceMap.get( proto )
   service = servicesForProto.get( port ) if servicesForProto else None
   if not service:
      return int( port )
   return service.port

# IP protocols
# Name -> IP protocol value table taken straight from AclNameMaps; looked up
# by ipProtoFromString with the stringified protocol token.
ipProtoByName = AclNameMaps.ipProtoByName

def ipProtoFromValue( proto ):
   '''Return the protocol name for an IP protocol value, or str( proto ) if
   AclNameMaps.ipProto has no name for it.'''
   name = AclNameMaps.ipProto.get( proto )
   return str( proto ) if name is None else name

def ipProtoFromString( proto ):
   '''Return the IP protocol value for a protocol name or numeric string.

   Known names are resolved through ipProtoByName; anything else goes
   through int() (raising ValueError if not numeric).
   '''
   known = ipProtoByName.get( str( proto ) )
   return int( proto ) if known is None else known

# IPv6 protocols
# Name -> IPv6 protocol value table taken straight from AclNameMaps; looked
# up by ip6ProtoFromString with the stringified protocol token.
ip6ProtoByName = AclNameMaps.ip6ProtoByName

def ip6ProtoFromString( proto ):
   '''Return the IPv6 protocol value for a protocol name or numeric string.

   Known names are resolved through ip6ProtoByName; anything else goes
   through int() (raising ValueError if not numeric).
   '''
   known = ip6ProtoByName.get( str( proto ) )
   return int( proto ) if known is None else known

def payloadStringToValue( payloadString ):
   '''Parse a payload string into a sorted list of entries.

   The string is split on PayloadStringConstants.fieldDelim into groups of
   four fields: offset, pattern, mask and alias. The alias field may carry
   extra markers separated by overrideDelim; the patternOverride and
   maskOverride markers set the corresponding flags.

   Each entry is a list
   [ offset, pattern, mask, alias, patternOverride, maskOverride ]
   with offset/pattern/mask converted to int. An empty string yields [].
   A field count that is not a multiple of four raises ValueError.
   '''
   if payloadString == '':
      return []
   consts = AclLib.PayloadStringConstants
   fields = payloadString.split( consts.fieldDelim )
   entries = []
   for start in range( 0, len( fields ), 4 ):
      offset, pattern, mask, aliasField = fields[ start:start + 4 ]
      alias, *markers = aliasField.split( consts.overrideDelim )
      entries.append( [ int( offset ), int( pattern ), int( mask ), alias,
                        consts.patternOverride in markers,
                        consts.maskOverride in markers ] )
   return sorted( entries )

def ipLenFromString( inputString ):
   '''Build an AclLib.IpLenValue from its space-separated string form.

   An empty/None string yields AclLib.anyIpLenValue. Otherwise the first
   token is skipped, the second token is passed as the first IpLenValue
   argument, and the remaining tokens are re-joined with spaces as the
   second argument.
   '''
   if not inputString:
      return AclLib.anyIpLenValue
   tokens = inputString.split( ' ' )
   firstArg = tokens[ 1 ]
   rest = ' '.join( tokens[ 2: ] )
   return AclLib.IpLenValue( firstArg, rest )

def portFromString( proto, inputString ):
   '''Build an AclLib.PortValue from its space-separated string form.

   An empty/None string yields AclLib.anyPortValue. Otherwise the first
   token is passed through unchanged, and every following token (a port
   number or a service name known for `proto`) is normalised to its numeric
   port via portNumberFromString.
   '''
   if not inputString:
      return AclLib.anyPortValue
   head, *portTokens = inputString.split( ' ' )
   numericPorts = [ str( portNumberFromString( proto, tok ) )
                    for tok in portTokens ]
   return AclLib.PortValue( head, ' '.join( numericPorts ) )

def remarkFromValue( remarkConfig ):
   '''Rebuild the "remark ..." CLI command from a remark config.'''
   text = remarkConfig.remark
   return "remark " + text

def ruleFromValue( rule, aclType, standard=False, convert=True ):
   '''Rebuild the CLI command string for `rule` via rule.ruleStr().

   `aclType` is accepted for interface compatibility but not used.
   '''
   render = rule.ruleStr
   return render( standard, convert )

# Error messages
def unsuppRulesWarning( name, aclType, errMsg="" ):
   '''Return the warning for ACL `name` containing data-plane-unsupported rules.

   The message is the same for every ACL type; `aclType` is kept only for
   interface compatibility (the previous per-type branches were identical).
   `errMsg`, if given, is appended verbatim after the standard text.
   '''
   return ( "Warning: ACL %s contains rules not supported in data-plane.\n" % name
            + "Unsupported options will be ignored by hardware.\n" + errMsg )

def notConfiguredError( aclType, aclName, intfName ):
   '''Message for removing an ACL that differs from the one applied on intfName.

   Targets containing 'control-plane' or 'service' are printed as-is; every
   other target is prefixed with 'interface '. `aclName` is not used.
   '''
   bareTarget = 'control-plane' in intfName or 'service' in intfName
   prefix = '' if bareTarget else 'interface '
   return "A different %s access-list is configured on %s%s." % (
      aclType, prefix, intfName )

def dpiRulesWarning( aclName, aclType ):
   '''Warning for an ACL with deep-inspection rules when the deep inspection
   payload skip value is not 0. `aclType` is not used.'''
   headline = ( "ACL %s contains deep inspection rules which "
                "require deep inspection payload skip value to be 0.\n" )
   detail = ( "Unsupported options ( nvgre, gre, vxlan, gtp, mpls "
              "matches ) will be ignored by hardware." )
   return headline % aclName + detail

def extendedAclWarning():
   '''Warning shown when an extended ACL is used where only source,
   destination and log are honoured.'''
   message = 'Only source, destination and log are used in extended ACL'
   return message

def errWithReason( msg, reason="" ):
   '''Return msg, followed by " (reason)" when reason is non-empty.'''
   return "%s (%s)" % ( msg, reason ) if reason else msg

def aclSessionCommitError( error ):
   '''Error for a failed ACL session commit, with `error` as the reason.'''
   headline = "Error in ACL commit, configuration may differ from hardware"
   return errWithReason( headline, error )

def aclTimeoutWarning():
   '''Warning shown while ACL configuration is still being programmed.'''
   return ( "The ACL configuration is still being programmed into hardware. "
            "The system might be busy for a while." )

def aclCommitError( name, aclType, error="" ):
   '''Error for failing to commit ACL `name`; `error` is the optional reason.'''
   headline = "Error: Cannot commit %s ACL %s" % ( aclType, name )
   return errWithReason( headline, error )

def aclCreateError( name, aclType, error="" ):
   '''Error for failing to create ACL `name`; `error` is the optional reason.'''
   headline = "Error: Cannot create %s ACL %s" % ( aclType, name )
   return errWithReason( headline, error )

def aclModifyError( name, aclType, error="" ):
   '''Error for failing to modify ACL `name`; `error` is the optional reason.'''
   headline = "Error: Cannot modify %s ACL %s" % ( aclType, name )
   return errWithReason( headline, error )

def intfAclConfigError( name, aclType, intf, error="" ):
   '''Error for failing to apply ACL `name` to `intf`; `error` is the reason.'''
   headline = "Error: Cannot apply %s ACL %s to %s" % ( aclType, name, intf )
   return errWithReason( headline, error )

def intfAclDeleteError( name, aclType, intf, error="" ):
   '''Error for failing to remove ACL `name` from `intf`; `error` is the reason.'''
   headline = "Error: Cannot remove %s ACL %s from %s" % ( aclType, name, intf )
   return errWithReason( headline, error )

def aclTypeError( name, aclType ):
   '''Error for an ACL name already used by an ACL of type `aclType`.

   `name` is not included in the message.
   '''
   template = "%s access list with this name already exists"
   return template % aclType

def aclTcamStatus( aclStatusDp, switchName ):
   '''Return the first truthy tcamStatus entry for switchName across all
   data-plane status objects in aclStatusDp, or None if there is none.'''
   perInstance = ( dp.tcamStatus.get( switchName )
                   for dp in aclStatusDp.values() )
   return next( ( status for status in perInstance if status ), None )

def getAclNames( config, aclType ):
   '''Return the names of all configured ACLs of type aclType, as a list.'''
   aclsOfType = config.config[ aclType ].acl
   return [ *aclsOfType ]

def setServiceAclTypeVrfMap( mode, serviceAclConfig,
                             aclName, aclType='ip', vrfName=None,
                             force=False ):
   '''Set (or clear) the ACL name for ( aclType, vrfName ) in serviceAclConfig.

   vrfName defaults to defaultVrf. An empty aclName removes the entry unless
   `force` is set, in which case the empty name is stored. When a CLI mode
   is given, waits for the Acl agent afterwards.
   '''
   key = Tac.Value( 'Acl::AclTypeAndVrfName', aclType, vrfName or defaultVrf )
   if aclName or force:
      serviceAclConfig.aclName[ key ] = aclName
   else:
      del serviceAclConfig.aclName[ key ]
   if mode:
      tryWaitForWarmup( mode )

def noServiceAclTypeVrfMap( mode, serviceAclConfig,
                            aclName, aclType='ip', vrfName=None ):
   '''Remove the ACL name for ( aclType, vrfName ) from serviceAclConfig.

   vrfName defaults to defaultVrf. The entry is removed when aclName is None
   or matches the configured name. Otherwise a "different ACL configured"
   warning is added to `mode`.

   `mode` may be None (non-CLI callers). In that case no agent wait is
   performed and no warning is reported.
   '''
   if not vrfName:
      vrfName = defaultVrf
   aclTypeAndVrfName = Tac.Value( 'Acl::AclTypeAndVrfName', aclType, vrfName )
   name = serviceAclConfig.aclName.get( aclTypeAndVrfName, "" )
   if aclName is None or name == aclName:
      del serviceAclConfig.aclName[ aclTypeAndVrfName ]
      if mode:
         tryWaitForWarmup( mode )
   elif mode:
      # Guard on mode like the branch above: without it a None mode raised
      # AttributeError on a name mismatch.
      mode.addWarning( notConfiguredError( aclType, aclName,
                  'service %s(%s VRF)' % ( serviceAclConfig.name, vrfName ) ) )

def checkServiceAcl( mode, config, aclName, aclType='ip' ):
   '''Warn on `mode` if aclName refers to an existing non-standard (extended)
   ACL of type aclType. Does nothing for an empty aclName or unknown ACL.'''
   if not aclName:
      return
   acl = config.config[ aclType ].acl.get( aclName )
   # mode can be None when called from CpAclTests
   if acl and not acl.standard and mode:
      mode.addWarning( extendedAclWarning() )

def setServiceAcl( mode, service, proto, config, cpConfig,
                   aclName, aclType='ip', vrfName=None,
                   port=None, sport=None, defaultAction='deny', tracked=False ):
   '''Apply aclName to `service` in vrfName (default: defaultVrf) via
   AclLib.newServiceAcl, then wait for the Acl agent if a mode is given.'''
   newServiceAcl( config, cpConfig, service, aclName,
                  vrf=vrfName or defaultVrf, aclType=aclType,
                  proto=proto, port=port, sport=sport,
                  defaultAction=defaultAction, tracked=tracked )
   if mode:
      tryWaitForWarmup( mode )

def noServiceAcl( mode, service, config, cpConfig,
                  aclName, aclType='ip', vrfName=None ):
   '''Remove the service ACL for `service` in vrfName (default: defaultVrf).

   The service entry is deleted only if aclName is None or matches the
   applied ACL, and no default service ACL is applied. The per-VRF entry is
   dropped once it has no services left. A mismatching aclName reports an
   error on `mode`. Nothing happens (not even the agent wait) when the VRF
   has no service ACL config.
   '''
   vrfName = vrfName or defaultVrf
   vrfConfig = cpConfig.cpConfig[ aclType ].serviceAcl.get( vrfName )
   if not vrfConfig:
      return
   serviceConfig = vrfConfig.service.get( service )
   if serviceConfig:
      if not ( aclName is None or serviceConfig.aclName == aclName ):
         # mode can be None is called from CpAclTests
         if mode:
            # NOTE(review): the trailing space in the target string yields
            # "... VRF) ." in the final message (noServiceAclTypeVrfMap has
            # no such space); kept in case tests match the exact text.
            mode.addError( notConfiguredError(
                  aclType, aclName, 'service %s(%s VRF) ' % ( service, vrfName ) ) )
      elif not serviceConfig.defaultAclName:
         # delete only when there is no default service Acl applied
         del vrfConfig.service[ service ]
         if not vrfConfig.service:
            del cpConfig.cpConfig[ aclType ].serviceAcl[ vrfName ]
   if mode:
      tryWaitForWarmup( mode )

def tryWaitForWarmup( mode ):
   '''Wait for the Acl agent to finish warming up, as a best effort.

   Throttles CLI configuration so it does not change too fast: for large
   ACLs the agent may need significant time to process each change. The
   wait uses timeout=300.0 (presumably seconds) on mode.entityManager. A
   Tac.Timeout is traced and otherwise ignored, so a slow agent never fails
   the CLI command.
   '''
   try:
      WaitForWarmup.wait( mode.entityManager,
                          agentsToGrab=[ 'Acl' ],
                          timeout=300.0, sleep=True,
                          verbose=False )
   except Tac.Timeout:
      # Deliberately best-effort, but leave a trace instead of failing silently.
      t0( 'tryWaitForWarmup: timed out waiting for Acl agent warmup' )

def isRemarkConfig( remarkConfig ):
   '''Return whether remarkConfig's type equals Acl::RemarkConfig.

   This is an exact type comparison rather than isinstance().
   '''
   actualType = type( remarkConfig )
   return actualType == Tac.Type( 'Acl::RemarkConfig' )

def numAclRules( aclSubConfig ):
   '''Count rules plus remarks (by sequence number) in an ACL sub-config.'''
   return sum( len( coll ) for coll in ( aclSubConfig.ruleBySequence,
                                         aclSubConfig.remarkBySequence ) )

def sortedSequenceNumbers( aclSubConfig ):
   '''Return all rule and remark sequence numbers, ascending, as a list.'''
   return sorted( [ *aclSubConfig.ruleBySequence,
                    *aclSubConfig.remarkBySequence ] )

def getRuleById( subconfig, aclType ):
   '''Return subconfig's "<type>RuleById" attribute for aclType.

   The prefix is aclType lower-cased with every "v" removed, so "ipv6"
   selects ip6RuleById and "ip" selects ipRuleById.
   '''
   prefix = aclType.lower().replace( "v", "" )
   return getattr( subconfig, '%sRuleById' % prefix )

def getRuleValue( subConfig, seq, aclType, standard=False, convert=True ):
   '''Return the CLI string of rule `seq` via subConfig.ruleStr().'''
   render = subConfig.ruleStr
   return render( seq, aclType, standard, convert )

def mergeAclTypeStatus( fromStatus, toStatus ):
   '''
   Merge Acl::AclTypeStatus object fromStatus into toStatus (in place).
   First we merge AclStatus collection, then we merge AclIntf collection.

   For every ACL in fromStatus:
     - ruleStatus / connRuleStatus entries present on both sides are
       combined (pkts summed, latest lastChangedTime kept); entries only in
       fromStatus are copied.
     - counterUpdateTime keeps the later value; version and
       countersIncomplete are taken from fromStatus.
     - noRuleMatches / noConnRuleMatches are copied when toStatus has 0 pkts
       recorded, otherwise merged.
   Every direction's intf entries of fromStatus.intf are then copied into
   toStatus.intf, overwriting entries with the same intf id.
   '''
   for status in [ fromStatus, toStatus ]:
      assert status.__class__.__name__ == "Acl::AclTypeStatus"

   def mergeRuleStatus( ruleStatus1, ruleStatus2 ):
      return Tac.Value( "Acl::RuleStatus",
                        ruleStatus1.pkts + ruleStatus2.pkts,
                        max( ruleStatus1.lastChangedTime,
                             ruleStatus2.lastChangedTime ) )

   def copyRuleStatus( ruleStatus ):
      return Tac.Value( "Acl::RuleStatus", ruleStatus.pkts,
                        ruleStatus.lastChangedTime )

   def mergeRuleStatusColl( fromColl, toColl ):
      # Merge a ruleId -> Acl::RuleStatus collection into another one.
      for ruleId, ruleStatus in fromColl.items():
         if ruleId in toColl:
            toColl[ ruleId ] = mergeRuleStatus( ruleStatus, toColl[ ruleId ] )
         else:
            toColl[ ruleId ] = copyRuleStatus( ruleStatus )

   for aclName, fromAclStatus in fromStatus.acl.items():
      if aclName not in toStatus.acl:
         toStatus.acl.newMember( aclName )
      toAclStatus = toStatus.acl[ aclName ]

      mergeRuleStatusColl( fromAclStatus.ruleStatus, toAclStatus.ruleStatus )
      # Existing connection-rule counters must be read from connRuleStatus
      # itself, not from the AclStatus object.
      mergeRuleStatusColl( fromAclStatus.connRuleStatus,
                           toAclStatus.connRuleStatus )

      toAclStatus.counterUpdateTime = max(
         toAclStatus.counterUpdateTime, fromAclStatus.counterUpdateTime )
      toAclStatus.version = fromAclStatus.version
      toAclStatus.countersIncomplete = fromAclStatus.countersIncomplete
      if toAclStatus.noRuleMatches.pkts == 0:
         toAclStatus.noRuleMatches = copyRuleStatus(
            fromAclStatus.noRuleMatches )
      else:
         toAclStatus.noRuleMatches = mergeRuleStatus(
            toAclStatus.noRuleMatches, fromAclStatus.noRuleMatches )
      if toAclStatus.noConnRuleMatches.pkts == 0:
         toAclStatus.noConnRuleMatches = copyRuleStatus(
            fromAclStatus.noConnRuleMatches )
      else:
         toAclStatus.noConnRuleMatches = mergeRuleStatus(
            toAclStatus.noConnRuleMatches, fromAclStatus.noConnRuleMatches )

   for direction, fromAclIntf in fromStatus.intf.items():
      # Create the direction on toStatus itself (not on a leftover per-ACL
      # loop variable, which is also unbound when fromStatus has no ACLs).
      if direction not in toStatus.intf:
         toStatus.intf.newMember( direction )
      toAclIntf = toStatus.intf[ direction ]
      for intfId, intfStr in fromAclIntf.intf.items():
         toAclIntf.intf[ intfId ] = intfStr

def aggregateServiceAclStatus( serviceAclStatus, aggAclType=None ):
   ''' Create an aggregated serviceAcl status entity to pass on to the ACL API.

       ServiceAclStatus is maintained per VRF to avoid a multi-writer problem
       with gated, but Checkpoint is not maintained per VRF. The ACL CLI APIs
       do not work with these two different ways of storing information, so
       to circumvent this the per-VRF ServiceAclStatus entities are merged.

       serviceAclStatus: collection of Acl::StatusService keyed by VRF.
       aggAclType: if given ("ip" or "ipv6"), only that ACL type's status is
          merged; aclVersion and defaultAction entries are copied for every
          type regardless.
       Returns a new Acl::StatusService instance named 'merge'.
   '''
   if aggAclType is not None:
      assert aggAclType in [ "ip", "ipv6" ]

   aggServiceAclStatus = Tac.newInstance( 'Acl::StatusService', 'merge' )
   for vrf in serviceAclStatus:
      vrfAclStatus = serviceAclStatus[ vrf ]
      assert vrfAclStatus.__class__.__name__ == "Acl::StatusService"
      for aclType in vrfAclStatus.status:
         if aggAclType is not None and aclType != aggAclType:
            continue
         if aclType not in aggServiceAclStatus.status:
            aggServiceAclStatus.status.newMember( aclType )
         mergeAclTypeStatus( vrfAclStatus.status[ aclType ],
                             aggServiceAclStatus.status[ aclType ] )

      for aclTypeVrfName in vrfAclStatus.aclVersion:
         aggServiceAclStatus.aclVersion[ aclTypeVrfName ] = \
            vrfAclStatus.aclVersion[ aclTypeVrfName ]

      # The keys come from defaultAction, so read from and write to
      # defaultAction as well (not a differently named collection).
      for aclTypeVrfName in vrfAclStatus.defaultAction:
         aggServiceAclStatus.defaultAction[ aclTypeVrfName ] = \
            vrfAclStatus.defaultAction[ aclTypeVrfName ]

   return aggServiceAclStatus

def aggregateUnsupportedQualsAndActions( aclName, aclType, aclStatusDp,
                                         feature="featureAcl" ):
   '''
   Collect the qualifiers and actions of ACL `aclName` (type `aclType`)
   reported as unsupported for `feature`, across every data-plane status
   and linecard, together with the interface applications they affect.

   Returns a List[(List[IntfApplication], List[qual : str], List[action : str ])]
   with one tuple per unique id in first-seen order, each list sorted.
   Unique ids with no interface, or with neither quals nor actions, are
   dropped.
   '''
   # unique id -> ( intfApps, quals, actions ), in first-seen order
   byUid = {}

   def slotFor( uid ):
      if uid not in byUid:
         byUid[ uid ] = ( set(), set(), set() )
      return byUid[ uid ]

   def matchingAclComponents():
      # Walk status -> feature -> linecard -> ACL type -> ACL, skipping any
      # level where the requested key is missing.
      for inst in aclStatusDp.values():
         featureComp = inst.unsupportedComponent.get( feature )
         if featureComp is None:
            continue
         for linecardComp in featureComp.linecardUComponent.values():
            typeComp = linecardComp.aclTypeUComponent.get( aclType )
            if typeComp is None:
               continue
            aclComp = typeComp.aclUnsupportedComponent.get( aclName )
            if aclComp is not None:
               yield aclComp

   for aclComp in matchingAclComponents():
      for uid, comp in aclComp.unsupportedComponent.items():
         _, quals, actions = slotFor( uid )
         quals.update( comp.qual.values() )
         actions.update( comp.action.values() )
      for intfApp, uidSet in aclComp.intfToUniqueId.items():
         for uid in uidSet.id:
            slotFor( uid )[ 0 ].add( intfApp )

   result = []
   for intfApps, quals, actions in byUid.values():
      if not intfApps or not ( quals or actions ):
         continue
      result.append( ( sorted( intfApps ), sorted( quals ), sorted( actions ) ) )
   return result

def dscpRange():
   '''Return ( min, max ) of Arnet::DscpValue as a tuple.'''
   lowest = TacDscp.min
   highest = TacDscp.max
   return lowest, highest

def getDscpAclNames( mode ):
   '''Return { DSCP keyword: second element of its dscpAclNames entry }.

   The second element is presumably the keyword's help text (this feeds a
   DynamicKeywordMatcher). `mode` is unused; presumably it is required by
   the matcher's callback signature.
   '''
   return dict( ( keyword, entry[ 1 ] )
                for keyword, entry in dscpAclNames.items() )

class TosExpr( CliCommand.CliExpression ):
   # CLI expression for a type-of-service match. Accepted forms:
   #   tos TOS_VAL          integer in [ 0, U8_MAX_VALUE ]
   #   tos dscp DSCP        integer within dscpRange()
   #   tos dscp DSCP_ACL    named DSCP keyword from getDscpAclNames()
   expression = 'tos ( TOS_VAL | ( dscp ( DSCP | DSCP_ACL ) ) )'
   data = {
      'tos' : CliCommand.singleKeyword( 'tos',
         helpdesc='Specify ToS value' ),
      'TOS_VAL' : CliMatcher.IntegerMatcher( 0, U8_MAX_VALUE,
         helpdesc='ToS value' ),
      'dscp' : CliCommand.singleKeyword( 'dscp',
         helpdesc='Specify DSCP value' ),
      # Bounds come from Arnet::DscpValue's min/max.
      'DSCP' : CliMatcher.IntegerMatcher( *dscpRange(),
         helpdesc='DSCP value' ),
      # Keywords are supplied by getDscpAclNames (presumably re-evaluated by
      # the matcher when the CLI is parsed -- confirm).
      'DSCP_ACL' : CliMatcher.DynamicKeywordMatcher( getDscpAclNames )
   }
