#!/usr/bin/env python3
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import os
import Tac
import Arnet
import MainCli
import Tracing
import LazyMount
import ConfigMount
import ArPyUtils
from CliDynamicSymbol import CliDynamicPlugin
from CliPlugin import VrfCli
from CliPlugin import IraRouteCommon
from CliPlugin.IraIpIntfCli import canSetVrfHook
from CliPlugin.IraVrfCli import (
   allVrfConfig,
   arpInputConfigCli,
   canDeleteVrfHook,
   deletedVrfHook,
   rdAutoInputDir,
   rdConfigInputDir,
   routingHwStatusCommon,
   routingVrfConfigDir,
   routingVrfRouteConfigDir,
   routingVrfRouteConfigDynamicDir,
   routing6VrfConfigDir,
   routing6VrfRouteConfigDir,
   routing6VrfRouteConfigDynamicDir,
   VrfDefinitionMode,
)
from TypeFuture import TacLazyType
from CliMode.Ira import VrfDefaultsMode
import sys
import Cell
from IpLibConsts import (
    DEFAULT_VRF,
    DEFAULT_VRF_OLD,
    VRFNAMES_RESERVED,
)

# Module-level handles to mounted entities.  They are None until Plugin()
# (at the bottom of this file) assigns them via ConfigMount / LazyMount.
vrfDefaultsConfig = None
allVrfStatusLocal = None
runnabilityConfig = None
ipConfig = None
ip6Config = None
l3ConfigDir = None

# Per-address-family helpers built from IraRouteCommon, plus the trace
# function and lazily-resolved Tac types used by the handlers below.
ipRedirect = IraRouteCommon.Ip4()
ip6Redirect = IraRouteCommon.Ip6()
routing = IraRouteCommon.routing( ipRedirect )
routing6 = IraRouteCommon.routing( ip6Redirect )
t0 = Tracing.trace0
InternalIntfId = TacLazyType( 'Arnet::InternalIntfId' )
TristateU32 = TacLazyType( 'Ark::TristateU32' )
# File read (via "cat" inside the VRF's network namespace) by
# getVrfReservedPortsEntryModel() to report the reserved-ports status.
reservedPortsSysctlPath = '/proc/sys/net/ipv4/ip_local_reserved_ports'

# Dynamically loaded plugin providing the CLI model classes used by the
# "show" handlers in this file.
IraVrfModel = CliDynamicPlugin( "IraVrfModel" )

def getVrfDefinitionRouteDistinguisherInput():
   """Return the 'vrfDefinition' entry of rdConfigInputDir.

   Asserts that the entry is present (truthy) before handing it back.
   """
   rdInput = rdConfigInputDir.get( 'vrfDefinition' )
   assert rdInput
   return rdInput

def gotoVrfDefMode( mode, vrfName, deprecatedCmd=False ):
   """Enter the VRF definition child mode for vrfName.

   An error is added to mode, and the mode is not entered, when:
     - vrfName is one of VRFNAMES_RESERVED,
     - vrfName is not yet in allVrfConfig.vrf and the number of configured
       VRFs has already reached vrfCapability.maxVrfs, or
     - an IndexError is raised while checking the name (reported to the
       user as the name being too long).

   deprecatedCmd is accepted but not used by this function.
   """
   vrfCapability = routingHwStatusCommon.vrfCapability

   if vrfName in VRFNAMES_RESERVED:
      mode.addError( f"The vrf name '{vrfName}' is reserved." )
      return

   try:
      if vrfName not in allVrfConfig.vrf:
         # A limit of -1 means the maxVrfs field has not been initialized
         # yet.  This happens when parsing the system startup config; since
         # the supported count is unknown, assume the config file was
         # written correctly and accept everything into Config.
         limit = vrfCapability.maxVrfs
         if limit != -1 and len( allVrfConfig.vrf ) >= limit:
            mode.addError(
               f"Can't create VRF {vrfName} (maximum number of VRFs"
               f" supported is {vrfCapability.maxVrfs})" )
            return
   except IndexError:
      maxNameLen = Tac.Type( "L3::VrfName" ).maxLength
      mode.addError( f"'{vrfName}' too long: must be no more than"
                     f' {maxNameLen} characters' )
      return

   vrfDefMode = mode.childMode( VrfDefinitionMode, vrfName=vrfName )
   mode.session_.gotoChildMode( vrfDefMode )

def deleteVrf( mode, vrfName ):
   """Delete the VRF vrfName and clean up the configuration that refers to it.

   Steps, in order:
     1. Reject names in VRFNAMES_RESERVED with an error.
     2. Ask every canDeleteVrfHook extension, and then (for each IPv4 and
        IPv6 interface currently in the VRF) every canSetVrfHook extension,
        whether the operation is acceptable.  Each hook returns
        ( accept, hookMsg ); a non-empty hookMsg is reported as a warning
        when accept is true and as an error otherwise.  The first refusal
        aborts the deletion with nothing modified by this function.
     3. For every interface in the VRF: remove its configured addresses,
        set addrSource to 'manual', and set the interface's VRF in
        l3ConfigDir to DEFAULT_VRF.  A warning is added if any interface
        was modified.
     4. Remove the VRF's entries from allVrfConfig, the 'vrfDefinition'
        route distinguisher input, the IPv4/IPv6 routing config
        directories, runnabilityConfig and arpInputConfigCli.
     5. Notify deletedVrfHook extensions and, if the current CLI session
        was using the deleted VRF, reset the session's VRF mapping.
   """
   if vrfName in VRFNAMES_RESERVED:
      mode.addError( f"The vrf name '{vrfName}' is reserved." )
      return
   # Number of interface entries touched below; only used to decide whether
   # to emit the "addresses have been removed" warning.
   intfModifiedCount = 0

   # first check to see if anyone is preventing us from deleting the VRF
   for hook in canDeleteVrfHook.extensions():
      ( accept, hookMsg ) = hook( vrfName )
      if hookMsg:
         # pylint: disable-msg=W0106
         mode.addWarning( hookMsg ) if accept else mode.addError( hookMsg )
      if not accept:
         return

   # next, for every IPv4 interface in the VRF, check that nobody objects to
   # the interface being moved to the default VRF as part of the deletion
   for ipIntfName in ipConfig.ipIntfConfig:
      iic = ipConfig.ipIntfConfig.get( ipIntfName )
      if iic is not None and iic.vrf == vrfName:
         for hook in canSetVrfHook.extensions():
            ( accept, hookMsg ) = hook( ipIntfName, oldVrf=vrfName,
                                        newVrf=DEFAULT_VRF, vrfDelete=True )
            if hookMsg:
               if accept:
                  mode.addWarning( hookMsg )
               else:
                  mode.addError( hookMsg )
            if not accept:
               return

   # same check for every IPv6 interface in the VRF
   for ipIntfName in ip6Config.intf:
      iic = ip6Config.intf.get( ipIntfName )
      if iic is not None and iic.vrf == vrfName:
         for hook in canSetVrfHook.extensions():
            ( accept, hookMsg ) = hook( ipIntfName, oldVrf=vrfName,
                                        newVrf=DEFAULT_VRF, vrfDelete=True )
            if hookMsg:
               if accept:
                  mode.addWarning( hookMsg )
               else:
                  mode.addError( hookMsg )
            if not accept:
               return

   # clear out the address(es) on all of the interfaces in the VRF &
   # move each interface back to the default VRF
   for ipIntfName in ipConfig.ipIntfConfig:
      iic = ipConfig.ipIntfConfig.get( ipIntfName )
      if iic is None:
         continue
      if iic.vrf == vrfName:
         # FIXFIXFIX really ought to find a way to call the noIpAddr()
         # function in IraIpIntfCli.py rather than replicating the
         # logic here.  Ick.
         intfModifiedCount = intfModifiedCount + 1
         for a in iic.secondaryWithMask:
            t0( f'clearing address {a} from intf {ipIntfName} due to deletion' +
                f' of vrf {vrfName}' )
            del iic.secondaryWithMask[ a ]
         for a in iic.virtualSecondaryWithMask:
            t0( f'clearing virtual secondary address {a} from' +
                f' intf {ipIntfName} due to deletion of vrf {vrfName}' )
            del iic.virtualSecondaryWithMask[ a ]
         # Disable dhcp on the interface
         # This would ensure that the interface doesn't get an IP address
         # via dhcp in the default vrf when this vrf is deleted
         iic.addrSource = 'manual'
         zeroAddr = Arnet.AddrWithMask( "0.0.0.0/0" )
         if iic.addrWithMask != zeroAddr:
            t0( f'clearing address {iic.addrWithMask} from' +
                f' intf {ipIntfName} due to deletion of vrf {vrfName}' )
            iic.addrWithMask = zeroAddr
         if iic.virtualAddrWithMask != zeroAddr:
            t0( f'clearing virtual address {iic.virtualAddrWithMask} from' +
                f' intf {ipIntfName} due to deletion of vrf {vrfName}' )
            iic.virtualAddrWithMask = zeroAddr
         # now v6
         if ip6Config.intf.get( ipIntfName ):
            i6ic = ip6Config.intf[ ipIntfName ]
            i6ic.addrSource = 'manual'
            for a in i6ic.addr:
               t0( f'clearing address {a} from intf {ipIntfName} due to deletion' +
                   f' of vrf {vrfName}' )
               del i6ic.addr[ a ]
         # now clear out the VRF
         l3ConfigDir.intfConfig[ ipIntfName ].vrf = DEFAULT_VRF

   # second pass over the IPv6 interface config for entries whose vrf still
   # equals vrfName at this point
   for ipIntfName in ip6Config.intf:
      iic = ip6Config.intf.get( ipIntfName )
      if iic is None:
         continue
      if iic.vrf == vrfName:
         intfModifiedCount = intfModifiedCount + 1
         iic.addrSource = 'manual'
         # NOTE(review): this looks up the same ip6Config.intf entry that
         # iic already refers to.
         if ip6Config.intf.get( ipIntfName ):
            i6ic = ip6Config.intf[ ipIntfName ]
            for a in i6ic.addr:
               t0( f'clearing address {a} from intf {ipIntfName} due to deletion' +
                   f' of vrf {vrfName}' )
               del i6ic.addr[ a ]
         # now clear out the VRF
         l3ConfigDir.intfConfig[ ipIntfName ].vrf = DEFAULT_VRF

   if intfModifiedCount > 0:
      mode.addWarning(
         f'IP addresses from all interfaces in VRF {vrfName} have been removed' )

   # remove the VRF itself and every per-VRF config entry keyed by its name
   if allVrfConfig.vrf.get( vrfName ):
      del allVrfConfig.vrf[ vrfName ]

   vrfDefinitionRouteDistinguisherInput = getVrfDefinitionRouteDistinguisherInput()

   if vrfName in vrfDefinitionRouteDistinguisherInput.routeDistinguisher:
      del vrfDefinitionRouteDistinguisherInput.routeDistinguisher[ vrfName ]

   # presence in routingVrfConfigDir gates removal from all three IPv4
   # routing config directories
   if vrfName in routingVrfConfigDir.vrfConfig:
      del routingVrfConfigDir.vrfConfig[ vrfName ]
      del routingVrfRouteConfigDir.vrfConfig[ vrfName ]
      del routingVrfRouteConfigDynamicDir.vrfConfig[ vrfName ]

   # likewise for the three IPv6 routing config directories
   if vrfName in routing6VrfConfigDir.vrfConfig:
      del routing6VrfConfigDir.vrfConfig[ vrfName ]
      del routing6VrfRouteConfigDir.vrfConfig[ vrfName ]
      del routing6VrfRouteConfigDynamicDir.vrfConfig[ vrfName ]

   if vrfName in runnabilityConfig.vrf:
      del runnabilityConfig.vrf[ vrfName ]

   if vrfName in arpInputConfigCli.vrf:
      del arpInputConfigCli.vrf[ vrfName ]

   # Notify interested parties that the vrf has been deleted
   deletedVrfHook.notifyExtensions( vrfName )
   # Reset the CLI session to VRF mapping, if required.
   cliSessVrf = VrfCli.vrfMap.getCliSessVrf( mode.session )
   if cliSessVrf == vrfName:
      VrfCli.vrfMap.setCliVrf( mode.session )

def handlerEnterVrfDefMode( mode, args ):
   """Enter VRF definition mode for the VRF named by args[ "VRF_NAME" ]."""
   requestedVrf = args[ "VRF_NAME" ]
   gotoVrfDefMode( mode, requestedVrf )

def noOrDefaultHandlerEnterVrfDefMode( mode, args ):
   """Delete the VRF named by args[ "VRF_NAME" ] (no/default form)."""
   doomedVrf = args[ "VRF_NAME" ]
   deleteVrf( mode, doomedVrf )

def handlerVrfRdCmd( mode, args ):
   """Configure the route distinguisher args[ "<rd>" ] for mode.vrfName.

   The assignment is first validated with an Ira::RdAssignmentChecker; if
   it is not acceptable, the checker's error message is reported and
   nothing is changed.  Otherwise a deprecation notice is shown and the RD
   is stored in the 'vrfDefinition' route distinguisher input.  A value
   that raises IndexError when parsed is reported as malformed.
   """
   autoInputDir = LazyMount.force( rdAutoInputDir )
   configInputDir = LazyMount.force( rdConfigInputDir )
   vrfName = mode.vrfName
   rdText = args[ "<rd>" ]

   rdChecker = Tac.Value( 'Ira::RdAssignmentChecker',
                          autoInputDir, configInputDir, vrfName, rdText )
   if not rdChecker.acceptable:
      mode.addError( rdChecker.errorMessage )
      return

   deprecationNotice = (
      "Since the RD is required for BGP operation, "
      f"please configure the RD for VRF {vrfName} "
      "under the 'router bgp vrf' submode. "
      "The configuration of the RD "
      "under the VRF definition submode "
      "is deprecated and no longer required." )
   mode.addMessage( deprecationNotice )

   rdValue = Tac.Value( 'Arnet::RouteDistinguisher' )
   try:
      rdValue.stringValue = rdText
   except IndexError:
      mode.addError( f'Malformed Route Distinguisher: {rdText}' )
      return

   rdInput = getVrfDefinitionRouteDistinguisherInput()
   rdInput.routeDistinguisher[ vrfName ] = rdValue

def noOrDefaultHandlerVrfRdCmd( mode, args ):
   """Remove the route distinguisher configured for mode.vrfName, if any."""
   rdInput = getVrfDefinitionRouteDistinguisherInput()
   vrfName = mode.vrfName

   if vrfName not in rdInput.routeDistinguisher:
      return
   del rdInput.routeDistinguisher[ vrfName ]

def handlerVrfDescriptionCmd( mode, args ):
   """Set the description of mode.vrfName to args[ "<desc>" ]."""
   newDescription = args[ "<desc>" ]
   vrfConfig = allVrfConfig.vrf[ mode.vrfName ]
   vrfConfig.description = newDescription

def noOrDefaultHandlerVrfDescriptionCmd( mode, args ):
   """Reset the description of mode.vrfName to the empty string."""
   vrfConfig = allVrfConfig.vrf[ mode.vrfName ]
   vrfConfig.description = ''

def handlerVrfRouteExpectedCountCmd( mode, args ):
   """Set expectedRouteCount of mode.vrfName from args[ "NUMBER" ]."""
   routeCount = TristateU32.valueSet( int( args[ "NUMBER" ] ) )
   vrfConfig = allVrfConfig.vrf[ mode.vrfName ]
   vrfConfig.expectedRouteCount = routeCount

def noOrDefaultHandlerVrfRouteExpectedCountCmd( mode, args ):
   """Reset expectedRouteCount of mode.vrfName to the invalid (unset) value."""
   unsetCount = TristateU32.valueInvalid()
   vrfConfig = allVrfConfig.vrf[ mode.vrfName ]
   vrfConfig.expectedRouteCount = unsetCount

def handlerVrfDefaultsRouteExpectedCountCmd( mode, args ):
   """Set the VRF-defaults expectedRouteCount from args[ "NUMBER" ]."""
   routeCount = TristateU32.valueSet( int( args[ "NUMBER" ] ) )
   vrfDefaultsConfig.expectedRouteCount = routeCount

def noOrDefaultHandlerVrfDefaultsRouteExpectedCountCmd( mode, args ):
   """Reset the VRF-defaults expectedRouteCount to the invalid (unset) value."""
   unsetCount = TristateU32.valueInvalid()
   vrfDefaultsConfig.expectedRouteCount = unsetCount

def handlerVrfDefaultsCmd( mode, args ):
   """Enter the VRF defaults child mode."""
   defaultsMode = mode.childMode( VrfDefaultsMode )
   cliSession = mode.session_
   cliSession.gotoChildMode( defaultsMode )

def noOrDefaultHandlerVrfDefaultsCmd( mode, args ):
   """Restore VRF defaults: expectedRouteCount goes back to invalid (unset)."""
   unsetCount = TristateU32.valueInvalid()
   vrfDefaultsConfig.expectedRouteCount = unsetCount

def getVrfReservedPortsEntryModel( mode, vrfName ):
   """Build the VrfReservedPortsEntry model for vrfName.

   config is taken from allVrfConfig.kernelNetConfig when an entry exists
   for the VRF and its reservedPorts is not just a newline.  status is read
   from reservedPortsSysctlPath inside the VRF's network namespace when the
   VRF is a default VRF or its local status is 'active' ("None" if the file
   content is blank); otherwise status mirrors config.
   """
   entry = IraVrfModel.VrfReservedPortsEntry()

   kernelNetConfig = allVrfConfig.kernelNetConfig
   vrfKernelConfig = kernelNetConfig.get( vrfName ) if kernelNetConfig else None
   if vrfKernelConfig and vrfKernelConfig.reservedPorts != "\n":
      entry.config = vrfKernelConfig.reservedPorts

   isDefaultVrf = vrfName in ( DEFAULT_VRF, DEFAULT_VRF_OLD )
   nsName = vrfName if isDefaultVrf else f'ns-{vrfName}'

   def vrfIsActive():
      vrfStatus = allVrfStatusLocal.vrf.get( vrfName )
      return vrfStatus and vrfStatus.state == 'active'

   # VRF namespace only exists if VRF is active
   if not ( isDefaultVrf or vrfIsActive() ):
      entry.status = entry.config
      return entry

   # Capture what the command prints on stdout instead of showing it.
   with ArPyUtils.FileHandleInterceptor( [ sys.stdout.fileno() ] ) as captured:
      mode.session_.runCmd( f"bash sudo ip netns exec {nsName}"
                            f" cat {reservedPortsSysctlPath} " )

   reserved = captured.contents().strip()
   entry.status = "None" if reserved == "" else reserved
   return entry

def showReservedPorts( mode, vrfName=None ):
   """Return a VrfReservedPorts model.

   With vrfName given, the model holds the entry for that VRF only.
   Otherwise it holds one entry per key of allVrfConfig.kernelNetConfig
   plus one per VRF in allVrfConfig.vrf that has no kernelNetConfig entry.
   """
   result = IraVrfModel.VrfReservedPorts()

   assert allVrfConfig is not None

   if vrfName:
      result.vrfs[ vrfName ] = getVrfReservedPortsEntryModel( mode, vrfName )
      return result

   kernelNetConfig = allVrfConfig.kernelNetConfig
   if kernelNetConfig:
      for name in kernelNetConfig:
         result.vrfs[ name ] = getVrfReservedPortsEntryModel( mode, name )
   for name in allVrfConfig.vrf:
      if name in allVrfConfig.kernelNetConfig:
         continue
      result.vrfs[ name ] = getVrfReservedPortsEntryModel( mode, name )

   return result

def handlerShowVrfReservedPortsCmd( mode, args ):
   """Show reserved ports for args[ "VRF_NAME" ], or for all VRFs if absent."""
   selectedVrf = args.get( "VRF_NAME" )
   return showReservedPorts( mode, vrfName=selectedVrf )

def nsNameFromVrfName( mode, vrfName, printErrors=True ):
   """Return the network namespace name for vrfName, or None.

   The default VRF names map to Arnet.NsLib.DEFAULT_NS.  For any other
   VRF the namespace comes from its entry in allVrfStatusLocal, and None
   is returned when the VRF has no status entry or is not 'active'; in
   those cases an error is added to mode if printErrors is true.
   """
   if vrfName in ( DEFAULT_VRF, DEFAULT_VRF_OLD ):
      return Arnet.NsLib.DEFAULT_NS

   vrfStatus = allVrfStatusLocal.vrf.get( vrfName )
   if vrfStatus is None:
      if printErrors:
         mode.addError( f"VRF {vrfName} does not exist" )
      return None

   if vrfStatus.state != 'active':
      if printErrors:
         mode.addError( f"VRF {vrfName} is not active" )
      return None

   namespace = vrfStatus.networkNamespace
   assert namespace != ''
   return namespace

def chVrf( mode, vrfName ):
   """Switch the CLI session's VRF context to vrfName.

   Nothing is changed when no namespace can be resolved for the VRF;
   nsNameFromVrfName reports the reason to the user in that case.
   """
   assert vrfName != ''
   t0( f'changing vrf context to {vrfName}' )
   resolvedNs = nsNameFromVrfName( mode, vrfName, printErrors=True )
   if resolvedNs:
      VrfCli.vrfMap.setCliVrf( mode.session, vrfName, resolvedNs )

def resetCliVrfContextFromEnv( session ):
   """Restore the session's VRF context from the CLI_VRF_NAME env variable.

   Does nothing when the variable is unset/empty or when no namespace can
   be resolved for the named VRF.
   """
   # In case we are running under the "watch" command (a Cli in Cli), re-set the
   # parent's cli vrf context (parent saved it into the env)
   inheritedVrf = os.environ.get( "CLI_VRF_NAME" )
   if not inheritedVrf:
      return
   resolvedNs = nsNameFromVrfName( session.mode, inheritedVrf )
   if resolvedNs:
      session.sessionDataIs( 'vrf', ( inheritedVrf, resolvedNs ) )

def handlerVrfForRtCtxCmd( mode, args ):
   """Change the session's VRF context to the VRF selected in args.

   DEFAULT_VRF takes precedence over DEFAULT_VRF_OLD when present as a key
   in args; otherwise the name is taken from args[ "<vrfName>" ].
   """
   for defaultName in ( DEFAULT_VRF, DEFAULT_VRF_OLD ):
      if defaultName in args:
         targetVrf = defaultName
         break
   else:
      # Neither default-VRF keyword was given.
      targetVrf = args.get( "<vrfName>" )

   chVrf( mode, targetVrf )

def handlerShowCliVrfCtxCmd( mode, args ):
   """Return a ShowCliVrfCtxModel describing the session's VRF context.

   vrfRoutingContext is the session's VRF name, or "Unknown" when the
   session has none.  A warning is added when the VRF is neither a default
   VRF nor a member of allVrfConfig.vrf.
   """
   sessionVrf = VrfCli.vrfMap.getCliSessVrf( mode.session )
   ctxModel = IraVrfModel.ShowCliVrfCtxModel()

   # TBD: Do we care if vrfName is None. Why is this an error? What if
   # we assume that if no vrf is associated with a session, then it is
   # the default vrf and execute commands in the default vrf?
   if sessionVrf is None:
      ctxModel.vrfRoutingContext = "Unknown"
      return ctxModel

   ctxModel.vrfRoutingContext = sessionVrf
   isDefaultVrf = sessionVrf in ( DEFAULT_VRF, DEFAULT_VRF_OLD )
   if not isDefaultVrf and sessionVrf not in allVrfConfig.vrf.members():
      mode.addWarning( f"Warning: VRF {sessionVrf} does not exist" )
   return ctxModel

def handlerReservedPortsCmd( mode, args ):
   """Store the reserved ports for mode.vrfName in kernelNetConfig.

   args[ 'PORT' ] is stored as a string; when absent, a single newline is
   stored instead.  If ports were given and the VRF is a default VRF or is
   locally 'active', a warning is added that some of them may be in use.
   """
   vrfName = mode.vrfName
   ports = args.get( 'PORT', '\n' )
   allVrfConfig.kernelNetConfig[ vrfName ].reservedPorts = str( ports )

   if vrfName in ( DEFAULT_VRF, DEFAULT_VRF_OLD ):
      vrfIsUp = True
   else:
      vrfStatus = allVrfStatusLocal.vrf.get( vrfName )
      vrfIsUp = bool( vrfStatus and vrfStatus.state == 'active' )

   if vrfIsUp and ports != '\n':
      mode.addWarning( 'One or more ports provided may currently be '
                       'in use.'
                       ' Please use the "show kernel ports in-use'
                       f' vrf {vrfName}" command to look at'
                       f' the ports in use in VRF {vrfName}.' )

######################################################################
# boilerplate
######################################################################

def Plugin( entityManager ):
   """Mount the entities this module uses and register the session hook.

   Assigns the module-level handles (vrfDefaultsConfig, allVrfStatusLocal,
   runnabilityConfig, ipConfig, ip6Config, l3ConfigDir) and adds
   resetCliVrfContextFromEnv to MainCli.newCliSessionHook.
   """
   global vrfDefaultsConfig, allVrfStatusLocal, runnabilityConfig
   global ipConfig, ip6Config, l3ConfigDir

   vrfDefaultsConfig = ConfigMount.mount(
      entityManager, "ip/vrf/defaults/config", "Ip::VrfDefaultsConfig", "w" )

   allVrfStatusLocal = LazyMount.mount(
      entityManager, Cell.path( "ip/vrf/status/local" ),
      "Ip::AllVrfStatusLocal", "r" )
   runnabilityConfig = ConfigMount.mount(
      entityManager, "routing6/runnability/config",
      "Routing6::Runnability::Config", "w" )

   ipConfig = ConfigMount.mount(
      entityManager, "ip/config", "Ip::Config", "w" )
   ip6Config = ConfigMount.mount(
      entityManager, "ip6/config", "Ip6::Config", "w" )
   l3ConfigDir = ConfigMount.mount(
      entityManager, "l3/intf/config", "L3::Intf::ConfigDir", "w" )

   # For running cli vrf context commands under the watch command
   MainCli.newCliSessionHook.addExtension( resetCliVrfContextFromEnv )
