#!/usr/bin/env python3
# Copyright (c) 2009-2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from CliPlugin import IntfCli, EthIntfCli, VlanCli
from CliPlugin import SubIntfCli
from CliPlugin.BridgingCli import warnMacTableUnsupportedUFTModeHook
from CliPlugin.PortSecCliModels import ( GeneralPortSecurityStatistics,
                               GeneralPortSecurityVlanStatistics,
                               PortSecurityAddresses,
                               PortSecurityInterfaces,
                               ethStr )
import CliCommand
import CliParser
import LazyMount
import SmashLazyMount
import ConfigMount
import Tac
from BridgingHostEntryType import isEntryTypeConfigured
import Toggles.MacMonToggleLib as Toggle

# Mount handles; every one stays None until Plugin() assigns it.
bridgingConfig = bridgingStatus = None
portSecConfig = portSecLocalConfig = portSecStatus = portSecHwCap = None

# Interface classes the port-security show/clear commands operate on.
validIntfTypes = ( EthIntfCli.EthIntf, SubIntfCli.SubIntf )
VlanId = Tac.Type( "Bridging::VlanId" )

class PortSecIntf( IntfCli.IntfDependentBase ):
   """Interface-dependent hook (registered in Plugin()) for port security.

   Its setDefault() removes every port-security config entry that belongs
   to the interface.
   """
   def setDefault( self ):
      """Delete this interface's port-security configuration.

      Removes the interface's entry in portSecConfig.intfConfig and all of
      its entries in portSecConfig.vlanConfig and
      portSecConfig.defaultVlanConfig.
      """
      intfName = self.intf_.name
      del portSecConfig.intfConfig[ intfName ]
      for coll in ( portSecConfig.vlanConfig, portSecConfig.defaultVlanConfig ):
         # Snapshot the matching keys before deleting: removing entries from
         # a collection while iterating over it is not safe.
         staleKeys = [ key for key in coll if key.intfId == intfName ]
         for key in staleKeys:
            del coll[ key ]

def intfConfig( intf ):
   """Return a modifiable PortSec::IntfConfig for the interface named intf.

   An existing entry in portSecConfig.intfConfig is returned via
   Tac.nonConst; otherwise a fresh Tac.Value is built for intf. The
   collection itself is not modified; callers addMember() the result.
   """
   existing = portSecConfig.intfConfig.get( intf )
   if existing is None:
      return Tac.Value( 'PortSec::IntfConfig', intf )
   return Tac.nonConst( existing )

def vlanConfig( intf, vlanId, default=False ):
   """Return a modifiable PortSec::VlanConfig for (intf, vlanId).

   Looks the PortSec::IntfVlanKey up in portSecConfig.defaultVlanConfig
   when default is true, otherwise in portSecConfig.vlanConfig. An existing
   entry is returned via Tac.nonConst; a missing one yields a fresh
   Tac.Value for the same key. The collection itself is not modified.
   """
   key = Tac.Value( 'PortSec::IntfVlanKey', intf, vlanId )
   table = portSecConfig.defaultVlanConfig if default else portSecConfig.vlanConfig
   existing = table.get( key )
   if existing is None:
      return Tac.Value( 'PortSec::VlanConfig', key )
   return Tac.nonConst( existing )

def intfStatus( mode, intf ):
   """Return the port-security status entry for intf (by name), or None.

   mode is accepted for call-site symmetry and is not used.
   """
   statusTable = portSecStatus.intfStatus
   return statusTable.get( intf.name, None )

def guardGlobalMacAddressConfig( mode, token ):
   """CLI guard: allow the token only if portSecHwCap.allowSecureAddressDeletion."""
   if portSecHwCap.allowSecureAddressDeletion:
      return None
   return CliParser.guardNotThisPlatform

def guardChipBasedConfig( mode, token ):
   """CLI guard: allow the token only if portSecHwCap.allowChipBasedProtect."""
   if portSecHwCap.allowChipBasedProtect:
      return None
   return CliParser.guardNotThisPlatform

def guardVlanBasedPortSec( mode, token ):
   """CLI guard: hide VLAN-based port security unless its toggle is enabled."""
   if Toggle.toggleVlanBasedPortSecEnabled():
      return None
   return CliParser.guardNotThisEosVersion

def guardDefaultVlanBasedPortSec( mode, token ):
   """CLI guard: hide default-VLAN port security unless its toggle is enabled."""
   if Toggle.toggleDefaultVlanBasedPortSecEnabled():
      return None
   return CliParser.guardNotThisEosVersion

def maximumRange( mode, context ):
   """Return the (low, high) range accepted for the maximum-address value.

   The upper bound comes from maximumLimit() for the current interface's
   config. Before portSecConfig is mounted it falls back to
   portSecHwCap.allowedMaxLimitDefault.
   """
   if portSecConfig:
      upper = maximumLimit( portSecConfig.intfConfig.get( mode.intf.name ) )
   else:
      upper = portSecHwCap.allowedMaxLimitDefault
   return 1, upper

def maximumLimit( cfg ):
   """Return the highest maximum-address value allowed for cfg.

   The chip-based limit applies only when chip-based port security is on
   globally and cfg (which may be None) is in 'protect' mode without
   logging. Every other case gets the default limit.
   """
   chipProtect = ( portSecConfig.chipBased and
                   cfg and
                   cfg.mode == 'protect' and
                   not cfg.log )
   if chipProtect:
      return portSecHwCap.allowedMaxLimitChipBased
   return portSecHwCap.allowedMaxLimitDefault

def doShowPortSecurity( mode, args ):
   """Build the model for the port-security summary show command.

   Reports one PortStatistic for each interface that has port-security
   status and is a valid (Ethernet or sub-) interface. Also reports the
   global secure-address moves/aging/persistence settings and the total
   number of addresses across the reported interfaces.

   An interface with status but no config is reported with a maximum of 0.
   Its remaining fields come from a default PortSec::IntfConfig.
   """
   intfs = { x.name for x in
             IntfCli.Intf.getAll( mode, intfType=validIntfTypes ) }
   portStatistics = {}
   numAddresses = 0
   secureAddressMoves = portSecConfig.allowSecureAddressMoves
   secureAddressAging = portSecConfig.allowSecureAddressAging
   persistence = portSecConfig.persistenceEnabled
   for iname in portSecStatus.intfStatus:
      if iname not in intfs:
         continue
      stat = portSecStatus.intfStatus[ iname ]
      config = portSecConfig.intfConfig.get( iname )
      if config:
         maxSecureAddr = config.maxAddrs
      else:
         # Status may exist for an interface that has no config. Report a
         # maximum of 0 and read the other fields from a default IntfConfig
         # instead of dereferencing None.
         maxSecureAddr = 0
         config = Tac.Value( 'PortSec::IntfConfig', iname )
      addrs = stat.addrs
      portStatistics[ iname ] = GeneralPortSecurityStatistics.PortStatistic(
                                   maxSecureAddr=maxSecureAddr,
                                   intfTotalMacLimitDisabled=config.vlanBased,
                                   intfStaticOnly=config.staticOnly,
                                   currentAddr=addrs,
                                   numberOfViolations=stat.violations,
                                   securityAction=config.mode )
      numAddresses += addrs
   return GeneralPortSecurityStatistics(
             portStatistics=portStatistics,
             totalAddresses=numAddresses,
             secureAddressMoves=secureAddressMoves,
             secureAddressAging=secureAddressAging,
             persistence=persistence )

def doShowPortSecurityAddress( mode, args ):
   """Build the model for the port-security address show command.

   Lists every entry in the bridging Smash FDB whose interface has
   port-security status and is a valid (Ethernet or sub-) interface.
   Configured entry types are tagged "secureConfigured"; all others are
   tagged "secureDynamic". remainingAge is always None.
   """
   statusByIntf = portSecStatus.intfStatus
   eligible = { intf.name for intf in
                IntfCli.Intf.getAll( mode, intfType=validIntfTypes ) }
   securedIntfs = eligible & { s.intfId for s in statusByIntf.values() }
   # TBD: addresses should come from allowedAddress rather than
   # smashFdbStatus.
   matching = [ entry for entry in bridgingStatus.smashFdbStatus.values()
                if entry.intf in securedIntfs ]
   addresses = [
      PortSecurityAddresses.Address(
         macAddress=entry.key.addr,
         vlan=entry.key.fid,
         entryType=( "secureConfigured"
                     if isEntryTypeConfigured( entry.entryType )
                     else "secureDynamic" ),
         interface=entry.intf,
         remainingAge=None )
      for entry in matching ]
   return PortSecurityAddresses( addresses=addresses,
                                 totalAddresses=len( addresses ) )

def doShowPortSecurityInterface( mode, args ):
   """Build the model for the port-security interface show command.

   For each selected Ethernet or sub-interface the model reports:
   - the enabled flag, a derived port status and the violation mode;
   - the static-only flag, the configured maximum and the status counters;
   - the most recent learned and violating addresses, with timestamps
     converted to the Tac.utcNow() clock;
   - in protect mode, the log flag and the allowed-address list.
   Interfaces without port-security config keep the default values below.
   """
   selected = IntfCli.Intf.getAll( mode, args.get( 'INTF', None ),
                                   intfType=validIntfTypes )
   if not selected:
      return PortSecurityInterfaces( interfaces={} )

   ifaceModel = PortSecurityInterfaces.Interface

   def toUtc( timestamp ):
      # Shift a Tac.now()-based timestamp by the current offset between
      # Tac.utcNow() and Tac.now().
      return timestamp + Tac.utcNow() - Tac.now()

   models = {}
   for intf in selected:
      # Values reported for an interface without port-security config.
      fields = dict(
         portSecurityEnabled=False,
         portStatus='secure-down',
         violationMode='none',
         staticOnly=False,
         portMaxEnabled=True,
         maxMacAddresses=1,
         agingTime=int( bridgingConfig.hostAgingTime / 60.0 + 0.5 ),
         agingType="inactivity",
         secureStaticAddressAging="disabled",
         secureAddressMoves=portSecConfig.allowSecureAddressMoves,
         secureAddressAging=portSecConfig.allowSecureAddressAging,
         persistence=portSecConfig.persistenceEnabled,
         totalMacAddresses=None,
         configuredMacAddresses=None,
         addressChanges=None,
         lastChangeDetails=None,
         lastViolation=None,
         securityViolationCount=None,
         logAddrsAfterLimit="disabled",
         allowedAddresses=[] )

      conf = portSecConfig.intfConfig.get( intf.name )
      if conf:
         stat = intfStatus( mode, intf )
         fields[ 'portMaxEnabled' ] = not conf.vlanBased
         fields[ 'maxMacAddresses' ] = conf.maxAddrs
         if stat and conf.enabled:
            if stat.restrictionStatus == 'restrictionActive':
               fields[ 'portStatus' ] = ( 'secure-shutdown'
                                          if conf.mode == 'shutdown'
                                          else 'secure-protected' )
            elif intf.lineProtocolState() == 'up':
               fields[ 'portStatus' ] = 'secure-up'
         if conf.staticOnly:
            fields[ 'staticOnly' ] = True
         if conf.enabled:
            fields[ 'portSecurityEnabled' ] = True
            fields[ 'violationMode' ] = conf.mode
         if stat:
            fields[ 'totalMacAddresses' ] = stat.addrs
            fields[ 'configuredMacAddresses' ] = stat.staticAddrs
            fields[ 'addressChanges' ] = stat.addrChanges
            fields[ 'securityViolationCount' ] = stat.violations
            if stat.addrs:
               when = toUtc( stat.lastNewAddrTime )
               fields[ 'lastChangeDetails' ] = ifaceModel.InterestingAddress(
                  macAddress=stat.lastNewAddr,
                  vlan=stat.lastNewVlanId,
                  time=when )
            if stat.violations:
               when = toUtc( stat.lastViolationTime )
               fields[ 'lastViolation' ] = ifaceModel.InterestingAddress(
                  macAddress=stat.lastViolatingAddr,
                  vlan=stat.lastViolatingVlanId,
                  time=when )
            if conf.mode == 'protect':
               if stat.log:
                  fields[ 'logAddrsAfterLimit' ] = "enabled"
               fields[ 'allowedAddresses' ] = [
                  "%s:%s" % ( ethStr( str( key.addr ) ), key.fid )
                  for key in stat.allowedAddr ]

      models[ intf.name ] = ifaceModel( **fields )
   return PortSecurityInterfaces( interfaces=models )

def doShowPortSecurityVlan( mode, args ):
   """Build the model for the port-security VLAN show command.

   Considers Ethernet interfaces selected by INTFS (or all of them if INTFS
   is absent) that have port-security status. Filters VLANs by VLANS, or by
   every VLAN from VlanCli.Vlan.getAll() when VLANS is absent.

   Each vlanStatus entry yields a row with its maxAddrs, address count,
   violation count and the interface status' mode as the action. An
   interface's default-VLAN entry (keyed by VlanId.invalid) is expanded
   into one row per matching defaultVlanCounter entry. Those rows take
   their counts from the counter and maxAddrs from the default entry.
   totalAddresses is the sum of all rows' address counts.
   """
   intfs = { x.name for x in
             IntfCli.Intf.getAll( mode, args.get( 'INTFS' ),
                                  intfType=EthIntfCli.EthIntf ) }
   if not intfs:
      return GeneralPortSecurityVlanStatistics( interfaces={},
                                                totalAddresses=0 )
   vlans = args.get( 'VLANS' ) or { vlan.id for vlan in VlanCli.Vlan.getAll( mode ) }

   # Index the default-VLAN counters by interface once, preserving collection
   # order, instead of rescanning the whole collection for every interface.
   defaultCounters = {}
   for defaultKey in portSecStatus.defaultVlanCounter:
      if defaultKey.intfId not in intfs or defaultKey.vlanId not in vlans:
         continue
      defaultCounters.setdefault( defaultKey.intfId, [] ).append(
         portSecStatus.defaultVlanCounter.get( defaultKey ) )

   vlanModels = {}
   numAddresses = 0

   def addVlanModel( intfId, status, vlanStat, dVlanCounter=None ):
      """Record one VLAN row under intfId and return its address count.

      Counts and the VLAN id come from dVlanCounter when it is given,
      otherwise from vlanStat. maxAddrs always comes from vlanStat.
      """
      counts = vlanStat if dVlanCounter is None else dVlanCounter
      vlanModel = GeneralPortSecurityVlanStatistics.VlanDict.VlanStatistic(
            maxAddrs=vlanStat.maxAddrs,
            numAddrs=counts.addrs,
            numViolations=counts.violations,
            action=status.mode )
      vlanModels.setdefault( intfId, {} )[ counts.vlanId ] = vlanModel
      return counts.addrs

   for intfVlanKey, vlanStat in portSecStatus.vlanStatus.items():
      intfId = intfVlanKey.intfId
      vlanId = intfVlanKey.vlanId
      if intfId not in intfs:
         continue
      status = portSecStatus.intfStatus.get( intfId )
      if status is None:
         continue
      if vlanId == VlanId.invalid:
         # Default-VLAN entry: one row per default-VLAN counter on this intf.
         for dVlanCounter in defaultCounters.get( intfId, () ):
            numAddresses += addVlanModel( intfId, status, vlanStat,
                                          dVlanCounter )
      if vlanId not in vlans:
         continue
      numAddresses += addVlanModel( intfId, status, vlanStat )

   intfModels = {
      intfId: GeneralPortSecurityVlanStatistics.VlanDict( vlans=rows )
      for intfId, rows in vlanModels.items() }
   return GeneralPortSecurityVlanStatistics( interfaces=intfModels,
                                             totalAddresses=numAddresses )

def clearPortSecurity( mode, args ):
   """Post a port-security clear request for each selected interface.

   Records the current Tac.now() value under the interface name in
   portSecLocalConfig.clearPortSecurityRequest.
   """
   targets = IntfCli.Intf.getAll( mode, args.get( 'INTF', None ),
                                  intfType=validIntfTypes )
   requests = portSecLocalConfig.clearPortSecurityRequest
   for intf in targets or ():
      requests[ intf.name ] = Tac.now()

def doEnDisPortSecurityChipBased( mode, args ):
   """Enable or (no/default form) disable chip-based port security globally.

   When disabling, any interface whose configured maximum exceeds the limit
   from maximumLimit() is clamped to that limit, with a warning. When
   enabling, warnMacTableUnsupportedUFTModeHook extensions are notified.
   """
   if not CliCommand.isNoOrDefaultCmd( args ):
      portSecConfig.chipBased = True
      # Hook is populated on Trident4 platforms to warn of conflicting
      # configuration.
      warnMacTableUnsupportedUFTModeHook.notifyExtensions( mode )
      return

   portSecConfig.chipBased = False
   for intfName in portSecConfig.intfConfig:
      cfg = intfConfig( intfName )
      limit = maximumLimit( cfg )
      if cfg.maxAddrs <= limit:
         continue
      mode.addWarning( "Because you have disabled chip-based port security, "
                       "the operating MAC address maximum on %s is limited "
                       "to the non-chip-based limit %d" %
                       ( intfName, limit ) )
      cfg.maxAddrs = limit
      portSecConfig.intfConfig.addMember( cfg )

def doEnDisPortSecurityLogging( mode, no=None ):
   """Enable protect-mode port security with logging, or turn logging off.

   Enabling warns when the configured maximum exceeds
   portSecHwCap.allowedMaxLimitAclBased, then calls doEnDisPortSecurity()
   in 'protect' mode with logging. The no form clears only the log flag of
   an existing interface config and does nothing if none exists.
   """
   intfName = mode.intf.name
   if not no:
      cfg = intfConfig( intfName )
      if cfg and cfg.maxAddrs > portSecHwCap.allowedMaxLimitAclBased:
         mode.addWarning( "Configured maximum addresses %d is higher than "
                          "the limit %d. Logging will not be enabled until "
                          "the configured maximum addresses value is below "
                          "the limit." % ( cfg.maxAddrs,
                                           portSecHwCap.allowedMaxLimitAclBased ) )
      doEnDisPortSecurity( mode, portSecMode='protect', logging=True )
      return

   existing = portSecConfig.intfConfig.get( intfName )
   if not existing:
      return
   writable = Tac.nonConst( existing )
   writable.log = False
   portSecConfig.intfConfig.addMember( writable )

def doEnDisPortSecurity( mode, no=None, portSecMode='shutdown', logging=False ):
   """Enable port security in portSecMode on the interface, or disable it.

   The no form is a no-op without an existing config. Otherwise it disables
   port security, clears logging and resets the mode to 'shutdown'.
   Enabling sets the mode, log flag and enabled flag. If the mode actually
   changed and the configured maximum exceeds maximumLimit(), the maximum
   is lowered with a warning. warnMacTableUnsupportedUFTModeHook extensions
   are then notified.
   """
   intfName = mode.intf.name
   if no:
      if not portSecConfig.intfConfig.get( intfName ):
         # Nothing configured: 'no switchport port-security' is a no-op.
         return
      cfg = intfConfig( intfName )
      cfg.enabled = False
      cfg.log = False
      cfg.mode = 'shutdown' # back to the default mode
      portSecConfig.intfConfig.addMember( cfg )
      return

   cfg = intfConfig( intfName )
   previousMode = cfg.mode
   cfg.mode = portSecMode
   cfg.log = logging
   cfg.enabled = True
   limit = maximumLimit( cfg )
   # Only a mode change can turn an existing maximum into a misconfiguration.
   modeChanged = previousMode != portSecMode
   if modeChanged and cfg.maxAddrs > limit:
      mode.addWarning( "Because you have configured port security %s mode, "
                       "the operating MAC address maximum on %s is lowered "
                       "to the limit %d."
                       % ( portSecMode, intfName, limit ) )
      cfg.maxAddrs = limit
   # Hook is populated on Trident4 platforms to warn of conflicting
   # configuration.
   warnMacTableUnsupportedUFTModeHook.notifyExtensions( mode )
   portSecConfig.intfConfig.addMember( cfg )

def doEnDisVlanBasedPortSec( mode, args ):
   """Add or remove per-VLAN port-security maximums on the interface.

   VLAN_SET entries live in portSecConfig.vlanConfig. The 'default' keyword
   selects the interface's entry in portSecConfig.defaultVlanConfig, keyed
   with VlanId.invalid. The no/default form deletes the selected entries.
   Otherwise each selected entry's maxAddrs is set to MAXIMUM, but only
   while the interface's violation mode is 'shutdown'. In any other mode a
   warning is issued and nothing changes.
   """
   intfName = mode.intf.name
   psiConf = intfConfig( intfName )
   maximum = args.get( 'MAXIMUM' )
   # 'default' targets the separate defaultVlanConfig collection.
   wantDefault = 'default' in args
   vlanSet = args.get( 'VLAN_SET', () )
   defaultVlanId = Tac.Value( 'Bridging::VlanIdOrAnyOrNone', VlanId.invalid )

   def asVlanId( vlan ):
      return Tac.Value( 'Bridging::VlanIdOrAnyOrNone', vlan )

   if CliCommand.isNoOrDefaultCmd( args ):
      for vlan in vlanSet:
         key = Tac.Value( 'PortSec::IntfVlanKey', intfName, asVlanId( vlan ) )
         del portSecConfig.vlanConfig[ key ]
      if wantDefault:
         key = Tac.Value( 'PortSec::IntfVlanKey', intfName, defaultVlanId )
         del portSecConfig.defaultVlanConfig[ key ]
      return

   if psiConf.mode != 'shutdown':
      mode.addWarning( "This command is not supported in protect mode" )
      return

   for vlan in vlanSet:
      vcfg = vlanConfig( intfName, asVlanId( vlan ) )
      vcfg.maxAddrs = maximum
      portSecConfig.vlanConfig.addMember( vcfg )
   if wantDefault:
      vcfg = vlanConfig( intfName, defaultVlanId, default=True )
      vcfg.maxAddrs = maximum
      portSecConfig.defaultVlanConfig.addMember( vcfg )

def doSetMaximum( mode, args ):
   """Store the interface's maximum, 'disabled' (vlanBased) and 'static' flags.

   MAXIMUM defaults to 1. 'static' is honoured only when the
   PortSecProtectStatic toggle is enabled. A request equal to the defaults
   (maximum 1, not disabled, not static) is ignored for an interface with
   no IntfConfig, so that no config entry is created needlessly.
   """
   intfName = mode.intf.name
   maximum = args.get( "MAXIMUM", 1 )
   vlanBased = 'disabled' in args
   staticToggle = Toggle.togglePortSecProtectStaticEnabled
   staticOnly = ( 'static' in args ) if staticToggle() else False
   requestsDefaults = ( maximum == 1 and
                        not vlanBased and
                        not ( staticToggle() and staticOnly ) )
   if requestsDefaults and not portSecConfig.intfConfig.get( intfName ):
      return
   # All three settings are stored as given. The intended precedence is
   # disabled > static > maximum, presumably applied by whoever consumes this
   # config (not enforced here).
   cfg = intfConfig( intfName )
   cfg.maxAddrs = maximum
   cfg.vlanBased = vlanBased
   cfg.staticOnly = staticOnly
   portSecConfig.intfConfig.addMember( cfg )

def doEnDisMacAddressLimitMaximum( mode, args ):
   """Enable the MAC address limit with MAXIMUM (default 1), or disable it.

   The no/default form is a no-op for an interface without an IntfConfig.
   Otherwise it disables the limit and resets the maximum to 1. Logging is
   always turned off.
   """
   intfName = mode.intf.name
   maximum = args.get( 'MAXIMUM', 1 )
   disabling = CliCommand.isNoOrDefaultCmd( args )
   if disabling and not portSecConfig.intfConfig.get( intfName ):
      # 'no mac address limit' with nothing configured: nothing to undo.
      return
   cfg = intfConfig( intfName )
   cfg.enabled = not disabling
   cfg.maxAddrs = 1 if disabling else maximum
   # Logging was not supported when this command was introduced.
   cfg.log = False
   portSecConfig.intfConfig.addMember( cfg )

def doSetMacAddressLimitViolationMode( mode, args ):
   """Set the violation mode: 'shutdown' if requested, else 'protect'."""
   cfg = intfConfig( mode.intf.name )
   cfg.mode = 'shutdown' if args.get( 'shutdown' ) else 'protect'
   portSecConfig.intfConfig.addMember( cfg )

def doEnAging( mode, args ):
   """Set portSecConfig.allowSecureAddressAging to True."""
   portSecConfig.allowSecureAddressAging = True

def doDisAging( mode, args ):
   """Set portSecConfig.allowSecureAddressAging to False."""
   portSecConfig.allowSecureAddressAging = False

def doEnMoves( mode, args ):
   """Configure secure-address moves.

   Records the 'data' and 'phone' qualifiers. The unqualified
   allowSecureAddressMoves flag is set only when neither qualifier was
   given.
   """
   dataOnly = 'data' in args
   phoneOnly = 'phone' in args
   portSecConfig.allowSecureAddressMovesData = dataOnly
   portSecConfig.allowSecureAddressMovesPhone = phoneOnly
   portSecConfig.allowSecureAddressMoves = not ( dataOnly or phoneOnly )

def doDisMoves( mode, args ):
   """Clear all three secure-address move flags (general, data, phone)."""
   cfg = portSecConfig
   cfg.allowSecureAddressMoves = False
   cfg.allowSecureAddressMovesData = False
   cfg.allowSecureAddressMovesPhone = False

def doEnPersistence( mode, args ):
   """Set portSecConfig.persistenceEnabled to True."""
   portSecConfig.persistenceEnabled = True

def doDisPersistence( mode, args ):
   """Set portSecConfig.persistenceEnabled to False."""
   portSecConfig.persistenceEnabled = False

def doEnShutdownPersistence( mode, args ):
   """Set portSecConfig.persistenceShutdownEnabled to True."""
   portSecConfig.persistenceShutdownEnabled = True

def doDisShutdownPersistence( mode, args ):
   """Set portSecConfig.persistenceShutdownEnabled to False."""
   portSecConfig.persistenceShutdownEnabled = False

def Plugin( entityManager ):
   """CLI plugin entry point: mount the entities this module uses.

   Assigns the module-level handles for bridging config/status and for the
   port-security config, local config, status and hardware capabilities.
   Also registers PortSecIntf with IntfCli.Intf as a dependent class.
   """
   global bridgingConfig, bridgingStatus
   global portSecConfig, portSecLocalConfig, portSecStatus, portSecHwCap
   # Read-only bridging config (hostAgingTime is read by the interface show).
   bridgingConfig = LazyMount.mount( entityManager, "bridging/config",
                                     "Bridging::Config", "r" )
   # Writable port-security config; the CLI handlers in this module edit it.
   portSecConfig = ConfigMount.mount( entityManager, "portsec/config",
                                      "PortSec::Config", "w" )
   # Writable local config; clearPortSecurity() posts clear requests here.
   portSecLocalConfig = LazyMount.mount( entityManager, "portsec/localconfig",
                                         "PortSec::LocalConfig", "w" )
   portSecStatus = LazyMount.mount( entityManager, "portsec/status",
                                    "PortSec::Status", "r" )
   portSecHwCap = LazyMount.mount( entityManager, "portsec/hwcap",
                                   "PortSec::HwCapabilities", "r" )
   IntfCli.Intf.registerDependentClass( PortSecIntf )
   # Smash-backed bridging status (its smashFdbStatus table is read by the
   # address show command), mounted as a reader.
   bridgingStatus = SmashLazyMount.mount( entityManager, "bridging/status",
                                          "Smash::Bridging::Status",
                                          SmashLazyMount.mountInfo( 'reader' ) )
