#!/usr/bin/env python3
# Copyright (c) 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import CliSave
from CliMode.PolicyMap import ClassMapModeBase, PolicyMapModeBase, \
   PolicyMapClassModeBase
from CliSavePlugin import IntfCliSave
from CliSavePlugin.AclCliSave import IpAclConfigMode
from AclCliLib import ruleFromValue
import re
import Tac
import PolicyMap
from PolicyMap import matchOptionToStr

# Tac type describing what PBR does with an unresolved nexthop; its dropDelayed
# and dropImmediate values are compared against the saved config further down.
UnresolvedNexthopAction = Tac.Type( "PolicyMap::PbrUnresolvedNexthopAction" )

# Global command sequence holding the 'ip policy hardware persistent' setting.
CliSave.GlobalConfigMode.addCommandSequence( 'Pbr.hardwarePersistent' )

@CliSave.saver( 'Pbr::PbrConfig', 'pbr/input/pmap/cli',
                requireMounts=( 'routing/hardware/status', ) )
def saveHardwarePersistentConfig( entity, root, requireMounts, options ):
   """Save the global 'ip policy hardware persistent' setting.

   The positive form is written whenever entity.isConfigDriven is set.  The
   'no' form is written only for saveAll output, and only when the routing
   hardware status reports pbrHardwarePersistentSupported.
   """
   hwStatus = requireMounts[ 'routing/hardware/status' ]
   sequence = root[ 'Pbr.hardwarePersistent' ]
   if entity.isConfigDriven:
      sequence.addCommand( 'ip policy hardware persistent' )
      return
   if options.saveAll and hwStatus.pbrHardwarePersistentSupported:
      sequence.addCommand( 'no ip policy hardware persistent' )

# Global command sequence holding the PBR counters poll interval.
CliSave.GlobalConfigMode.addCommandSequence( 'Pbr.countersPollInterval' )

@CliSave.saver( 'Pbr::PbrConfig', 'pbr/input/pmap/cli',
                requireMounts=( 'routing/hardware/status', ) )
def saveCountersPollIntervalConfig( entity, root, requireMounts, options ):
   """Save the global 'ip policy counters poll interval' setting.

   A truthy entity.counterPollInterval is written as a whole number of
   seconds.  Otherwise the 'no' form is written only for saveAll output, and
   only when the routing hardware status reports pbrSharkCounterSupported.
   """
   hwStatus = requireMounts[ 'routing/hardware/status' ]
   sequence = root[ 'Pbr.countersPollInterval' ]
   interval = entity.counterPollInterval
   if interval:
      sequence.addCommand(
            f'ip policy counters poll interval {int( interval )} seconds' )
      return
   if options.saveAll and hwStatus.pbrSharkCounterSupported:
      sequence.addCommand( 'no ip policy counters poll interval' )

# Global command sequence holding 'ip policy hardware update hitful forced'.
CliSave.GlobalConfigMode.addCommandSequence( 'Pbr.hardwareUpdateHitfulForced' )

@CliSave.saver( 'Pbr::PbrConfig', 'pbr/input/pmap/cli',
                requireMounts=( 'routing/hardware/status', ) )
def saveHardwareHitfulStrictConfig( entity, root, requireMounts, options ):
   """Save the global 'ip policy hardware update hitful forced' setting.

   The positive form is written whenever entity.inPlaceUpdateDisabled is set.
   The 'no' form is written only for saveAll output, and only when the routing
   hardware status reports pbrInPlaceTcamUpdateSupported.
   """
   hwStatus = requireMounts[ 'routing/hardware/status' ]
   sequence = root[ 'Pbr.hardwareUpdateHitfulForced' ]
   if entity.inPlaceUpdateDisabled:
      sequence.addCommand( 'ip policy hardware update hitful forced' )
      return
   if options.saveAll and hwStatus.pbrInPlaceTcamUpdateSupported:
      sequence.addCommand( 'no ip policy hardware update hitful forced' )

# Global command sequence shared by the nexthop-related 'ip policy' savers.
CliSave.GlobalConfigMode.addCommandSequence( 'Pbr.nexthopConfig' )

@CliSave.saver( 'Pbr::PbrConfig', 'pbr/input/pmap/cli' )
def saveNexthopsUnresolvedActionConfig( entity, root, requireMounts, options ):
   """Save the global 'ip policy unresolved-nexthop action' setting.

   dropDelayed is written as 'drop' and dropImmediate as 'drop immediate'.
   Any other value produces output only under saveAll, as the 'no' form.
   """
   sequence = root[ 'Pbr.nexthopConfig' ]
   action = entity.pbrUnresolvedNexthopAction
   command = None
   if action == UnresolvedNexthopAction.dropDelayed:
      command = 'ip policy unresolved-nexthop action drop'
   elif action == UnresolvedNexthopAction.dropImmediate:
      command = 'ip policy unresolved-nexthop action drop immediate'
   elif options.saveAll:
      command = 'no ip policy unresolved-nexthop action'
   if command is not None:
      sequence.addCommand( command )

@CliSave.saver( 'Pbr::PbrConfig', 'pbr/input/pmap/cli' )
def savePbrMacAddressAgingConfig( entity, root, requireMounts, options ):
   """Save the global 'ip policy mac address aging' setting.

   The config attribute has the opposite sense to the command: the positive
   form is written when entity.pbrEnableNonAgingMacAddr is false, and the 'no'
   form (saveAll only) when it is true.
   """
   sequence = root[ 'Pbr.nexthopConfig' ]
   if entity.pbrEnableNonAgingMacAddr:
      if options.saveAll:
         sequence.addCommand( 'no ip policy mac address aging' )
   else:
      sequence.addCommand( 'ip policy mac address aging' )

@CliSave.saver( 'Pbr::PbrConfig', 'pbr/input/pmap/cli' )
def saveBgpPbrRedirect( entity, root, requireMounts, options ):
   """Save the global 'ip policy match protocol bgp' setting.

   The positive form is written when entity.bgpPbrRedirect is set; otherwise
   the 'no' form is written, but only for saveAll output.
   """
   sequence = root[ 'Pbr.nexthopConfig' ]
   if entity.bgpPbrRedirect:
      command = 'ip policy match protocol bgp'
   elif options.saveAll:
      command = 'no ip policy match protocol bgp'
   else:
      return
   sequence.addCommand( command )

# 'Pbr.config' is registered both globally (requested after the IP ACL mode)
# and per interface.
# NOTE(review): only the interface-level sequence is written to in this file
# (the service-policy attachments) -- confirm whether the global one is still
# needed.
CliSave.GlobalConfigMode.addCommandSequence( 'Pbr.config',
                                             after=[ IpAclConfigMode ] )
IntfCliSave.IntfConfigMode.addCommandSequence( 'Pbr.config' )

#-------------------------------------------------------------------------------
# Object used for saving commands in class-map configuration mode.
#-------------------------------------------------------------------------------
class ClassMapConfigMode( ClassMapModeBase, CliSave.Mode ):
   def __init__( self, param ):
      # Both bases are initialised explicitly, in this order: the save mode is
      # keyed on self.longModeKey, which is presumably set up by
      # ClassMapModeBase.__init__ -- TODO confirm.
      ClassMapModeBase.__init__( self, param )
      CliSave.Mode.__init__( self, self.longModeKey )

# Class maps are a child of the global config mode, requested after the IP ACL
# mode.
CliSave.GlobalConfigMode.addChildMode( ClassMapConfigMode,
                                       after=[ IpAclConfigMode ] )

# Sequence receiving the 'match ... access-group' lines of a class map.
ClassMapConfigMode.addCommandSequence( 'Pbr.cmap' )

#-------------------------------------------------------------------------------
# Object used for saving commands in "config-pmap" mode.
#-------------------------------------------------------------------------------
class PolicyMapConfigMode( PolicyMapModeBase, CliSave.Mode ):
   def __init__( self, param ):
      # Both bases are initialised explicitly, in this order: the save mode is
      # keyed on self.longModeKey, which is presumably set up by
      # PolicyMapModeBase.__init__ -- TODO confirm.
      PolicyMapModeBase.__init__( self, param )
      CliSave.Mode.__init__( self, self.longModeKey )

# Policy maps are a child of the global config mode, requested after the
# class-map mode.
CliSave.GlobalConfigMode.addChildMode( PolicyMapConfigMode,
                                       after=[ ClassMapConfigMode ] )
# Command sequence for commands issued directly in policy-map mode.
PolicyMapConfigMode.addCommandSequence( 'Pbr.pmap' )

#-------------------------------------------------------------------------------
# Object used for saving commands in "config-pmap-c" mode.
#
# One instance exists per class entry of a policy map.  An entry built with a
# pre-rendered enter command is written as that single line; an entry without
# one is written as '<prio> class <name>' followed by its own commands.
#-------------------------------------------------------------------------------
class PolicyMapClassConfigMode( PolicyMapClassModeBase, CliSave.Mode ):
   def __init__( self, param ):
      # param is a 6-tuple.  Its last element is the optional pre-rendered
      # enter command, which is kept here and not handed to the mode base.
      ( self.mapType, self.mapStr, self.pmapName, self.cmapName,
        self.prio, self.entCmd_ ) = param
      PolicyMapClassModeBase.__init__( self, param[ : -1 ] )
      CliSave.Mode.__init__( self, self.longModeKey )

   def enterCmd( self ):
      # A pre-rendered command wins; otherwise build the regular class line.
      return self.entCmd_ or '%s class %s' % ( self.prio, self.cmapName )

   def instanceKey( self ):
      # Instances are identified by the class priority.
      return self.prio

   @classmethod
   def useInsertionOrder( cls ):
      # because `instanceKey` is overridden with prio
      return True

   def modeSeparator( self ):
      # Only entries without a pre-rendered enter command ask for a separator.
      return not self.entCmd_

# Class entries live inside a policy map; their actions go to 'Pbr.pmapc'.
PolicyMapConfigMode.addChildMode( PolicyMapClassConfigMode )
PolicyMapClassConfigMode.addCommandSequence( 'Pbr.pmapc' )

class PbrCliSaver:
   """Render the PBR policy maps, class maps and interface attachments as
   CliSave commands.

   entity is the PBR config being saved (it provides pmapType.pmap and
   cmapType.cmap), root is the CliSave root mode, and the per-interface
   attachments are read from the 'pbr/input/intf/cli' entry of requireMounts.
   """

   def __init__( self, entity, root, requireMounts, options ):
      self.entity = entity
      self.root = root
      self.options = options
      self.mapType = 'mapPbr'
      self.mapStr = 'pbr'
      # Command sequence currently being filled in; it is retargeted as the
      # saver moves from one mode to the next.
      self.cmds = None
      # Policy-map / class-map config currently being rendered.
      self.pmap = None
      self.cmap = None
      self.intfConfig = requireMounts[ 'pbr/input/intf/cli' ]

   def _pmapMode( self, pmapName ):
      """Return (creating it if needed) the save mode of policy map pmapName."""
      return self.root[ PolicyMapConfigMode ].getOrCreateModeInstance(
         ( self.mapType, self.mapStr, pmapName ) )

   def _pmapClassMode( self, pmapMode, pmapName, cmapName, prio, enterCmd ):
      """Return (creating it if needed) the class entry mode under pmapMode.

      enterCmd is the pre-rendered single-line command of a raw match entry,
      or None for a regular '<prio> class <name>' entry.
      """
      return pmapMode[ PolicyMapClassConfigMode ].getOrCreateModeInstance(
         ( self.mapType, self.mapStr, pmapName, cmapName, prio, enterCmd ) )

   def _withClassAction( self, cmapName, matchCmd ):
      """Return matchCmd followed by the actions of class cmapName.

      When the class has no action, matchCmd is returned unchanged so that the
      saved command does not end in a stray space.
      """
      action = self.savePMapClassAction( cmapName, addCmd=False )
      if not action:
         return matchCmd
      return '%s %s' % ( matchCmd, action )

   def savePMapClassAction( self, cmapName, addCmd=True ):
      """Render the actions of class cmapName in the current policy map.

      With addCmd=True each action is added as its own command to self.cmds
      and None is returned.  With addCmd=False nothing is added; the actions
      are returned joined by single spaces ('' when there are none).
      """
      actions = self.pmap.classAction[ cmapName ].policyAction
      lines = []
      if 'deny' in actions:
         lines.append( 'drop' )
      if 'setNexthop' in actions:
         pbrAct = actions[ 'setNexthop' ]
         # Nexthops are sorted so the saved command is stable.
         nhops = ' '.join( sorted( nh.stringValue for nh in pbrAct.nexthop ) )
         recursive = 'recursive ' if pbrAct.recursive else ''
         vrf = ( ' vrf %s' % pbrAct.vrfName ) if pbrAct.vrfName != '' else ''
         lines.append( 'set nexthop %s%s%s' % ( recursive, nhops, vrf ) )
      if 'setNexthopGroup' in actions:
         lines.append(
            'set nexthop-group %s' % actions[ 'setNexthopGroup' ].nexthopGroup )
      if 'setDscp' in actions:
         lines.append( 'set dscp %d' % actions[ 'setDscp' ].dscp )
      if 'setTtl' in actions:
         lines.append( 'set ttl %d' % actions[ 'setTtl' ].ttl )
      if not addCmd:
         return ' '.join( lines )
      for line in lines:
         self.cmds.addCommand( line )
      return None

   def insertAclTypeAndAddCmd( self, cmapName, prio, aclType, cmd ):
      """Turn a rendered ACL rule into a raw policy-map match command.

      A leading 'permit' becomes 'match', aclType is inserted unless it is
      already one of the tokens, the priority is prefixed and the class
      actions are appended.  Despite the name, nothing is added to a command
      sequence here: the finished command is returned to the caller.
      """
      tokens = re.sub( r'^permit', 'match', cmd ).split()
      if 'vlan' in tokens:
         # Insert after 'vlan' and the two tokens that follow it (presumably
         # the vlan id and mask -- confirm against ruleFromValue).
         insIndex = tokens.index( 'vlan' ) + 3
      else:
         insIndex = 1
      if aclType not in tokens:
         tokens.insert( insIndex, aclType )
      return self._withClassAction( cmapName,
                                    '%d %s' % ( prio, ' '.join( tokens ) ) )

   def saveRawClassMap( self, pmapName, cmapName, prio ):
      """Save every match of raw class map cmapName as a single-line entry.

      Each IPv4 rule, each IPv6 rule and the MPLS match option produce one
      class entry mode whose enter command is the whole match-plus-action
      line.  Nothing is saved when the policy map has no such raw class map.
      """
      self.cmap = self.pmap.rawClassMap.get( cmapName, None )
      if not self.cmap:
         return
      pmapMode = self._pmapMode( pmapName )
      for cmapMatch in self.cmap.match.values():
         for ipRuleCfg in cmapMatch.ipRule.values():
            cmd = self.insertAclTypeAndAddCmd(
               cmapName, prio, 'ip', ruleFromValue( ipRuleCfg, 'ip' ) )
            self._pmapClassMode( pmapMode, pmapName, cmapName, prio, cmd )

         for ip6RuleCfg in cmapMatch.ip6Rule.values():
            cmd = self.insertAclTypeAndAddCmd(
               cmapName, prio, 'ipv6', ruleFromValue( ip6RuleCfg, 'ipv6' ) )
            self._pmapClassMode( pmapMode, pmapName, cmapName, prio, cmd )

         if cmapMatch.option == PolicyMap.matchMplsAccessGroup:
            cmd = self._withClassAction( cmapName, '%d match mpls any' % prio )
            self._pmapClassMode( pmapMode, pmapName, cmapName, prio, cmd )

   def savePMapClass( self, pmapName, cmapName, prio ):
      """Save one class entry of the current policy map.

      Raw class maps are saved as single-line match commands; any other class
      gets a '<prio> class <name>' mode with one command per action.
      """
      pmapMode = self._pmapMode( pmapName )
      if self._rawClassMap( cmapName ):
         self.saveRawClassMap( pmapName, cmapName, prio )
         return
      classMode = self._pmapClassMode( pmapMode, pmapName, cmapName, prio, None )
      self.cmds = classMode[ 'Pbr.pmapc' ]
      self.savePMapClassAction( cmapName )

   def savePMapClassAll( self, pmapName ):
      """Save every class of the current policy map, in classPrio order."""
      for prio, cmap in self.pmap.classPrio.items():
         self.savePMapClass( pmapName, cmap, prio )

   def savePMap( self, pmapName ):
      """Save policy map pmapName; a policy map without currCfg is skipped."""
      self.pmap = self.entity.pmapType.pmap[ pmapName ].currCfg
      if not self.pmap:
         return
      # Creating the mode here means a policy map with no classes is still
      # saved.
      self.cmds = self._pmapMode( pmapName )[ 'Pbr.pmap' ]
      self.savePMapClassAll( pmapName )

   def savePMapAll( self ):
      """Save all policy maps in sorted order of their names."""
      for pmapName in sorted( self.entity.pmapType.pmap.keys() ):
         self.savePMap( pmapName )

   def saveCMapMatch( self, cmapName, option ):
      """Add one 'match ... access-group' command per ACL of a match option.

      cmapName is not used; the matches are read from the current class map.
      """
      cmapMatch = self.cmap.match[ option ]
      ip = matchOptionToStr( option )
      for prio, aclName in cmapMatch.acl.items():
         self.cmds.addCommand( '%s match %s access-group %s' %
                               ( prio, ip, aclName ) )

   def saveCMapMatchAll( self, cmapName ):
      """Save the matches of every match option of the current class map."""
      for option in self.cmap.match:
         self.saveCMapMatch( cmapName, option )

   def _rawClassMap( self, cmapName ):
      """Return whether cmapName is a raw class map of the current policy map."""
      return cmapName in self.pmap.rawClassMap

   def saveCMap( self, cmapName ):
      """Save class map cmapName; a class map without currCfg is skipped."""
      self.cmap = self.entity.cmapType.cmap[ cmapName ].currCfg
      if not self.cmap:
         return
      cmapMode = self.root[ ClassMapConfigMode ].getOrCreateModeInstance(
         ( self.mapType, self.mapStr, cmapName ) )
      self.cmds = cmapMode[ 'Pbr.cmap' ]
      self.saveCMapMatchAll( cmapName )

   def saveCMapAll( self ):
      """Save all class maps in sorted order of their names."""
      for cmapName in sorted( self.entity.cmapType.cmap.keys() ):
         self.saveCMap( cmapName )

   def saveServicePolicy( self, attachType ):
      """Save the per-interface 'service-policy type pbr input' commands.

      attachType 'primary' saves the intf collection of the interface config;
      any other value saves intfFallback and appends the 'fallback' keyword.
      """
      if attachType == 'primary':
         intfCollection = self.intfConfig.intf
         fallback = ''
      else:
         intfCollection = self.intfConfig.intfFallback
         fallback = ' fallback'

      for intfName, pmap in intfCollection.items():
         intfMode = self.root[
            IntfCliSave.IntfConfigMode ].getOrCreateModeInstance( intfName )
         self.cmds = intfMode[ 'Pbr.config' ]
         self.cmds.addCommand( 'service-policy type pbr input %s%s' %
                               ( pmap, fallback ) )

   def save( self ):
      """Save policy maps, class maps, then primary and fallback attachments."""
      self.savePMapAll()
      self.saveCMapAll()
      self.saveServicePolicy( 'primary' )
      self.saveServicePolicy( 'fallback' )

@CliSave.saver( 'Pbr::PbrConfig', 'pbr/input/pmap/cli',
                requireMounts=( 'pbr/input/intf/cli', ) )
def saveConfig( entity, root, requireMounts, options ):
   """Save the PBR policy maps, class maps and interface attachments.

   All of the work is delegated to a PbrCliSaver built from the arguments.
   """
   PbrCliSaver( entity, root, requireMounts, options ).save()
