# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import AirStreamLib
import GnmiSetCliSession
import Tac
import Tracing
import re
import LazyMount
from QosTypes import (
      tacQueueType, tacTxQueuePriority,
      tacPercent, tacShapeRate, tacGuaranteedBw,
      tacQueueThresholdUnit, tacTxQueueId
      )
from QosOpenConfigCommonLib import getInterfaceIdAfterValidation

t0 = Tracing.Handle( "OpenConfigQosQMP" ).trace0

def registerQueueManagementProfilesHandler( entMan ):
   """Register the OpenConfig queue-management-profile precommit handler.

   The registered handler reads OpenConfig queue-management-profile,
   interface output-queue and queue ECN config from a gNMI set session and
   writes the corresponding native QoS config (qos profiles, per-interface
   tx-queue WRED/ECN config and CLI comments).

   Args:
      entMan: entity manager used to fetch session entities and to mount the
         QoS hardware status.
   """
   # precommit handlers
   #  -QueueManagementProfiles
   class ToNativeQueueManagementProfilesSyncher(
         GnmiSetCliSession.PreCommitHandler ):
      """Sync OpenConfig QMP / queue ECN config into native QoS config."""
      externalPathList = [ 'qos/openconfig/config/queues',
                           'qos/openconfig/config/interfaces',
                           'qos/openconfig/config/queueManagementProfiles' ]
      nativePathList = [ 'qos/input/config/cli',
                         'qos/profile',
                         'cli/config' ]

      @classmethod
      def run( cls, sessionName ):
         """Translate the session's OpenConfig QoS config into native config.

         Processing happens in three passes:
           1. each queue-management-profile becomes a native qos profile
              named ``__YANG_QMP_[<name>]`` with tx-queue 0 WRED/ECN config;
           2. each interface output-queue referencing an existing
              queue-management-profile gets tx-queue WRED or ECN config plus
              a CLI comment naming the profile;
           3. ECN config from ``qos/openconfig/config/queues`` is applied to
              queues not already handled by pass 2.
         After passes 2 and 3, native config that is no longer backed by the
         session's OpenConfig data is removed/reset.

         Args:
            sessionName: name of the gNMI set CLI session being committed.

         Raises:
            AirStreamLib.ToNativeSyncherError: if a profile's ECN/WRED
               thresholds or weight are outside the hardware limits.
         """
         t0( "Running precommit handler for QueueManagementProfiles" )
         queues = AirStreamLib.getSessionEntity(
            entMan, sessionName, 'qos/openconfig/config/queues' )
         queueManagementProfiles = AirStreamLib.getSessionEntity(
            entMan, sessionName, 'qos/openconfig/config/queueManagementProfiles' )
         interfaces = AirStreamLib.getSessionEntity(
            entMan, sessionName, 'qos/openconfig/config/interfaces' )
         qosCliConfig = AirStreamLib.getSessionEntity(
               entMan, sessionName, 'qos/input/config/cli' )
         qosProfiles = AirStreamLib.getSessionEntity(
               entMan, sessionName, 'qos/profile' )
         qosProfileConfig = qosProfiles.config
         cliConfig = AirStreamLib.getSessionEntity(
            entMan, sessionName, 'cli/config' )

         commentHelper = Tac.Type( "QosOc::CommentHelper" )
         qosHwStatus = LazyMount.mount( entMan, "qos/hardware/status/global",
                                       "Qos::HwStatus", "r" )
         # Highest configurable tx-queue id; one fewer queue is usable when
         # the hardware reports tc7ToAnyTxQueueRestricted.
         maxTxQ = qosHwStatus.numTxQueueSupported - 1
         if qosHwStatus.tc7ToAnyTxQueueRestricted:
            maxTxQ = qosHwStatus.numTxQueueSupported - 2

         queueNameToTxQueueIdMap = \
            qosCliConfig.queueNameToTxQueueIdMap.queueNameToTxQueueId

         # local data populated as
         # wredConfigs = { QMPName : wredConfig, ... }
         wredConfigs = {}
         # intfQueueWredConfigs = { intfId : { qId : True }, ... }
         # Every (interface, tx-queue id) touched in this run is recorded here
         # so the stale-config cleanups and the queues pass skip it. The
         # queue id is always stored in normalized (int) form.
         intfQueueWredConfigs = {}

         def getTxQueueType():
            """Return the default tx-queue type for this platform.

            ``ucq`` when the hardware reports multicast-queue support,
            ``unknown`` otherwise.
            """
            if qosHwStatus.numMulticastQueueSupported:
               return tacQueueType.ucq
            return tacQueueType.unknown

         def encodeProfileName( qmpName ):
            """Return the native qos profile name used for a QMP."""
            return f"__YANG_QMP_[{qmpName}]"

         def decodeProfileName( profileName ):
            """Return the QMP name encoded in profileName, or None."""
            qmpName = None
            result = re.search( r"__YANG_QMP_\[(.*)\]", profileName )
            if result:
               qmpName = result.groups()[ 0 ]
            return qmpName

         def updateQosProfileData( qmpName ):
            """Create/update the native qos profile for QMP qmpName.

            Writes WRED and/or ECN config for tx-queue 0 of the profile from
            wredConfigs[ qmpName ]; the config not enabled is cleared.
            """
            t0( f"updateQosProfileData for QMP: {qmpName}" )
            profileName = encodeProfileName( qmpName )
            # create an empty qos profile
            qosProfile = qosProfileConfig.newMember( profileName )
            qosProfile.qosProfile = ( profileName, )
            # create tx-queue 0 config
            localWredConfig = wredConfigs[ qmpName ]
            txq = Tac.newInstance( "Qos::TxQueue" )
            txq.id = 0
            txq.type = getTxQueueType()
            txqConfig = qosProfile.qosProfile.txQueueConfig.get( txq, None )
            if not txqConfig:
               txqConfig = qosProfile.qosProfile.txQueueConfig.newMember(
                     txq, tacTxQueuePriority.priorityInvalid,
                     tacPercent(), tacShapeRate(), tacGuaranteedBw() )
            if qosHwStatus.queueWeightSupported:
               txqConfig.weightConfig = localWredConfig[ "weight" ]
            if localWredConfig[ "max-drop-probability-percent" ] == 0:
               # 0 is not a valid Cli Percent config, so convert it to invalid
               localWredConfig[ "max-drop-probability-percent" ] = tacPercent()
            if localWredConfig[ "drop" ]:
               txqConfig.wredConfig = ( localWredConfig[ "min-threshold" ],
                  localWredConfig[ "max-threshold" ], tacQueueThresholdUnit.bytes,
                  localWredConfig[ "max-drop-probability-percent" ],
                  localWredConfig[ "weight" ] )
            else:
               txqConfig.wredConfig = None
            if localWredConfig[ "enable-ecn" ]:
               txqConfig.ecnConfig = ( localWredConfig[ "min-threshold" ],
                  localWredConfig[ "max-threshold" ], tacQueueThresholdUnit.bytes,
                  localWredConfig[ "max-drop-probability-percent" ],
                  localWredConfig[ "weight" ] )
            else:
               txqConfig.ecnConfig = None

         def cleanupStaleConfigForQueueManagementProfiles():
            """Remove QMP-derived native config not produced in this run.

            Deletes ``__YANG_QMP_[...]`` qos profiles whose QMP no longer
            exists, and clears WRED/ECN config (and the QMP part of the CLI
            comment) on tx-queues whose comment names a QMP but which were
            not configured by the interface pass.
            """
            t0( "cleanupStaleConfigForQueueManagementProfiles" )
            # qos profile portion; iterate over a snapshot of the keys since
            # entries are deleted from the collection inside the loop
            for profileName in list( qosProfileConfig ):
               qmpName = decodeProfileName( profileName )
               if qmpName and qmpName not in wredConfigs:
                  t0( f"deleting profile {profileName}" )
                  del qosProfileConfig[ profileName ]
            # interface attachment portion
            intfConfigs = qosCliConfig.intfConfig
            for intfId, intfConfig in intfConfigs.items():
               txqConfigs = intfConfig.txQueueConfig
               for txq, txqConfig in txqConfigs.items():
                  commentKey = commentHelper.commentKey( intfId, txq )
                  comment = cliConfig.comment.get( commentKey, "" )
                  qmpInfo = commentHelper.queueManagementProfileInfoFromComment(
                     comment )
                  if not qmpInfo:
                     # delete the config only if qmp related comment is present
                     continue
                  if intfId not in intfQueueWredConfigs or \
                        txq.id not in intfQueueWredConfigs[ intfId ]:
                     t0( "deleting wred/ecn config from "
                         f"intf: {intfId}, q: {txq.id}" )
                     txqConfig.ecnConfig = None
                     txqConfig.wredConfig = None
                     newComment = \
                        commentHelper.deleteQueueManagementProfileComment( comment )
                     if newComment:
                        cliConfig.comment[ commentKey ] = newComment
                     else:
                        del cliConfig.comment[ commentKey ]
                     # add cleaned up queue to local data
                     intfQueueWredConfigs.setdefault( intfId, {} ).update(
                           { txq.id: True } )

         def updateCliComment( intfId, qId, qmpInfo ):
            """Record qmpInfo in the CLI comment of (intfId, tx-queue qId)."""
            t0( f"updateCliComment intf = {intfId} queue = {qId}"
                f" QMPInfo = {qmpInfo}" )
            interface = Tac.newInstance( "Arnet::IntfId", intfId )
            txQueue = Tac.newInstance( "Qos::TxQueue" )
            txQueue.id = qId
            txQueue.type = getTxQueueType()
            commentKey = commentHelper.commentKey( interface, txQueue )
            oldComment = cliConfig.comment.get( commentKey, "" )
            newComment = commentHelper.createQueueManagementProfileComment(
                  qmpInfo, oldComment )
            cliConfig.comment[ commentKey ] = newComment

         # qos profile portion
         for qmpName, qmp in queueManagementProfiles.queueManagementProfile.items():
            ecnWred = qmp.wred.uniform.config
            # BUG886023: Allow max(uint32) value in ECN/WRED QueueThresholds
            if ecnWred.minThreshold > qosHwStatus.ecnMaxQueueThresholdInBytes or \
               ecnWred.maxThreshold > qosHwStatus.ecnMaxQueueThresholdInBytes:
               raise AirStreamLib.ToNativeSyncherError( sessionName,
                     "ToNativeQueueManagementProfilesSyncher",
                     "RangeException: ECN/WRED threshold greater than " +
                     f"{qosHwStatus.ecnMaxQueueThresholdInBytes} not supported",
                     f"Configured Values - minTh: {ecnWred.minThreshold} " +
                     f"maxTh: {ecnWred.maxThreshold}" )
            if ecnWred.weight > qosHwStatus.maxWeightValue or \
               ecnWred.weight < qosHwStatus.minWeightValue:
               raise AirStreamLib.ToNativeSyncherError( sessionName,
                     "ToNativeQueueManagementProfilesSyncher",
                     "RangeException: ECN/WRED average queue length weight is " +
                     f"outside the allowed range: {qosHwStatus.minWeightValue} - " +
                     f"{qosHwStatus.maxWeightValue}",
                     f"Configured value - {ecnWred.weight}" )
            localWredConfig = { "min-threshold": ecnWred.minThreshold,
                                "max-threshold": ecnWred.maxThreshold,
                                "enable-ecn": ecnWred.enableEcn,
                                "drop": ecnWred.drop,
                                "weight": ecnWred.weight,
                                "max-drop-probability-percent":
                                    ecnWred.maxDropProbabilityPercent }
            wredConfigs[ qmpName ] = localWredConfig
            updateQosProfileData( qmpName )

         # interface attachment portion
         for interfaceId, interface in interfaces.interface.items():
            interfaceId = getInterfaceIdAfterValidation( interfaces, interfaceId )
            if not interface.output:
               continue
            if not interface.output.queues:
               continue
            interfaceOutputQueues = interface.output.queues
            for queueName, interfaceOutputQueue in interfaceOutputQueues.q.items():
               queueManagementProfileName = \
                     interfaceOutputQueue.config.queueManagementProfile
               if not queueManagementProfileName:
                  continue
               queueManagementProfile = \
                  queueManagementProfiles.queueManagementProfile.get(
                     queueManagementProfileName )
               if not queueManagementProfile:
                  t0( f'for interface: {interfaceId}, output-queue: {queueName}'
                      ' the attached queue-management-profile is not created: '
                      f'{queueManagementProfileName}' )
                  continue
               wred = queueManagementProfile.wred.uniform.config
               # exactly one of drop / enable-ecn must be set for the
               # attachment to be applied
               if wred.drop == wred.enableEcn:
                  t0( "Warning: Both drop and enable-ecn are "
                      f"{wred.drop}, Ignoring config" )
                  continue
               if wred.maxDropProbabilityPercent == 0:
                  t0( "Warning: Drop Percent is given as 0, ignoring config" )
                  continue
               qId = queueNameToTxQueueIdMap.get( queueName )
               t0( f"intf: {interfaceId}, qId: {qId}, "
                   f"queue-management-profile: {queueManagementProfileName}" )
               if ( qId is None ) or ( int( qId ) > maxTxQ ):
                  continue

               intfId = Tac.Value( "Arnet::IntfId", interfaceId )
               txq = Tac.newInstance( "Qos::TxQueue" )
               txq.id = int( qId )
               txq.type = getTxQueueType()

               # update comment before txQueueConfig as augmented sm expects
               # this; drop != enableEcn (checked above), so one is enabled
               qmpInfo = Tac.Value( "QosOc::QueueManagementProfileCommentInfo" )
               qmpInfo.queueManagementProfile = queueManagementProfileName
               updateCliComment( interfaceId, qId, qmpInfo )

               intfConfig = qosCliConfig.intfConfig.get( intfId, None )
               if not intfConfig:
                  intfConfig = qosCliConfig.intfConfig.newMember( intfId )

               txqConfig = intfConfig.txQueueConfig.get( txq, None )
               if not txqConfig:
                  txqConfig = intfConfig.txQueueConfig.newMember(
                     txq, tacTxQueuePriority.priorityInvalid,
                     tacPercent(), tacShapeRate(), tacGuaranteedBw() )

               # When the hardware supports a per-queue weight it goes into
               # weightConfig and the WRED/ECN tuple carries a default weight;
               # otherwise the weight is carried in the WRED/ECN tuple.
               ecnWredWeight = Tac.Value( 'Qos::Weight' )
               if qosHwStatus.queueWeightSupported:
                  txqConfig.weightConfig = wred.weight
               else:
                  ecnWredWeight = wred.weight

               maxThreshold = wred.maxThreshold
               minThreshold = wred.minThreshold
               if maxThreshold < minThreshold:
                  t0( " setting max-threshold = min-threshold" )
                  maxThreshold = minThreshold

               if wred.drop:
                  t0( f"setting drop : {wred}" )
                  txqConfig.wredConfig = ( minThreshold, maxThreshold,
                  tacQueueThresholdUnit.bytes, wred.maxDropProbabilityPercent,
                  ecnWredWeight )
                  txqConfig.ecnConfig = None
               if wred.enableEcn:
                  t0( f"setting ecn : {wred}" )
                  txqConfig.ecnConfig = ( minThreshold, maxThreshold,
                  tacQueueThresholdUnit.bytes, wred.maxDropProbabilityPercent,
                  ecnWredWeight )
                  txqConfig.wredConfig = None
               if txqConfig.wredConfig or txqConfig.ecnConfig:
                  # store the normalized queue id (txq.id == int( qId )) so
                  # the lookups in the cleanups and the queues pass match
                  intfQueueWredConfigs.setdefault( interfaceId, {} ).update(
                     { txq.id: True } )

         # clean stale entries( also handles deletion or updation )
         cleanupStaleConfigForQueueManagementProfiles()

         # queues pre commit handler
         def getThresholdUnit( unit ):
            """Map an OpenConfig threshold unit name to the native unit.

            Returns None for an unrecognized unit.
            """
            if unit == "BYTES":
               return tacQueueThresholdUnit.bytes
            if unit == "KILOBYTES":
               return tacQueueThresholdUnit.kbytes
            if unit == "MEGABYTES":
               return tacQueueThresholdUnit.mbytes
            if unit == "SEGMENTS":
               return tacQueueThresholdUnit.segments
            if unit == "MILLISECONDS":
               return tacQueueThresholdUnit.milliseconds
            if unit == "MICROSECONDS":
               return tacQueueThresholdUnit.microseconds
            return None

         def cleanupStaleConfigForQueues():
            """Reset ECN config on every tx-queue not touched in this run."""
            intfConfigs = qosCliConfig.intfConfig
            for intfId, intfConfig in intfConfigs.items():
               txqConfigs = intfConfig.txQueueConfig
               for txq, txqConfig in txqConfigs.items():
                  if intfId not in intfQueueWredConfigs or \
                     txq.id not in intfQueueWredConfigs[ intfId ]:
                     t0( "setting default ecn value "
                         f"for intf: {intfId}, TxQ: {txq.id}" )
                     txqConfig.ecnConfig = None

         def txQueueType( qType ):
            """Map the type token parsed from a queue name to a queue type.

            The type token is optional in the queue name (e.g. 'Ethernet1-6'),
            in which case qType is None and the platform default tx-queue
            type is used, as in the queue-management-profile passes.
            """
            if qType is None:
               return getTxQueueType()
            if qType == 'ucq':
               return tacQueueType.ucq
            if qType == 'mcq':
               return tacQueueType.mcq
            assert qType == 'unknown'
            return tacQueueType.unknown

         # Only queues whose name follows the augmentation format are
         # processed, e.g. qName = 'Ethernet1/2/3-unknown-6':
         #   group 1 = 'Ethernet1/2/3' ( name of the interface )
         #   group 2 = 'unknown' ( type of the queue; optional -> None )
         #   group 3 = '6' ( txQueueId )
         queueNameRe = re.compile(
               r"((?:Ethernet\d+(?:/\d+)*)|(?:Port-Channel\d+))" +
               r"(?:-(ucq|mcq|unknown))?-(\d+)" )

         t0( "running queues preCommit handler" )
         for name, queue in queues.q.items():
            if not queue.ecn:
               continue
            match = queueNameRe.search( name )
            if not match:
               continue
            ( intfName, qType, qId ) = match.groups()
            if intfName in intfQueueWredConfigs and \
               int( qId ) in intfQueueWredConfigs[ intfName ]:
               # already configured by a queue-management-profile
               t0( f"skipping for intf: {intfName}, qId: {qId}" )
               continue
            txQueue = Tac.newInstance( "Qos::TxQueue" )
            txQueue.id = tacTxQueueId( int( qId ) )
            txQueue.type = txQueueType( qType )
            if txQueue.id > maxTxQ:
               continue
            t0( f"processing queue: {name}, value: {queue}" )
            intfId = Tac.newInstance( "Arnet::IntfId", intfName )
            intfConfig = qosCliConfig.intfConfig.get( intfId )
            if not intfConfig:
               intfConfig = qosCliConfig.intfConfig.newMember( intfId )
            txQueueConfig = intfConfig.txQueueConfig.get( txQueue )
            if not txQueueConfig:
               txQueueConfig = intfConfig.txQueueConfig.newMember(
                  txQueue, tacTxQueuePriority.priorityInvalid,
                  tacPercent(), tacShapeRate(), tacGuaranteedBw() )
            ecnData = queue.ecn.config
            t0( f"updating ecn data: {ecnData}" )
            # a zero/unset max drop rate maps to the invalid (default) percent
            percent = ecnData.maxDropRate if ecnData.maxDropRate else \
                  tacPercent()
            # NOTE(review): getThresholdUnit returns None for an unknown
            # unit; confirm the Tac ecnConfig assignment rejects that.
            unit = getThresholdUnit( ecnData.thresholdUnit )
            txQueueConfig.ecnConfig = ( ecnData.minThreshold,
                  ecnData.maxThreshold, unit,
                  percent, ecnData.weight )
            intfQueueWredConfigs.setdefault( intfName, {} ).update(
                  { txQueue.id: True } )

         # remove stale config
         cleanupStaleConfigForQueues()

   GnmiSetCliSession.registerPreCommitHandler(
         ToNativeQueueManagementProfilesSyncher )
