# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
from CliDynamicPlugin.IcmpErrorResponderModel import (
      ErrorResponderAfCountersModel,
      ErrorResponderCountersModel,
      ErrorResponderVrfCountersModel,
      ErrorResponderModel,
      ErrorResponderVrfModel )
from CliPlugin.IcmpErrorResponderConfigCli import (
      configDir,
      cleanupAllIcmpErrConfig,
      cleanupVrfIcmpErrConfig,
      IcmpErrExtMode,
      IcmpErrExtVrfMode )
from CliPlugin.IcmpErrorResponderShowCli import (
      allVrfStatusLocal,
      aclCliConfig,
      statusDir )
from CliPlugin.RouterGeneralCli import (
      RouterGeneralMode,
      RouterGeneralVrfMode )
from IpLibConsts import DEFAULT_VRF
from Toggles.IcmpResponderToggleLib import (
      toggleIcmpExtendedErrorResponderNdVrfEnabled )
from TypeFuture import TacLazyType

import AgentCommandRequest
import Tac

# Tac types used by the handlers below, resolved lazily via TacLazyType.
TristateBool = TacLazyType( 'Ark::TristateBoolean' )
NodeIdCodePoint = TacLazyType( 'IcmpResponder::ErrorResponder::NodeIdCodePoint' )
NodeIdVrfConfig = TacLazyType(
      'IcmpResponder::ErrorResponder::NodeIdVrfConfig' )
IpAddr = TacLazyType( 'Arnet::IpAddr' )
Ip6Addr = TacLazyType( 'Arnet::Ip6Addr' )
IntfId = TacLazyType( 'Arnet::IntfId' )
AclType = TacLazyType( 'Acl::AclType' )
AddressFamily = TacLazyType( 'Arnet::AddressFamily' )
# Counter-id enums. The fully qualified type names are kept as separate
# constants because Tac.enumValue() is called with the type name when
# reading counters.
AfCounterIdEnumTypeName = (
      'IcmpResponder::ErrorResponder::AfCounterId::CounterIdEnum' )
AfCounterIdEnum = TacLazyType( AfCounterIdEnumTypeName )
VrfCounterIdEnumTypeName = (
      'IcmpResponder::ErrorResponder::VrfCounterId::CounterIdEnum' )
VrfCounterIdEnum = TacLazyType( VrfCounterIdEnumTypeName )

def doIcmpErrExt( mode, args ):
   '''
   Enter the ICMP error-responder extensions child mode.

   From a RouterGeneralMode parent the child is IcmpErrExtMode; from a
   RouterGeneralVrfMode parent it is IcmpErrExtVrfMode.

   Raises AssertionError if invoked from any other parent mode.
   '''
   if isinstance( mode, RouterGeneralMode ):
      childModeType = IcmpErrExtMode
   elif isinstance( mode, RouterGeneralVrfMode ):
      childModeType = IcmpErrExtVrfMode
   else:
      # Explicit raise instead of 'assert False': assert statements are
      # stripped under 'python -O'. Without this raise, an unknown mode would
      # fall through with childMode unbound (UnboundLocalError).
      raise AssertionError(
            f'unknown parent mode {type(mode)} in doIcmpErrExt' )
   childMode = mode.childMode( childModeType )
   mode.session_.gotoChildMode( childMode )

def noIcmpErrExt( mode, args ):
   '''
   Undo the ICMP error-responder extensions config for the parent mode.

   From RouterGeneralMode this calls cleanupAllIcmpErrConfig; from
   RouterGeneralVrfMode it calls cleanupVrfIcmpErrConfig for that mode's VRF.
   Any other mode is silently ignored.
   '''
   # Checked in this order on purpose: the first matching isinstance wins.
   if isinstance( mode, RouterGeneralMode ):
      cleanupAllIcmpErrConfig( mode )
      return
   if isinstance( mode, RouterGeneralVrfMode ):
      cleanupVrfIcmpErrConfig( mode.vrfName )

def _getOrCreateVrfConfig( mode, noCreate=False ):
   '''
   Return configDir's VRF config entry for mode.vrfName.

   A missing entry is created with configDir.newVrfConfig, unless noCreate is
   true, in which case the (missing) lookup result is returned as-is.
   '''
   vrfName = mode.vrfName
   existing = configDir.vrfConfig.get( vrfName )
   if existing is not None or noCreate:
      return existing
   return configDir.newVrfConfig( vrfName )

def _deleteVrfConfigIfEmpty( mode ):
   '''
   Remove the VRF config entry for mode.vrfName when every tracked member
   (enabled, v4AclName, v6AclName, nodeIdConfig) is falsy.
   '''
   vrfName = mode.vrfName
   vrfConfig = configDir.vrfConfig.get( vrfName )
   if not vrfConfig:
      return
   # NOTE(review): relies on the truthiness of the Tac values (TristateBool,
   # nodeIdConfig) meaning "differs from default" -- confirm.
   memberNames = ( 'enabled', 'v4AclName', 'v6AclName', 'nodeIdConfig' )
   anySet = any( bool( getattr( vrfConfig, member ) ) for member in memberNames )
   if not anySet:
      del configDir.vrfConfig[ vrfName ]

def tristateEnabledIs( args ):
   '''
   Return a set TristateBool whose value is False when the 'disabled'
   keyword is among args and True otherwise.
   '''
   isEnabled = 'disabled' not in args
   return TristateBool.valueSet( isEnabled )

def doExtensionsVrfAll( mode, args ):
   '''Set the global configDir.enabled state from the command keywords.'''
   enabledState = tristateEnabledIs( args )
   configDir.enabled = enabledState

def noExtensionsVrfAll( mode, args ):
   '''Reset configDir.enabled to a default-constructed TristateBool.'''
   defaultState = TristateBool()
   configDir.enabled = defaultState

def _addHostNameIs( add ):
   '''Store add (which must be a bool) in configDir.addHostName.'''
   assert isinstance( add, bool )
   configDir.addHostName = add

def doParamsIncludeHostname( mode, args ):
   '''Turn on configDir.addHostName.'''
   _addHostNameIs( add=True )

def noParamsIncludeHostname( mode, args ):
   '''Turn off configDir.addHostName.'''
   _addHostNameIs( add=False )

def doNodeidCodepoint( mode, args ):
   '''Store the CODEPOINT argument as the global node-id code point.'''
   configDir.nodeIdCodePoint = args[ 'CODEPOINT' ]

def noNodeidCodepoint( mode, args ):
   '''Reset the global node-id code point to a default NodeIdCodePoint.'''
   defaultCodePoint = NodeIdCodePoint()
   configDir.nodeIdCodePoint = defaultCodePoint

# ----------------------------
# VRF mode commands begin here
# ----------------------------
def _resetVrfAttr( mode, attrName, defaultValue ):
   '''
   Set attrName on mode's VRF config back to defaultValue and drop the entry
   if it became empty. Does nothing when no entry exists (never creates one).
   '''
   vrfConfig = _getOrCreateVrfConfig( mode, noCreate=True )
   if not vrfConfig:
      return
   setattr( vrfConfig, attrName, defaultValue )
   _deleteVrfConfigIfEmpty( mode )

def doExtensions( mode, args ):
   '''Set this VRF's enabled state from the command keywords.'''
   vrfConfig = _getOrCreateVrfConfig( mode )
   vrfConfig.enabled = tristateEnabledIs( args )

def noExtensions( mode, args ):
   '''Reset this VRF's enabled state to a default TristateBool.'''
   _resetVrfAttr( mode, 'enabled', TristateBool() )

def doIpAccessGroup( mode, args ):
   '''Set this VRF's IPv4 ACL name from the ACLNAME argument.'''
   vrfConfig = _getOrCreateVrfConfig( mode )
   vrfConfig.v4AclName = args[ 'ACLNAME' ]

def noIpAccessGroup( mode, args ):
   '''Clear this VRF's IPv4 ACL name.'''
   _resetVrfAttr( mode, 'v4AclName', '' )

def doIp6AccessGroup( mode, args ):
   '''Set this VRF's IPv6 ACL name from the ACLNAME argument.'''
   vrfConfig = _getOrCreateVrfConfig( mode )
   vrfConfig.v6AclName = args[ 'ACLNAME' ]

def noIp6AccessGroup( mode, args ):
   '''Clear this VRF's IPv6 ACL name.'''
   _resetVrfAttr( mode, 'v6AclName', '' )

def _setNodeIdField( vrfConfig, fieldName, value ):
   '''
   Set one field of vrfConfig.nodeIdConfig.

   A modifiable version is obtained with Tac.nonConst, the field is updated,
   and the result is assigned back to vrfConfig.nodeIdConfig.
   '''
   nodeIdConfig = Tac.nonConst( vrfConfig.nodeIdConfig )
   setattr( nodeIdConfig, fieldName, value )
   vrfConfig.nodeIdConfig = nodeIdConfig

def _clearNodeIdField( mode, fieldName, defaultValue ):
   '''
   Reset one nodeIdConfig field of mode's VRF config (if the entry exists)
   and drop the entry when it became empty. Never creates an entry.
   '''
   vrfConfig = _getOrCreateVrfConfig( mode, noCreate=True )
   if not vrfConfig:
      return
   _setNodeIdField( vrfConfig, fieldName, defaultValue )
   _deleteVrfConfigIfEmpty( mode )

def doNodeidIpv4Address( mode, args ):
   '''Set this VRF's node-id IPv4 address from the IPADDR argument.'''
   addr = args[ 'IPADDR' ]
   vrfConfig = _getOrCreateVrfConfig( mode )
   _setNodeIdField( vrfConfig, 'v4Address', addr )

def noNodeidIpv4Address( mode, args ):
   '''Reset this VRF's node-id IPv4 address to a default IpAddr.'''
   _clearNodeIdField( mode, 'v4Address', IpAddr() )

def doNodeidIpv6Address( mode, args ):
   '''Set this VRF's node-id IPv6 address from the IP6ADDR argument.'''
   addr = args[ 'IP6ADDR' ]
   vrfConfig = _getOrCreateVrfConfig( mode )
   _setNodeIdField( vrfConfig, 'v6Address', addr )

def noNodeidIpv6Address( mode, args ):
   '''Reset this VRF's node-id IPv6 address to a default Ip6Addr.'''
   _clearNodeIdField( mode, 'v6Address', Ip6Addr() )

def doNodeidInterface( mode, args ):
   '''Set this VRF's node-id interface from the INTF argument's name.'''
   intf = args[ 'INTF' ]
   vrfConfig = _getOrCreateVrfConfig( mode )
   _setNodeIdField( vrfConfig, 'intfId', IntfId( intf.name ) )

def noNodeidInterface( mode, args ):
   '''Reset this VRF's node-id interface to a default IntfId.'''
   _clearNodeIdField( mode, 'intfId', IntfId() )

def _isDisabled( tristateBool ):
   return tristateBool.isSet and not tristateBool.value

def _isEnabled( tristateBool ):
   return tristateBool.isSet and tristateBool.value

def _isVrfDisabled( vrfName ):
   vrfConfig = configDir.vrfConfig.get( vrfName )
   if vrfConfig:
      return _isDisabled( vrfConfig.enabled )
   return False

def _isVrfEnabled( vrfName ):
   vrfConfig = configDir.vrfConfig.get( vrfName )
   if vrfConfig:
      return _isEnabled( vrfConfig.enabled )
   return False

def _aclNameAndActive( af, vrfName ):
   '''
   Look up the ACL configured for address family af on VRF vrfName.

   Returns None if no ACL is configured (no VRF config entry, or an empty
   v4AclName/v6AclName for af).
   Returns a tuple ( aclName, activeState ) if an ACL is configured.
   activeState is True only when aclCliConfig holds an ACL of that name for
   the matching ACL type (AclType.ip for IPv4, AclType.ipv6 otherwise).
   '''
   vrfConfig = configDir.vrfConfig.get( vrfName )
   if not vrfConfig:
      return None

   isV4 = af == AddressFamily.ipv4
   aclName = vrfConfig.v4AclName if isV4 else vrfConfig.v6AclName
   if not aclName:
      return None

   aclType = AclType.ip if isV4 else AclType.ipv6
   aclTypeConfig = aclCliConfig.config.get( aclType )
   # The per-type ACL config may be missing (.get() can return None). In
   # that case report the ACL as inactive instead of raising AttributeError.
   aclConfig = ( aclTypeConfig.acl.get( aclName )
                 if aclTypeConfig is not None else None )
   return ( aclName, ( aclConfig is not None ) )

def _erVrfModel( vrfName, v4Status, v6Status ):
   '''
   Build the ErrorResponderVrfModel for vrfName.

   v4Status/v6Status are the per-AF status entries and may be None. An AF is
   "responding" for the VRF when its status exists and has a vrfStatus entry
   for vrfName. ACL and extensions-active attributes are only filled in for
   responding AFs; ACL attributes only when an ACL is configured.
   '''
   ervm = ErrorResponderVrfModel()
   ervm.v4Responding = bool( v4Status ) and ( vrfName in v4Status.vrfStatus )
   ervm.v6Responding = bool( v6Status ) and ( vrfName in v6Status.vrfStatus )

   perAf = ( ( AddressFamily.ipv4, ervm.v4Responding ),
             ( AddressFamily.ipv6, ervm.v6Responding ) )
   for af, responding in perAf:
      if not responding:
         continue
      isV4 = af == AddressFamily.ipv4

      aclActive = None
      aclInfo = _aclNameAndActive( af, vrfName )
      if aclInfo:
         aclName, aclActive = aclInfo
         assert isinstance( aclActive, bool )
         if isV4:
            ervm.v4AclName = aclName
            ervm.v4AclActive = aclActive
         else:
            ervm.v6AclName = aclName
            ervm.v6AclActive = aclActive

      # Extensions are reported inactive only when a configured ACL is
      # inactive (aclActive is False); no ACL (None) counts as active.
      # TODO with ndVrf support: Add other conditions like PAM creation failure
      extensionsActive = aclActive is not False
      if isV4:
         ervm.v4ExtensionsActive = extensionsActive
      else:
         ervm.v6ExtensionsActive = extensionsActive
   return ervm

def showIcmpErrExt( mode, args ):
   '''
   Build the ErrorResponderModel for the ICMP error-responder extensions
   show command.

   A VRF is shown if it is "configured" or present in the IPv4/IPv6 status.
   Candidate VRFs are DEFAULT_VRF, every VRF with a config entry and, when
   the ndVrf toggle is on, every VRF in allVrfStatusLocal. Among those:
     - global state disabled: no VRF counts as configured;
     - global state enabled:  every candidate not explicitly disabled;
     - global state unset:    only candidates explicitly enabled.
   An optional 'VRF' argument restricts the output to that VRF.
   '''
   erModel = ErrorResponderModel()
   v4Status = statusDir.afStatus.get( AddressFamily.ipv4 )
   v6Status = statusDir.afStatus.get( AddressFamily.ipv6 )
   erModel.v4Active = ( v4Status is not None )
   erModel.v6Active = ( v6Status is not None )

   globalDisable = _isDisabled( configDir.enabled )
   globalEnable = _isEnabled( configDir.enabled )

   # Derive vrfsConfigured
   vrfsConsidered = { DEFAULT_VRF }
   vrfsConsidered.update( configDir.vrfConfig.keys() )
   if toggleIcmpExtendedErrorResponderNdVrfEnabled():
      vrfsConsidered.update( allVrfStatusLocal.vrf.keys() )
   if globalDisable:
      vrfsConfigured = set()
   elif globalEnable:
      vrfsConfigured = { v for v in vrfsConsidered if not _isVrfDisabled( v ) }
   else:
      vrfsConfigured = { v for v in vrfsConsidered if _isVrfEnabled( v ) }

   # Derive vrfsInStatus
   vrfsInStatus = set()
   for afStatus in ( v4Status, v6Status ):
      if afStatus:
         vrfsInStatus.update( afStatus.vrfStatus.keys() )

   vrfsToShow = vrfsConfigured | vrfsInStatus

   # Apply vrfName filter
   vrfName = args.get( 'VRF' )
   if vrfName is not None:
      vrfsToShow &= { vrfName }

   # Populate in sorted order (as _afCountersModel does) so the model's
   # insertion order does not depend on set iteration / hash randomization.
   for v in sorted( vrfsToShow ):
      erModel.vrfs[ v ] = _erVrfModel( v, v4Status, v6Status )

   return erModel

def _vrfCountersModel( vrfStatus ):
   '''
   Build an ErrorResponderVrfCountersModel from vrfStatus.vrfCounter.

   Each model field is read from the vrfCounter entry keyed by the matching
   VrfCounterId enum value.
   '''
   vrfModel = ErrorResponderVrfCountersModel()
   fieldToCounterId = (
      ( 'packetsProcessed', VrfCounterIdEnum.inboundPkts ),
      ( 'socketNotFound', VrfCounterIdEnum.pamNotFound ),
      ( 'extendedIcmpErrorsSent', VrfCounterIdEnum.icmpExtErrPktsSent ),
      ( 'icmpErrorsSent', VrfCounterIdEnum.icmpErrPktsSent ),
      ( 'socketErrors', VrfCounterIdEnum.pamPktTxErrors ),
   )
   for fieldName, counterIdEnum in fieldToCounterId:
      counterId = Tac.enumValue( VrfCounterIdEnumTypeName, counterIdEnum )
      setattr( vrfModel, fieldName, vrfStatus.vrfCounter[ counterId ] )

   return vrfModel

def _afCountersModel( afStatus, vrfFilter ):
   '''
   Build an ErrorResponderAfCountersModel from afStatus.

   AF-level fields come from afStatus.afCounter. Per-VRF counters are added,
   in sorted order, for every VRF in afStatus.vrfStatus, or only for
   vrfFilter when one is given (skipped if that VRF has no status).
   '''
   afModel = ErrorResponderAfCountersModel()
   fieldToCounterId = (
      ( 'enqueued', AfCounterIdEnum.enqueuedPkts ),
      ( 'queueDropped', AfCounterIdEnum.queueDroppedPkts ),
      ( 'dequeued', AfCounterIdEnum.dequeuedPkts ),
      ( 'invalidPktObj', AfCounterIdEnum.invalidSzPktInfo ),
      ( 'originalDatagramInPktObjTooShort',
        AfCounterIdEnum.origDgmInPktInfoTooShort ),
      ( 'invalidVrfInPktObj', AfCounterIdEnum.nsValidationErrors ),
      ( 'packetsProcessed', AfCounterIdEnum.pktCountAfterAfErrorChecks ),
   )
   for fieldName, counterIdEnum in fieldToCounterId:
      counterId = Tac.enumValue( AfCounterIdEnumTypeName, counterIdEnum )
      setattr( afModel, fieldName, afStatus.afCounter[ counterId ] )

   if vrfFilter:
      vrfNames = [ vrfFilter ]
   else:
      vrfNames = sorted( afStatus.vrfStatus.keys() )
   for vrfName in vrfNames:
      vrfStatus = afStatus.vrfStatus.get( vrfName )
      if not vrfStatus:
         continue
      afModel.vrfs[ vrfName ] = _vrfCountersModel( vrfStatus )

   return afModel

def showIcmpErrExtCounters( mode, args ):
   '''
   Build the ErrorResponderCountersModel.

   ipv4Counters/ipv6Counters are populated only for address families that
   have a status entry; an optional 'VRF' argument restricts the per-VRF
   counters to that VRF.
   '''
   erModel = ErrorResponderCountersModel()
   v4Status = statusDir.afStatus.get( AddressFamily.ipv4 )
   v6Status = statusDir.afStatus.get( AddressFamily.ipv6 )
   vrfFilter = args.get( 'VRF' )
   for attrName, afStatus in ( ( 'ipv4Counters', v4Status ),
                               ( 'ipv6Counters', v6Status ) ):
      if afStatus:
         setattr( erModel, attrName, _afCountersModel( afStatus, vrfFilter ) )
   return erModel

def clearIcmpErrExtCounters( mode, args ):
   '''
   Send the "clear counters" socket command (dirName "IcmpResponder",
   commandType "errorResponderClearCounters") through
   AgentCommandRequest.runSocketCommand.
   '''
   request = dict( dirName="IcmpResponder",
                   commandType="errorResponderClearCounters",
                   command="clear counters" )
   AgentCommandRequest.runSocketCommand( mode.entityManager, **request )
