# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
from CliPlugin.PolicyMapCliLib import PolicyOpChkr
from TypeFuture import TacLazyType
from CliMode.AleCpuPolicy import CpuTrafficPolicyEnforcementConfigMode
import ConfigMount
import CliCommand
import LazyMount
import CliGlobal
import Cell
import CliSession

# Module-wide state shared by the handlers below. PolicyEnforcementConfig is
# the lazily-resolved Tac value type written into
# policiesVrfConfig.trafficPolicies; the None-valued mount fields are
# populated by Plugin().
gv = CliGlobal.CliGlobal(
   PolicyEnforcementConfig=TacLazyType( 'PolicyMap::PolicyEnforcementConfig' ),
   # PolicyMap::VrfConfig mounted at 'trafficPolicies/cpu/vrf'
   policiesVrfConfig=None,
   # PolicyMap::IntfConfig mounted at 'trafficPolicies/cpu/intf'
   policiesIntfConfig=None,
   # TrafficPolicy::CpuPolicyGlobalConfig at 'trafficPolicies/param/config/cpu'
   cpuPolicyGlobalConfig=None,
   # PolicyMap::PolicyMapStatusRequestDir at 'trafficPolicies/statusRequest/cli'
   policiesStatusRequestDir=None,
   # Tac::Dir mounted at 'cell/<cellId>/trafficPolicies/status'
   policiesStatus=None,
   # NOTE(review): not read anywhere in this file.
   entityManager=None,
   # L3::Intf::StatusDir mounted at 'l3/intf/status'
   l3IntfStatus=None,
)

def printCmd( cmd, indentLevel=0 ):
   """Print `cmd` indented by three spaces per `indentLevel`.

   Nothing is printed when `cmd` is empty or None.
   """
   if not cmd:
      return
   indent = " " * ( 3 * indentLevel )
   print( f"{indent}{cmd}" )

def getCpuTrafficPolicyVrfs( policyName ):
   """Return the VRFs whose CPU traffic-policy config names `policyName`.

   The result is a list in the iteration order of
   gv.policiesVrfConfig.trafficPolicies.
   """
   return [ vrfName
            for vrfName, config in gv.policiesVrfConfig.trafficPolicies.items()
            if config.policyName == policyName ]

def getCpuPolicyEnforcementCmds( context ):
   """Return the enforcement sub-commands implied by a pending policy context.

   A falsy `context` yields no commands. Otherwise the list contains
   'enforcement management' when `context.includeManagement` is truthy.
   """
   if context and context.includeManagement:
      return [ 'enforcement management' ]
   return []

class CpuTrafficPolicyContext:
   """Pending application of a CPU traffic-policy to one VRF.

   Holds the policy name, the target VRF (default 'all') and the
   'enforcement management' flag while the user is in the CPU traffic-policy
   enforcement child mode. `commit` then writes them to
   gv.policiesVrfConfig.trafficPolicies and verifies the result.
   """

   def __init__( self, mode, trafficPolicyName, vrfName ):
      """Capture the mode, policy and VRF, and seed includeManagement.

      Args:
         mode: CLI mode that created the context; its session_ is consulted
            at commit time.
         trafficPolicyName: name of the traffic-policy being applied.
         vrfName: target VRF; a falsy value means 'all'.
      """
      self.mode = mode
      self.trafficPolicyName = trafficPolicyName
      self.vrfName = vrfName if vrfName else 'all'

      # Gather all VRFs to which this policy is currently applied, so that
      # re-entering the mode for a VRF that already uses this policy keeps its
      # existing 'enforcement management' setting. When the policy is newly
      # applied to the VRF the flag starts out False.
      self.includeManagementPerVrf = {
         vrf: policyConfig.includeManagement
         for vrf, policyConfig in gv.policiesVrfConfig.trafficPolicies.items()
         if policyConfig.policyName == self.trafficPolicyName
      }
      self.includeManagement = self.includeManagementPerVrf.get( self.vrfName,
                                                                 False )

   def abort( self ):
      """Clear all pending state (policy name, VRF and management flags)."""
      self.trafficPolicyName = None
      self.vrfName = None
      self.includeManagementPerVrf.clear()
      self.includeManagement = False

   def includeManagementIs( self, value ):
      """Set the pending 'enforcement management' flag."""
      self.includeManagement = value

   def commit( self ):
      """Write the pending policy for self.vrfName and verify it.

      Inside a config session, verification is deferred to an on-commit
      handler registered as "cpu-traffic-policy". Otherwise it runs
      immediately, except for startup-config and standalone sessions.
      """
      # Snapshot taken before the write, used for rollback on failure.
      prevPoliciesVrfConfig = dict( gv.policiesVrfConfig.trafficPolicies.items() )
      policyConfig = gv.PolicyEnforcementConfig( self.trafficPolicyName,
                                                 self.includeManagement )
      gv.policiesVrfConfig.trafficPolicies[ self.vrfName ] = policyConfig

      if self.mode.session_.inConfigSession():
         def handler( mode, onSessionCommit=True ):
            return CpuTrafficPolicyContext._checkStatus( mode,
                                                         prevPoliciesVrfConfig )
         CliSession.registerSessionOnCommitHandler(
               self.mode.session_.entityManager, "cpu-traffic-policy", handler )
         return

      if not ( self.mode.session_.startupConfig() or
               self.mode.session_.isStandalone() ):
         CpuTrafficPolicyContext._checkStatus( self.mode, prevPoliciesVrfConfig )

   @staticmethod
   def _restoreVrfConfig( trafficPolicies, prevPoliciesVrfConfig ):
      """Make `trafficPolicies` match the `prevPoliciesVrfConfig` snapshot.

      Entries missing from the snapshot are removed, and every snapshot entry
      is written back. The keys are copied into a list first so the collection
      is not mutated while it is being iterated. Mutating it mid-iteration
      raises RuntimeError for dicts and can skip entries in general.
      """
      for vrf in list( trafficPolicies ):
         if vrf not in prevPoliciesVrfConfig:
            del trafficPolicies[ vrf ]
      for vrf, policy in prevPoliciesVrfConfig.items():
         trafficPolicies[ vrf ] = policy

   @staticmethod
   def _checkStatus( mode, prevPoliciesVrfConfig ):
      """Verify programming for interfaces in the configured VRFs.

      On the first failed request, an error is reported on `mode`. Outside a
      config session, gv.policiesVrfConfig.trafficPolicies is then restored to
      `prevPoliciesVrfConfig` and the restore is re-verified.

      Returns:
         The failure reason string, or None when every request succeeded.
      """
      def getVrfStatuses():
         # Register one 'intf' request per interface whose VRF has a CPU
         # policy configured ('all' matches every interface).
         chkr = PolicyOpChkr( gv.policiesStatusRequestDir, gv.policiesStatus )
         for vrf in gv.policiesVrfConfig.trafficPolicies:
            for intfId, status in gv.l3IntfStatus.intfStatus.items():
               intfVrf = status.vrf
               if vrf in ( intfVrf, 'all' ):
                  chkr.registerRequest( 'intf', str( intfId ) )
         # NOTE(review): assumes verifyAllRequests() yields
         # ( commitStatus, result ) pairs, as the loops below unpack it.
         return chkr.verifyAllRequests()

      def rollback( result ):
         reason = result.error if result else 'unknown'
         mode.addError( 'Failed to commit traffic-policy : %s' % reason )

         if mode.session_.inConfigSession():
            return reason

         mode.addError( 'Rolling back to previous configuration' )

         CpuTrafficPolicyContext._restoreVrfConfig(
               gv.policiesVrfConfig.trafficPolicies, prevPoliciesVrfConfig )

         # Rolling back multiple VRFs (and therefore many interfaces) can itself
         # take time, so block the user until this is complete.
         for commitStatus, _ in getVrfStatuses():
            if not commitStatus:
               mode.addError( 'Failed to roll back to previous configuration' )
         return reason

      for commitStatus, result in getVrfStatuses():
         if not commitStatus:
            return rollback( result )
      return None

def handleCpuTrafficPolicyConfig( mode, args ):
   """Enter the CPU traffic-policy enforcement child mode for POLICY [VRF].

   A fresh CpuTrafficPolicyContext carries the pending change into the child
   mode. A missing VRF argument is passed through as None.
   """
   pendingContext = CpuTrafficPolicyContext( mode, args[ 'POLICY' ],
                                             args.get( 'VRF' ) )
   mode.session_.gotoChildMode(
      mode.childMode( CpuTrafficPolicyEnforcementConfigMode,
                      parentMode=mode,
                      context=pendingContext ) )

def handleNoOrDefaultTrafficPolicyConfig( mode, args ):
   """Remove POLICY from VRF (default 'all'), if that VRF uses POLICY.

   A VRF that currently has a different CPU traffic-policy is left untouched.
   """
   policyName = args[ 'POLICY' ]
   targetVrf = args.get( 'VRF' ) or 'all'
   if targetVrf not in getCpuTrafficPolicyVrfs( policyName ):
      return
   del gv.policiesVrfConfig.trafficPolicies[ targetVrf ]

def handleIntfCpuTrafficPolicyConfig( mode, args ):
   """Apply CPU traffic-policy POLICY to the interface of the current mode.

   The interface's previous policy (if any) is remembered so that a failed
   verification outside a config session can restore it. Inside a config
   session, verification is deferred to an on-commit handler registered as
   "cpu-traffic-policy-intf". Startup-config and standalone sessions skip
   verification.
   """
   def _verify( cliMode, intf, oldPolicy ):
      # Returns the failure reason, or None when the request succeeded.
      checker = PolicyOpChkr( gv.policiesStatusRequestDir,
                              gv.policiesStatus )
      ok, result = checker.verify( 'intf', intf )
      if ok:
         return None

      reason = result.error if result else 'unknown'
      cliMode.addError( 'Failed to commit traffic-policy : %s' % reason )
      if cliMode.session_.inConfigSession():
         return reason

      cliMode.addError( 'Rolling back to previous configuration' )
      # Put back the old policy, or drop the entry if there was none.
      if oldPolicy:
         gv.policiesIntfConfig.intf[ intf ] = oldPolicy
      else:
         del gv.policiesIntfConfig.intf[ intf ]
      return reason

   policyName = args[ 'POLICY' ]
   intfName = mode.intf.name
   prevPolicy = gv.policiesIntfConfig.intf.get( intfName )
   gv.policiesIntfConfig.intf[ intfName ] = policyName

   if mode.session_.inConfigSession():
      def handler( mode, onSessionCommit=True ):
         return _verify( mode, intfName, prevPolicy )
      CliSession.registerSessionOnCommitHandler(
            mode.session_.entityManager, "cpu-traffic-policy-intf", handler )
      return

   if mode.session_.startupConfig() or mode.session_.isStandalone():
      return
   _verify( mode, intfName, prevPolicy )

def handleNoOrDefaultIntfCpuTrafficPolicyConfig( mode, args ):
   """Remove the CPU traffic-policy from the interface of the current mode.

   Does nothing when the interface has no CPU traffic-policy configured, so
   the 'no'/'default' form can be issued repeatedly without error.
   """
   intfPolicies = gv.policiesIntfConfig.intf
   intfName = mode.intf.name
   if intfName in intfPolicies:
      del intfPolicies[ intfName ]

def handllePerVrfCpuPolicyConfig( mode, args ):
   """Per-VRF-mode variant of handleCpuTrafficPolicyConfig.

   Stores the current mode's VRF into args[ 'VRF' ] (mutating `args`) and then
   delegates. NOTE(review): the 'handlle' spelling is kept because external
   command definitions presumably reference this name.
   """
   currentVrf = mode.vrfName
   args[ 'VRF' ] = currentVrf
   handleCpuTrafficPolicyConfig( mode, args )

def handleNoOrDefaultVrfCpuPolicyConfig( mode, args ):
   """Remove the CPU traffic-policy applied to the current mode's VRF.

   Does nothing when no CPU traffic-policy is configured for that VRF, so the
   'no'/'default' form can be issued repeatedly without error.
   """
   trafficPolicies = gv.policiesVrfConfig.trafficPolicies
   vrfName = mode.vrfName
   if vrfName in trafficPolicies:
      del trafficPolicies[ vrfName ]

def handleShowPendingCpuPolicyEnforcement( mode, args ):
   """Print the pending CPU policy configuration of the current child mode.

   The mode's entry command is printed at indent level 0, followed by each
   enforcement sub-command at indent level 1.
   """
   pending = mode.context
   printCmd( mode.enterCmd(), indentLevel=0 )
   for subCmd in getCpuPolicyEnforcementCmds( pending ):
      printCmd( subCmd, indentLevel=1 )

def handleCpuTrafficPolicyEnforceManagementCmd( mode, args ):
   """Set or clear 'enforcement management' on the pending policy context.

   The positive form sets the flag; the 'no'/'default' form clears it.
   """
   pending = mode.context
   isNoOrDefault = CliCommand.isNoOrDefaultCmd( args )
   pending.includeManagementIs( not isNoOrDefault )

def handleCpuPolicyPermitFragmentConfig( mode, args ):
   """Positive form: set installPermitFragment to False.

   NOTE(review): the positive command clears the flag and the no/default form
   sets it. Confirm this inversion matches the command's wording.
   """
   globalConfig = gv.cpuPolicyGlobalConfig
   globalConfig.installPermitFragment = False

def handleNoOrDefaultCpuPolicyPermitFragmentConfig( mode, args ):
   """No/default form: set installPermitFragment back to True."""
   globalConfig = gv.cpuPolicyGlobalConfig
   globalConfig.installPermitFragment = True

def handleCpuPolicyEnforcementIpTtlExpiredConfig( mode, args ):
   """Positive form: set enforcementIpTtlExpired to True."""
   globalConfig = gv.cpuPolicyGlobalConfig
   globalConfig.enforcementIpTtlExpired = True

def handleNoOrDefaultCpuPolicyEnforcementIpTtlExpiredConfig( mode, args ):
   """No/default form: set enforcementIpTtlExpired back to False."""
   globalConfig = gv.cpuPolicyGlobalConfig
   globalConfig.enforcementIpTtlExpired = False

def Plugin( em ):
   """CLI plugin entry point: record `em` and mount the entities used here.

   Populates gv with:
     * entityManager            - `em` itself
     * policiesStatus           - 'cell/<cellId>/trafficPolicies/status'
                                  (Tac::Dir, LazyMount, 'ri')
     * policiesVrfConfig        - 'trafficPolicies/cpu/vrf'
                                  (PolicyMap::VrfConfig, ConfigMount, 'wi')
     * policiesIntfConfig       - 'trafficPolicies/cpu/intf'
                                  (PolicyMap::IntfConfig, ConfigMount, 'wi')
     * cpuPolicyGlobalConfig    - 'trafficPolicies/param/config/cpu'
                                  (TrafficPolicy::CpuPolicyGlobalConfig,
                                  ConfigMount, 'wi')
     * policiesStatusRequestDir - 'trafficPolicies/statusRequest/cli'
                                  (PolicyMap::PolicyMapStatusRequestDir,
                                  LazyMount, 'wc')
     * l3IntfStatus             - 'l3/intf/status'
                                  (L3::Intf::StatusDir, LazyMount, 'r')
   """
   policiesRootNode = 'trafficPolicies'
   policiesCellRootNode = 'cell/%d/trafficPolicies' % Cell.cellId()
   statusNode = 'status'

   policiesVrfConfigNode = 'cpu/vrf'
   policiesVrfConfigPath = policiesRootNode + '/' + policiesVrfConfigNode
   policiesVrfConfigType = 'PolicyMap::VrfConfig'
   policiesIntfConfigNode = 'cpu/intf'
   policiesIntfConfigPath = policiesRootNode + '/' + policiesIntfConfigNode
   policiesIntfConfigType = 'PolicyMap::IntfConfig'
   cpuPolicyGlobalConfigPath = policiesRootNode + '/param/config/cpu'
   cpuPolicyGlobalConfigType = 'TrafficPolicy::CpuPolicyGlobalConfig'

   policiesStatusPath = policiesCellRootNode + '/' + statusNode
   policiesStatusType = 'Tac::Dir'

   statusRequestDirNode = 'statusRequest/cli'
   policiesStatusRequestDirPath = policiesRootNode + '/' + statusRequestDirNode
   policiesStatusRequestDirType = 'PolicyMap::PolicyMapStatusRequestDir'
   entityManager = em
   # gv declares an entityManager field; populate it rather than leaving it
   # permanently None.
   gv.entityManager = entityManager
   gv.policiesStatus = LazyMount.mount( entityManager,
                                        policiesStatusPath,
                                        policiesStatusType, 'ri' )
   gv.policiesVrfConfig = ConfigMount.mount( entityManager, policiesVrfConfigPath,
                                             policiesVrfConfigType, 'wi' )
   gv.policiesIntfConfig = ConfigMount.mount( entityManager, policiesIntfConfigPath,
                                              policiesIntfConfigType, 'wi' )
   gv.cpuPolicyGlobalConfig = ConfigMount.mount( entityManager,
                                                 cpuPolicyGlobalConfigPath,
                                                 cpuPolicyGlobalConfigType, 'wi' )
   gv.policiesStatusRequestDir = LazyMount.mount( entityManager,
                                                  policiesStatusRequestDirPath,
                                                  policiesStatusRequestDirType,
                                                  'wc' )
   gv.l3IntfStatus = LazyMount.mount( entityManager, "l3/intf/status",
                                      "L3::Intf::StatusDir", "r" )
