#!/usr/bin/env python3
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac
from CliPlugin.QosCliIntfTypes import ethIntfPrefixes
from CliPlugin.QosCli import ( invalidGuaranteedBw,
                              invalidShapeRate, qosInputConfig )
from CliPlugin.QosCliCommon import ( getIntfListFromMode, goToTxQueueRangeMode,
                                     setIntfConfig, cliBlockingFail,
                                     QosProfileMode, defaultWeight,
                                     setQosProfileConfig )
from CliPlugin.QosCliEcn import ( setTxQueueNonEct, ecnMaxDroprate,
                                  setEcnQueueCounterConfig )
from QosTypes import ( tacFeatureName, tacPercent, tacTxQueuePriority, tacQueueType,
                       tacBwWeight, tacEcnDelayThreshold, tacLatencyThreshold )

# -----------------------------------------------------------------------------------
# The "queue length weight <WEIGHT>" command, in 'tx-queue' mode.
# -----------------------------------------------------------------------------------
def setWeightCliBlockingFail( mode, timestamp, feature, description, intf,
                              prevWeightConfig ):
   """Roll back the tx-queue weight on `intf` if the CLI change failed.

   cliBlockingFail is asked whether the change to `feature` made at `timestamp`
   on `intf` failed. If so, the interface and tx-queue config entries are
   obtained via newMember (with "invalid" placeholders for the other tx-queue
   fields), `prevWeightConfig` is written back as the weight, and
   setIntfConfig( intf ) is called. Otherwise nothing happens.
   """
   if not cliBlockingFail( mode, timestamp, feature, description, intf ):
      return
   queueConfigs = qosInputConfig.intfConfig.newMember( intf ).txQueueConfig
   queueConfig = queueConfigs.newMember(
      mode.txQueue, tacTxQueuePriority.priorityInvalid, tacPercent.invalid,
      invalidShapeRate, invalidGuaranteedBw )
   queueConfig.weightConfig = prevWeightConfig
   setIntfConfig( intf )

def setTxQueueWeight( mode, txQueueConfig,
                      noOrDefaultKw=True, weight=defaultWeight ):
   """Update the queue-length weight stored in `txQueueConfig` and return it.

   When `noOrDefaultKw` is true the weight is reset to defaultWeight. Otherwise
   `weight` is written, except when the current weight already equals `weight`
   and is not the default, in which case the attribute is left as-is.
   `mode` is not used.
   """
   if noOrDefaultKw:
      txQueueConfig.weightConfig = defaultWeight
      return txQueueConfig
   currentWeight = txQueueConfig.weightConfig
   # Equivalent to: skip only when a non-default weight is already `weight`.
   if currentWeight == defaultWeight or currentWeight != weight:
      txQueueConfig.weightConfig = weight
   return txQueueConfig

# -----------------------------------------------------------------------------------
# Handler of AvrgWeightCmd
# -----------------------------------------------------------------------------------
def setWeight( mode, args ):
   """Set the queue-length weight of the current tx-queue to args[ 'WEIGHT' ].

   defaultWeight is used when 'WEIGHT' is absent from `args`. The tx-queue
   config entry is obtained with newMember, using "invalid" placeholders for the
   remaining fields. Inside a qos-profile the profile entry is updated and
   setQosProfileConfig is called. Otherwise every interface of the parent mode is
   updated, and setWeightCliBlockingFail restores the previous weight on any
   interface whose change cliBlockingFail reports as failed.
   """
   weight = args.get( 'WEIGHT', defaultWeight )
   parent = mode.parent_

   if not isinstance( parent, QosProfileMode ):
      for intf in getIntfListFromMode( parent ):
         # Taken before the change so cliBlockingFail can judge this update.
         startTime = Tac.now()
         queueConfigs = qosInputConfig.intfConfig.newMember( intf ).txQueueConfig
         queueConfig = queueConfigs.newMember(
            mode.txQueue, tacTxQueuePriority.priorityInvalid, tacPercent.invalid,
            invalidShapeRate, invalidGuaranteedBw )
         oldWeight = queueConfig.weightConfig
         setTxQueueWeight( mode, queueConfig, False, weight )
         setIntfConfig( intf )
         setWeightCliBlockingFail( mode, startTime, tacFeatureName.weight,
                                   "WEIGHT", intf, oldWeight )
      return

   profile = parent.qosProfileModeContext.currentEntry_
   queueConfig = profile.txQueueConfig.newMember(
      mode.txQueue, tacTxQueuePriority.priorityInvalid, tacPercent.invalid,
      invalidShapeRate, invalidGuaranteedBw )
   setTxQueueWeight( mode, queueConfig, False, weight )
   setQosProfileConfig( profile )

# -----------------------------------------------------------------------------------
# No and default handler of AvrgWeightCmd
# -----------------------------------------------------------------------------------
def setNoOrDefaultWeight( mode, args ):
   """Reset the queue-length weight of the current tx-queue ("no"/"default").

   Only existing tx-queue configs are touched: nothing is created for a profile
   or interface that has no entry for this queue. For interfaces,
   setWeightCliBlockingFail restores the previous weight when cliBlockingFail
   reports that the change failed.
   """
   parent = mode.parent_

   if not isinstance( parent, QosProfileMode ):
      for intf in getIntfListFromMode( parent ):
         # Taken before the change so cliBlockingFail can judge this update.
         startTime = Tac.now()
         intfConfig = qosInputConfig.intfConfig.get( intf )
         queueConfig = ( None if intfConfig is None else
                         intfConfig.txQueueConfig.get( mode.txQueue ) )
         if queueConfig is None:
            continue
         oldWeight = queueConfig.weightConfig
         setTxQueueWeight( mode, queueConfig )
         setIntfConfig( intf )
         setWeightCliBlockingFail( mode, startTime, tacFeatureName.weight,
                                   "WEIGHT", intf, oldWeight )
      return

   profile = parent.qosProfileModeContext.currentEntry_
   queueConfig = profile.txQueueConfig.get( mode.txQueue )
   if queueConfig:
      setTxQueueWeight( mode, queueConfig )
      setQosProfileConfig( profile )

# -----------------------------------------------------------------------------------
# Helpers shared by the RandomDetectNonEctCmd handlers
# -----------------------------------------------------------------------------------
def _newNonEctTxQueueConfig( txQueueConfigColl, txQueue ):
   """Create the `txQueue` entry in `txQueueConfigColl` via newMember.

   The constructor arguments are the same "invalid" placeholders that every
   handler in this file passes to txQueueConfig.newMember.
   """
   return txQueueConfigColl.newMember( txQueue,
                                       tacTxQueuePriority.priorityInvalid,
                                       tacPercent.invalid, invalidShapeRate,
                                       invalidGuaranteedBw )

def _hasNonEctConflictingWred( txQueueConfig ):
   """Return True if `txQueueConfig` has a WRED config or any DP WRED entry.

   NON-ECT and WRED ( DP WRED ) configs are mutually exclusive, so a queue for
   which this returns True must not be given NON-ECT parameters.
   """
   if txQueueConfig.wredConfig:
      return True
   dpConfig = txQueueConfig.dpConfig
   return bool( dpConfig and len( dpConfig.dpWredConfig ) != 0 )

def _revertNonEctOnBlockingFail( mode, timestamp, intf, prevNonEctConfig ):
   """Undo a NON-ECT change on `intf` if cliBlockingFail reports that it failed.

   On failure, the nonEctConfig of `mode.txQueue` on `intf` is handled in one of
   two ways:
   - If `prevNonEctConfig` is set, the config is rebuilt from its
     minThd/maxThd/unit/weight, with ecnMaxDroprate in the droprate slot.
   - Otherwise the config is cleared.
   setIntfConfig( intf ) is then called. The interface and tx-queue entries are
   re-created if they are missing.
   """
   if not cliBlockingFail( mode, timestamp, tacFeatureName.ecn,
                           "NON-ECT parameters", intf ):
      return
   intfConfig = qosInputConfig.intfConfig.newMember( intf )
   txQueueConfig = intfConfig.txQueueConfig.get( mode.txQueue )
   if not txQueueConfig:
      txQueueConfig = _newNonEctTxQueueConfig( intfConfig.txQueueConfig,
                                               mode.txQueue )
   if prevNonEctConfig:
      txQueueConfig.nonEctConfig = ( prevNonEctConfig.minThd,
                                     prevNonEctConfig.maxThd,
                                     prevNonEctConfig.unit,
                                     ecnMaxDroprate,
                                     prevNonEctConfig.weight )
   else:
      txQueueConfig.nonEctConfig = None
   setIntfConfig( intf )

# -----------------------------------------------------------------------------------
# Handler of RandomDetectNonEctCmd
# -----------------------------------------------------------------------------------
def setNonEctParams( mode, args ):
   """Apply the NON-ECT threshold range (and optional weight) to the tx-queue.

   `args[ 'THD_RANGE' ]` supplies ( minThd, maxThd, unit ). `args[ 'WEIGHT' ]`
   is optional and defaults to defaultWeight. The tx-queue config is created if
   missing. Queues that already carry WRED / DP WRED config are rejected with an
   error, because the two are mutually exclusive. In a qos-profile the profile is
   updated. Otherwise every interface of the parent mode is updated, and a change
   that cliBlockingFail reports as failed is rolled back.
   """
   weight = args.get( 'WEIGHT', defaultWeight )
   minThd, maxThd, unit = args[ 'THD_RANGE' ]

   msg = "Cannot apply Non ECT configuration. WRED configuration is already present."
   if isinstance( mode.parent_, QosProfileMode ):
      profile = mode.parent_.qosProfileModeContext.currentEntry_
      txQueueConfig = profile.txQueueConfig.get( mode.txQueue )
      if not txQueueConfig:
         txQueueConfig = _newNonEctTxQueueConfig( profile.txQueueConfig,
                                                  mode.txQueue )
      if _hasNonEctConflictingWred( txQueueConfig ):
         mode.addError( msg )
         return
      setTxQueueNonEct( mode, txQueueConfig, False, minThd,
                        maxThd, unit, weight )
      setQosProfileConfig( profile )
      return

   for intf in getIntfListFromMode( mode.parent_ ):
      # we need a timestamp to pass as an argument to cliBlockingFail
      timestamp = Tac.now()
      intfConfig = qosInputConfig.intfConfig.get( intf )
      if intfConfig is None:
         intfConfig = qosInputConfig.intfConfig.newMember( intf )
         txQueueConfig = _newNonEctTxQueueConfig( intfConfig.txQueueConfig,
                                                  mode.txQueue )
      else:
         txQueueConfig = intfConfig.txQueueConfig.get( mode.txQueue )
         if txQueueConfig is None:
            txQueueConfig = _newNonEctTxQueueConfig( intfConfig.txQueueConfig,
                                                     mode.txQueue )

      # Captured before any change so a failed commit can be rolled back.
      prevNonEctConfig = txQueueConfig.nonEctConfig

      if _hasNonEctConflictingWred( txQueueConfig ):
         mode.addError( msg )
         continue
      setTxQueueNonEct( mode, txQueueConfig, False, minThd,
                        maxThd, unit, weight )
      setIntfConfig( intf )
      _revertNonEctOnBlockingFail( mode, timestamp, intf, prevNonEctConfig )

# -----------------------------------------------------------------------------------
# No or default handler of RandomDetectNonEctCmd
# -----------------------------------------------------------------------------------
def setNoOrDefaultNonEctParams( mode, args ):
   """Handle "no"/"default" of the NON-ECT parameters for the current tx-queue.

   Calls setTxQueueNonEct with its default arguments, presumably the
   no/default path (TODO confirm against QosCliEcn). Only queues that already
   have a tx-queue config are touched. For interfaces, a change that
   cliBlockingFail reports as failed is rolled back to the previous NON-ECT
   config.
   """
   if isinstance( mode.parent_, QosProfileMode ):
      profile = mode.parent_.qosProfileModeContext.currentEntry_
      txQueueConfig = profile.txQueueConfig.get( mode.txQueue )
      if not txQueueConfig:
         return
      setTxQueueNonEct( mode, txQueueConfig )
      setQosProfileConfig( profile )
      return

   for intf in getIntfListFromMode( mode.parent_ ):
      # we need a timestamp to pass as an argument to cliBlockingFail
      timestamp = Tac.now()
      intfConfig = qosInputConfig.intfConfig.get( intf )
      if intfConfig is None:
         continue
      txQueueConfig = intfConfig.txQueueConfig.get( mode.txQueue )
      if txQueueConfig is None:
         continue

      # Captured before any change so a failed commit can be rolled back.
      prevNonEctConfig = txQueueConfig.nonEctConfig

      setTxQueueNonEct( mode, txQueueConfig )
      setIntfConfig( intf )
      _revertNonEctOnBlockingFail( mode, timestamp, intf, prevNonEctConfig )

# -----------------------------------------------------------------------------------
# Handler of QueueSetIntfRangeModeletCmd
# -----------------------------------------------------------------------------------
def queueSetIntfRangeModeletCmdHandler( mode, args ):
   """Enter tx-queue range mode for the queue kind present in `args`.

   The first keyword found, in the order tx-queue / uc-tx-queue / mc-tx-queue,
   selects which queue-set argument (TXQSET / UCTXQSET / MCTXQSET) supplies the
   queue ids passed to goToTxQueueRangeMode. If none is present, nothing
   happens.
   """
   for queueKw, queueSetArg in ( ( 'tx-queue', 'TXQSET' ),
                                 ( 'uc-tx-queue', 'UCTXQSET' ),
                                 ( 'mc-tx-queue', 'MCTXQSET' ) ):
      if queueKw in args:
         queueIds = list( args.get( queueSetArg ).values() )
         goToTxQueueRangeMode( mode, queueKw, queueIds )
         return

# -----------------------------------------------------------------------------------
# No or default handler of QueueSetIntfRangeModeletCmd
# -----------------------------------------------------------------------------------
def queueSetIntfRangeModeletCmdNoOrDefaultHandler( mode, args ):
   """Reset every tx-queue in the queue set named by `args` ("no"/"default").

   The first keyword found, in the order tx-queue / uc-tx-queue / mc-tx-queue,
   selects which queue-set argument (TXQSET / UCTXQSET / MCTXQSET) supplies the
   queue ids passed to setDefaultTxQRangeConfig. If none is present, nothing
   happens.
   """
   for queueKw, queueSetArg in ( ( 'tx-queue', 'TXQSET' ),
                                 ( 'uc-tx-queue', 'UCTXQSET' ),
                                 ( 'mc-tx-queue', 'MCTXQSET' ) ):
      if queueKw in args:
         queueIds = list( args.get( queueSetArg ).values() )
         setDefaultTxQRangeConfig( mode, True, queueKw, queueIds )
         return

def setDefaultTxQConfig( mode, noOrDefaultKw, tokenQueueType, txQueueId ):
   """Reset one tx-queue's QoS configuration to its defaults.

   A Qos::TxQueue key is built from `txQueueId` and `tokenQueueType`, which is
   one of 'tx-queue', 'uc-tx-queue' or 'mc-tx-queue'. The tx-queue config for
   that key is looked up in one of two places:
   - the current qos-profile entry, when `mode` is a QosProfileMode;
   - otherwise, the input intfConfig of `mode.intf`.

   If the config exists, its priority, shape rate, guaranteed bandwidth,
   bandwidth, bandwidth weight, ECN/NON-ECT/WRED/dp configs, drop thresholds and
   latency threshold are reset. setQosProfileConfig or setIntfConfig is then
   called.

   Separately, for non-profile modes on interfaces whose name starts with one of
   ethIntfPrefixes, and for 'tx-queue'/'uc-tx-queue' types, an existing ECN
   queue counter config for `txQueueId` is removed via
   setEcnQueueCounterConfig( ..., False ).

   Note that `mode` itself (not `mode.parent_`) is the profile/interface mode
   here. `noOrDefaultKw` is not read by this function. An unrecognized
   `tokenQueueType` trips an assertion.
   """
   txQueueConfig = None
   txQueue = Tac.Value( 'Qos::TxQueue' )
   txQueue.id = txQueueId
   # Map the CLI keyword onto the queue type of the lookup key.
   if 'tx-queue' == tokenQueueType:
      txQueue.type = tacQueueType.unknown
   elif 'uc-tx-queue' == tokenQueueType:
      txQueue.type = tacQueueType.ucq
   elif 'mc-tx-queue' == tokenQueueType:
      txQueue.type = tacQueueType.mcq
   else:
      assert 0, 'Unknown queue type supplied'

   if isinstance( mode, QosProfileMode ):
      # NOTE(review): the other handlers in this file read the attribute
      # `qosProfileModeContext.currentEntry_`; this calls `currentEntry()`.
      # Presumably equivalent -- confirm.
      profile = mode.qosProfileModeContext.currentEntry()
      txQueueConfig = profile.txQueueConfig.get( txQueue )
   else:
      intf = mode.intf.name
      intfConfig = qosInputConfig.intfConfig.get( intf )
      if intfConfig:
         txQueueConfig = intfConfig.txQueueConfig.get( txQueue )
   # Nothing to reset if the queue was never configured here.
   if txQueueConfig:
      txQueueConfig.priority = tacTxQueuePriority.priorityInvalid
      txQueueConfig.shapeRate = invalidShapeRate
      txQueueConfig.guaranteedBw = invalidGuaranteedBw
      txQueueConfig.bandwidth = tacPercent.invalid
      txQueueConfig.bandwidthWeight = tacBwWeight.invalid
      txQueueConfig.ecnConfig = None
      txQueueConfig.bufferBasedDecnConfig = None
      txQueueConfig.ecnDelayConfig = None
      txQueueConfig.nonEctConfig = None
      txQueueConfig.delayEcnEnabled = False
      # NOTE(review): assigned without a call, unlike tacLatencyThreshold()
      # below; presumably tacEcnDelayThreshold is already a value -- confirm
      # in QosTypes.
      txQueueConfig.ecnDelayThreshold = tacEcnDelayThreshold
      txQueueConfig.wredConfig = None
      txQueueConfig.dpConfig = None
      txQueueConfig.dropThresholds.clear()
      txQueueConfig.latencyThreshold = tacLatencyThreshold()
      if isinstance( mode, QosProfileMode ):
         setQosProfileConfig( profile )
      else:
         setIntfConfig( intf )

   # Remove the ECN queue counter config. This only applies to interface modes
   # (not qos-profiles), to unknown/unicast queue types, and to interface names
   # starting with one of ethIntfPrefixes. The counter config is keyed by the
   # raw queue id.
   if ( not isinstance( mode, QosProfileMode ) ) and \
      txQueue.type in ( tacQueueType.unknown, tacQueueType.ucq ) and \
      mode.intf.name.startswith( ethIntfPrefixes ):
      intf = mode.intf.name
      intfCounterCfg = qosInputConfig.ecnIntfCounterConfig.get( intf )
      if intfCounterCfg:
         txQCounterCfg = intfCounterCfg.ecnTxQueueCounterConfig.get( txQueueId )
         if txQCounterCfg:
            setEcnQueueCounterConfig( mode, intf, txQueueId, txQCounterCfg, False )

# Reset every tx-queue in a tx-queue range to its default configuration.
def setDefaultTxQRangeConfig( mode, noOrDefaultKw, tokenQueueType, txQueueIds ):
   """Call setDefaultTxQConfig for each id in `txQueueIds`, in order.

   `mode`, `noOrDefaultKw` and `tokenQueueType` are forwarded unchanged.
   """
   for queueId in txQueueIds:
      setDefaultTxQConfig( mode, noOrDefaultKw, tokenQueueType, queueId )

# -----------------------------------------------------------------------------------
# Handler of QueueSetSubIntfRangeModeletCmd
# -----------------------------------------------------------------------------------
def queueSetSubIntfRangeModeletCmdHandler( mode, args ):
   """Enter 'tx-queue' range mode for the queue ids given in TXQSET."""
   queueIds = list( args.get( 'TXQSET' ).values() )
   goToTxQueueRangeMode( mode, 'tx-queue', queueIds )

# -----------------------------------------------------------------------------------
# No or default handler of QueueSetSubIntfRangeModeletCmd
# -----------------------------------------------------------------------------------
def queueSetSubIntfRangeModeletCmdNoOrDefaultHandler( mode, args ):
   """Reset every 'tx-queue' listed in TXQSET ("no"/"default" form)."""
   queueIds = list( args.get( 'TXQSET' ).values() )
   setDefaultTxQRangeConfig( mode, True, 'tx-queue', queueIds )
