# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import AirStreamLib
import GnmiSetCliSession
import LazyMount
import QosTypes
import re
import Tac
import Tracing
from QosOpenConfigCommonLib import getInterfaceIdAfterValidation

# Trace helpers for the "OpenConfigQosSP" trace handle: t0 (trace0) is mostly
# used for handler/helper entry messages, t2 (trace2) for per-item details.
t0 = Tracing.Handle( "OpenConfigQosSP" ).trace0
t2 = Tracing.Handle( "OpenConfigQosSP" ).trace2

def registerSchedulerPoliciesHandler( entMan ):
   """Register the pre-commit handler that syncs OpenConfig QoS scheduling
   config into native QoS config.

   The registered handler reads the OpenConfig queues, scheduler policies and
   interfaces of a gNMI Set CLI session and writes:
      - per-interface tx-queue priority/bandwidth/shape-rate config
        ( 'qos/input/config/cli' ),
      - one qos profile per scheduler policy ( 'qos/profile' ),
      - comments recording the OpenConfig origin of that config
        ( 'cli/config' ).

   Args:
      entMan: entity manager used to fetch session entities and to mount
         'qos/hardware/status/global'.
   """
   # precommit handlers
   #    - Scheduler Policy
   #    - Interface Output Scheduler Policy
   #    - Augmented queue scheduling/bandwidth
   class ToNativeSchedulerPoliciesSyncher( GnmiSetCliSession.PreCommitHandler ):
      """Pre-commit handler translating OpenConfig scheduler policies and
      queues into native QoS config."""
      externalPathList = [ 'qos/openconfig/config/queues',
                           'qos/openconfig/config/schedulerPolicies',
                           'qos/openconfig/config/interfaces' ]
      nativePathList = [ 'qos/input/config/cli',
                         'qos/profile',
                         'cli/config' ]

      @classmethod
      def run( cls, sessionName ):
         """Translate the session's OpenConfig QoS config into native config.

         Processing order:
            1. mirror every scheduler policy into local data and a qos
               profile ( with a comment describing the policy ),
            2. program the tx queues of interfaces that have an output
               scheduler policy attached,
            3. undo scheduler-policy config that is no longer present,
            4. apply augmented per-queue scheduling/bandwidth config for
               queues not already handled by a scheduler policy,
            5. reset priority/bandwidth of every tx queue not handled above.

         Args:
            sessionName: name of the gNMI Set CLI session being committed.
         """
         t0( "running SchedulerPolicies PreCommitHandler" )
         queues = AirStreamLib.getSessionEntity(
            entMan, sessionName, 'qos/openconfig/config/queues' )
         schedulerPolicies = AirStreamLib.getSessionEntity(
               entMan, sessionName, 'qos/openconfig/config/schedulerPolicies' )
         interfaces = AirStreamLib.getSessionEntity(
               entMan, sessionName, 'qos/openconfig/config/interfaces' )
         qosCliConfig = AirStreamLib.getSessionEntity(
               entMan, sessionName, 'qos/input/config/cli' )
         qosProfiles = AirStreamLib.getSessionEntity(
               entMan, sessionName, 'qos/profile' )
         qosProfileConfig = qosProfiles.config
         cliConfig = AirStreamLib.getSessionEntity(
               entMan, sessionName, 'cli/config' )
         qosHwStatus = LazyMount.mount( entMan, "qos/hardware/status/global",
               "Qos::HwStatus", "r" )
         queueNameToTxQueueIdMap = qosCliConfig.queueNameToTxQueueIdMap
         commentHelper = Tac.Type( "QosOc::CommentHelper" )
         # Highest tx-queue id that may be programmed; one lower when
         # qosHwStatus.tc7ToAnyTxQueueRestricted is set.
         maxQId = qosHwStatus.numTxQueueSupported - 1
         if qosHwStatus.tc7ToAnyTxQueueRestricted:
            maxQId = qosHwStatus.numTxQueueSupported - 2

         # local data
         # - rrQueueToWeightMap[ schedPolName ] = { qName: weight, ...,
         #                                          "totalWeight": totalWeight }
         rrQueueToWeightMap = {}
         # - schedulerPolicyConfigs[ schedulerPolicyName ] = scheduler
         # -- scheduler[ sequence ] = { "priority": "STRICT"
         #                              "one-rate-two-color": ShapingData
         #                              "inputs": schedulerInput }
         # -- ShapingData= { "cir" : value
         #                   "cir-pct" : percent value }
         # -- inputs[ inputId ] = { "queue": qName
         #                          "weight": weight }
         schedulerPolicyConfigs = {}
         # - intfSchedulerPolicyConfigs = { intfId: { qId1: True, ... }, ... }
         #   records every ( interface, tx queue ) handled during this run.
         # NOTE(review): interface keys mix OpenConfig names ( str ) and
         # Arnet::IntfId values from qosCliConfig.intfConfig, and queue keys
         # mix TxQueueId values and ints; this relies on them comparing
         # equal -- confirm.
         intfSchedulerPolicyConfigs = {}

         def getTxQueueType():
            """Return the tx-queue type used for scheduler-policy queues:
            ucq when the hardware supports multicast queues, else unknown."""
            if qosHwStatus.numMulticastQueueSupported:
               return QosTypes.tacQueueType.ucq
            return QosTypes.tacQueueType.unknown

         def getShapeRateFromConfig( oneRateTwoColor ):
            """Build a shape rate from an OpenConfig one-rate-two-color
            container.

            cir is divided by 1000 and tagged as Kbps ( so presumably given
            in bps ), then clamped to the hardware min/max shape rate.
            Otherwise cir-pct is stored as both rate and percent. Returns a
            default shape rate when oneRateTwoColor is absent or empty.
            """
            t0( f"getShapeRateFromConfig oneRateTwoColor = {oneRateTwoColor} " )
            shapeRate = QosTypes.tacShapeRate()
            if not oneRateTwoColor:
               return shapeRate
            if oneRateTwoColor.config.cir:
               rate = int( oneRateTwoColor.config.cir / 1000 )
               if rate < qosHwStatus.minTxQueueShapeRate:
                  rate = qosHwStatus.minTxQueueShapeRate
               if rate > qosHwStatus.maxShapeRateKbps:
                  rate = qosHwStatus.maxShapeRateKbps
               shapeRate.rate = rate
               shapeRate.unit = QosTypes.tacShapeRateUnit.shapeRateKbps
            elif oneRateTwoColor.config.cirPct:
               # need to populate the rate as well as percent
               shapeRate.rate = oneRateTwoColor.config.cirPct
               shapeRate.percent = oneRateTwoColor.config.cirPct
            else:
               t0( f"Invalid oneRateTwoColor cir = {oneRateTwoColor.config.cir}"
                   f" cir-pct = {oneRateTwoColor.config.cirPct}" )
            t2( f"Return ShapeRate = {shapeRate}" )
            return shapeRate

         def calculateBandwidthForSchedulerPolicies():
            """Fill rrQueueToWeightMap with the weights of round-robin
            ( non-priority ) scheduler inputs and their total per policy.

            Skipped when the hardware supports bandwidth weights natively,
            since weights are then programmed directly.
            """
            t0( "calculateBandwidthForSchedulerPolicies" )
            if qosHwStatus.bandwidthWeightSupported:
               t2( "bandwidthWeightSupported is true, returning" )
               return
            for polName, schedulerPolicy in \
                  schedulerPolicies.schedulerPolicy.items():
               rrQueueToWeightMap[ polName ] = {}
               totalWeight = 0
               t2( f"adding RR queues for Policy = {polName}" )
               if not schedulerPolicy.schedulers:
                  continue
               for scheduler in schedulerPolicy.schedulers.scheduler.values():
                  if scheduler.config.priority:
                     continue
                  if not scheduler.inputs:
                     continue
                  for schedulerInput in scheduler.inputs.input.values():
                     if not schedulerInput.config.weight:
                        continue
                     qName = schedulerInput.config.q
                     weight = schedulerInput.config.weight
                     t2( f" - qName = {qName}, weight = {weight}" )
                     rrQueueToWeightMap[ polName ][ qName ] = weight
                     totalWeight = totalWeight + weight
               t2( f" - totalWeight = {totalWeight} " )
               rrQueueToWeightMap[ polName ][ "totalWeight" ] = totalWeight

         def getBandwidthForQueue( schedulerPolicyName, qName ):
            """Return the bandwidth percent for qName under a policy.

            With a total weight <= 100 the queue's weight is used directly as
            the percent; otherwise it is scaled to a percent of the total,
            rounded half-up. Returns tacPercent.invalid when the queue has no
            weight, or when the scaled percent rounds to 0.
            """
            t0( f"getBandwidthForQueue qName = {qName}"
                f" SchedulerPolicy = {schedulerPolicyName} " )
            bandwidth = QosTypes.tacPercent.invalid
            totalWeight = rrQueueToWeightMap[
                  schedulerPolicyName ].get( "totalWeight" )
            if totalWeight:
               weight = rrQueueToWeightMap[ schedulerPolicyName ].get( qName )
               if totalWeight <= 100:
                  if weight:
                     bandwidth = weight
               else:
                  if weight:
                     # integer half-up rounding of weight / totalWeight * 100
                     thousandth = weight * 1000 // totalWeight
                     bandwidth = thousandth // 10 + (
                           1 if thousandth % 10 >= 5 else 0 )
                  if not bandwidth:
                     bandwidth = QosTypes.tacPercent.invalid
            t2( f"Return bandwidth = {bandwidth}" )
            return bandwidth

         def encodeProfileName( schedulerPolicyName ):
            """Return the qos profile name used for a scheduler policy."""
            return f"__YANG_SP_[{schedulerPolicyName}]"

         def decodeProfileName( profileName ):
            """Return the scheduler policy name encoded in profileName, or
            None when profileName was not produced by encodeProfileName."""
            schedulerPolicyName = None
            result = re.search( r"__YANG_SP_\[(.*)\]", profileName )
            if result:
               schedulerPolicyName = result.groups()[ 0 ]
            return schedulerPolicyName

         def updateQosProfileData( schedulerPolicyName ):
            """Create the qos profile for a scheduler policy and store a
            comment describing the policy ( built from
            schedulerPolicyConfigs ) under the profile's comment key."""
            t0( f"updateQosProfileData for {schedulerPolicyName}" )
            profileName = encodeProfileName( schedulerPolicyName )
            # create an empty qos profile
            qosProfile = qosProfileConfig.newMember( profileName )
            qosProfile.qosProfile = ( profileName, )
            # create commentInfo; use the parameter rather than the caller's
            # loop variable so this helper is correct for any caller
            schedPol = schedulerPolicyConfigs[ schedulerPolicyName ]
            sPProfileInfo = Tac.newInstance(
                  "QosOc::SchedulerPolicyProfileCommentInfo" )
            sPProfileInfo.name = schedulerPolicyName
            for seq, sched in schedPol.items():
               schedInfo = Tac.newInstance( "QosOc::SchedulerCommentInfo", seq )
               if "priority" in sched:
                  schedInfo.priority = Tac.Type(
                        "QosOc::SchedulerPriorityCommentInfo" ).STRICT
               if "one-rate-two-color" in sched:
                  oneRateTwoColorInfo = Tac.newInstance(
                        "QosOc::OneRateTwoColorCommentInfo" )
                  if sched[ "one-rate-two-color" ].get( "cir" ):
                     oneRateTwoColorInfo.cir = sched[ "one-rate-two-color" ][ "cir" ]
                  if sched[ "one-rate-two-color" ].get( "cir-pct" ):
                     oneRateTwoColorInfo.cirPct = sched[
                           "one-rate-two-color" ][ "cir-pct" ]
                  schedInfo.oneRateTwoColor = oneRateTwoColorInfo
               if "inputs" in sched:
                  for inId, schedInput in sched[ "inputs" ].items():
                     schedInputInfo = Tac.newInstance(
                        "QosOc::SchedulerInputCommentInfo", inId )
                     if "queue" in schedInput:
                        schedInputInfo.q = schedInput[ "queue" ]
                     if "weight" in schedInput:
                        schedInputInfo.weight = schedInput[ "weight" ]
                     schedInfo.inputs.addMember( schedInputInfo )
               sPProfileInfo.schedulers.addMember( schedInfo )
            comment = commentHelper.createSchedulerPolicyProfileComment(
                  sPProfileInfo )
            profileCommentKey = commentHelper.profileCommentKey( profileName )
            cliConfig.comment[ profileCommentKey ] = comment

         def updateTxQueueConfig(
               interfaceId, qId, priority, bandwidth, bwWeight, shapeRate ):
            """Create or update the native tx-queue config of an interface.

            At most one of bandwidth ( percent ) and bwWeight may be valid.
            """
            t0( f"updateTxQueueConfig intf = {interfaceId} queue = {qId} priority ="
                f" {priority} bandwidth = {bandwidth} bwWeight = {bwWeight}"
                f" shapeRate = {shapeRate}" )
            # only one of bwPercent and bwWeight should be programmed
            assert ( bandwidth == QosTypes.tacPercent.invalid ) or \
                   ( bwWeight == QosTypes.tacBwWeight.invalid )
            interface = Tac.newInstance( "Arnet::IntfId", interfaceId )
            if interface not in qosCliConfig.intfConfig.keys():
               qosCliConfig.intfConfig.newMember( interface )
            intfConfig = qosCliConfig.intfConfig[ interface ]
            txQueue = Tac.newInstance( "Qos::TxQueue" )
            txQueue.id = qId
            txQueue.type = getTxQueueType()
            if txQueue not in intfConfig.txQueueConfig.keys():
               intfConfig.txQueueConfig.newMember( txQueue, priority,
                     bandwidth, shapeRate, QosTypes.tacGuaranteedBw() )
               intfConfig.txQueueConfig[ txQueue ].bandwidthWeight = bwWeight
            else:
               txQueueConfig = intfConfig.txQueueConfig[ txQueue ]
               txQueueConfig.priority = priority
               txQueueConfig.bandwidth = bandwidth
               txQueueConfig.shapeRate = shapeRate
               txQueueConfig.bandwidthWeight = bwWeight

         def updateCliComment( interfaceId, qId, sPInfo ):
            """Record the attached scheduler policy in the comment of an
            interface tx queue, merging with any existing comment."""
            t0( f"updateCliComment intf = {interfaceId} queue = {qId}"
                f" SPInfo = {sPInfo}" )
            interface = Tac.newInstance( "Arnet::IntfId", interfaceId )
            txQueue = Tac.newInstance( "Qos::TxQueue" )
            txQueue.id = qId
            txQueue.type = getTxQueueType()
            commentKey = commentHelper.commentKey( interface, txQueue )
            oldComment = cliConfig.comment.get( commentKey, "" )
            newComment = commentHelper.createSchedulerPolicyComment(
                  sPInfo, oldComment )
            cliConfig.comment[ commentKey ] = newComment

         def updateIntfSchedulerPolicyCliComment( interfaceId, schedPolInfo ):
            """Store the interface-level comment naming the attached
            scheduler policy."""
            t0( f"updateIntfSchedulerPolicyCliComment intf = {interfaceId} "
                f"schedPolInfo = {schedPolInfo}" )
            interface = Tac.newInstance( "Arnet::IntfId", interfaceId )
            intfCommentKey = commentHelper.intfCommentKey( interface )
            comment = commentHelper.createIntfSchedulerPolicyComment( schedPolInfo )
            cliConfig.comment[ intfCommentKey ] = comment

         def cleanupStaleConfigforScedulerPolicies():
            """Remove native config created for scheduler policies or
            attachments that no longer exist in this session.

            Deletes orphaned '__YANG_SP_' profiles and their comments, drops
            interface comments for interfaces without an attached policy, and
            resets priority/bandwidth/shape rate of tx queues whose comment
            still names a scheduler policy but that were not programmed in
            this run ( those queues are then recorded as handled ).
            """
            t0( "cleanupStaleConfigforScedulerPolicies" )
            # qos profile portion; iterate over a snapshot since entries are
            # deleted from the collection inside the loop
            for profileName in list( qosProfileConfig ):
               polName = decodeProfileName( profileName )
               if not polName or polName in schedulerPolicyConfigs:
                  continue
               t0( f"deleting profile {profileName}" )
               del qosProfileConfig[ profileName ]
               profileCommentKey = commentHelper.profileCommentKey( profileName )
               if profileCommentKey in cliConfig.comment:
                  del cliConfig.comment[ profileCommentKey ]
            # interface attachment portion
            intfConfigs = qosCliConfig.intfConfig
            for intfId, intfConfig in intfConfigs.items():
               intfCommentKey = commentHelper.intfCommentKey( intfId )
               if intfCommentKey not in cliConfig.comment:
                  continue
               intfSchedPolComment = cliConfig.comment[ intfCommentKey ]
               intfSchedPolInfo = commentHelper.intfSchedulerPolicyInfoFromComment(
                     intfSchedPolComment )
               if not intfSchedPolInfo:
                  continue
               if intfId not in intfSchedulerPolicyConfigs:
                  t2( f"delete SchedulerPolicyConfig for {intfId}" )
                  del cliConfig.comment[ intfCommentKey ]

               for txQueue, txQueueConfig in intfConfig.txQueueConfig.items():
                  commentKey = commentHelper.commentKey( intfId, txQueue )
                  if commentKey not in cliConfig.comment:
                     continue
                  comment = cliConfig.comment[ commentKey ]
                  sPCommentInfo = \
                        commentHelper.schedulerPolicyInfoFromComment( comment )
                  if not sPCommentInfo:
                     continue
                  if intfId not in intfSchedulerPolicyConfigs or \
                        txQueue.id not in intfSchedulerPolicyConfigs[ intfId ]:
                     t2( f"delete SchedulerPolicy config for {intfId} "
                         f"txQueue {txQueue.id}" )
                     newComment = commentHelper.deleteSchedulerPolicyComment(
                           comment )
                     txQueueConfig.priority = \
                           QosTypes.tacTxQueuePriority.priorityInvalid
                     txQueueConfig.bandwidth = QosTypes.tacPercent.invalid
                     txQueueConfig.shapeRate = QosTypes.tacShapeRate()
                     if newComment:
                        cliConfig.comment[ commentKey ] = newComment
                     else:
                        del cliConfig.comment[ commentKey ]
                     # add cleaned-up queues also to local data
                     intfSchedulerPolicyConfigs.setdefault( intfId, {} ).update(
                           { txQueue.id: True } )

         # starting pre-Commit handler
         calculateBandwidthForSchedulerPolicies()

         # scheduler policy portion
         # pylint: disable-msg=too-many-nested-blocks
         for schedPolName, schedPol in schedulerPolicies.schedulerPolicy.items():
            schedulerPolicyConfigs[ schedPolName ] = {}
            if schedPol.schedulers:
               for seq, sched in schedPol.schedulers.scheduler.items():
                  schedulerPolicyConfigs[ schedPolName ][ seq ] = {}
                  if sched.config.priority:
                     schedulerPolicyConfigs[ schedPolName ][ seq ][ "priority" ] = \
                           "STRICT"
                  if sched.oneRateTwoColor:
                     shapingData = {}
                     if sched.oneRateTwoColor.config.cir:
                        shapingData[ "cir" ] = sched.oneRateTwoColor.config.cir
                     if sched.oneRateTwoColor.config.cirPct:
                        shapingData[ "cir-pct" ] = \
                              sched.oneRateTwoColor.config.cirPct
                     schedulerPolicyConfigs[ schedPolName ][ seq ][
                           "one-rate-two-color" ] = shapingData
                  if sched.inputs:
                     schedulerPolicyConfigs[ schedPolName ][ seq ][ "inputs" ] = {}
                     for inId, schedInput in sched.inputs.input.items():
                        schedInputData = {}
                        if schedInput.config.q:
                           schedInputData[ "queue" ] = schedInput.config.q
                        if schedInput.config.weight:
                           schedInputData[ "weight" ] = schedInput.config.weight
                        schedulerPolicyConfigs[ schedPolName ][ seq ][
                              "inputs" ][ inId ] = schedInputData
            updateQosProfileData( schedPolName )

         # interface attachment portion
         for interfaceId, interface in interfaces.interface.items():
            interfaceId = getInterfaceIdAfterValidation( interfaces, interfaceId )
            t2( f"interface {interfaceId}" )
            if not interface.output:
               continue
            if not interface.output.schedulerPolicy:
               continue
            schedulerPolicy = schedulerPolicies.schedulerPolicy.get(
                  interface.output.schedulerPolicy.config.name )
            if not schedulerPolicy:
               t0( f"for interface: {interfaceId} attached scheduler-policy "
                   f"not created: {interface.output.schedulerPolicy.config.name}" )
               continue
            if not schedulerPolicy.schedulers:
               continue

            sPInfo = Tac.Value( "QosOc::SchedulerPolicyCommentInfo" )
            sPInfo.schedulerPolicy = schedulerPolicy.config.name
            updateIntfSchedulerPolicyCliComment( interfaceId, sPInfo )
            intfSchedulerPolicyConfigs.setdefault( interfaceId, {} )

            t2( f"attached scheduler Policy = {schedulerPolicy.config.name}" )
            for scheduler in schedulerPolicy.schedulers.scheduler.values():
               if not scheduler.inputs:
                  continue
               for schedulerInput in scheduler.inputs.input.values():
                  qId = queueNameToTxQueueIdMap.queueNameToTxQueueId.get(
                        schedulerInput.config.q )
                  if qId is None or int( qId ) > maxQId:
                     continue
                  priority = QosTypes.tacTxQueuePriority.priorityRoundRobin
                  if scheduler.config.priority:
                     # use priority invalid in place of priority strict
                     priority = QosTypes.tacTxQueuePriority.priorityInvalid
                  shapeRate = getShapeRateFromConfig( scheduler.oneRateTwoColor )
                  bandwidth = QosTypes.tacPercent.invalid
                  bwWeight = QosTypes.tacBwWeight.invalid
                  if qosHwStatus.bandwidthWeightSupported:
                     if not scheduler.config.priority and \
                           schedulerInput.config.weight:
                        bwWeight = schedulerInput.config.weight
                  else:
                     bandwidth = getBandwidthForQueue(
                        schedulerPolicy.name, schedulerInput.config.q )
                  updateCliComment( interfaceId, qId, sPInfo )
                  updateTxQueueConfig(
                     interfaceId, qId, priority, bandwidth, bwWeight, shapeRate )
                  intfSchedulerPolicyConfigs[ interfaceId ].update(
                        { qId: True } )

         # clean stale entries
         cleanupStaleConfigforScedulerPolicies()

         # queue preCommit handler
         def cleanupStaleConfigForQueues():
            """Reset priority and bandwidth of every native tx queue that was
            not handled in this run ( neither by a scheduler policy nor by
            the augmented queue config )."""
            intfConfigs = qosCliConfig.intfConfig
            for intfId, intfConfig in intfConfigs.items():
               txqConfigs = intfConfig.txQueueConfig
               for txq, txqConfig in txqConfigs.items():
                  if intfId not in intfSchedulerPolicyConfigs or \
                     txq.id not in intfSchedulerPolicyConfigs[ intfId ]:
                     t0( "setting default scheduling/bandwidth values "
                         f"for intf: {intfId}, TxQ: {txq.id}" )
                     txqConfig.priority = \
                           QosTypes.tacTxQueuePriority.priorityInvalid
                     txqConfig.bandwidth = QosTypes.tacPercent()

         def txQueueType( qType ):
            """Map the queue-type segment of an augmented queue name to a
            Qos queue type.

            The segment is optional in the name format; when it is missing
            ( None ) the platform default from getTxQueueType() is used, the
            same type used for scheduler-policy tx queues. Previously a
            missing segment failed the assertion below.
            """
            if qType is None:
               return getTxQueueType()
            if qType == 'ucq':
               return QosTypes.tacQueueType.ucq
            if qType == 'mcq':
               return QosTypes.tacQueueType.mcq
            assert qType == 'unknown'
            return QosTypes.tacQueueType.unknown

         t0( "running augmented queue bandwidth/scheduling handler" )
         # Augmented queue names encode interface, optional type and queue id,
         # e.g. 'Ethernet1/2/3-unknown-6':
         #    group 1 = 'Ethernet1/2/3' ( name of the interface )
         #    group 2 = 'unknown' ( type of the queue; may be absent )
         #    group 3 = '6' ( txQueueId )
         augmentedQueueRe = re.compile(
               r"((?:Ethernet\d+(?:/\d+)*)|(?:Port-Channel\d+))"
               r"(?:-(ucq|mcq|unknown))?-(\d+)" )
         for name, queue in queues.q.items():
            match = augmentedQueueRe.search( name )
            if not match:
               # process only those queues which follow augmentation format
               continue
            ( intfName, qType, qId ) = match.groups()
            if intfName in intfSchedulerPolicyConfigs and \
               int( qId ) in intfSchedulerPolicyConfigs[ intfName ]:
               t2( f"skipping the queue handler for intf: {intfName}, qId: {qId}" )
               continue
            txQueue = Tac.newInstance( "Qos::TxQueue" )
            txQueue.id = QosTypes.tacTxQueueId( int( qId ) )
            txQueue.type = txQueueType( qType )
            if txQueue.id > maxQId:
               continue
            t2( f"processing queue: {name}, value: {queue}" )
            intfId = Tac.newInstance( "Arnet::IntfId", intfName )
            intfConfig = qosCliConfig.intfConfig.get( intfId )
            if not intfConfig:
               intfConfig = qosCliConfig.intfConfig.newMember( intfId )
            txQueueConfig = intfConfig.txQueueConfig.get( txQueue )
            if not txQueueConfig:
               txQueueConfig = intfConfig.txQueueConfig.newMember(
                  txQueue, QosTypes.tacTxQueuePriority.priorityInvalid,
                  QosTypes.tacPercent(), QosTypes.tacShapeRate(),
                  QosTypes.tacGuaranteedBw() )
            scheduling = queue.config.scheduling
            if scheduling == "STRICT":
               priority = QosTypes.tacTxQueuePriority.priorityStrict
            elif scheduling == "ROUND_ROBIN":
               priority = QosTypes.tacTxQueuePriority.priorityRoundRobin
            else:
               # Unset/unknown scheduling: fall back to the same default used
               # for new and cleaned-up queues instead of reusing a stale
               # `priority` from an earlier iteration ( or hitting an
               # unbound name ).
               priority = QosTypes.tacTxQueuePriority.priorityInvalid
            bandwidth = queue.config.bandwidth if queue.config.bandwidth else \
                  QosTypes.tacPercent()
            t0( f"updating scheduling: {priority} bandwidth: {bandwidth}" )
            txQueueConfig.priority = priority
            txQueueConfig.bandwidth = bandwidth
            intfSchedulerPolicyConfigs.setdefault( intfName, {} ).update(
                  { txQueue.id: True } )

         # remove stale config
         cleanupStaleConfigForQueues()

   GnmiSetCliSession.registerPreCommitHandler( ToNativeSchedulerPoliciesSyncher )
