#!/usr/bin/env python3
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import AirStreamLib
import GnmiSetCliSession
import Plugins
import Tac
import Tracing
from QosTypes import tacPfcWatchdogAction

# Trace function for this plugin, bound to the "OpenConfigQos" tracing handle;
# used below to emit PFC watchdog configuration warnings.
t0 = Tracing.Handle( "OpenConfigQos" ).trace0

# OpenConfig-side PFC watchdog action enum; the watchdog syncher maps its
# values onto the native tacPfcWatchdogAction values.
tacWatchdogAction = Tac.Type( "QosOc::PfcWatchdog::Action" )

# Threshold used when a watchdog timeout / recovery time conflicts with the
# configured polling interval: a conflicting interval below this value yields
# an operational polling interval of half that interval (otherwise 0.1).
# Units are presumably seconds, as the warning text says "second(s)" -- confirm.
PFC_WATCHDOG_INTERMEDIATE_BOUNDARY = 0.2

@Plugins.plugin( requires=( 'session', ) )
def Plugin( entMan ):
   """Register gNMI-set pre-commit handlers that translate OpenConfig QoS PFC
   configuration into the native QoS CLI config ('qos/input/config/cli').

   entMan is the entity manager handed to the plugin; the handlers use it to
   fetch session-scoped entities via AirStreamLib.getSessionEntity.
   """
   # pre-commit handlers
   # - Intf Pfc config
   class ToNativeInterfacePfcSyncher( GnmiSetCliSession.PreCommitHandler ):
      """Mirror per-interface OpenConfig PFC no-drop priorities into the native
      per-interface pfcPortConfig, and clear pfcPortConfig on interfaces that
      no longer carry OpenConfig PFC output config."""
      externalPathList = [ 'qos/openconfig/config/interfaces' ]
      nativePathList = [ 'qos/input/config/cli' ]

      @classmethod
      def run( cls, sessionName ):
         interfaces = AirStreamLib.getSessionEntity(
            entMan, sessionName, 'qos/openconfig/config/interfaces' )
         cliConfig = AirStreamLib.getSessionEntity(
            entMan, sessionName, 'qos/input/config/cli' )

         # Snapshot (before newMember below mutates the collection) of native
         # interfaces that currently have PFC port config. Any of these not
         # re-asserted from OpenConfig is stale.
         nativePfcIntfIds = set()
         for intfId, intfConfig in cliConfig.intfConfig.items():
            if intfConfig.pfcPortConfig:
               nativePfcIntfIds.add( intfId )

         # Interfaces whose PFC config is driven by OpenConfig in this commit,
         # keyed by Arnet::IntfId so the stale check below compares the same
         # key type as cliConfig.intfConfig uses.
         configuredIntfIds = set()
         for interfaceId, interface in interfaces.interface.items():
            if not interface.output or not interface.output.pfc:
               continue
            intfId = Tac.Value( "Arnet::IntfId", interfaceId )
            # get or create intfConfig for an intfId
            intfConfig = cliConfig.intfConfig.newMember( intfId )
            if not intfConfig.pfcPortConfig:
               intfConfig.pfcPortConfig = ( interfaceId, False )

            # prioritiesNoDrop is a hex string; treat an empty value as "no
            # priorities" instead of letting int( '', 16 ) raise ValueError.
            prioritiesNoDrop = interface.output.pfc.config.prioritiesNoDrop
            # OpenConfig is authoritative for this interface (see the stale
            # removal below), so assign rather than OR into the existing bits;
            # otherwise priorities removed in OpenConfig would never clear.
            intfConfig.pfcPortConfig.priorities = (
               int( prioritiesNoDrop, 16 ) if prioritiesNoDrop else 0 )
            configuredIntfIds.add( intfId )

         # Removal of stale config: native PFC port config on interfaces that
         # OpenConfig no longer configures.
         for intfId in nativePfcIntfIds - configuredIntfIds:
            cliConfig.intfConfig[ intfId ].pfcPortConfig = None

   # pre-commit handlers
   # - Pfc watchdog
   class ToNativePfcWatchdogSyncher( GnmiSetCliSession.PreCommitHandler ):
      """Translate the OpenConfig PFC watchdog config into the native watchdog
      settings (timeout, recovery, polling interval, action, non-disruptive
      priorities) on the QoS CLI config."""
      externalPathList = [ 'qos/openconfig/config/pfc' ]
      nativePathList = [ 'qos/input/config/cli' ]

      @classmethod
      def run( cls, sessionName ):
         pfcConfig = AirStreamLib.getSessionEntity( entMan, sessionName,
                                                    'qos/openconfig/config/pfc' )
         cliConfig = AirStreamLib.getSessionEntity( entMan, sessionName,
                                                    'qos/input/config/cli' )

         def getPfcWatchdogCfgConflictDetails():
            """Return ( conflictType, conflictingValue, pollingInterval ) when
            the polling interval is more than half of the recovery time or the
            timeout, else None. No conflict is reported unless both timeout
            and polling interval are non-zero."""
            timeout = cliConfig.watchdogTimeout
            pollingInterval = cliConfig.watchdogPollingInterval
            if not ( timeout and pollingInterval ):
               return None
            recoveryCfg = cliConfig.watchdogRecoveryCfg
            recoveryTime = recoveryCfg.recoveryTime
            if ( recoveryTime and not recoveryCfg.forcedRecovery and
                 recoveryTime < 2 * pollingInterval and recoveryTime < timeout ):
               # Report the conflicting value itself, as the timeout branch does.
               return ( "recovery-time", recoveryTime, pollingInterval )
            if timeout < 2 * pollingInterval:
               return ( "timeout", timeout, pollingInterval )
            return None

         def addPfcWatchdogWarning( conflictCfgType, timeout, recoveryTime,
                                    pollingInterval ):
            """Trace a warning that the configured polling interval exceeds half
            of the conflicting timeout / recovery time, including the
            operational polling interval that will be used instead."""
            if conflictCfgType == 'timeout':
               conflictCfgTimeInterval = timeout
            elif conflictCfgType == 'recovery-time':
               conflictCfgTimeInterval = recoveryTime
            else:
               # Unknown conflict type: nothing meaningful to report.
               return
            # 0.1 is the reported operational polling interval unless the
            # conflicting interval is small enough to require halving it.
            operPollingInterval = 0.1
            if conflictCfgTimeInterval < PFC_WATCHDOG_INTERMEDIATE_BOUNDARY:
               operPollingInterval = round( 0.5 * conflictCfgTimeInterval, 3 )
            warningMessage = (
               f"User configured polling-interval {pollingInterval:.5f} "
               f"second(s) is greater than half of {conflictCfgType} "
               f"{conflictCfgTimeInterval:.4f} second(s). "
               f"Setting polling-interval to {operPollingInterval:.5f} second(s)" )
            t0( warningMessage )

         def verifyPfcWatchdogConfig( prevConflictDetails ):
            """Warn if the current config has a conflict that differs from the
            one present before the latest change (prevConflictDetails)."""
            currConflictDetails = getPfcWatchdogCfgConflictDetails()
            if currConflictDetails and (
                  currConflictDetails != prevConflictDetails ):
               addPfcWatchdogWarning( currConflictDetails[ 0 ],
                                      cliConfig.watchdogTimeout,
                                      cliConfig.watchdogRecoveryCfg.recoveryTime,
                                      cliConfig.watchdogPollingInterval )

         def defaultWatchdogConfig():
            """Reset every native watchdog setting to its default value."""
            cliConfig.watchdogTimeout = 0
            recoveryConfig = Tac.Value( "Pfc::WatchdogRecoveryConfig" )
            recoveryConfig.recoveryTime = 0
            recoveryConfig.forcedRecovery = True
            cliConfig.watchdogRecoveryCfg = recoveryConfig
            cliConfig.watchdogPollingInterval = 0
            cliConfig.watchdogAction = tacPfcWatchdogAction.errdisable
            cliConfig.watchdogNonDisruptivePriorities = 0

         def configWatchdogTimeout( watchdogConfig ):
            """Apply defaultTimeout (rounded to 2 decimals; unset -> 0)."""
            prevConflictDetails = getPfcWatchdogCfgConflictDetails()
            defaultTimeout = watchdogConfig.defaultTimeout
            if defaultTimeout:
               cliConfig.watchdogTimeout = round( defaultTimeout, 2 )
            else:
               cliConfig.watchdogTimeout = 0
            verifyPfcWatchdogConfig( prevConflictDetails )

         def configWatchdogRecovery( watchdogConfig ):
            """Apply defaultRecoveryTime; an unset value selects forced
            recovery with a zero recovery time."""
            prevConflictDetails = getPfcWatchdogCfgConflictDetails()
            recoveryConfig = Tac.Value( "Pfc::WatchdogRecoveryConfig" )
            defaultRecoveryTime = watchdogConfig.defaultRecoveryTime
            if defaultRecoveryTime:
               recoveryConfig.recoveryTime = round( defaultRecoveryTime, 2 )
               recoveryConfig.forcedRecovery = False
            else:
               recoveryConfig.recoveryTime = 0
               recoveryConfig.forcedRecovery = True
            cliConfig.watchdogRecoveryCfg = recoveryConfig
            verifyPfcWatchdogConfig( prevConflictDetails )

         def configWatchdogPollingInterval( watchdogConfig ):
            """Apply defaultPollingInterval (rounded to 3 decimals; unset -> 0)."""
            prevConflictDetails = getPfcWatchdogCfgConflictDetails()
            defaultPollingInterval = watchdogConfig.defaultPollingInterval
            if defaultPollingInterval:
               cliConfig.watchdogPollingInterval = round(
                  defaultPollingInterval, 3 )
            else:
               cliConfig.watchdogPollingInterval = 0
            verifyPfcWatchdogConfig( prevConflictDetails )

         def configWatchdogAction( watchdogConfig ):
            """Map the OpenConfig action onto the native action; unset maps to
            errdisable, unrecognized values leave the native action as-is."""
            action = watchdogConfig.action
            if not action or action == tacWatchdogAction.ERROR_DISABLE:
               cliConfig.watchdogAction = tacPfcWatchdogAction.errdisable
            elif action == tacWatchdogAction.DROP:
               cliConfig.watchdogAction = tacPfcWatchdogAction.drop
            elif action == tacWatchdogAction.NOTIFY_ONLY:
               cliConfig.watchdogAction = tacPfcWatchdogAction.notifyOnly

         def configWatchdogNonDisruptivePriority( watchdogConfig ):
            """Apply nonDisruptivePriority, parsed as a hex string (unset -> 0)."""
            nonDisruptivePriority = watchdogConfig.nonDisruptivePriority
            if nonDisruptivePriority:
               cliConfig.watchdogNonDisruptivePriorities = int(
                  nonDisruptivePriority, 16 )
            else:
               cliConfig.watchdogNonDisruptivePriorities = 0

         watchdogConfig = pfcConfig.watchdog.config
         if not watchdogConfig:
            # Nothing configured via OpenConfig: restore native defaults and
            # stop, since the helpers below dereference watchdogConfig.
            defaultWatchdogConfig()
            return
         configWatchdogTimeout( watchdogConfig )
         configWatchdogRecovery( watchdogConfig )
         configWatchdogPollingInterval( watchdogConfig )
         configWatchdogAction( watchdogConfig )
         configWatchdogNonDisruptivePriority( watchdogConfig )

   GnmiSetCliSession.registerPreCommitHandler( ToNativePfcWatchdogSyncher )
   GnmiSetCliSession.registerPreCommitHandler( ToNativeInterfacePfcSyncher )
