# Copyright (c) 2022 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function
import Tac
from CliPlugin import PolicyMapCliLib
from CliPlugin.TrafficPolicyCliLib import ( ActionType,
                                            TrafficPolicyMatchRuleAction,
                                            ClassPriorityConstant,
                                            ReservedClassMapNames,
                                            UniqueId,
                                            matchIpAccessGroup,
                                            matchIpv6AccessGroup,
                                            matchMacAccessGroup, )
from CliMode.TrafficPolicy import TrafficPolicyConfigMode
import Toggles.TrafficPolicyToggleLib
from Toggles.TrafficPolicyToggleLib import (
   toggleTrafficPolicyDefaultMacMatchEnabled )
from ClassificationLib import getIdFromDesc

import six

# Local aliases for the PolicyMapCliLib comparison results returned by
# identicalPolicyMap() and identicalClassActions().
CHANGED, IDENTICAL, RESEQUENCED = ( PolicyMapCliLib.CHANGED,
                                    PolicyMapCliLib.IDENTICAL,
                                    PolicyMapCliLib.RESEQUENCED )

class TrafficPolicyBaseContext( PolicyMapCliLib.PolicyMapContext ):
   """Common edit context for traffic-policy style policy maps.

   Subclasses must set ``pmapType`` to a non-empty value and implement
   ``childMode``. Every edited policy carries one reserved ("default") rule per
   name in ``reservedClassMapNames()``. Those rules are kept at fixed class
   priorities (``classPriorityConstants()``) so that they stay at the end of the
   policy, even after a resequence.
   """
   pmapType = ''

   def __init__( self, config, statusReqDir, status, trafficPolicyName,
                 rollbackEnabled=True ):
      # This base class is abstract: a concrete subclass must name its pmap type.
      if not self.pmapType:
         raise NotImplementedError

      PolicyMapCliLib.PolicyMapContext.__init__( self, config, statusReqDir,
                                                 status, trafficPolicyName,
                                                 self.pmapType,
                                                 rollbackEnabled=rollbackEnabled )
      self.matchRuleContext = None
      self.shouldResequence = False

   def childMode( self ):
      """Return the CLI mode class used to edit a policy; subclasses override."""
      raise NotImplementedError

   def hasPolicy( self, name ):
      """Return True if a policy map called ``name`` exists in the config."""
      return name in self.config().pmapType.pmap

   def mapTypeStr( self ):
      return 'traffic-policy'

   def reservedClassMapNames( self ):
      """Return a new list of the reserved (default) class map names."""
      return [ ReservedClassMapNames.classV4Default,
               ReservedClassMapNames.classV6Default ]

   def classPriorityConstants( self ):
      """Return a new list of the fixed priorities used by the reserved rules."""
      return [ ClassPriorityConstant.classV4DefaultPriority,
               ClassPriorityConstant.classV6DefaultPriority ]

   def defaultClassMapToClassPriority( self ):
      """Map each reserved class map name to its fixed class priority."""
      return { ReservedClassMapNames.classV4Default :
                  ClassPriorityConstant.classV4DefaultPriority,
               ReservedClassMapNames.classV6Default :
                  ClassPriorityConstant.classV6DefaultPriority }

   def defaultClassMapToAccessGroup( self ):
      """Map each reserved class map name to the match option its rule uses."""
      return { ReservedClassMapNames.classV4Default : matchIpAccessGroup,
               ReservedClassMapNames.classV6Default : matchIpv6AccessGroup }

   def initializeRuleToSeq( self ):
      """Extend the base rule-to-sequence tables with one for match rules."""
      super( TrafficPolicyBaseContext, self ).initializeRuleToSeq()
      self.ruleToSeq[ TrafficPolicyMatchRuleAction.keyTag() ] = \
         PolicyMapCliLib.PolicyRuleToSeqDict()

   def setRuleToSeq( self, policyMapSubConfig ):
      """Record the sequence number of every match rule in policyMapSubConfig.

      Returns the highest sequence number used by a non-reserved rule, or 0 if
      there is none. Reserved rules are recorded too, but they are excluded from
      the maximum because they always stay at the end of the policy.
      """
      maxSeq = 0
      matchRuleFilterToSeq = self.ruleToSeq[ TrafficPolicyMatchRuleAction.keyTag() ]
      reservedNames = self.reservedClassMapNames()
      reservedPrios = self.classPriorityConstants()
      for prio, cmapName in policyMapSubConfig.classPrio.items():
         rawCmap = policyMapSubConfig.rawClassMap.get( cmapName, None )
         assert rawCmap
         assert matchIpAccessGroup in rawCmap.match or \
                matchIpv6AccessGroup in rawCmap.match or \
                matchMacAccessGroup in rawCmap.match, \
                'Unsupported ClassMap matchTypes %s' % rawCmap.match
         matchRuleFilterToSeq[ cmapName ] = prio
         if cmapName in reservedNames:
            # The default rules should always be at the highest priority value. Even
            # after resequencing.
            assert prio in reservedPrios
            # When sequencing rules, the default rules should not be considered as
            # they should always be at the end.
            continue
         maxSeq = max( maxSeq, prio )
      return maxSeq

   def defaultRuleSeq( self, inc ):
      """Return the sequence number each reserved rule is expected to have
      right after the base-class resequence, keyed by reserved class map name.

      If the IPv6 default is reserved, it is expected at lastSequence() - inc and
      the IPv4 default at lastSequence() - 2 * inc. Otherwise, if the IPv4
      default is reserved, it is expected at lastSequence() - inc. Entries that
      are not assigned are None.

      NOTE(review): these fixed offsets assume every reserved class map is still
      configured. If one was removed (see supportsDefaultRuleRemoval), the
      remaining reserved rules shift towards the end and may not sit at these
      offsets. Confirm before relying on this with removal enabled.
      """
      ipv6DefaultRuleSeq = None
      ipv4DefaultRuleSeq = None
      if ReservedClassMapNames.classV6Default in self.reservedClassMapNames():
         ipv6DefaultRuleSeq = self.lastSequence() - inc
         ipv4DefaultRuleSeq = self.lastSequence() - 2 * inc
      elif ReservedClassMapNames.classV4Default in self.reservedClassMapNames():
         ipv4DefaultRuleSeq = self.lastSequence() - inc

      return {
         ReservedClassMapNames.classV4Default : ipv4DefaultRuleSeq,
         ReservedClassMapNames.classV6Default : ipv6DefaultRuleSeq }

   def resequence( self, start, inc ):
      """Renumber the policy's rules from ``start`` in steps of ``inc``.

      The base-class resequence renumbers every rule, including the reserved
      default rules. Afterwards the reserved rules are moved back to their fixed
      class priorities, and the last sequence is reset to the last user rule.

      Returns 'errSequenceOutOfRange' if the base-class resequence returned it,
      otherwise 'success'.
      """
      defaultClassMapToPrioDict = self.defaultClassMapToClassPriority()
      # The user may have removed a reserved class map; only resequence the reserved
      # class maps which are still configured.
      remainingReservedClassMaps = [
         className for className in self.reservedClassMapNames()
         if defaultClassMapToPrioDict[ className ] in self.npmap.classPrio ]

      result = super( TrafficPolicyBaseContext, self ).resequence( start, inc )
      if result == 'errSequenceOutOfRange':
         return result

      # The still-configured reserved rules now occupy the trailing slots, so the
      # last user rule is one increment before them. Count only the reserved rules
      # that are actually configured; counting removed ones would put the last
      # user sequence too low by one increment per removed rule.
      numDefaultMatchRules = len( remainingReservedClassMaps )
      lastUserRuleSeq = self.lastSequence() - ( numDefaultMatchRules + 1 ) * inc

      defaultRuleSeqDict = self.defaultRuleSeq( inc )

      # Assert the trailing rules are the default rules
      # Assert the proper default rules' seqnos are not occupied by other rules
      for className in remainingReservedClassMaps:
         defaultRuleSeq = defaultRuleSeqDict[ className ]
         classPrio = defaultClassMapToPrioDict[ className ]
         assert self.npmap.classPrio[ defaultRuleSeq ] == className
         # super().resequence() clears npmap.classPrio so the default class prio will
         # never be present in it.
         assert classPrio not in self.npmap.classPrio

      # Delete the old defaults
      for className in remainingReservedClassMaps:
         del self.npmap.classPrio[ defaultRuleSeqDict[ className ] ]

      # Restore the proper defaults
      for className in remainingReservedClassMaps:
         self.npmap.classPrio[ defaultClassMapToPrioDict[ className ] ] = className

      # Clean up the rule to sequence mapping
      matchRuleFilterToSeq = self.ruleToSeq[ TrafficPolicyMatchRuleAction.keyTag() ]
      for className in remainingReservedClassMaps:
         classPrio = defaultClassMapToPrioDict[ className ]
         matchRuleFilterToSeq.pop( classPrio, None )
         matchRuleFilterToSeq[ className ] = classPrio

      # Reset the last sequence to be the last user-defined rule
      self.lastSequenceIs( lastUserRuleSeq )

      self.shouldResequence = False
      self.moving = False

      return 'success'

   def currentPolicy( self ):
      return self.currentPmap()

   def copyRawClassMap( self, src, dst, mapType ):
      ''' Copy src's match condition, description and match filters into dst.

      src and dst are of type ClassMapSubConfig. mapType is not used here.
      '''
      dst.matchCondition = src.matchCondition
      dst.matchDesc = src.matchDesc
      for option in src.match:
         srcMatch = src.match[ option ]
         dst.match.newMember( option )
         dstMatch = dst.match[ option ]
         if option in ( matchIpAccessGroup, matchIpv6AccessGroup,
                        matchMacAccessGroup ):
            dstMatch.structuredFilter = ( "", )
            dstMatch.structuredFilter.copy( srcMatch.structuredFilter )
         else:
            assert False, 'unknown match option ' + option

   def copyAction( self, src ):
      """Create a copy of action ``src`` under a fresh UniqueId in the shared
      actions config, and return the new action."""
      actionType = src.actionType
      actions = self.config().actions
      if actionType == ActionType.deny:
         dst = actions.dropAction.newMember( src.className, UniqueId() )
         dst.msgType = src.msgType
         return dst
      elif actionType == ActionType.police:
         return actions.policeAction.newMember( src.className, UniqueId(),
                                                src.rateLimit, src.burstSize )
      elif actionType == ActionType.count:
         return actions.countAction.newMember( src.className, UniqueId(),
                                               src.counterName )
      elif actionType == ActionType.log:
         return actions.logAction.newMember( src.className, UniqueId() )
      elif actionType == ActionType.actionGoto:
         return actions.gotoAction.newMember( src.className, UniqueId(),
                                              src.gotoClassName )
      elif actionType == ActionType.setDscp:
         return actions.setDscpAction.newMember( src.className, UniqueId(),
                                                 src.dscp )
      elif actionType == ActionType.setTtl:
         return actions.setTtlAction.newMember( src.className, UniqueId(),
                                                src.ttl )
      elif actionType == ActionType.setTc:
         return actions.setTcAction.newMember( src.className, UniqueId(),
                                               src.tc )
      elif actionType == ActionType.useVrfSecondary:
         return actions.useVrfSecondaryAction.newMember( src.className,
                                                         UniqueId() )
      elif actionType == ActionType.setVrf:
         return actions.setVrfAction.newMember( src.className, UniqueId(),
                                                src.vrfName )
      elif actionType == ActionType.setVrfSecondary:
         return actions.setVrfSecondaryAction.newMember( src.className,
                                                         UniqueId(),
                                                         src.vrfName )
      elif actionType == ActionType.setDecapVrf:
         return actions.setDecapVrfAction.newMember( src.className,
                                                     UniqueId(),
                                                     src.decapVrfName,
                                                     src.fallbackVrfName,
                                                     src.postDecapVrfName )
      elif actionType == ActionType.setIdentityTag:
         return actions.setIdTagAction.newMember( src.className, UniqueId(),
                                                  src.idTag )
      elif actionType == ActionType.stripHeaderBytes:
         dst = actions.stripHdrBytesAction.newMember( src.className, UniqueId() )
         dst.copyAction( src )
         return dst
      elif actionType == ActionType.setAggregationGroup:
         dst = actions.setAggGroupAction.newMember( src.className, UniqueId() )
         dst.copyAction( src )
         return dst
      elif actionType == ActionType.setNexthopGroup:
         dst = actions.setNexthopGroupAction.newMember( src.className, UniqueId() )
         dst.copyAction( src )
         return dst
      elif actionType == ActionType.setNexthop:
         dst = actions.setNexthopAction.newMember( src.className, UniqueId() )
         dst.copyAction( src )
         return dst
      elif actionType == ActionType.setTimestampHeader:
         return actions.setTimestampHeaderAction.newMember(
               src.className, UniqueId() )
      elif actionType == ActionType.mirror:
         dst = actions.mirrorAction.newMember( src.className, UniqueId() )
         dst.copyAction( src )
         return dst
      elif actionType == ActionType.setMacAddress:
         return actions.setMacAddressAction.newMember( src.className, UniqueId(),
                                                       src.macAddress )
      elif actionType == ActionType.sflow:
         return actions.sflowAction.newMember( src.className, UniqueId() )
      elif actionType == ActionType.setHeaderRemove:
         return actions.setHeaderRemoveAction.newMember( src.className, UniqueId(),
                                                         src.headerRemove )
      elif actionType == ActionType.actionSet:
         dst = actions.replicateAction.newMember( src.className, UniqueId() )
         dst.copyAction( src )
         return dst
      elif actionType == ActionType.redirectTunnel:
         dst = actions.redirectTunnelAction.newMember( src.className, UniqueId() )
         dst.copyAction( src )
         return dst
      else:
         # %s formatting keeps the assertion message intact even if actionType
         # is not a plain str.
         assert False, 'unknown actionType %s' % actionType
      return None

   def updateNamedCounters( self, counterNames, add=True ):
      """Add counterNames to the policy's named counters, or remove them when
      add is False. Removing with an empty counterNames clears all named
      counters."""
      pmap = self.npmap
      if add:
         for c in counterNames:
            pmap.namedCounter.add( c )
      else:
         if not counterNames:
            # Delete all counters
            pmap.namedCounter.clear()
         else:
            # Only delete existing counters
            for c in counterNames:
               del pmap.namedCounter[ c ]

   def updatePolicyDesc( self, policyDesc, add=True ):
      """Set (add=True) or clear (add=False) the policy description and the
      policy id parsed from it.

      Returns an error string if the parsed id is outside 1..maxPolicyId
      (inclusive). Otherwise returns ''.
      """
      pmap = self.npmap
      if add:
         pId = getIdFromDesc( desc=policyDesc, descTypePolicy=True )
         maxPolicyId = Tac.Value( 'Classification::Constants' ).maxPolicyId
         # Inclusive on both ends, so the check agrees with the error message.
         if pId and not 1 <= pId <= maxPolicyId:
            errStr = f'Policy id range should be from 1 to {maxPolicyId}'
            return errStr
         pmap.policyDesc = policyDesc
         pmap.policyId = pId
      else:
         pmap.policyDesc = ''
         pmap.policyId = 0
      return ''

   def maxRules( self ):
      """Not implemented for traffic policies; returns None."""

   def lastSequenceIs( self, seqnum ):
      """Record seqnum as the last sequence, unless it is one of the reserved
      default-rule priorities."""
      if seqnum not in self.classPriorityConstants():
         super( TrafficPolicyBaseContext, self ).lastSequenceIs( seqnum )

   def maxSeq( self ):
      return ClassPriorityConstant.classPriorityMax

   def getRawTagAndFilter( self, cmapName ):
      return 'matchRuleFilter', cmapName

   def getRuleAtSeqnum( self, seqnum ):
      '''
        Return only the matching rule at seqnum (or None if there is none). If
        necessary, we can return the actions later on.
      '''
      policyRule = None
      ruleName = self.npmap.classPrio.get( seqnum, None )
      if ruleName:
         matchRule = self.npmap.rawClassMap.get( ruleName, None )
         if matchRule:
            matchOption = list( matchRule.match )[ 0 ]
            sfilter = matchRule.match.get( matchOption ).structuredFilter
            policyRule = TrafficPolicyMatchRuleAction( self, ruleName,
                                                       matchOption, sfilter )
      return policyRule

   def delDefaultActions( self, matchOption ):
      del self.npmap.defaultAction[ matchOption ]

   def removeAction( self, action ):
      """Delete ``action`` from the actions collection for its action type."""
      actType = action.actionType
      actions = self.config().actions
      if actType == ActionType.deny:
         del actions.dropAction[ action.id ]
      elif actType == ActionType.police:
         del actions.policeAction[ action.id ]
      elif actType == ActionType.count:
         del actions.countAction[ action.id ]
      elif actType == ActionType.log:
         del actions.logAction[ action.id ]
      elif actType == ActionType.actionGoto:
         del actions.gotoAction[ action.id ]
      elif actType == ActionType.setDscp:
         del actions.setDscpAction[ action.id ]
      elif actType == ActionType.setTtl:
         del actions.setTtlAction[ action.id ]
      elif actType == ActionType.setTc:
         del actions.setTcAction[ action.id ]
      elif actType == ActionType.useVrfSecondary:
         del actions.useVrfSecondaryAction[ action.id ]
      elif actType == ActionType.setVrf:
         del actions.setVrfAction[ action.id ]
      elif actType == ActionType.setVrfSecondary:
         del actions.setVrfSecondaryAction[ action.id ]
      elif actType == ActionType.setAggregationGroup:
         del actions.setAggGroupAction[ action.id ]
      elif actType == ActionType.setIdentityTag:
         del actions.setIdTagAction[ action.id ]
      elif actType == ActionType.setTimestampHeader:
         del actions.setTimestampHeaderAction[ action.id ]
      elif actType == ActionType.stripHeaderBytes:
         del actions.stripHdrBytesAction[ action.id ]
      elif actType == ActionType.setNexthopGroup:
         del actions.setNexthopGroupAction[ action.id ]
      elif actType == ActionType.setNexthop:
         del actions.setNexthopAction[ action.id ]
      elif actType == ActionType.mirror:
         del actions.mirrorAction[ action.id ]
      elif actType == ActionType.setMacAddress:
         del actions.setMacAddressAction[ action.id ]
      elif actType == ActionType.sflow:
         del actions.sflowAction[ action.id ]
      elif actType == ActionType.setHeaderRemove:
         del actions.setHeaderRemoveAction[ action.id ]
      elif actType == ActionType.setDecapVrf:
         del actions.setDecapVrfAction[ action.id ]
      elif actType == ActionType.actionSet:
         del actions.replicateAction[ action.id ]
      elif actType == ActionType.redirectTunnel:
         del actions.redirectTunnelAction[ action.id ]
      else:
         # 'action' is an action object, not a str; concatenating it would raise
         # TypeError instead of reporting the unknown action type.
         assert False, 'unknown action %s' % actType

   def identicalFilter( self, filter1, filter2 ):
      """Return True if both filters are None, or both exist and are equal."""
      if filter1 is None and filter2 is None:
         return True
      elif filter1 is None or filter2 is None:
         return False
      return filter1.isEqual( filter2 )

   def identicalMatch( self, matchRule1, matchRule2 ):
      """Return True if two single-option match rules have the same option,
      class name, description and structured filter."""
      assert len( matchRule1.match ) == 1
      if list( matchRule1.match ) != list( matchRule2.match ):
         return False
      if matchRule1.className != matchRule2.className:
         return False
      if matchRule1.matchDesc != matchRule2.matchDesc:
         return False
      filter1 = list( matchRule1.match.values() )[ 0 ].structuredFilter
      filter2 = list( matchRule2.match.values() )[ 0 ].structuredFilter
      return self.identicalFilter( filter1, filter2 )

   def compareMatchRules( self, p1, p2 ):
      """Return CHANGED if p1 and p2 differ in match rule names, order or
      content. Otherwise return None."""
      p1Rules = list( p1.rawClassMap )
      p2Rules = list( p2.rawClassMap )
      if p1Rules != p2Rules:
         return CHANGED
      for k, matchRule in six.iteritems( p1.rawClassMap ):
         if not self.identicalMatch( matchRule, p2.rawClassMap.get( k ) ):
            return CHANGED
      return None

   def identicalPolicyMap( self, p1, p2 ):
      """Compare two policy maps.

      Returns IDENTICAL if nothing differs. Returns RESEQUENCED if the only
      difference is the sequence numbers of an otherwise unchanged, same-order
      set of rules. Returns CHANGED otherwise, including when either map is None.
      """
      if p1 is None or p2 is None:
         return CHANGED

      if p1 == p2:
         return IDENTICAL

      # compare named counters
      p1NamedCounters = p1.namedCounter
      p2NamedCounters = p2.namedCounter
      if set( p1NamedCounters ) != set( p2NamedCounters ):
         return CHANGED

      if p1.policyDesc != p2.policyDesc:
         return CHANGED
      # compare class Actions
      p1ClassActions = list( p1.classAction )
      p2ClassActions = list( p2.classAction )
      if p1ClassActions != p2ClassActions:
         return CHANGED
      for ruleName in p1ClassActions:
         r1Action = p1.classAction[ ruleName ]
         r2Action = p2.classAction[ ruleName ]
         ret = self.identicalClassActions( r1Action, r2Action )
         if ret != IDENTICAL:
            return ret

      # Compare default actions
      if set( p1.defaultAction ) != set( p2.defaultAction ):
         return CHANGED
      for matchOption in p1.defaultAction:
         p1DefaultActions = p1.defaultAction[ matchOption ]
         p2DefaultActions = p2.defaultAction[ matchOption ]
         ret = self.identicalDefaultActions( p1DefaultActions,
                                             p2DefaultActions )
         if ret != IDENTICAL:
            return ret

      # compare class priorities
      p1Prios = list( p1.classPrio )
      p2Prios = list( p2.classPrio )
      if p1Prios != p2Prios:
         if list( p1.classPrio.values() ) == list( p2.classPrio.values() ):
            # The keys were changed, but the values are still in the same order so
            # we've either resequenced or had one of the match rules change but the
            # name is still the same.
            if self.compareMatchRules( p1, p2 ) == CHANGED:
               return CHANGED
            return RESEQUENCED
         else:
            return CHANGED
      if list( p1.classPrio.values() ) != list( p2.classPrio.values() ):
         return CHANGED
      if self.compareMatchRules( p1, p2 ) == CHANGED:
         return CHANGED

      return IDENTICAL

   def identicalClassActions( self, c1Action, c2Action ):
      return IDENTICAL if c1Action.equalTo( c2Action ) else CHANGED

   def delMatchRule( self, ruleName ):
      del self.currentPolicy().rawClassMap[ ruleName ]

   def abort( self ):
      self.delPolicyResources()

   def addDefaultRule( self, cmapName, matchOption, classPriority ):
      """Add reserved rule ``cmapName`` at ``classPriority``, with a single
      ``matchOption`` match and an empty structured filter. Does nothing if a
      rule with that name already exists."""
      if self.npmap.rawClassMap.get( cmapName, None ):
         # Don't add the rule twice
         return

      cmap = self.npmap.rawClassMap.newMember( cmapName, UniqueId() )

      cmap.match.clear()
      cmapMatch = cmap.match.newMember( matchOption )
      cmapMatch.structuredFilter = ( "", )

      self.npmap.classPrio[ classPriority ] = cmapName

      matchRuleAction = TrafficPolicyMatchRuleAction( self,
                                                      cmapName,
                                                      matchOption,
                                                      cmapMatch.structuredFilter )
      self.addRuleCommon( classPriority, matchRuleAction )

   def newEditPmap( self ):
      """After the base-class newEditPmap(), add any missing reserved rules."""
      super( TrafficPolicyBaseContext, self ).newEditPmap()
      accessGroupDict = self.defaultClassMapToAccessGroup()
      classPrioDict = self.defaultClassMapToClassPriority()
      for className in self.reservedClassMapNames():
         self.addDefaultRule( className,
                              accessGroupDict[ className ],
                              classPrioDict[ className ] )

   def supportsDefaultRuleRemoval( self ):
      """
      This flag dictates whether removing of reserved class names e.g.
      ipv4-all-default is supported.

      This is an _opt-in_ configuration for the purposes of
      backwards-compatibility, hence the default value of "False".

      For vanilla traffic policies i.e. those under the "traffic-policies"
      configuration mode, this is supported. For other users e.g. VRF selection, this
      is not supported.
      """
      return False

class TrafficPolicyContext( TrafficPolicyBaseContext ):
   """Edit context for policy maps of type 'mapTrafficPolicy'.

   On top of the base context, this adds the MAC default rule
   (ReservedClassMapNames.classMacDefault) whenever the default-MAC-match toggle
   is enabled. It allows removal of reserved rules when
   toggleCpuTrafficPolicyIntfEnabled() is set.
   """
   pmapType = 'mapTrafficPolicy'

   def childMode( self ):
      """Return the CLI mode used to edit a traffic policy."""
      return TrafficPolicyConfigMode

   def reservedClassMapNames( self ):
      """Base reserved names, plus the MAC default when its toggle is on."""
      base = super( TrafficPolicyContext, self )
      names = base.reservedClassMapNames()
      if not toggleTrafficPolicyDefaultMacMatchEnabled():
         return names
      names.append( ReservedClassMapNames.classMacDefault )
      return names

   def classPriorityConstants( self ):
      """Base reserved priorities, plus the MAC default's when its toggle is on."""
      base = super( TrafficPolicyContext, self )
      prios = base.classPriorityConstants()
      if not toggleTrafficPolicyDefaultMacMatchEnabled():
         return prios
      prios.append( ClassPriorityConstant.classMacDefaultPriority )
      return prios

   def defaultClassMapToClassPriority( self ):
      """Reserved name -> fixed class priority, including the MAC default if on."""
      base = super( TrafficPolicyContext, self )
      mapping = base.defaultClassMapToClassPriority()
      if not toggleTrafficPolicyDefaultMacMatchEnabled():
         return mapping
      mapping[ ReservedClassMapNames.classMacDefault ] = \
         ClassPriorityConstant.classMacDefaultPriority
      return mapping

   def defaultClassMapToAccessGroup( self ):
      """Reserved name -> match option, including the MAC default if on."""
      base = super( TrafficPolicyContext, self )
      mapping = base.defaultClassMapToAccessGroup()
      if not toggleTrafficPolicyDefaultMacMatchEnabled():
         return mapping
      mapping[ ReservedClassMapNames.classMacDefault ] = matchMacAccessGroup
      return mapping

   def defaultRuleSeq( self, inc ):
      """Return the sequence number each reserved rule is expected to have
      right after the base-class resequence, keyed by reserved class map name.

      Slots are counted back from lastSequence() in steps of inc, in the order
      IPv6, IPv4, MAC. The starting point is the first of those (IPv6, then
      IPv4, then MAC) that is a reserved name. Names without a slot map to None.
      """
      v4Name = ReservedClassMapNames.classV4Default
      v6Name = ReservedClassMapNames.classV6Default
      macName = ReservedClassMapNames.classMacDefault

      reserved = self.reservedClassMapNames()
      if v6Name in reserved:
         trailing = ( v6Name, v4Name, macName )
      elif v4Name in reserved:
         trailing = ( v4Name, macName )
      elif macName in reserved:
         trailing = ( macName, )
      else:
         trailing = ()

      seqs = { v4Name : None, v6Name : None, macName : None }
      for stepsBack, name in enumerate( trailing, 1 ):
         seqs[ name ] = self.lastSequence() - stepsBack * inc
      return seqs

   def supportsDefaultRuleRemoval( self ):
      """Reserved rules may be removed iff the CPU traffic-policy interface
      toggle is enabled."""
      return Toggles.TrafficPolicyToggleLib.toggleCpuTrafficPolicyIntfEnabled()
