# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import Tac
from CliPlugin import StormControlCli
from CliPlugin import IntfCli
from CliDynamicSymbol import CliDynamicPlugin

# Handle to the "StormControlModel" plugin obtained through CliDynamicPlugin.
# Used below for the CLI model classes (StormControlStatus,
# IntfStormControlStatus, StormControlType) and for convertToHighestBps().
StormControlCliModel = CliDynamicPlugin( "StormControlModel" )


# CLI rate unit -> value stored in Threshold.bitResolution, recording which unit
# a bits-per-second level was entered in.
BPS_RESOLUTION = { "bps": "bpsResolution", "kbps": "kbpsResolution",
                   "mbps": "mbpsResolution", "gbps": "gbpsResolution" }
# CLI rate unit -> multiplier converting a rate in that unit to bits per second.
BPS_FACTOR = { "bps": 1, "kbps": 1000, "mbps": 1000000, "gbps": 1000000000 }
# Error text for a too-large bps level; formatted with
# ( requested rate, maximum configurable rate ), both as "<value> <unit>".
BPS_EXCEED_MAX_ERROR = "Error: %s is greater than the maximum configurable value: %s"

def setAggregateLevelPps( mode, args ):
   """Configure an aggregate storm-control level, in packets per second, for
   the traffic class args[ 'TRAFFIC_CLASS' ].

   The level comes from args[ 'PPS' ]. The per-class config entry is created
   first if it does not exist yet.
   """
   trafficClass = args[ 'TRAFFIC_CLASS' ]
   ppsThreshold = Tac.Value( 'Bridging::StormControl::Threshold',
                             'packetsPerSecond', args[ 'PPS' ] )
   # Internal invariant: a pps threshold never equals the default threshold.
   assert ppsThreshold != StormControlCli.defaultThreshold
   config = StormControlCli.config
   if trafficClass not in config.aggregateTcConfig:
      config.newAggregateTcConfig( trafficClass )
   config.aggregateTcConfig[ trafficClass ].level = ppsThreshold

def noAggregateLevelPps( mode, args ):
   """Remove the aggregate storm-control level of args[ 'TRAFFIC_CLASS' ].

   Does nothing when the class has no config. Otherwise the level is first
   reset to the default threshold and then the entry is deleted.
   """
   trafficClass = args[ 'TRAFFIC_CLASS' ]
   tcConfigs = StormControlCli.config.aggregateTcConfig
   if trafficClass not in tcConfigs:
      return
   # NOTE(review): resetting before deleting looks deliberate; keep the order.
   tcConfigs[ trafficClass ].level = StormControlCli.defaultThreshold
   del tcConfigs[ trafficClass ]


# Full-scale value of a 'percentage' Threshold, read from a default-constructed
# Threshold. The set*LevelPercentage helpers below store a CLI percentage as
# round( level * maxLevelPercentage / 100 ).
maxLevelPercentage = \
    Tac.Value( 'Bridging::StormControl::Threshold' ).maxLevelPercentage

def setUUcastThreshold( mode, cpu, threshold ):
   """Store an unknown-unicast threshold in the config of mode.intf.

   cpu selects uucastCpuLevel instead of uucastLevel. Applying the default
   threshold to an interface without config is a no-op. A pre-existing config
   is deleted again when stormControlDefaultConfig() reports it as default.
   """
   config = StormControlCli.config
   intfName = mode.intf.name
   hadConfig = intfName in config.intfConfig
   if not hadConfig and threshold == StormControlCli.defaultThreshold:
      return
   levelAttr = 'uucastCpuLevel' if cpu else 'uucastLevel'
   setattr( config.newIntfConfig( intfName ), levelAttr, threshold )
   if hadConfig and StormControlCli.stormControlDefaultConfig( intfName ):
      del config.intfConfig[ intfName ]

def setUUcastLevelPercentage( mode, level, cpu=False ):
   """Apply an unknown-unicast limit given as a percentage.

   The percentage is rescaled to maxLevelPercentage units and rounded half up
   before being stored as a 'percentage' Threshold.
   """
   scaledLevel = int( 0.5 + ( level * maxLevelPercentage ) / 100 )
   pctThreshold = Tac.Value( 'Bridging::StormControl::Threshold', 'percentage',
                             scaledLevel )
   setUUcastThreshold( mode, cpu, pctThreshold )

# Handlers for the known-multicast policing CLI
def setPoliceKnownMulticast( mode, args, isPolicingEnabled=True ):
   """Set config.policeKnownMulticast to isPolicingEnabled (True by default)."""
   config = StormControlCli.config
   config.policeKnownMulticast = isPolicingEnabled

def unsetPoliceKnownMulticast( mode, args ):
   """Disable known-multicast policing via setPoliceKnownMulticast."""
   enabled = False
   setPoliceKnownMulticast( mode, args, isPolicingEnabled=enabled )

def setUUcastLevelPps( mode, level, cpu=False ):
   """Apply an unknown-unicast limit expressed in packets per second."""
   ppsThreshold = Tac.Value( 'Bridging::StormControl::Threshold',
                             'packetsPerSecond', level )
   setUUcastThreshold( mode, cpu, ppsThreshold )

def setUUcastLevelBps( mode, rate, rateUnit, cpu=False ):
   """Apply an unknown-unicast limit of rate rateUnit (a BPS_FACTOR key).

   A rate above defaultThreshold.maxLevelBps is rejected with a CLI error and
   nothing is configured.
   """
   bitsPerSec = rate * BPS_FACTOR[ rateUnit ]
   maxBps = StormControlCli.defaultThreshold.maxLevelBps
   if bitsPerSec > maxBps:
      requested = '%s %s' % ( rate, rateUnit )
      allowed = '%s %s' % StormControlCliModel.convertToHighestBps( maxBps )
      mode.addError( BPS_EXCEED_MAX_ERROR % ( requested.lower(),
                                              allowed.lower() ) )
      return
   bpsThreshold = Tac.Value( 'Bridging::StormControl::Threshold',
                             'bitsPerSecond', bitsPerSec )
   # Record which unit the rate was entered in.
   bpsThreshold.bitResolution = BPS_RESOLUTION[ rateUnit ]
   setUUcastThreshold( mode, cpu, bpsThreshold )

def noUUcastLevel( mode, args ):
   """Reset the unknown-unicast level (the CPU level when 'cpu' is in args) by
   applying 100 percent, presumably equal to the default threshold."""
   targetsCpu = 'cpu' in args
   setUUcastLevelPercentage( mode, 100, targetsCpu )

def unicastHandler( mode, args ):
   """Dispatch the unknown-unicast storm-control command.

   The level comes from 'PPS', 'RATE'/'RATE_UNIT' or 'LEVEL' (a percentage),
   checked in that order. 'cpu' in args targets the CPU level. With no level
   at all, the interface's configured uucastLevel is copied to the CPU level;
   if there is none, a CLI error is reported instead.
   """
   cpu = 'cpu' in args
   if 'PPS' in args:
      setUUcastLevelPps( mode, args[ 'PPS' ], cpu=cpu )
   elif 'RATE' in args:
      setUUcastLevelBps( mode, args[ 'RATE' ], args[ 'RATE_UNIT' ], cpu=cpu )
   elif 'LEVEL' in args:
      setUUcastLevelPercentage( mode, args[ 'LEVEL' ], cpu=cpu )
   else:
      current = StormControlCli.config.intfConfig.get( mode.intf.name )
      unset = current is None or \
              current.uucastLevel == StormControlCli.defaultThreshold
      if unset:
         mode.addError( "cpu level unspecified" )
      else:
         setUUcastThreshold( mode, True, current.uucastLevel )

def setBroadcastThreshold( mode, cpu, threshold ):
   """Store a broadcast threshold in the config of mode.intf.

   cpu selects broadcastCpuLevel instead of broadcastLevel. Applying the
   default threshold to an interface without config is a no-op. A pre-existing
   config is deleted again when stormControlDefaultConfig() reports it as
   default.
   """
   config = StormControlCli.config
   intfName = mode.intf.name
   hadConfig = intfName in config.intfConfig
   if not hadConfig and threshold == StormControlCli.defaultThreshold:
      return
   levelAttr = 'broadcastCpuLevel' if cpu else 'broadcastLevel'
   setattr( config.newIntfConfig( intfName ), levelAttr, threshold )
   if hadConfig and StormControlCli.stormControlDefaultConfig( intfName ):
      del config.intfConfig[ intfName ]

def setBroadcastLevelPercentage( mode, level, cpu=False ):
   """Apply a broadcast limit given as a percentage.

   The percentage is rescaled to maxLevelPercentage units and rounded half up
   before being stored as a 'percentage' Threshold.
   """
   scaledLevel = int( 0.5 + ( level * maxLevelPercentage ) / 100 )
   pctThreshold = Tac.Value( 'Bridging::StormControl::Threshold', 'percentage',
                             scaledLevel )
   setBroadcastThreshold( mode, cpu, pctThreshold )

def setBroadcastLevelPps( mode, level, cpu=False ):
   """Apply a broadcast limit expressed in packets per second."""
   ppsThreshold = Tac.Value( 'Bridging::StormControl::Threshold',
                             'packetsPerSecond', level )
   setBroadcastThreshold( mode, cpu, ppsThreshold )

def setBroadcastLevelBps( mode, rate, rateUnit, cpu=False ):
   """Apply a broadcast limit of rate rateUnit (a BPS_FACTOR key).

   A rate above defaultThreshold.maxLevelBps is rejected with a CLI error and
   nothing is configured.
   """
   bitsPerSec = rate * BPS_FACTOR[ rateUnit ]
   maxBps = StormControlCli.defaultThreshold.maxLevelBps
   if bitsPerSec > maxBps:
      requested = '%s %s' % ( rate, rateUnit )
      allowed = '%s %s' % StormControlCliModel.convertToHighestBps( maxBps )
      mode.addError( BPS_EXCEED_MAX_ERROR % ( requested.lower(),
                                              allowed.lower() ) )
      return
   bpsThreshold = Tac.Value( 'Bridging::StormControl::Threshold',
                             'bitsPerSecond', bitsPerSec )
   # Record which unit the rate was entered in.
   bpsThreshold.bitResolution = BPS_RESOLUTION[ rateUnit ]
   setBroadcastThreshold( mode, cpu, bpsThreshold )

def noBroadcastLevel( mode, args ):
   """Reset the broadcast level (the CPU level when 'cpu' is in args) by
   applying 100 percent, presumably equal to the default threshold."""
   targetsCpu = 'cpu' in args
   setBroadcastLevelPercentage( mode, 100, targetsCpu )

def broadcastHandler( mode, args ):
   """Dispatch the broadcast storm-control command.

   The level comes from 'PPS', 'RATE'/'RATE_UNIT' or 'LEVEL' (a percentage),
   checked in that order. 'cpu' in args targets the CPU level. With no level
   at all, the interface's configured broadcastLevel is copied to the CPU
   level; if there is none, a CLI error is reported instead.
   """
   cpu = 'cpu' in args
   if 'PPS' in args:
      setBroadcastLevelPps( mode, args[ 'PPS' ], cpu=cpu )
   elif 'RATE' in args:
      setBroadcastLevelBps( mode, args[ 'RATE' ], args[ 'RATE_UNIT' ], cpu=cpu )
   elif 'LEVEL' in args:
      setBroadcastLevelPercentage( mode, args[ 'LEVEL' ], cpu=cpu )
   else:
      current = StormControlCli.config.intfConfig.get( mode.intf.name )
      unset = current is None or \
              current.broadcastLevel == StormControlCli.defaultThreshold
      if unset:
         mode.addError( "cpu level unspecified" )
      else:
         setBroadcastThreshold( mode, True, current.broadcastLevel )

def setMulticastThreshold( mode, cpu, threshold ):
   """Store a multicast threshold in the config of mode.intf.

   cpu selects multicastCpuLevel instead of multicastLevel. Applying the
   default threshold to an interface without config is a no-op. A pre-existing
   config is deleted again when stormControlDefaultConfig() reports it as
   default.
   """
   config = StormControlCli.config
   intfName = mode.intf.name
   hadConfig = intfName in config.intfConfig
   if not hadConfig and threshold == StormControlCli.defaultThreshold:
      return
   levelAttr = 'multicastCpuLevel' if cpu else 'multicastLevel'
   setattr( config.newIntfConfig( intfName ), levelAttr, threshold )
   if hadConfig and StormControlCli.stormControlDefaultConfig( intfName ):
      del config.intfConfig[ intfName ]

def setMulticastLevelPercentage( mode, level, cpu=False ):
   """Apply a multicast limit given as a percentage.

   The percentage is rescaled to maxLevelPercentage units and rounded half up
   before being stored as a 'percentage' Threshold.
   """
   scaledLevel = int( 0.5 + ( level * maxLevelPercentage ) / 100 )
   pctThreshold = Tac.Value( 'Bridging::StormControl::Threshold', 'percentage',
                             scaledLevel )
   setMulticastThreshold( mode, cpu, pctThreshold )

def setMulticastLevelPps( mode, level, cpu=False ):
   """Apply a multicast limit expressed in packets per second."""
   ppsThreshold = Tac.Value( 'Bridging::StormControl::Threshold',
                             'packetsPerSecond', level )
   setMulticastThreshold( mode, cpu, ppsThreshold )

def setMulticastLevelBps( mode, rate, rateUnit, cpu=False ):
   """Apply a multicast limit of rate rateUnit (a BPS_FACTOR key).

   A rate above defaultThreshold.maxLevelBps is rejected with a CLI error and
   nothing is configured.
   """
   bitsPerSec = rate * BPS_FACTOR[ rateUnit ]
   maxBps = StormControlCli.defaultThreshold.maxLevelBps
   if bitsPerSec > maxBps:
      requested = '%s %s' % ( rate, rateUnit )
      allowed = '%s %s' % StormControlCliModel.convertToHighestBps( maxBps )
      mode.addError( BPS_EXCEED_MAX_ERROR % ( requested.lower(),
                                              allowed.lower() ) )
      return
   bpsThreshold = Tac.Value( 'Bridging::StormControl::Threshold',
                             'bitsPerSecond', bitsPerSec )
   # Record which unit the rate was entered in.
   bpsThreshold.bitResolution = BPS_RESOLUTION[ rateUnit ]
   setMulticastThreshold( mode, cpu, bpsThreshold )

def noMulticastLevel( mode, args ):
   """Reset the multicast level (the CPU level when 'cpu' is in args) by
   applying 100 percent, presumably equal to the default threshold."""
   targetsCpu = 'cpu' in args
   setMulticastLevelPercentage( mode, 100, targetsCpu )

def multicastHandler( mode, args ):
   """Dispatch the multicast storm-control command.

   The level comes from 'PPS', 'RATE'/'RATE_UNIT' or 'LEVEL' (a percentage),
   checked in that order. 'cpu' in args targets the CPU level. With no level
   at all, the interface's configured multicastLevel is copied to the CPU
   level; if there is none, a CLI error is reported instead.
   """
   cpu = 'cpu' in args
   if 'PPS' in args:
      setMulticastLevelPps( mode, args[ 'PPS' ], cpu=cpu )
   elif 'RATE' in args:
      setMulticastLevelBps( mode, args[ 'RATE' ], args[ 'RATE_UNIT' ], cpu=cpu )
   elif 'LEVEL' in args:
      setMulticastLevelPercentage( mode, args[ 'LEVEL' ], cpu=cpu )
   else:
      current = StormControlCli.config.intfConfig.get( mode.intf.name )
      unset = current is None or \
              current.multicastLevel == StormControlCli.defaultThreshold
      if unset:
         mode.addError( "cpu level unspecified" )
      else:
         setMulticastThreshold( mode, True, current.multicastLevel )

def setAllThreshold( mode, cpu, threshold ):
   """Store an 'all' storm-control threshold in the config of mode.intf.

   cpu selects allCpuLevel instead of allLevel. Applying the default threshold
   to an interface without config is a no-op. A pre-existing config is deleted
   again when stormControlDefaultConfig() reports it as default.

   When a non-default 'all' threshold is stored while the interface also has a
   non-default unknown-unicast, broadcast or multicast level, a warning is
   added, because the 'all' limit overrides those.
   """
   name = mode.intf.name
   existing = name in StormControlCli.config.intfConfig
   if threshold == StormControlCli.defaultThreshold and not existing:
      return
   intfConfig = StormControlCli.config.newIntfConfig( name )
   if cpu:
      intfConfig.allCpuLevel = threshold
   else:
      intfConfig.allLevel = threshold
   if existing and StormControlCli.stormControlDefaultConfig( name ):
      del StormControlCli.config.intfConfig[ name ]
      return
   if threshold == StormControlCli.defaultThreshold:
      # Removing the 'all' limit overrides nothing, so there is nothing to warn
      # about.
      return
   # Check whenever the config survives. A config created just above presumably
   # has no per-type levels yet, so in practice this warns only for interfaces
   # that already had broadcast/multicast/unknown-unicast limits configured.
   # TODO: Check CPU levels
   perTypeLevels = ( intfConfig.uucastLevel, intfConfig.broadcastLevel,
                     intfConfig.multicastLevel )
   if any( level != StormControlCli.defaultThreshold for level in perTypeLevels ):
      mode.addWarning( "'storm-control all' will override broadcast,"
                       " multicast and unknown-unicast limits" )

def setAllLevelPercentage( mode, level, cpu=False ):
   """Apply an 'all' storm-control limit given as a percentage.

   The percentage is rescaled to maxLevelPercentage units and rounded half up
   before being stored as a 'percentage' Threshold.
   """
   scaledLevel = int( 0.5 + ( level * maxLevelPercentage ) / 100 )
   pctThreshold = Tac.Value( 'Bridging::StormControl::Threshold', 'percentage',
                             scaledLevel )
   setAllThreshold( mode, cpu, pctThreshold )

def setAllLevelPps( mode, level, cpu=False ):
   """Apply an 'all' storm-control limit expressed in packets per second."""
   ppsThreshold = Tac.Value( 'Bridging::StormControl::Threshold',
                             'packetsPerSecond', level )
   setAllThreshold( mode, cpu, ppsThreshold )

def setAllLevelBps( mode, rate, rateUnit, cpu=False ):
   """Apply an 'all' storm-control limit of rate rateUnit (a BPS_FACTOR key).

   A rate above defaultThreshold.maxLevelBps is rejected with a CLI error and
   nothing is configured.
   """
   bitsPerSec = rate * BPS_FACTOR[ rateUnit ]
   maxBps = StormControlCli.defaultThreshold.maxLevelBps
   if bitsPerSec > maxBps:
      requested = '%s %s' % ( rate, rateUnit )
      allowed = '%s %s' % StormControlCliModel.convertToHighestBps( maxBps )
      mode.addError( BPS_EXCEED_MAX_ERROR % ( requested.lower(),
                                              allowed.lower() ) )
      return
   bpsThreshold = Tac.Value( 'Bridging::StormControl::Threshold',
                             'bitsPerSecond', bitsPerSec )
   # Record which unit the rate was entered in.
   bpsThreshold.bitResolution = BPS_RESOLUTION[ rateUnit ]
   setAllThreshold( mode, cpu, bpsThreshold )

def noAllLevel( mode, args ):
   """Reset the 'all' level (the CPU level when 'cpu' is in args) by applying
   100 percent, presumably equal to the default threshold."""
   targetsCpu = 'cpu' in args
   setAllLevelPercentage( mode, 100, targetsCpu )


def allLevelHandler( mode, args ):
   """Dispatch the 'storm-control all' command.

   The level comes from 'PPS', 'RATE'/'RATE_UNIT' or 'LEVEL' (a percentage),
   checked in that order. 'cpu' in args targets the CPU level. With no level
   at all, the interface's configured allLevel is copied to the CPU level; if
   there is none, a CLI error is reported instead.
   """
   cpu = 'cpu' in args
   if 'PPS' in args:
      setAllLevelPps( mode, args[ 'PPS' ], cpu=cpu )
   elif 'RATE' in args:
      setAllLevelBps( mode, args[ 'RATE' ], args[ 'RATE_UNIT' ], cpu=cpu )
   elif 'LEVEL' in args:
      setAllLevelPercentage( mode, args[ 'LEVEL' ], cpu=cpu )
   else:
      current = StormControlCli.config.intfConfig.get( mode.intf.name )
      unset = current is None or \
              current.allLevel == StormControlCli.defaultThreshold
      if unset:
         mode.addError( "cpu level unspecified" )
      else:
         setAllThreshold( mode, True, current.allLevel )

def dropLoggingIntervalHandler( mode, args ):
   """Set the storm-control drop-log interval; 0 when INTERVAL is not given."""
   interval = args.get( 'INTERVAL', 0 )
   config = StormControlCli.config
   config.stormControlDropLogInterval = interval
 
def enableDropLogginghandler( mode, args ):
   """Turn storm-control drop logging on (mode 'on')."""
   config = StormControlCli.config
   config.stormControlDropLoggingMode = 'on'

def noDropLoggingHandler( mode, args ):
   """Turn storm-control drop logging off (mode 'off')."""
   config = StormControlCli.config
   config.stormControlDropLoggingMode = 'off'

def trafficTypeModel( threshold, rate, drops, dropOctets, dormant, cntSupported ):
   """Build one StormControlType entry for the show model.

   threshold supplies the level and threshold type. drops and dropOctets are
   reported only when cntSupported is true; otherwise both are None.
   """
   model = StormControlCliModel.StormControlType()
   model.level = threshold.level
   model.thresholdType = threshold.type
   model.rate = rate
   model.drop = drops if cntSupported else None
   model.dropOctets = dropOctets if cntSupported else None
   model.dormant = dormant
   return model

def showStormControl( mode, args ):
   """Build the StormControlStatus model for the storm-control show command.

   args may contain 'INTFS', the interfaces to report. Without it, the
   aggregate per-traffic-class levels are included (when the platform reports
   aggregateStormControlSupported) and every interface that has storm-control
   status is reported. Interfaces without a usable status entry are skipped.

   Returns a StormControlCliModel.StormControlStatus.
   """
   intfs = args.get( 'INTFS' )
   status = StormControlCli.status()
   config = StormControlCli.config
   errdisableSupported = StormControlCli.errdisableSupported()
   defaultThreshold = StormControlCli.defaultThreshold

   stormControlStatus = StormControlCliModel.StormControlStatus()
   stormControlStatus.errdisableEnabledIs( errdisableSupported )

   if status.aggregateStormControlSupported and intfs is None:
      for tc in config.aggregateTcConfig:
         level = config.aggregateTcConfig[ tc ].level
         stormControlStatus.aggregateTrafficClasses[ tc ] = trafficTypeModel(
            level, 0, 0, 0, False, False )

   # storm control count is only available on some platforms
   cntSupported = status.dropCountSupported

   if intfs is None:
      if cntSupported:
         intfs = status.allIntfStatus
      elif StormControlCli.stormControlSupported():
         # Port-channel interfaces can have status entries in more
         # than one linecard, so this is a set.
         intfs = { intf
                   for sliceStatus in StormControlCli.sliceStatusDir().values()
                   for intf in sliceStatus.intfStatus }

   # No interfaces to report. This also covers intfs still being None when
   # the platform has no per-interface status source; iterating None below
   # would raise TypeError.
   if not intfs:
      return stormControlStatus

   # Short name of every LAG member -> short name of its port-channel.
   lagIntfStatus = StormControlCli.lagStatus.intfStatus
   mem2PortChannels = {}
   for poIntfName in lagIntfStatus:
      poShortName = IntfCli.Intf.getShortname( poIntfName )
      for member in lagIntfStatus.get( poIntfName ).member:
         mem2PortChannels[ IntfCli.Intf.getShortname( member ) ] = poShortName

   levelTypes = ( 'all', 'multicast', 'broadcast', 'uucast' )
   # ( model key, intfStatus level attribute, counter key, CPU level? ), in the
   # order entries are added to the model. The 'all' entries come first: once
   # 'all' or 'all-cpu' is configured, the per-type entries after them are
   # reported as dormant.
   trafficTypeSpecs = (
      ( 'all', 'allLevel', 'all', False ),
      ( 'all-cpu', 'allCpuLevel', 'all', True ),
      ( 'broadcast', 'broadcastLevel', 'broadcast', False ),
      ( 'broadcast-cpu', 'broadcastCpuLevel', 'broadcast', True ),
      ( 'unknown-unicast', 'uucastLevel', 'uucast', False ),
      ( 'unknown-unicast-cpu', 'uucastCpuLevel', 'uucast', True ),
      ( 'multicast', 'multicastLevel', 'multicast', False ),
      ( 'multicast-cpu', 'multicastCpuLevel', 'multicast', True ),
   )

   for i in intfs:
      dropCnt = dict.fromkeys( levelTypes, 0 )
      dropOctets = dict.fromkeys( levelTypes, 0 )
      totalRates = dict.fromkeys( levelTypes, 0 )
      dropCntCpu = dict.fromkeys( levelTypes, 0 )
      # NOTE(review): no CPU drop-octet counter is read, so these stay 0.
      dropOctetsCpu = dict.fromkeys( levelTypes, 0 )
      totalRatesCpu = dict.fromkeys( levelTypes, 0 )

      intfStatus = None
      # For Sand platform, we don't want to read counters from intfStatus because
      # intfStatus doesn't have space holder for DropOctets. When we add this field,
      # we should remove check for stormControlDropCountersSupported.
      if cntSupported:
         # For platforms which supports both counting and metering
         if i not in status.allIntfStatus:
            continue
         for intfStatus in status.allIntfStatus[ i ].intfStatus:
            totalRates[ 'all' ] += intfStatus.allRate
            totalRates[ 'broadcast' ] += intfStatus.broadcastRate
            totalRates[ 'multicast' ] += intfStatus.multicastRate
            totalRates[ 'uucast' ] += intfStatus.uucastRate
            totalRatesCpu[ 'all' ] += intfStatus.allCpuRate
            totalRatesCpu[ 'broadcast' ] += intfStatus.broadcastCpuRate
            totalRatesCpu[ 'multicast' ] += intfStatus.multicastCpuRate
            totalRatesCpu[ 'uucast' ] += intfStatus.uucastCpuRate
            # For non-Sand platform, we read counters from IntfStatus
            dropCnt[ 'all' ] += intfStatus.allDrop
            dropCnt[ 'broadcast' ] += intfStatus.broadcastDrop
            dropCnt[ 'multicast' ] += intfStatus.multicastDrop
            dropCnt[ 'uucast' ] += intfStatus.uucastDrop
            dropOctets[ 'all' ] += intfStatus.allDropOcts
            dropOctets[ 'broadcast' ] += intfStatus.broadcastDropOcts
            dropOctets[ 'multicast' ] += intfStatus.multicastDropOcts
            dropOctets[ 'uucast' ] += intfStatus.uucastDropOcts
            dropCntCpu[ 'all' ] += intfStatus.allCpuDrop
            dropCntCpu[ 'broadcast' ] += intfStatus.broadcastCpuDrop
            dropCntCpu[ 'multicast' ] += intfStatus.multicastCpuDrop
            dropCntCpu[ 'uucast' ] += intfStatus.uucastCpuDrop
      elif StormControlCli.stormControlSupported():
         for sliceStatus in StormControlCli.sliceStatusDir().values():
            if i not in sliceStatus.intfStatus:
               continue
            # Rates are summed over every slice reporting the interface; the
            # levels shown below come from the last such slice.
            intfStatus = sliceStatus.intfStatus[ i ]
            totalRates[ 'all' ] += intfStatus.allRate
            totalRates[ 'broadcast' ] += intfStatus.broadcastRate
            totalRates[ 'multicast' ] += intfStatus.multicastRate
            totalRates[ 'uucast' ] += intfStatus.uucastRate
            totalRatesCpu[ 'all' ] += intfStatus.allCpuRate
            totalRatesCpu[ 'broadcast' ] += intfStatus.broadcastCpuRate
            totalRatesCpu[ 'multicast' ] += intfStatus.multicastCpuRate
            totalRatesCpu[ 'uucast' ] += intfStatus.uucastCpuRate

      if intfStatus is None:
         # No status entry for this interface, or no supported status source:
         # skip it rather than dereference None below.
         continue

      intfStormControlStatus = StormControlCliModel.IntfStormControlStatus()
      stormControlAll = False
      for modelKey, levelAttr, ctrKey, isCpu in trafficTypeSpecs:
         level = getattr( intfStatus, levelAttr )
         if level != defaultThreshold:
            isAll = ctrKey == 'all'
            if isCpu:
               rates, drops, octets = totalRatesCpu, dropCntCpu, dropOctetsCpu
            else:
               rates, drops, octets = totalRates, dropCnt, dropOctets
            intfStormControlStatus.trafficTypes[ modelKey ] = trafficTypeModel(
               level, rates[ ctrKey ], drops[ ctrKey ], octets[ ctrKey ],
               False if isAll else stormControlAll, cntSupported )
            if isAll:
               stormControlAll = True

      stormCtrStatus = True
      reason = ''
      if i.startswith( 'Et' ):
         parentShort = IntfCli.Intf.getShortname( i.split( '.' )[ 0 ] )
         if parentShort in mem2PortChannels:
            stormCtrStatus = False
            reason = 'member of %s' % mem2PortChannels[ parentShort ]
      elif i.startswith( 'Po' ):
         lagIntf = lagIntfStatus.get( i.split( '.' )[ 0 ] )
         # A port-channel without a LAG status entry has no members either.
         if lagIntf is None or not lagIntf.member:
            stormCtrStatus = False
            reason = 'no members'
      intfStormControlStatus.active = stormCtrStatus
      intfStormControlStatus.reason = reason
      intfStormControlStatus.errdisabled = False
      if errdisableSupported:
         intfStormControlStatus.errdisabled = i in \
            StormControlCli.errdisableStatus.errdisabledIntf
      stormControlStatus.interfaces[ i ] = intfStormControlStatus
   return stormControlStatus

def setBurst( mode, args ):
   """Store the storm-control burst (args 'TIME' in 'UNITS') under the single
   "burstConfig" entry, creating that entry if needed."""
   burstTime = args[ 'TIME' ]
   burstUnits = args[ 'UNITS' ]
   burstInfo = Tac.Value( 'Bridging::StormControl::BurstInfo', burstUnits,
                          burstTime )
   entryName = "burstConfig"
   config = StormControlCli.config
   if entryName not in config.burstConfig:
      config.newBurstConfig( entryName )
   config.burstConfig[ entryName ].burst = burstInfo

def noBurst( mode, args ):
   """Delete the "burstConfig" burst entry if it is present."""
   entryName = "burstConfig"
   burstConfigs = StormControlCli.config.burstConfig
   if entryName in burstConfigs:
      del burstConfigs[ entryName ]