# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import AirStreamLib
import GnmiSetCliSession
import Plugins
import re
import Tac
import Tracing
from QosOpenConfigCommonLib import getInterfaceIdAfterValidation

from QosTypes import (
      tacTrafficClass, tacCMapNm,
      tacPolicyMapType, tacDirection, tacActionType
      )

# Prefix of every pmap/cmap name generated from OpenConfig configuration; the
# stale-entry cleanup relies on it to avoid touching CLI-created policy-maps.
ocNamePrefix = "__yang_"

# Tac enum type of an OpenConfig classifier's type (e.g. IPV4, IPV6).
tacClassifierType = Tac.Type( "QosOc::Classifier::Type" )

# trace0 handle used for all tracing in this plugin.
t0 = Tracing.Handle( "OpenConfigQos" ).trace0

@Plugins.plugin( requires=( 'session', ) )
def Plugin( entMan ):
   """Register the OpenConfig QoS interface-classifier precommit handler.

   The registered handler translates the OpenConfig classifier and interface
   input-classifier configuration of a gNMI set session into native EOS QoS
   configuration (policy-maps under 'qos/acl/input/cli' and input
   service-policies under 'qos/input/config/cli').
   """
   # Matches the pmap names produced by pmapNameFromClassifiers() below; pmaps
   # whose name does not match were not created by OpenConfig and are never
   # modified by the stale-entry cleanup.
   ocPmapNameRe = re.compile(
         "^" + re.escape( ocNamePrefix ) +
         r"(?:\[IPV4__([^\]]+)\])?(?:\[IPV6__([^\]]+)\])?$" )

   # precommit handlers
   #    - Interface Input Classifiers
   class ToNativeInterfaceClassifiersSyncher( GnmiSetCliSession.PreCommitHandler ):
      """Sync OpenConfig input classifiers into native QoS pmaps/service-policies."""
      externalPathList = [ 'qos/openconfig/config/classifiers',
                           'qos/openconfig/config/interfaces' ]
      nativePathList = [ 'qos/input/config/cli',
                         'qos/acl/input/cli' ]

      @classmethod
      def run( cls, sessionName ):
         """Translate the session's OpenConfig input classifiers to native QoS.

         For every interface with input classifiers, create (or refresh) a QoS
         pmap named after its attached IPV4/IPV6 classifiers. The pmap gets one
         classAction per classifier term. A classAction sets the traffic class
         mapped from the term's target forwarding group, when one is mapped.
         The pmap is then attached to the interface as an input QoS
         service-policy. Afterwards, OpenConfig-generated service-policies,
         pmaps, interface bindings, classActions and setTc actions that the
         current OpenConfig config no longer calls for are removed.

         Raises AirStreamLib.ToNativeSyncherError if an interface references a
         classifier whose name is not specified.
         """
         t0( "running ToNativeInterfaceClassifiersSyncher" )
         classifiers = AirStreamLib.getSessionEntity(
               entMan, sessionName, 'qos/openconfig/config/classifiers' )
         interfaces = AirStreamLib.getSessionEntity(
               entMan, sessionName, 'qos/openconfig/config/interfaces' )
         qosCliConfig = AirStreamLib.getSessionEntity(
               entMan, sessionName, 'qos/input/config/cli' )
         qosAclCliConfig = AirStreamLib.getSessionEntity(
               entMan, sessionName, 'qos/acl/input/cli' )
         forwardingGroupNameToTcMap = qosCliConfig.forwardingGroupNameToTcMap

         # classifierName -> { 'type': <classifier type or "">,
         #                     'classActions': { cmapName: tc } }
         # Used for creating the classActions of each pmap.
         classifierToClassActionsMap = {}

         # pmapName -> { 'intfIds': set of interface names the servicePolicy is
         #                          expected to be attached to,
         #               'classActions': { cmapName: tc } expected in the pmap }
         # Used for cleaning up stale entries in EOS.
         pmapToClassActionsAndIntfsMap = {}

         def cmapNameFromClassifierTerm( classifierName, termId ):
            """Return the generated cmap name for a classifier term."""
            return ocNamePrefix + "[" + classifierName + "]_[" + termId + "]"

         def pmapNameFromClassifiers( intfClassifiers ):
            """Return the generated pmap name for an interface's classifiers.

            intfClassifiers maps classifier type -> classifier name. Returns ""
            when it is empty.
            """
            pmapName = ""
            if intfClassifiers:
               pmapName = ocNamePrefix
               ipv4Classifier = intfClassifiers.get( tacClassifierType.IPV4, None )
               ipv6Classifier = intfClassifiers.get( tacClassifierType.IPV6, None )
               if ipv4Classifier:
                  pmapName += "[IPV4__" + ipv4Classifier + "]"
               if ipv6Classifier:
                  pmapName += "[IPV6__" + ipv6Classifier + "]"
            return pmapName

         def populateClassifierLocalData():
            """Fill classifierToClassActionsMap from the OC classifiers."""
            for classifierName, classifier in classifiers.classifier.items():
               classifierData = classifierToClassActionsMap.setdefault(
                     classifierName, {} )
               classifierData.setdefault( 'type', "" )
               if classifier.config.type:
                  classifierData[ 'type' ] = classifier.config.type
               classActions = classifierData.setdefault( 'classActions', {} )
               if not classifier.terms:
                  continue
               for termId, term in classifier.terms.term.items():
                  cmapName = cmapNameFromClassifierTerm( classifierName, termId )
                  # A term without actions, or whose target group has no
                  # mapped traffic class, yields a classAction without setTc.
                  tc = tacTrafficClass.invalid
                  if term.actions:
                     targetGroup = term.actions.config.targetGroup
                     tc = forwardingGroupNameToTcMap.forwardingGroupNameToTc.get(
                           targetGroup, tacTrafficClass.invalid )
                  classActions[ cmapName ] = tc

         def populateIntfClassifiers( interface, interfaceId ):
            """Return { classifierType: classifierName } usable on interface.

            Classifiers that do not exist, have no type, or whose type does not
            match the type they are attached as are skipped (with a trace).
            Raises ToNativeSyncherError for a classifier with no name.
            """
            intfClassifiers = {}
            for classifierType, classifier in (
                  interface.input.classifiers.classifier.items() ):
               if not classifier.config.name:
                  raise AirStreamLib.ToNativeSyncherError(
                        sessionName, 'ToNativeInterfaceClassifiersSyncher',
                        f'Cannot attach the classifier to intf {interfaceId}: '
                        'name of the classifier has not been specified' )

               classifierName = classifier.config.name
               classifierData = classifierToClassActionsMap.get( classifierName )
               if classifierData is None:
                  t0( f'Please create the classifier:{classifierName} before '
                      f'attaching it on the intf:{interfaceId}' )
                  continue
               if not classifierData[ 'type' ]:
                  t0( f'Cannot attach the classifier:{classifierName} to intf: '
                      f'{interfaceId}. Please create the classifier with a type' )
                  continue
               if classifierData[ 'type' ] != classifierType:
                  t0( f'Cannot attach the classifier:{classifierName} to the '
                      f'intf: {interfaceId} as classifier types do not match' )
                  continue
               intfClassifiers[ classifierType ] = classifierName
            return intfClassifiers

         def createPmapAndServicePolicy( intfClassifiers, interfaceId ):
            """Create/refresh the pmap for intfClassifiers and attach it to
            interfaceId as an input QoS service-policy.
            """
            if not intfClassifiers:
               t0( 'No classifiers are attached to the interface' )
               return

            pmapName = pmapNameFromClassifiers( intfClassifiers )

            # Record the expected classActions and intfIds for this pmap. The
            # classActions are recomputed on every call; the intfIds accumulate
            # across interfaces sharing the same pmap.
            pmapData = pmapToClassActionsAndIntfsMap.setdefault( pmapName, {} )
            pmapData.setdefault( 'intfIds', set() )
            pmapData[ 'classActions' ] = {}
            for classifierName in intfClassifiers.values():
               pmapData[ 'classActions' ].update(
                     classifierToClassActionsMap[ classifierName ][ 'classActions' ] )

            # Create the pmap
            t0( f'Creating the pmap: {pmapName}' )
            # pmap.version and pmap.uniqueId are incremented by entityCopy handler
            pmapTypeConfig = qosAclCliConfig.pmapType.newMember(
                  tacPolicyMapType.mapQos )
            pmap = pmapTypeConfig.pmap.newMember( pmapName,
                                                  tacPolicyMapType.mapQos )
            pmap.classDefault = ( tacCMapNm.classDefault,
                                  tacPolicyMapType.mapQos )
            match = pmap.classDefault.match.newMember( 'matchIpAccessGroup' )
            match.strValue = 'default'
            pmap.classActionDefault = ( tacCMapNm.classDefault, )

            # Create the classActions in the pmap; priorities are rebuilt from 1
            # in classActions iteration order.
            pmap.classPrio.clear()
            for cmapPrio, ( cmapName, tc ) in enumerate(
                  pmapData[ 'classActions' ].items(), start=1 ):
               classAction = pmap.classAction.newMember( cmapName )
               t0( f'Adding the classAction: {cmapName} to pmap: {pmapName}' )
               if tc != tacTrafficClass.invalid:
                  policyAction = classAction.policyAction.newMember(
                        tacActionType.actionSetTc )
                  policyAction.value = tc
                  t0( f'Adding the setTc action to classAction: {cmapName}'
                      f', pmap: {pmapName}' )
               classPrio = pmap.classPrio.newMember( cmapPrio )
               classPrio.cmapName = cmapName

            # Create the service-policy and attach it to the interface
            servicePolicyKey = Tac.Type( "Qos::ServicePolicyKey" )(
                  tacPolicyMapType.mapQos, tacDirection.input, pmapName )
            servicePolicyConfig = qosCliConfig.servicePolicyConfig.newMember(
                  servicePolicyKey )
            intfId = Tac.Value( "Arnet::IntfId", interfaceId )
            servicePolicyConfig.intfIds[ intfId ] = True
            t0( f'Attaching the servicePolicy {pmapName} on intf: {interfaceId}' )
            pmapData[ 'intfIds' ].add( interfaceId )

         def cleanUpStalePmapAndServicePolicy():
            """Remove OpenConfig-generated QoS state no longer expected."""
            # pylint: disable-msg=too-many-nested-blocks
            t0( 'Deleting the stale entries from servicePolicy and pmaps' )
            # Fetch the QoS pmapType here rather than at the start of run():
            # createPmapAndServicePolicy() may have just created it, in which
            # case a value captured earlier would still be None.
            pmapTypeQos = qosAclCliConfig.pmapType.get( tacPolicyMapType.mapQos,
                                                        None )
            for servicePolicyKey, servicePolicyConfig in list(
                  qosCliConfig.servicePolicyConfig.items() ):
               if ( servicePolicyKey.type != tacPolicyMapType.mapQos or
                    servicePolicyKey.direction != tacDirection.input ):
                  continue

               pmapName = servicePolicyKey.pmapName
               # If the pmap is not created by openConfig, don't tamper with it
               if not ocPmapNameRe.search( pmapName ):
                  continue

               # Delete stale service-policy and pmap
               if pmapName not in pmapToClassActionsAndIntfsMap:
                  del qosCliConfig.servicePolicyConfig[ servicePolicyKey ]
                  t0( f'deleting the servicePolicy {pmapName}' )
                  if pmapTypeQos and pmapName in pmapTypeQos.pmap:
                     t0( f'deleting the pmap {pmapName}' )
                     del pmapTypeQos.pmap[ pmapName ]
                  continue

               pmapData = pmapToClassActionsAndIntfsMap[ pmapName ]

               # Delete stale intfIds to which the service-policy is attached
               expIntfIds = pmapData[ 'intfIds' ]
               for intfId in list( servicePolicyConfig.intfIds ):
                  if intfId not in expIntfIds:
                     del servicePolicyConfig.intfIds[ intfId ]
                     t0( f'Detaching the servicePolicy {pmapName} '
                         f'from the intf: {intfId}' )

               # Delete the stale classActions in the pmap
               pmapConfig = ( pmapTypeQos.pmap.get( pmapName, None )
                              if pmapTypeQos else None )
               if pmapConfig is None:
                  continue
               expClassActions = pmapData[ 'classActions' ]
               for cmapName, classAction in list( pmapConfig.classAction.items() ):
                  if cmapName not in expClassActions:
                     del pmapConfig.classAction[ cmapName ]
                     t0( f'deleting the classAction: {cmapName} '
                         f'from the pmap: {pmapName}' )
                     continue
                  # Drop a setTc action whose term no longer maps to a valid tc
                  if ( tacActionType.actionSetTc in classAction.policyAction and
                       expClassActions[ cmapName ] == tacTrafficClass.invalid ):
                     del classAction.policyAction[ tacActionType.actionSetTc ]
                     t0( f'Deleting the setTc action from classAction: '
                         f'{cmapName}, pmap: {pmapName}' )

         populateClassifierLocalData()
         for interfaceId, interface in interfaces.interface.items():
            interfaceId = getInterfaceIdAfterValidation( interfaces, interfaceId )
            if not ( interface.input and interface.input.classifiers and
                     interface.input.classifiers.classifier ):
               continue

            intfClassifiers = populateIntfClassifiers( interface, interfaceId )
            createPmapAndServicePolicy( intfClassifiers, interfaceId )
         cleanUpStalePmapAndServicePolicy()

   GnmiSetCliSession.registerPreCommitHandler( ToNativeInterfaceClassifiersSyncher )
