#!/usr/bin/env python3
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import CliCommand
import Tac
import Tracing
from CliPlugin.QosCli import ( qosInputConfig, invalidGuaranteedBw )
from CliPlugin.QosCliCommon import ( getIntfListFromMode, setIntfConfig,
                                     QosProfileMode, setQosProfileConfig )
from CliPlugin.QosCliScheduling import QosSchedulingPolicyMode
from QosTypes import ( tacShapeRateVal,
                       tacShapeRateUnit,
                       tacTxQueuePriority,
                       tacPercent )

# Trace handle for this module, created under the name 'QosCliShapeHandler'.
__defaultTraceHandle__ = Tracing.Handle( 'QosCliShapeHandler' )
# Short module-level aliases for Tracing.trace0 / trace1 / trace8.
t0 = Tracing.trace0
t1 = Tracing.trace1
t8 = Tracing.trace8

def setPortShapeRate( mode, noOrDefaultKw=None, shapeRate=None,
                      shapeRateUnit=None, shapeRateShared=False,
                      burstSizeVal=None, burstSizeUnit=None ):
   """Apply (or clear) the port-level shape rate for the given CLI mode.

   When mode is a QosProfileMode the shape rate is written to the profile
   entry currently being edited (mode.qosProfileModeContext.currentEntry_);
   otherwise the values are handed to setIntfConfig() for mode.intf.name.

   Args:
      mode: CLI mode the command was entered in.
      noOrDefaultKw: truthy for the no/default form. The rate is then set to
         tacShapeRateVal.invalid, the unit to kbps, shared to False and the
         burst size arguments are dropped, whatever else was passed in.
      shapeRate: rate value; with the 'percent' unit it is also used to build
         the Qos::Percent value.
      shapeRateUnit: 'pps', 'percent', or anything else, which is handled as
         kbps.
      shapeRateShared: stored as the shared flag of the shape rate.
      burstSizeVal: burst size value. In profile mode it is only set on the
         Qos::BurstSize when truthy; in interface mode it is forwarded to
         setIntfConfig() unchanged.
      burstSizeUnit: burst size unit, handled like burstSizeVal.
   """
   if noOrDefaultKw:
      # The no/default form clears the configuration. The unit is cleared as
      # well so that a stray 'pps'/'percent' argument cannot leak into the
      # cleared state; this keeps interface mode and profile mode identical.
      shapeRate = tacShapeRateVal.invalid
      shapeRateUnit = None
      shapeRateShared = False
      burstSizeVal = None
      burstSizeUnit = None

   # Map the CLI unit keyword onto the configured unit / percent pair.
   shapeRatePercent = Tac.Value( 'Qos::Percent' )
   if shapeRateUnit == 'pps':
      cfgShapeRateUnit = tacShapeRateUnit.shapeRatePps
   else:
      if shapeRateUnit == 'percent':
         # Configured value is percent of bandwidth in kbps
         shapeRatePercent = Tac.Value( 'Qos::Percent', shapeRate )
      cfgShapeRateUnit = tacShapeRateUnit.shapeRateKbps

   if not isinstance( mode, QosProfileMode ):
      # Interface mode: setIntfConfig() receives the individual values.
      setIntfConfig( mode.intf.name, cfgShapeRate=shapeRate,
                     cfgShapeRateUnit=cfgShapeRateUnit,
                     cfgShapeRateShared=shapeRateShared,
                     cfgShapeRatePercent=shapeRatePercent,
                     burstSizeVal=burstSizeVal,
                     burstSizeUnit=burstSizeUnit )
      return

   # Profile mode: build the complete Qos::ShapeRate value and store it on the
   # profile entry being edited.
   profile = mode.qosProfileModeContext.currentEntry_
   tempShapeRate = Tac.Value( 'Qos::ShapeRate' )
   tempShapeRate.rate = shapeRate
   tempShapeRate.unit = cfgShapeRateUnit
   tempShapeRate.shared = shapeRateShared
   tempShapeRate.percent = shapeRatePercent
   burstSize = Tac.Value( 'Qos::BurstSize' )
   # Falsy burst arguments leave the freshly constructed BurstSize untouched.
   if burstSizeVal:
      burstSize.value = burstSizeVal
   if burstSizeUnit:
      burstSize.unit = burstSizeUnit
   tempShapeRate.burstSize = burstSize
   profile.shapeRate = tempShapeRate

# --------------------------------------------------------------------------------
# Handler of SetPortRateCmd
# --------------------------------------------------------------------------------
def setPortRateCmdHandler( mode, args ):
   """Translate the parsed SetPortRateCmd arguments into setPortShapeRate().

   The no/default form clears the shape rate. Otherwise args[ 'RATE' ] and
   args[ 'UNIT' ] are required, and the optional args[ 'BURST' ] entry is
   unpacked as a ( unit, value ) pair.
   """
   if CliCommand.isNoOrDefaultCmd( args ):
      setPortShapeRate( mode, noOrDefaultKw=True )
      return

   burstUnit, burstVal = args.get( 'BURST', ( None, None ) )
   setPortShapeRate( mode,
                     shapeRate=args[ 'RATE' ],
                     shapeRateUnit=args[ 'UNIT' ],
                     burstSizeVal=burstVal,
                     burstSizeUnit=burstUnit )

def setShapeRate( mode, noOrDefaultKw=None, shapeRate=None,
                  shapeRateUnit=False, burstSizeVal=None, burstSizeUnit=None ):
   """Apply (or clear) the shape rate for the target selected by mode.

   A Qos::ShapeRate value is built from the arguments and then stored in one
   of three places:
     * mode.parent_ is a QosProfileMode: the tx-queue entry mode.txQueue of
       the profile being edited, followed by setQosProfileConfig();
     * mode is a QosSchedulingPolicyMode: the scheduling policy being edited;
     * otherwise: the tx-queue entry mode.txQueue of every interface returned
       by getIntfListFromMode( mode.parent_ ), each followed by
       setIntfConfig().

   Missing tx-queue (and interface) entries are created on demand, except
   when there is nothing to configure: the no/default form in profile mode,
   or an invalid rate in interface mode.

   Args:
      mode: CLI mode the command was entered in.
      noOrDefaultKw: truthy for the no/default form; the rate becomes
         tacShapeRateVal.invalid and the burst arguments are dropped.
      shapeRate: rate value; with the 'percent' unit it is also used to build
         the Qos::Percent value.
      shapeRateUnit: 'pps', 'percent', or anything else, which is handled as
         kbps.
      burstSizeVal: burst size value; only applied when truthy.
      burstSizeUnit: burst size unit; only applied when truthy.
   """
   if noOrDefaultKw:
      shapeRate = tacShapeRateVal.invalid
      burstSizeVal = burstSizeUnit = None

   # Assemble the value that will be written to the selected target.
   newShapeRate = Tac.Value( 'Qos::ShapeRate' )
   newShapeRate.rate = shapeRate
   if shapeRateUnit == 'percent':
      # Configured value is percent of bandwidth in kbps
      newShapeRate.percent = Tac.Value( 'Qos::Percent', shapeRate )
   newShapeRate.unit = ( tacShapeRateUnit.shapeRatePps
                         if shapeRateUnit == 'pps'
                         else tacShapeRateUnit.shapeRateKbps )

   burst = Tac.Value( 'Qos::BurstSize' )
   if burstSizeVal:
      burst.value = burstSizeVal
   if burstSizeUnit:
      burst.unit = burstSizeUnit
   newShapeRate.burstSize = burst

   def _newTxQueueConfig( owner ):
      # Create the tx-queue entry on owner with everything but the shape rate
      # left at its invalid value.
      return owner.txQueueConfig.newMember(
         mode.txQueue, tacTxQueuePriority.priorityInvalid, tacPercent.invalid,
         newShapeRate, invalidGuaranteedBw )

   if isinstance( mode.parent_, QosProfileMode ):
      profile = mode.parent_.qosProfileModeContext.currentEntry_
      existing = profile.txQueueConfig.get( mode.txQueue )
      if existing:
         existing.shapeRate = newShapeRate
      elif noOrDefaultKw:
         # Nothing configured and nothing to configure.
         return
      else:
         _newTxQueueConfig( profile )
      setQosProfileConfig( profile )
      return

   if isinstance( mode, QosSchedulingPolicyMode ):
      mode.qosSchedulingPolicyModeContext.currentEntry_.shapeRate = newShapeRate
      return

   rateIsInvalid = shapeRate == tacShapeRateVal.invalid
   for intf in getIntfListFromMode( mode.parent_ ):
      intfConfig = qosInputConfig.intfConfig.get( intf )
      if intfConfig is None:
         txQueueConfig = None
      else:
         txQueueConfig = intfConfig.txQueueConfig.get( mode.txQueue )

      if txQueueConfig is not None:
         txQueueConfig.shapeRate = newShapeRate
      else:
         if rateIsInvalid:
            # Do not create entries just to store an invalid rate.
            continue
         if intfConfig is None:
            intfConfig = qosInputConfig.intfConfig.newMember( intf )
         _newTxQueueConfig( intfConfig )
      setIntfConfig( intf )

# --------------------------------------------------------------------------------
# Handler of SetShapeRateCmd
# --------------------------------------------------------------------------------
def setShapeRateCmdHandler( mode, args ):
   """Translate the parsed SetShapeRateCmd arguments into setShapeRate().

   The no/default form clears the shape rate. Otherwise args[ 'RATE' ] and
   args[ 'UNIT' ] are required, and the optional args[ 'BURST' ] entry is
   unpacked as a ( unit, value ) pair.
   """
   if CliCommand.isNoOrDefaultCmd( args ):
      setShapeRate( mode, noOrDefaultKw=True )
      return

   burstUnit, burstVal = args.get( 'BURST', ( None, None ) )
   setShapeRate( mode,
                 shapeRate=args[ 'RATE' ],
                 shapeRateUnit=args[ 'UNIT' ],
                 burstSizeVal=burstVal,
                 burstSizeUnit=burstUnit )

# --------------------------------------------------------------------------------
# Handler of SetPortRateSharedCmd
# --------------------------------------------------------------------------------
def setPortRateSharedCmdHandler( mode, args ):
   """Translate the parsed SetPortRateSharedCmd arguments into
   setPortShapeRate().

   args[ 'RATE' ] and args[ 'UNIT' ] are required. The shared flag is True
   when the 'shared' key is present in args, and the optional
   args[ 'BURST' ] entry is unpacked as a ( unit, value ) pair. This handler
   does not check for the no/default form.
   """
   burstUnit, burstVal = args.get( 'BURST', ( None, None ) )
   setPortShapeRate( mode,
                     shapeRate=args[ 'RATE' ],
                     shapeRateUnit=args[ 'UNIT' ],
                     shapeRateShared='shared' in args,
                     burstSizeVal=burstVal,
                     burstSizeUnit=burstUnit )
