#!/usr/bin/env python3
# Copyright (c) 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import Tac
import CliSave
from CliSavePlugin import IntfCliSave
from CliSavePlugin.AclCliSave import IpAclConfigMode
from CliSavePlugin.IntfCliSave import IntfConfigMode
from CliMode.PolicyMap import ClassMapModeBase, PolicyMapModeBase, \
   PolicyMapClassModeBase
from CliMode.TapAgg import TapAggActionModeBase
from AclCliLib import ruleFromValue
import re
import Intf.IntfRange
import six

# TAC type handle for Ethernet addresses; its ethAddrZero value is what the
# savers in this file compare against before emitting a MAC address.
tacEthAddr = Tac.Type( "Arnet::EthAddr" )

# Save after interface mode so we can parse in startup-config
CliSave.GlobalConfigMode.addCommandSequence( 'TapAgg.pmapconfig',
                           after=[ IpAclConfigMode, IntfConfigMode ] )

# Per-interface command sequence; saveServicePolicy() adds the
# 'service-policy type tapagg input ...' lines to it.
IntfCliSave.IntfConfigMode.addCommandSequence( 'TapAgg.intfconfig' )

class ClassMapConfigMode( ClassMapModeBase, CliSave.Mode ):
   # Object used for saving commands of a class map.  In this file it is
   # created with param = ( mapType, mapStr, cmapName ).
   def __init__( self, param ):
      # The CLI mode base is initialised first because longModeKey is read
      # right after it (presumably set up by that base -- TODO confirm).
      ClassMapModeBase.__init__( self, param )
      CliSave.Mode.__init__( self, self.longModeKey )

# Register class-map modes under global config, ordered after the IP ACL modes.
CliSave.GlobalConfigMode.addChildMode( ClassMapConfigMode,
                                       after=[ IpAclConfigMode ] )

# Sequence that receives the '<prio> match <type> access-group <acl>' lines.
ClassMapConfigMode.addCommandSequence( 'TapAgg.cmap' )


#-------------------------------------------------------------------------------
# Object used for saving commands in "config-pmap" mode.
#-------------------------------------------------------------------------------
class PolicyMapConfigMode( PolicyMapModeBase, CliSave.Mode ):
   # In this file the mode is created with
   # param = ( mapType, mapStr, pmapName ).
   def __init__( self, param ):
      # The CLI mode base is initialised first because longModeKey is read
      # right after it (presumably set up by that base -- TODO confirm).
      PolicyMapModeBase.__init__( self, param )
      CliSave.Mode.__init__( self, self.longModeKey )

# Register policy-map modes under global config, ordered after class-map modes.
CliSave.GlobalConfigMode.addChildMode( PolicyMapConfigMode,
                                       after=[ ClassMapConfigMode ] )
# Command sequence owned by the policy-map mode itself; savePMap() points
# self.cmds at it before walking the classes.
PolicyMapConfigMode.addCommandSequence( 'TapAgg.pmap' )

#-------------------------------------------------------------------------------
# Object used for saving commands in "config-pmap-c" mode.
#-------------------------------------------------------------------------------
class PolicyMapClassConfigMode( PolicyMapClassModeBase, CliSave.Mode ):
   # param is a 6-tuple ( mapType, mapStr, pmapName, cmapName, prio, entCmd ).
   # entCmd is either None (regular 'class' entry) or the complete,
   # pre-rendered line of a raw match statement.
   def __init__( self, param ):
      ( self.mapType, self.mapStr, self.pmapName, self.cmapName,
        self.prio, self.entCmd_ ) = param
      # The CLI mode base takes everything except the trailing enter command.
      baseParam = param[ : -1 ]
      PolicyMapClassModeBase.__init__( self, baseParam )
      CliSave.Mode.__init__( self, self.longModeKey )

   def enterCmd( self ):
      '''Return the pre-rendered enter command if there is one, otherwise
      the regular '<prio> class <cmapName>' line.'''
      return self.entCmd_ or '%s class %s' % ( self.prio, self.cmapName )

   def instanceKey( self ):
      '''Instances of this mode are keyed by their priority.'''
      return self.prio

   @classmethod
   def useInsertionOrder( cls ):
      # because `instanceKey` is overridden with prio
      return True

   def modeSeparator( self ):
      '''Return True only for entries without a pre-rendered enter command.'''
      return False if self.entCmd_ else True

# Class entries are child modes of their policy map.
PolicyMapConfigMode.addChildMode( PolicyMapClassConfigMode )
# Sequence that receives the action commands of a regular class entry.
PolicyMapClassConfigMode.addCommandSequence( 'TapAgg.pmapc' )

#-------------------------------------------------------------------------------
# Object used for saving commands in "config-pmap-cmap-a" mode.
#-------------------------------------------------------------------------------
class TapAggActionSaveMode( TapAggActionModeBase, CliSave.Mode ):
   # In this file the mode is created with param = ( parent class mode,
   # mapType, mapStr, pmapName, cmapName, action-set name ).
   def __init__( self, param ):
      # The CLI mode base is initialised first because longModeKey is read
      # right after it (presumably set up by that base -- TODO confirm).
      TapAggActionModeBase.__init__( self, param )
      CliSave.Mode.__init__( self, self.longModeKey )

# Named action sets are child modes of a class entry, ordered after the
# class entry's own 'TapAgg.pmapc' command sequence.
PolicyMapClassConfigMode.addChildMode(
      TapAggActionSaveMode, after=[ 'TapAgg.pmapc' ] )
# Sequence that receives the commands of one named action set.
TapAggActionSaveMode.addCommandSequence( 'TapAgg.action' )

def CmdPMapClassAction( classAction, requireMounts ):
   '''Render the actions of a policy-map class as one command fragment.

   Args:
      classAction: object whose policyAction mapping is queried by action name
         ( 'setAggregationGroup', 'setIdentityTag', 'stripHeaderBytes',
         'setNexthopGroup', 'deny', 'setMacAddress', 'setTimestampHeader' ).
      requireMounts: mapping providing 'tapagg/cliconfig'; its
         toolGroupNewFormat flag forces the 'group <name>' syntax.

   Returns:
      The action text, or '' when no action produces output.  A non-empty
      nexthop-group list replaces everything rendered before it, and a 'deny'
      action replaces the whole text with 'drop'.
   '''
   actions = classAction.policyAction
   aggGroup = actions.get( 'setAggregationGroup' )
   idTag = actions.get( 'setIdentityTag' )
   stripHdr = actions.get( 'stripHeaderBytes' )
   nexthopGroupAct = actions.get( 'setNexthopGroup' )
   dropAct = actions.get( 'deny' )
   macAddress = actions.get( 'setMacAddress' )
   timestampHeader = actions.get( 'setTimestampHeader' )

   toolGroupNewFormat = requireMounts[ 'tapagg/cliconfig' ].toolGroupNewFormat

   cmd = ''
   if aggGroup or idTag or macAddress or timestampHeader:
      cmd = 'set'
      if aggGroup:
         aggGroupList = list( aggGroup.aggGroup )
         aggIntfList = list( aggGroup.aggIntf )
         if aggGroupList:
            # The bare '<name>' form is only used for a single group without
            # interfaces while the new format is off.
            if toolGroupNewFormat or aggIntfList or len( aggGroupList ) > 1:
               cmd += ' aggregation-group group %s' \
                   % ' group '.join( sorted( aggGroupList ) )
            else:
               cmd += ' aggregation-group %s' % aggGroupList[ 0 ]

         if aggIntfList:
            cmd += ' interface'
            for intf in Intf.IntfRange.intfListToCanonical( aggIntfList ):
               cmd += ' ' + intf

      if idTag:
         cmd += ' id-tag %d' % idTag.idTag.outer
         if idTag.idTag.inner:
            cmd += ' inner %d' % idTag.idTag.inner

      # A zero destination MAC is treated as "not configured".
      if macAddress and macAddress.macAddress.destMac != tacEthAddr.ethAddrZero:
         cmd += ' mac-address dest %s' % macAddress.macAddress.destMac
         if macAddress.macAddress.srcMac != tacEthAddr.ethAddrZero:
            cmd += ' src %s' % macAddress.macAddress.srcMac

      if timestampHeader:
         cmd += ' mac timestamp header'

   if stripHdr:
      # Separate the 'remove' text from any preceding 'set ...' text.
      cmd += ' ' if cmd else ''
      if stripHdr.stripHdrBytes.hdrType == 'dot1q':
         cmd += 'remove dot1q outer %s' \
               % stripHdr.stripHdrBytes.dot1qRemoveVlans

   if nexthopGroupAct:
      sortedNhgList = sorted( nexthopGroupAct.nexthopGroup )
      if sortedNhgList:
         cmd = 'set nexthop-group %s' % ' '.join( sortedNhgList )
         # Re-add the timestamp suffix only when cmd was rebuilt just above.
         # With an empty nexthop-group list cmd still carries the suffix from
         # the 'set' section, and appending again would duplicate it.
         if timestampHeader:
            cmd += ' mac timestamp header'

   if dropAct:
      cmd = 'drop'

   return cmd

class TapAggPmapCliSaver( object ): # pylint: disable=useless-object-inheritance
   def __init__( self, entity, root, requireMounts, options ):
      '''Remember the save context and look up the interface config mount.

      requireMounts must provide 'tapagg/intfconfig'.
      '''
      # Context passed through from saveConfig().
      self.entity, self.root = entity, root
      self.requireMounts, self.options = requireMounts, options
      # Leading elements of every mode key built by this saver.
      self.mapType, self.mapStr = 'mapTapAgg', 'tapagg'
      # Working state, filled in while the config is walked.
      self.cmds = self.pmap = self.cmap = None
      self.intfConfig = requireMounts[ 'tapagg/intfconfig' ]

   def savePMapClassAction( self, cmapName ):
      classAction = self.pmap.classAction[ cmapName ]
      aggGroup = None
      idTag = None
      aggGroupList = []
      aggIntfList = []
      nexthopGroupList = []
      aggGroup = classAction.policyAction.get( 'setAggregationGroup' )
      idTag = classAction.policyAction.get( 'setIdentityTag' )
      stripHdr = classAction.policyAction.get( 'stripHeaderBytes' )
      nexthopGroupAct = classAction.policyAction.get( 'setNexthopGroup' )
      dropAct = classAction.policyAction.get( 'deny' )
      macAddress = classAction.policyAction.get( 'setMacAddress' )
      timestampHeader = classAction.policyAction.get( 'setTimestampHeader' )
      headerRemove = classAction.policyAction.get( 'setHeaderRemove' )

      tapAggConfig = self.requireMounts[ 'tapagg/cliconfig' ] 
      toolGroupNewFormat = tapAggConfig.toolGroupNewFormat

      if aggGroup:
         aggGroupList = list( aggGroup.aggGroup )
      if aggGroupList:
         cmd = 'set aggregation-group'
         if toolGroupNewFormat or ( len(aggGroupList) > 1 ):
            cmd += ' group %s' % ' group '.join( sorted( aggGroupList ) )
         else:
            cmd += ' %s' % ' '.join( sorted( aggGroupList ) )
         if idTag:
            if idTag.idTag.inner:
               cmd += ' id-tag %d inner %d' % ( idTag.idTag.outer,
                                                idTag.idTag.inner )
            else:
               cmd += ' id-tag %d' % idTag.idTag.outer
         if macAddress and macAddress.macAddress.destMac != tacEthAddr.ethAddrZero:
            cmd += ' mac-address dest %s' % macAddress.macAddress.destMac
            if macAddress.macAddress.srcMac != tacEthAddr.ethAddrZero:
               cmd += ' src %s' % macAddress.macAddress.srcMac
         if timestampHeader:
            cmd += ' mac timestamp header'
         self.cmds.addCommand( cmd )
      elif idTag:
         if idTag.idTag.inner:
            cmd = 'set id-tag %d inner %d' % ( idTag.idTag.outer,
                                               idTag.idTag.inner )
         else:
            cmd = 'set id-tag %d' % idTag.idTag.outer
         if macAddress and macAddress.macAddress.destMac != tacEthAddr.ethAddrZero:
            cmd += ' mac-address dest %s' % macAddress.macAddress.destMac
            if macAddress.macAddress.srcMac != tacEthAddr.ethAddrZero:
               cmd += ' src %s' % macAddress.macAddress.srcMac
         if timestampHeader:
            cmd += ' mac timestamp header'
         self.cmds.addCommand( cmd )
      elif macAddress and macAddress.macAddress.destMac != tacEthAddr.ethAddrZero:
         cmd = 'set mac-address dest %s' % macAddress.macAddress.destMac
         if macAddress.macAddress.srcMac != tacEthAddr.ethAddrZero:
            cmd += ' src %s' % macAddress.macAddress.srcMac
         self.cmds.addCommand( cmd )

      if nexthopGroupAct:
         nexthopGroupList = list( nexthopGroupAct.nexthopGroup )

      if timestampHeader and not any( [ aggGroupList, idTag, nexthopGroupList ] ):
         cmd = 'set mac timestamp header'
         self.cmds.addCommand( cmd )

      if aggGroup:
         aggIntfList = list( aggGroup.aggIntf )
      if aggIntfList:
         printIntfs = Intf.IntfRange.intfListToCanonical( aggIntfList )
         for intfs in printIntfs:
            self.cmds.addCommand( 'set interface %s' % intfs )

      if stripHdr:
         if stripHdr.stripHdrBytes.hdrType == 'dot1q':
            self.cmds.addCommand( 'remove dot1q outer %s'
                                  % stripHdr.stripHdrBytes.dot1qRemoveVlans )

      if headerRemove and headerRemove.headerRemove.size != 0:
         cmd = 'remove header size %d' % headerRemove.headerRemove.size
         if headerRemove.headerRemove.preserveEth:
            cmd += ' preserve ethernet'
         self.cmds.addCommand( cmd )

      if nexthopGroupList:
         sortedNhgList = sorted( nexthopGroupList )
         cmd = 'set nexthop-group %s' % ' '.join( sortedNhgList )
         if timestampHeader:
            cmd += ' mac timestamp header'
         self.cmds.addCommand( cmd )

      if dropAct:
         self.cmds.addCommand( 'drop' )

   def insertAclTypeAndAddCmd( self, pmapName, cmapName, prio, aclType, cmd ):
      '''Turn a raw ACL rule into a policy-map class entry and register it.

      A leading 'permit' in cmd becomes 'match', aclType is inserted unless it
      is already one of the tokens, and the line is prefixed with prio and
      followed by the rendered class actions.  The result is passed as the
      enter command of a PolicyMapClassConfigMode instance under the policy
      map's mode.
      '''
      words = re.sub( '^permit', 'match', cmd ).split()

      # Where the acl type goes: three tokens past 'inner' (checked first) or
      # 'vlan' -- presumably skipping the keyword plus a value/mask pair, TODO
      # confirm -- otherwise right after the first word.
      if 'inner' in words:
         insertAt = words.index( 'inner' ) + 3
      elif 'vlan' in words:
         insertAt = words.index( 'vlan' ) + 3
      else:
         insertAt = 1
      if aclType not in words:
         words.insert( insertAt, aclType )

      entry = ' '.join( [
         '%d %s' % ( prio, ' '.join( words ) ),
         CmdPMapClassAction( self.pmap.classAction[ cmapName ],
                             self.requireMounts ) ] )

      pmapKey = ( self.mapType, self.mapStr, pmapName )
      classKey = ( self.mapType, self.mapStr, pmapName, cmapName, prio, entry )
      pmapMode = self.root[ PolicyMapConfigMode ].getOrCreateModeInstance( pmapKey )
      # Creating the instance is the whole point; nothing is added to it here.
      pmapMode[ PolicyMapClassConfigMode ].getOrCreateModeInstance( classKey )

   def saveRawClassMap( self, pmapName, cmapName, prio ):
      self.cmap = self.pmap.rawClassMap.get( cmapName, None )
      if not self.cmap:
         return

      for cmapMatch in six.itervalues( self.cmap.match ):
         for ipRuleCfg in six.itervalues( cmapMatch.ipRule ):
            ipRuleCmd = ruleFromValue( ipRuleCfg, 'ip' )
            self.insertAclTypeAndAddCmd( pmapName, cmapName, prio, 'ip',
                                         ipRuleCmd )

         for ip6RuleCfg in six.itervalues( cmapMatch.ip6Rule ):
            ip6RuleCmd = ruleFromValue( ip6RuleCfg, 'ipv6' )
            self.insertAclTypeAndAddCmd( pmapName, cmapName, prio, 'ipv6',
                                         ip6RuleCmd )

         for macRuleCfg in six.itervalues( cmapMatch.macRule ):
            macRuleCmd = ruleFromValue( macRuleCfg, 'mac' )
            self.insertAclTypeAndAddCmd( pmapName, cmapName, prio, 'mac',
                                         macRuleCmd )

   def savePMapClassActionSet( self, name, namedActionSet ):
      if namedActionSet.setAggGroup and namedActionSet.setAggGroup.aggGroup:
         cmd = 'set aggregation-group'
         groupList = list( namedActionSet.setAggGroup.aggGroup )
         if len( groupList ) == 1:
            cmd += ' %s' % groupList[ 0 ]
         else:
            cmd += ' group %s' % ' group '.join( sorted( groupList ) )
         self.cmds.addCommand( cmd )

      if namedActionSet.setAggGroup and namedActionSet.setAggGroup.aggIntf:
         cmd = 'set interface'
         intfList = list( namedActionSet.setAggGroup.aggIntf )
         printIntfs = Intf.IntfRange.intfListToCanonical( intfList )
         for intfs in printIntfs:
            self.cmds.addCommand( cmd + ' %s' % intfs )

      if namedActionSet.setIdentityTag:
         cmd = 'set id-tag'
         idTag = namedActionSet.setIdentityTag.idTag
         if idTag.inner:
            cmd += ' %d inner %d' % ( idTag.outer,
                                      idTag.inner )
         else:
            cmd += ' %d' % idTag.outer
         self.cmds.addCommand( cmd )

      if namedActionSet.stripHeaderBytes:
         cmd = 'remove dot1q outer'
         vlanIndices = namedActionSet.stripHeaderBytes.stripHdrBytes.dot1qRemoveVlans
         cmd += ' %s' % vlanIndices
         self.cmds.addCommand( cmd )

      if namedActionSet.setMacAddress:
         cmd = 'set mac-address dest'
         dstMac = namedActionSet.setMacAddress.macAddress.destMac
         cmd += ' %s' % dstMac
         srcMac = namedActionSet.setMacAddress.macAddress.srcMac
         if srcMac != tacEthAddr.ethAddrZero:
            cmd += ' src %s' % srcMac
         self.cmds.addCommand( cmd )

   def savePMapClassActionSetAll( self, configPmapClassMode, pmapName, cmapName ):
      classAction = self.pmap.classAction[ cmapName ]
      actionSet = classAction.policyAction.get( 'actionSet' )
      if not actionSet:
         return
      for aName in sorted( actionSet.namedActionSet.keys() ):
         nas = actionSet.namedActionSet[ aName ]
         param = ( configPmapClassMode, self.mapType, self.mapStr,
                   pmapName, cmapName, aName )
         aMode = configPmapClassMode[
               TapAggActionSaveMode ].getOrCreateModeInstance( param )
         self.cmds = aMode[ 'TapAgg.action' ]
         self.savePMapClassActionSet( aName, nas )

   def savePMapClass( self, pmapName, cmapName, prio ):
      '''Save one class entry of policy map pmapName.

      Raw class maps are expanded rule by rule; any other class gets a
      '<prio> class <cmapName>' mode holding its actions and action sets.
      '''
      pmapKey = ( self.mapType, self.mapStr, pmapName )
      # The policy-map mode is obtained (and created if needed) in both cases.
      pmapMode = self.root[ PolicyMapConfigMode ].getOrCreateModeInstance( pmapKey )
      if self._rawClassMap( cmapName ):
         self.saveRawClassMap( pmapName, cmapName, prio )
         return

      # A trailing None means "no pre-rendered enter command".
      classKey = ( self.mapType, self.mapStr, pmapName, cmapName, prio, None )
      classMode = pmapMode[ PolicyMapClassConfigMode ].getOrCreateModeInstance(
            classKey )
      self.cmds = classMode[ 'TapAgg.pmapc' ]
      self.savePMapClassAction( cmapName )
      self.savePMapClassActionSetAll( classMode, pmapName, cmapName )

   def savePMapClassAll( self, pmapName ):
      self.pmap = self.entity.pmapType.pmap[ pmapName ].currCfg
      if self.pmap:
         for prio, cmap in six.iteritems( self.pmap.classPrio ):
            self.savePMapClass( pmapName, cmap, prio )

   def savePMap( self, pmapName ):
      '''Get or create the mode for policy map pmapName, point self.cmds at
      its 'TapAgg.pmap' sequence and save all of its classes.'''
      pmapKey = ( self.mapType, self.mapStr, pmapName )
      pmapMode = self.root[ PolicyMapConfigMode ].getOrCreateModeInstance( pmapKey )
      self.cmds = pmapMode[ 'TapAgg.pmap' ]
      self.savePMapClassAll( pmapName )

   def savePMapAll( self ):
      # display all pmaps in sorted order of names
      pmapNames = sorted( self.entity.pmapType.pmap.keys() )
      for pmapName in pmapNames:
         self.savePMap( pmapName )

   def saveCMapMatch( self, cmapName, option ):
      '''Add one '<prio> match <type> access-group <acl>' line to self.cmds
      for every ACL of the working class map's match entry `option`.

      cmapName is accepted but not used.
      '''
      aclByPrio = self.cmap.match[ option ].acl
      # Match option name -> keyword shown in the command.
      keywordFor = { 'matchIpAccessGroup' : 'ip',
                     'matchIpv6AccessGroup' : 'ipv6',
                     'matchMacAccessGroup' : 'mac' }
      for prio, aclName in six.iteritems( aclByPrio ):
         line = '%s match %s access-group %s' % (
               prio, keywordFor[ option ], aclName )
         self.cmds.addCommand( line )
   def saveCMapMatchAll( self, cmapName ):
      self.cmap = self.entity.cmapType.cmap[ cmapName ].currCfg
      if self.cmap:
         for option in self.cmap.match:
            self.saveCMapMatch( cmapName, option )

   def _rawClassMap( self, cmapName ):
      return cmapName in self.pmap.rawClassMap

   def saveCMap( self, cmapName ):
      '''Get or create the mode for class map cmapName, point self.cmds at
      its 'TapAgg.cmap' sequence and save all of its match statements.'''
      cmapKey = ( self.mapType, self.mapStr, cmapName )
      cmapMode = self.root[ ClassMapConfigMode ].getOrCreateModeInstance( cmapKey )
      self.cmds = cmapMode[ 'TapAgg.cmap' ]
      self.saveCMapMatchAll( cmapName )

   def saveCMapAll( self ):
      # display all class maps in sorted order of names
      cmapNames = sorted( self.entity.cmapType.cmap.keys() )
      for cmapName in cmapNames:
         self.saveCMap( cmapName )

   def saveServicePolicy( self ):
      '''Add a 'service-policy type tapagg input <pmap>' line to the interface
      mode of every entry in self.intfConfig.intf.'''
      # Entries are walked as ( interface name, value shown as the policy map ).
      for intfName, pmapName in six.iteritems( self.intfConfig.intf ):
         intfModes = self.root[ IntfCliSave.IntfConfigMode ]
         intfMode = intfModes.getOrCreateModeInstance( intfName )
         self.cmds = intfMode[ 'TapAgg.intfconfig' ]
         self.cmds.addCommand( 'service-policy type tapagg input %s' % pmapName )

   def save( self ):
      '''Save policy maps, then class maps, then interface service policies.'''
      steps = ( self.savePMapAll, self.saveCMapAll, self.saveServicePolicy )
      for step in steps:
         step()

@CliSave.saver( 'TapAgg::PmapConfig', 'tapagg/pmapconfig',
                requireMounts = ( 'bridging/config','tapagg/cliconfig',
                   'tapagg/intfconfig' ) )
def saveConfig( entity, root, requireMounts, options ):
   '''CliSave entry point for the TapAgg policy-map config: builds a
   TapAggPmapCliSaver from the arguments and runs it.'''
   TapAggPmapCliSaver( entity, root, requireMounts, options ).save()


