#!/usr/bin/env python3
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import CliCommand
import ConfigMount
import LazyMount
import Plugins
import Tac
from QosLib import isLagPort
from QosTypes import ( tacDropPrecedence, tacQueueThresholdUnit, tacPercent,
                       tacTxQueuePriority, tacFeatureName )
from CliPlugin import IntfCli
from CliPlugin import EthIntfCli
from CliPlugin.QosCliIntfTypes import ethOrLagIntfPrefixes
from CliPlugin.QosCliModel import ( WredIntfQueueCountersModel, TxQueueWredModel,
                                    IntfWredCollectionModel, IntfWredModel,
                                    WredParametersModel )
from CliPlugin.QosCli import ( invalidShapeRate, invalidGuaranteedBw )
from CliPlugin.QosCliCommon import ( getIntfListFromMode, setIntfConfig,
                                     cliBlockingFail, QosProfileMode,
                                     defaultWeight, setQosProfileConfig )
from CliPlugin.QosCliWred import showWredIntfQueueCountersHook

# -----------------------------------------------------------------------------------
# Variables for Qos Wred associated mount paths from Sysdb
# -----------------------------------------------------------------------------------
qosInputConfig = None
qosHwStatus = None
qosStatus = None

# -----------------------------------------------------------------------------------
# The "show qos interfaces [ INTF ] wred counters queue" command
# -----------------------------------------------------------------------------------
def showWredIntfQueueCounters( mode, args ):
   """Handler for "show qos interfaces [ INTF ] wred counters queue".

   Walks the registered showWredIntfQueueCountersHook extensions and returns the
   model produced by the one whose guard allows it (guard returns None). When
   no hook is allowed, an empty WredIntfQueueCountersModel is returned.
   """
   requestedIntf = args.get( 'INTF' )
   result = WredIntfQueueCountersModel()

   allowedHooks = 0
   for hookFn, hookGuard in showWredIntfQueueCountersHook.extensions():
      # A guard that returns anything other than None blocks this hook.
      if hookGuard( mode, None ) is not None:
         continue
      result = hookFn( mode, requestedIntf )
      allowedHooks += 1

   # At most one extension is expected to be allowed at any time.
   assert allowedHooks <= 1, "Found too many possible results"
   return result

# -----------------------------------------------------------------------------------
# The "show [ mls ] qos interfaces [ INTF ] random-detect drop" command, in
# "enable" mode.
# -----------------------------------------------------------------------------------
def showInterfacesWred( mode, args ):
   """Handler for "show [ mls ] qos interfaces [ INTF ] random-detect drop".

   Returns an IntfWredCollectionModel keyed by interface name, holding one
   IntfWredModel per matching interface whose name starts with one of
   ethOrLagIntfPrefixes.
   """
   requested = args.get( 'INTF' )
   collection = IntfWredCollectionModel()

   candidates = IntfCli.Intf.getAll( mode, requested, None,
                                     intfType=EthIntfCli.EthIntf )
   if not candidates:
      return collection

   for candidate in candidates:
      name = candidate.name
      if name.startswith( ethOrLagIntfPrefixes ):
         collection.intfWredCollection[ name ] = showWredInterface( candidate )

   return collection

def _wredParametersModel( setting, queueWeight, dp ):
   """Build the WredParametersModel describing one WRED setting.

   setting     -- object exposing minThd, maxThd, unit, maxDroprate and weight.
   queueWeight -- per-queue weight; when it differs from defaultWeight it
                  replaces setting.weight in the result.
   dp          -- drop precedence to report, or None for the queue-level entry.
   """
   params = WredParametersModel()
   params.minThreshold = setting.minThd
   params.maxThreshold = setting.maxThd
   params.unit = setting.unit
   params.maxDroprate = setting.maxDroprate
   # A non-default queue weight takes precedence over the setting's own weight.
   params.weight = queueWeight if queueWeight != defaultWeight else setting.weight
   params.dp = dp
   # The segment size is only reported for thresholds expressed in segments.
   if setting.unit == tacQueueThresholdUnit.segments:
      params.segmentSizeInBytes = qosHwStatus.ecnSegmentSizeInBytes
   return params

def showWredInterface( intf ):
   """Return the IntfWredModel for a single interface.

   Tx-queues are visited from the highest hardware queue id down to 0. When
   multicast queues are supported, only queues whose type prefix is 'UC' are
   reported. Every reported queue contributes one TxQueueWredModel (with
   parameters only if a WRED setting exists), followed by one additional
   TxQueueWredModel per configured drop precedence.

   For LAG ports the queue weight and per-DP settings are read from
   qosInputConfig; for other ports they are read from qosStatus.
   """
   intfStatus = qosStatus.intfStatus.get( intf.name )

   intfWredModel = IntfWredModel()

   if qosHwStatus.numTxQueueSupported == 0:
      return intfWredModel

   # Loop-invariant: resolve the helper type once rather than per queue.
   qosShowCommandHelper = Tac.Type( "Qos::QosShowCommandHelper" )

   for hwtxqid in range( qosHwStatus.numTxQueueSupported - 1, -1, -1 ):
      clitxq = qosHwStatus.hwTxQueue[ hwtxqid ].txQueue
      if qosHwStatus.numMulticastQueueSupported != 0:
         prefix = clitxq.type[ : 2 ].upper()
         if prefix != 'UC':
            continue
         txQName = prefix + str( clitxq.id )
      else:
         txQName = str( clitxq.id )

      txQueueWred = TxQueueWredModel()
      txQueueWred.txQueue = txQName
      txQueueWred.wredWeightSupported = qosHwStatus.wredWeightSupported or \
                                        qosHwStatus.queueWeightSupported
      txQueueWred.dpWredSupported = qosHwStatus.dpWredSupported

      dpWredSetting = None
      queueWeight = defaultWeight
      if isLagPort( intf.name ):  # For Lag read from config
         intfConfig = qosInputConfig.intfConfig.get( intf.name )
         if intfConfig is not None:
            txQueueConfig = intfConfig.txQueueConfig.get( clitxq )
            if txQueueConfig and txQueueConfig.weightConfig:
               queueWeight = txQueueConfig.weightConfig
            if txQueueConfig and txQueueConfig.dpConfig:
               dpWredSetting = txQueueConfig.dpConfig.dpWredConfig
      elif intfStatus:
         txQueueStatus = intfStatus.txQueueStatus.get( clitxq )
         if txQueueStatus and txQueueStatus.weightStatus:
            queueWeight = txQueueStatus.weightStatus
         if txQueueStatus and txQueueStatus.dpStatus:
            dpWredSetting = txQueueStatus.dpStatus.dpWredStatus

      wredSetting = qosShowCommandHelper.getWredStatus(
         qosInputConfig, qosStatus, intf.name, clitxq )
      if wredSetting:
         txQueueWred.txQueueWredParameters = _wredParametersModel(
            wredSetting, queueWeight, None )

      # The queue-level entry is always listed, even without WRED parameters.
      intfWredModel.txQueueWredList.append( txQueueWred )

      if not dpWredSetting:
         continue

      # NOTE(review): range() excludes tacDropPrecedence.max itself -- confirm
      # that max is not a valid drop-precedence key.
      for dp in range( tacDropPrecedence.max ):
         if dp not in dpWredSetting:
            continue
         dpWredParam = dpWredSetting[ dp ]
         txQueueDpWred = TxQueueWredModel()
         txQueueDpWred.txQueue = txQName
         # NOTE(review): unlike the queue-level entry above, this does not
         # consider queueWeightSupported -- confirm this is intentional.
         txQueueDpWred.wredWeightSupported = qosHwStatus.wredWeightSupported
         txQueueDpWred.dpWredSupported = qosHwStatus.dpWredSupported
         txQueueDpWred.txQueueWredParameters = _wredParametersModel(
            dpWredParam, queueWeight, dpWredParam.dp )
         intfWredModel.txQueueWredList.append( txQueueDpWred )

   return intfWredModel

# -----------------------------------------------------------------------------------
# The "[ mls ] qos random-detect drop minimum-threshold <min_threshold>
# maximum-threshold <max_threshold> drop-probability <maxDroprate>"
# command, in "uc-tx-queue" mode.
# -----------------------------------------------------------------------------------
def setTxQueueDpWred( txQueueConfig, dp, noOrDefaultKw=True, minThd=None,
                      maxThd=None, wredUnit=None, maxDroprate=100,
                      weight=defaultWeight ):
   """Set or remove the WRED entry for drop precedence dp on a tx-queue config.

   With noOrDefaultKw true, an existing entry for dp is deleted and nothing
   else is touched. Otherwise the dpConfig container is created if missing and
   the entry for dp is written, unless an identical entry is already present.
   """
   existing = None
   if txQueueConfig.dpConfig and dp in txQueueConfig.dpConfig.dpWredConfig:
      existing = txQueueConfig.dpConfig.dpWredConfig[ dp ]

   if noOrDefaultKw:
      if existing:
         del txQueueConfig.dpConfig.dpWredConfig[ dp ]
      return

   if txQueueConfig.dpConfig is None:
      txQueueConfig.dpConfig = ( "", )
   dpConfig = txQueueConfig.dpConfig

   if existing:
      unchanged = ( existing.minThd == minThd and
                    existing.maxThd == maxThd and
                    existing.unit == wredUnit and
                    existing.maxDroprate == maxDroprate and
                    existing.weight == weight )
      if unchanged:
         return
      # An entry is modified by deleting it and adding it back with new values.
      del dpConfig.dpWredConfig[ dp ]

   dpConfig.dpWredConfig.newMember( minThd, maxThd,
                                    wredUnit, maxDroprate, weight, dp )

def setNoOrDefaultWred( mode, args ):
   """Handler for the no/default form of the WRED command in tx-queue mode.

   args may carry 'DP'; when present only the per-drop-precedence WRED entry
   is removed, otherwise the queue-level WRED config is removed.

   In qos-profile mode the change is applied to the profile being edited. In
   interface mode it is applied to every interface of the parent mode, and if
   cliBlockingFail() reports a failure for an interface, the configuration
   captured before the change is written back for that interface.
   """
   dp = args.get( 'DP', None )

   if isinstance( mode.parent_, QosProfileMode ):
      profile = mode.parent_.qosProfileModeContext.currentEntry_
      txQueueConfig = profile.txQueueConfig.get( mode.txQueue )
      if not txQueueConfig:
         return
      if dp is not None:
         setTxQueueDpWred( txQueueConfig, dp )
      else:
         setTxQueueWred( mode, txQueueConfig )
      setQosProfileConfig( profile )
      return

   for intf in getIntfListFromMode( mode.parent_ ):
      timestamp = Tac.now()
      intfConfig = qosInputConfig.intfConfig.get( intf )
      if intfConfig is None:
         continue
      txQueueConfig = intfConfig.txQueueConfig.get( mode.txQueue )
      if txQueueConfig is None:
         continue

      # Capture the current settings so they can be restored on failure.
      prevWredConfig = txQueueConfig.wredConfig
      prevDpWredConfig = None
      if ( dp is not None and txQueueConfig.dpConfig and
            dp in txQueueConfig.dpConfig.dpWredConfig ):
         prevDpWredConfig = txQueueConfig.dpConfig.dpWredConfig[ dp ]

      if dp is not None:
         setTxQueueDpWred( txQueueConfig, dp )
      else:
         setTxQueueWred( mode, txQueueConfig )
      setIntfConfig( intf )

      failed = cliBlockingFail( mode, timestamp, tacFeatureName.wred, "WRED",
                                intf )
      if not failed:
         continue

      # Roll back. The interface / queue config are fetched again and created
      # if they are no longer present.
      intfConfig = qosInputConfig.intfConfig.newMember( intf )
      txQueueConfig = intfConfig.txQueueConfig.get( mode.txQueue )
      if not txQueueConfig:
         txQueueConfig = intfConfig.txQueueConfig.newMember(
            mode.txQueue, tacTxQueuePriority.priorityInvalid,
            tacPercent.invalid, invalidShapeRate, invalidGuaranteedBw )

      if dp is None:
         if prevWredConfig:
            txQueueConfig.wredConfig = ( prevWredConfig.minThd,
                  prevWredConfig.maxThd, prevWredConfig.unit,
                  prevWredConfig.maxDroprate, prevWredConfig.weight )
         else:
            txQueueConfig.wredConfig = None
      elif prevDpWredConfig:
         # The dpConfig container may be missing on a re-created queue config;
         # create it the same way setTxQueueDpWred does before restoring.
         if txQueueConfig.dpConfig is None:
            txQueueConfig.dpConfig = ( "", )
         dpConfig = txQueueConfig.dpConfig
         del dpConfig.dpWredConfig[ dp ]
         dpConfig.dpWredConfig.newMember( prevDpWredConfig.minThd,
                                          prevDpWredConfig.maxThd,
                                          prevDpWredConfig.unit,
                                          prevDpWredConfig.maxDroprate,
                                          prevDpWredConfig.weight,
                                          dp )
      elif txQueueConfig.dpConfig is not None:
         # Nothing existed before; only delete when there is a container to
         # delete from, instead of dereferencing a None dpConfig.
         del txQueueConfig.dpConfig.dpWredConfig[ dp ]

      setIntfConfig( intf )

def setWred( mode, args ):
   """Handler for the WRED configuration command in tx-queue mode.

   Reads 'WEIGHT' (default defaultWeight), optional 'DP', 'MAX_DROPRATE' and
   'THD_RANGE' (min threshold, max threshold, unit) from args. With 'DP' the
   per-drop-precedence entry is written, otherwise the queue-level WRED config.

   The command is rejected with an error when a conflicting ECN configuration
   is present. In interface mode, if cliBlockingFail() reports a failure for an
   interface, the configuration captured before the change is written back.
   """
   errMsg = "Cannot apply WRED configuration. ECN configuration is already present."
   weight = args.get( 'WEIGHT', defaultWeight )
   dp = args.get( 'DP', None )
   maxDroprate = args[ 'MAX_DROPRATE' ]
   minThd, maxThd, wredUnit = args[ 'THD_RANGE' ]

   def newTxQueueConfig( owner ):
      # Create the tx-queue entry for this mode's queue with invalid/unset values.
      return owner.txQueueConfig.newMember(
         mode.txQueue, tacTxQueuePriority.priorityInvalid, tacPercent.invalid,
         invalidShapeRate, invalidGuaranteedBw )

   def blockedByEcn( txqConfig ):
      # WRED is always exclusive with global ECN config and non-ect config.
      # It is exclusive with per-queue ECN config only when the hardware is
      # initialized and reports wredEcnMutuallyExclusive.
      if qosInputConfig.globalEcnConfig or txqConfig.nonEctConfig:
         return True
      if not qosHwStatus.hwInitialized:
         return False
      if not qosHwStatus.wredEcnMutuallyExclusive:
         return False
      return bool( txqConfig.ecnConfig )

   def applyWred( txqConfig ):
      if dp is None:
         setTxQueueWred( mode, txqConfig, False, minThd,
                         maxThd, wredUnit, maxDroprate, weight )
      else:
         setTxQueueDpWred( txqConfig, dp, False, minThd, maxThd, wredUnit,
                           maxDroprate, weight )

   if isinstance( mode.parent_, QosProfileMode ):
      profile = mode.parent_.qosProfileModeContext.currentEntry_
      txQueueConfig = profile.txQueueConfig.get( mode.txQueue )
      if not txQueueConfig:
         txQueueConfig = newTxQueueConfig( profile )

      if blockedByEcn( txQueueConfig ):
         mode.addError( errMsg )
         return

      applyWred( txQueueConfig )
      setQosProfileConfig( profile )
      return

   for intf in getIntfListFromMode( mode.parent_ ):
      timestamp = Tac.now()
      intfConfig = qosInputConfig.intfConfig.get( intf )
      if intfConfig is None:
         intfConfig = qosInputConfig.intfConfig.newMember( intf )
      txQueueConfig = intfConfig.txQueueConfig.get( mode.txQueue )
      if txQueueConfig is None:
         txQueueConfig = newTxQueueConfig( intfConfig )

      # Capture the current settings so they can be restored on failure.
      prevWredConfig = txQueueConfig.wredConfig
      prevDpWredConfig = None
      if ( dp is not None and txQueueConfig.dpConfig and
            dp in txQueueConfig.dpConfig.dpWredConfig ):
         prevDpWredConfig = txQueueConfig.dpConfig.dpWredConfig[ dp ]

      if blockedByEcn( txQueueConfig ):
         mode.addError( errMsg )
         continue

      applyWred( txQueueConfig )
      setIntfConfig( intf )

      failed = cliBlockingFail( mode, timestamp, tacFeatureName.wred, "WRED",
                                intf )
      if not failed:
         continue

      # Roll back to the captured configuration.
      intfConfig = qosInputConfig.intfConfig.newMember( intf )
      txQueueConfig = intfConfig.txQueueConfig.get( mode.txQueue )
      if not txQueueConfig:
         txQueueConfig = newTxQueueConfig( intfConfig )

      if dp is None:
         restored = None
         if prevWredConfig:
            restored = ( prevWredConfig.minThd, prevWredConfig.maxThd,
                         prevWredConfig.unit, prevWredConfig.maxDroprate,
                         prevWredConfig.weight )
         txQueueConfig.wredConfig = restored
      else:
         # The entry just written is removed in both cases; the previous one,
         # if any, is then added back.
         dpConfig = txQueueConfig.dpConfig
         del dpConfig.dpWredConfig[ dp ]
         if prevDpWredConfig:
            dpConfig.dpWredConfig.newMember( prevDpWredConfig.minThd,
                                             prevDpWredConfig.maxThd,
                                             prevDpWredConfig.unit,
                                             prevDpWredConfig.maxDroprate,
                                             prevDpWredConfig.weight,
                                             dp )

      setIntfConfig( intf )

def setTxQueueWred( mode, txQueueConfig, noOrDefaultKw=True, minThd=None,
                    maxThd=None, wredUnit=None,
                    maxDroprate=100, weight=defaultWeight ):
   """Set or clear the queue-level WRED config on txQueueConfig.

   With noOrDefaultKw true, an existing wredConfig is cleared. Otherwise
   wredConfig is written with the given values unless it already holds
   exactly those values. mode is not used. Returns txQueueConfig.
   """
   current = txQueueConfig.wredConfig

   if noOrDefaultKw:
      if current:
         txQueueConfig.wredConfig = None
      return txQueueConfig

   requested = ( minThd, maxThd, wredUnit, maxDroprate, weight )
   if current:
      present = ( current.minThd, current.maxThd, current.unit,
                  current.maxDroprate, current.weight )
      if present == requested:
         # Identical config is already in place; avoid a redundant write.
         return txQueueConfig

   txQueueConfig.wredConfig = requested
   return txQueueConfig

# -----------------------------------------------------------------------------------
# The " [ no | default ] qos random-detect drop allow ect" command,
# in "global config" mode.
# -----------------------------------------------------------------------------------
def configureGlobalWredAllowEct( mode, args ):
   """Handler for "[ no | default ] qos random-detect drop allow ect".

   Sets qosInputConfig.wredAllowEct to False for the no/default form of the
   command and to True otherwise.
   """
   qosInputConfig.wredAllowEct = not CliCommand.isNoOrDefaultCmd( args )

@Plugins.plugin( provides=( "QosCliWred", ) )
def Plugin( entityManager ):
   """Plugin entry point: mount the Sysdb paths used by this module.

   Populates the module-level qosHwStatus and qosStatus (mounted with mode "r")
   and qosInputConfig (mounted with mode "w").
   """
   global qosHwStatus, qosStatus, qosInputConfig

   hwStatusPath, hwStatusType = "qos/hardware/status/global", "Qos::HwStatus"
   qosHwStatus = LazyMount.mount( entityManager, hwStatusPath, hwStatusType, "r" )

   statusPath, statusType = "qos/status", "Qos::Status"
   qosStatus = LazyMount.mount( entityManager, statusPath, statusType, "r" )

   inputPath, inputType = "qos/input/config/cli", "Qos::Input::Config"
   qosInputConfig = ConfigMount.mount( entityManager, inputPath, inputType, "w" )
