# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import sys

import Arnet
import CliCommand
from CliDynamicSymbol import CliDynamicPlugin
from CliModel import cliPrinted
from CliPlugin.MplsCli import (
      labelOperation,
      mplsSupported,
      showMplsConfigWarnings,
      validateLabelStackSize,
)
from CliPlugin import MplsDebugCli as Globals
from CliPlugin import MplsModel
from CliPlugin.TunnelCli import getTunnelIdFromIndex
import LazyMount
from MplsTypeLib import tunnelTypeXlate
import Tac
from TypeFuture import TacLazyType

# Handles to dynamically loaded CLI plugins. This file calls their helpers,
# e.g. BridgingCliHandler.setStaticHelper and
# MplsCliHandler.handleMplsLabelConfig / handleNoMplsLabelConfig.
BridgingCliHandler = CliDynamicPlugin( "BridgingCliHandler" )
MplsCliHandler = CliDynamicPlugin( "MplsCliHandler" )

# Short aliases for the Tac types used by the handlers below. Tac.Type is
# called at import time; TacLazyType presumably defers the type lookup until
# first use. Keep the import-time Tac.Type calls as they are.
AddressFamily = TacLazyType( 'Arnet::AddressFamily' )
FecId = Tac.Type( 'Smash::Fib::FecId' )
IntfId = Tac.Type( 'Arnet::IntfId' )
LfibFlag = Tac.Type( 'Mpls::LfibFlag' )
LfibSysdbIpLookupVia = Tac.Type( 'Mpls::LfibSysdbIpLookupVia' )
LfibSysdbVlanVia = Tac.Type( 'Mpls::LfibSysdbVlanVia' )
LfibSysdbVlanFloodVia = Tac.Type( 'Mpls::LfibSysdbVlanFloodVia' )
LfibSysdbESFilterVia = Tac.Type( 'Mpls::LfibSysdbEthernetSegmentFilterVia' )
LfibSysdbExtFecVia = Tac.Type( 'Mpls::LfibSysdbExtFecVia' )
MplsVia = TacLazyType( 'Tunnel::TunnelTable::MplsVia' )
PayloadType = Tac.Type( 'Mpls::PayloadType' )
StaticTunnelConfigEntry = TacLazyType( 'Tunnel::Static::StaticTunnelConfigEntry' )
TunnelIdConstants = TacLazyType( 'Tunnel::TunnelTable::TunnelIdConstants' )

def MplsDebugIpLookupVia_handler( mode, args ):
   """Install a debug IP-lookup LFIB via for LABEL in VRF.

   The payload type is IPv4 when the 'ipv4' keyword is present and IPv6
   otherwise. The via is added to Globals.debugLfib.ipLookupVia.
   """
   label = args[ 'LABEL' ]
   vrf = args[ 'VRF' ]
   if 'ipv4' in args:
      payload = PayloadType.ipv4
   else:
      payload = PayloadType.ipv6
   Globals.debugLfib.ipLookupVia.addMember(
      LfibSysdbIpLookupVia( label, vrf, payload ) )

def MplsDebugIpLookupVia_noOrDefaultHandler( mode, args ):
   """Remove the debug IP-lookup via keyed by LABEL."""
   ipLookupVias = Globals.debugLfib.ipLookupVia
   del ipLookupVias[ args[ 'LABEL' ] ]

def MplsDebugVlanVia_handler( mode, args ):
   """Install a debug VLAN via mapping LABEL onto VLAN.

   The control-word flag is LfibFlag( 1 ) when the 'true' keyword is present
   and LfibFlag( 0 ) otherwise.
   """
   label = args[ 'LABEL' ]
   vlan = args[ 'VLAN' ]
   cwFlag = LfibFlag( 1 if 'true' in args else 0 )
   Globals.debugLfib.vlanVia.addMember( LfibSysdbVlanVia( label, vlan, cwFlag ) )

def MplsDebugVlanVia_noOrDefaultHandler( mode, args ):
   """Remove the debug VLAN via keyed by LABEL."""
   vlanVias = Globals.debugLfib.vlanVia
   del vlanVias[ args[ 'LABEL' ] ]

def MplsDebugVlanFloodVia_handler( mode, args ):
   """Install a debug VLAN flood via for LABEL on VLAN.

   The control-word flag is LfibFlag( 1 ) when the 'true' keyword is present
   and LfibFlag( 0 ) otherwise.
   """
   label = args[ 'LABEL' ]
   vlan = args[ 'VLAN' ]
   cwFlag = LfibFlag( 1 if 'true' in args else 0 )
   floodVia = LfibSysdbVlanFloodVia( label, vlan, cwFlag )
   Globals.debugLfib.vlanFloodVia.addMember( floodVia )

def MplsDebugVlanFloodVia_noOrDefaultHandler( mode, args ):
   """Remove the debug VLAN flood via keyed by LABEL."""
   floodVias = Globals.debugLfib.vlanFloodVia
   del floodVias[ args[ 'LABEL' ] ]

def MplsDebugEthernetSegmentFilterVia_handler( mode, args ):
   """Install a debug ethernet-segment filter via for LABEL on INTF.

   The interface's name is converted to an Arnet::IntfId for the via.
   """
   label = args[ 'LABEL' ]
   filterIntfId = IntfId( args[ 'INTF' ].name )
   esFilterVias = Globals.debugLfib.ethernetSegmentFilterVia
   esFilterVias.addMember( LfibSysdbESFilterVia( label, filterIntfId ) )

def MplsDebugEthernetSegmentFilterVia_noOrDefaultHandler( mode, args ):
   """Remove the debug ethernet-segment filter via keyed by LABEL."""
   esFilterVias = Globals.debugLfib.ethernetSegmentFilterVia
   del esFilterVias[ args[ 'LABEL' ] ]

def MacAddrDebugStaticTunnel_handler( mode, args ):
   """Add or remove a static MAC entry that points at a tunnel interface.

   The tunnel interface id is derived from TUN_TYPE/TUN_IDX. The work is
   delegated to BridgingCliHandler.setStaticHelper, which receives
   args[ 'LABEL' ] and whether this is a no/default command.
   """
   isNoCmd = CliCommand.isNoOrDefaultCmd( args )
   tunnelIntfIds = [
      Globals.getTunnelIntfId( args[ 'TUN_TYPE' ], args[ 'TUN_IDX' ] ),
   ]
   BridgingCliHandler.setStaticHelper( mode, args, tunnelIntfIds,
                                       args[ 'LABEL' ], no=isNoCmd )

def getTunnelIntfIdFromTunnelInfo( tunnelType, tunnelIndex ):
   """Return the dynamic tunnel interface id for a tunnel type and index.

   The address family passed to getTunnelIdFromIndex is 'ipv6' when the index
   has any bit set under TunnelIdConstants.tunnelAfMask, else 'ipv4'. The
   resulting tunnel id is converted with
   Arnet::DynamicTunnelIntfId.tunnelIdToIntfId.
   """
   afBits = tunnelIndex & TunnelIdConstants.tunnelAfMask
   addrFamily = 'ipv6' if afBits else 'ipv4'
   xlatedType = tunnelTypeXlate[ tunnelType ]
   tunnelId = getTunnelIdFromIndex( xlatedType, tunnelIndex, af=addrFamily )
   dynTunnelIntfIdType = Tac.Type( 'Arnet::DynamicTunnelIntfId' )
   return dynTunnelIntfIdType.tunnelIdToIntfId( tunnelId )

def MplsStaticTopLabelTunnel_handler( mode, args ):
   """Configure a static top-label route for LABEL.

   The next hop is either a tunnel (TTYPE/TINDEX, resolved to a tunnel
   interface id, with no nexthop address) or an address (ADDR_ON_INTF or
   ADDR) with an optional INTF. An optional label operation is included:
   pop with payload type PTYPE (default 'autoDecide'), or swap to OUT_LABEL.
   METRIC and the backup flag are also passed to
   MplsCliHandler.handleMplsLabelConfig.
   """
   label = args[ 'LABEL' ]
   if 'TTYPE' not in args:
      nexthop = args.get( 'ADDR_ON_INTF' ) or args.get( 'ADDR' )
      nexthopIntf = args.get( 'INTF' )
   else:
      nexthop = None
      nexthopIntf = getTunnelIntfIdFromTunnelInfo( args[ 'TTYPE' ],
                                                   args[ 'TINDEX' ] )

   labelOp = None
   if 'pop' in args:
      labelOp = { 'pop': ( args.get( 'PTYPE', 'autoDecide' ), False ) }
   elif 'swap-label' in args:
      labelOp = { 'swap': { 'outLabel': args.get( 'OUT_LABEL' ) } }

   via = {
      'intfNexthop': { 'intf': nexthopIntf, 'nexthop': nexthop },
      'metric': args.get( 'METRIC' ),
      'labelOp': labelOp,
      'backup': args.get( 'backup' ),
   }
   MplsCliHandler.handleMplsLabelConfig( mode, [ label ], via )

def MplsStaticTopLabelTunnel_noOrDefaultHandler( mode, args ):
   """Remove a static top-label route, or one via of it, for LABEL.

   If a specific next hop is given, a via dict describing it is passed to
   MplsCliHandler.handleNoMplsLabelConfig. The next hop is either an address
   (ADDR_ON_INTF or ADDR, with optional INTF) or a tunnel (TTYPE and TINDEX).
   The via dict also carries any label operation (pop/swap), METRIC and the
   backup flag.

   With no next hop, None is passed as the via.
   """
   label = args[ 'LABEL' ]
   tunnelType = args.get( 'TTYPE' )
   tunnelIndex = args.get( 'TINDEX' )
   addr = args.get( 'ADDR_ON_INTF' ) or args.get( 'ADDR' )
   if 'pop' in args:
      labelOp = {
         'pop': ( args.get( 'PTYPE', 'autoDecide' ), False ),
      }
   elif 'swap-label' in args:
      labelOp = {
         'swap': {
            'outLabel': args.get( 'OUT_LABEL' ),
         },
      }
   else:
      labelOp = None

   # Test tunnel presence with `is not None`, not truthiness. A tunnel index
   # of 0 is falsy and would otherwise fall through to `via = None`, i.e. be
   # handled as if no next hop had been given at all.
   haveTunnel = tunnelType is not None and tunnelIndex is not None
   if addr or haveTunnel:
      if addr:
         via = {
            'intfNexthop': {
               'intf': args.get( 'INTF' ),
               'nexthop': addr,
            },
         }
      else:
         tunnelIntfId = getTunnelIntfIdFromTunnelInfo( tunnelType, tunnelIndex )
         via = {
            'intfNexthop': {
               'intf': tunnelIntfId,
            },
         }
      via[ 'labelOp' ] = labelOp
      via[ 'metric' ] = args.get( 'METRIC' )
      via[ 'backup' ] = args.get( 'backup' )
   else:
      via = None
   MplsCliHandler.handleNoMplsLabelConfig( mode, [ label ], via )

def ShowMplsLfibInputBgpLuCmd_handler( mode, args ):
   """Show the MPLS LFIB routes from the BGP LU input (Globals.bgpLuLfibInput).

   A Tac Mpls::MplsRouteHelper renders the routes directly to stdout's file
   descriptor in the session's output format, using the bgpLuInputFilter
   display filter. The handler then returns cliPrinted( ... ), presumably
   signalling to the CLI framework that the output has already been printed.
   """
   showMplsConfigWarnings( mode )
   capiRevision = mode.session_.requestedModelRevision()

   # Resolve the lazily mounted lfibInfo before testing it. If it is still
   # falsy, return an empty MplsRoutes model without rendering anything.
   LazyMount.force( Globals.lfibInfo )
   if not Globals.lfibInfo:
      return cliPrinted( MplsModel.MplsRoutes() )

   # Default-constructed label stack handed to the helper; presumably it means
   # "no label filter" -- TODO confirm against MplsRouteHelper.
   labelStack = Tac.Value( "Arnet::BoundedMplsLabelStack" )
   dispFilter = Tac.Type( "Mpls::DisplayFilter" ).bgpLuInputFilter

   # The helper writes straight to the raw fd, so flush Python's buffered
   # stdout first; otherwise earlier output (e.g. the config warnings above)
   # could appear after the rendered routes.
   sys.stdout.flush()
   fd = sys.stdout.fileno()
   fmt = mode.session_.outputFormat()
   helper = Tac.newInstance( "Mpls::MplsRouteHelper", labelStack, capiRevision,
                              mplsSupported() )
   # Wire every status/table the helper consults while rendering.
   helper.lfibInfo = Globals.lfibInfo
   helper.tunnelFib = Globals.tunnelFib
   helper.srteForwardingStatus = Globals.srteForwardingStatus
   helper.srTePolicyStatus = Globals.srTePolicyStatus
   helper.rsvpLerTunnelTable = Globals.rsvpLerTunnelTable
   helper.lfibVskOverrideConfig = Globals.lfibVskOverrideConfig
   helper.lfibViaSetStatus = Globals.lfibViaSetStatus
   # render() takes positional arguments. The None slots and trailing booleans
   # follow its Tac signature, which is not visible here; keep them in order.
   helper.render( fd, fmt, None, None, None, None, None, None,
                  Globals.bgpLuLfibInput, dispFilter, True, False )
   # NOTE(review): this return passes the MplsRoutes class, while the early
   # return above passes an MplsRoutes() instance -- confirm which form
   # cliPrinted expects.
   return cliPrinted( MplsModel.MplsRoutes )

def MplsStaticTopLabelFec_handler( mode, args ):
   """Install a debug external-FEC via for LABEL.

   FTYPE (a short FEC type name, expanded via
   Globals.shortFecTypeToCompleteType) and FINDEX identify the FEC. The label
   action is 'pop' when the 'pop' keyword is present and 'forward' otherwise.
   The payload type is PTYPE, defaulting to 'autoDecide'.
   """
   label = args[ 'LABEL' ]
   shortFecType = args[ 'FTYPE' ]
   fecIndex = args[ 'FINDEX' ]
   adjType = Globals.shortFecTypeToCompleteType[ shortFecType ]
   fecId = FecId.fecIdForAdjType( adjType, fecIndex )

   if 'pop' in args:
      action = 'pop'
   else:
      action = 'forward'

   extFecVia = LfibSysdbExtFecVia( label, fecId, action )
   extFecVia.payloadType = args.get( 'PTYPE', 'autoDecide' )
   Globals.debugLfib.extFecVia.addMember( extFecVia )

def MplsStaticTopLabelFec_noOrDefaultHandler( mode, args ):
   """Remove the debug external-FEC via keyed by LABEL."""
   extFecVias = Globals.debugLfib.extFecVia
   del extFecVias[ args[ 'LABEL' ] ]

def MplsDebugStaticTunnel_handler( mode, args ):
   """Create or update the debug static MPLS tunnel NAME.

   The new via (nexthop address, interface id, label operation) is built in
   one of three ways:
     * 'resolving': look up the IPv4/IPv6 route for the resolving prefix. The
       interface id is the hierarchical intf id of that route's FEC, converted
       for tunnel use, and the nexthop is 0.0.0.0.
     * 'resolving-tunnel': use the tunnel identified by RT_TYPE/RT_INDEX as
       the interface, with nexthop 0.0.0.0.
     * otherwise: MplsCliHandler.getMplsTunnelNhTepAndIntfId derives the
       nexthop, tunnel endpoint and interface id from the given arguments.

   The labels are [ 3 ] for 'imp-null-tunnel'; otherwise they are LABELS,
   after a stack-size check.

   With 'backup' the via becomes the entry's backupVia; otherwise it replaces
   all primary vias. On any validation failure the handler returns before
   touching Globals.staticTunnelConfig.
   """
   tunnelName = args[ 'NAME' ]
   endpoint = ( args.get( 'TEP_V4_ADDR' ) or args.get( 'TEP_V4_PREFIX' ) or
                  args.get( 'TEP_V6_ADDR' ) or args.get( 'TEP_V6_PREFIX' ) )
   nexthop = args.get( 'NEXTHOP_V4' ) or args.get( 'NEXTHOP_V6' )
   intf = args.get( 'INTF' )
   resolvingPrefix = args.get( 'R_TEP' ) or args.get( 'RT_TEP' )
   resolvingTunnelType = args.get( 'RT_TYPE' )
   resolvingTunnelIndex = args.get( 'RT_INDEX' )
   if 'imp-null-tunnel' in args:
      labels = [ 3 ]
   else:
      labels = args.get( 'LABELS' )
      if not validateLabelStackSize( mode, labels ):
         return

   if 'resolving' in args:
      # Resolve through the route for the prefix in the matching AF table.
      resolvingPrefixAddr = Arnet.IpGenPrefix( str( resolvingPrefix ) )
      if resolvingPrefixAddr.af == AddressFamily.ipv4:
         prefixKey = Arnet.Prefix( resolvingPrefix )
         route = Globals.routingStatus.route.get( prefixKey )
      elif resolvingPrefixAddr.af == AddressFamily.ipv6:
         prefixKey = Arnet.Ip6Prefix( resolvingPrefix )
         route = Globals.routing6Status.route.get( prefixKey )
      else:
         mode.addError( "Invalid resolving prefix: neither IPv4 nor IPv6." )
         return
      if route is None:
         # pylint: disable-next=consider-using-f-string
         mode.addError( "Route to %s not found." % resolvingPrefix )
         return
      nexthopAddr = Arnet.IpGenAddr( '0.0.0.0' )
      fecId = route.fecId
      # Convert the route's FEC id into its tunnel-usable form, then map it
      # to a hierarchical interface id for the via.
      usedByTunnelFecId = FecId.convertFecIdForUseByTunnel( fecId )
      intfId = FecId.fecIdToHierarchicalIntfId( usedByTunnelFecId )
      tunnelEndpoint = resolvingPrefixAddr
   elif 'resolving-tunnel' in args:
      # Resolve through another tunnel; no route lookup is performed.
      resolvingPrefixAddr = Arnet.IpGenPrefix( str( resolvingPrefix ) )
      nexthopAddr = Arnet.IpGenAddr( '0.0.0.0' )
      intfId = getTunnelIntfIdFromTunnelInfo( resolvingTunnelType,
                                                resolvingTunnelIndex )
      tunnelEndpoint = resolvingPrefixAddr
   else:
      nexthopAddr, tunnelEndpoint, intfId = \
         MplsCliHandler.getMplsTunnelNhTepAndIntfId( mode, nexthop, endpoint, intf )
      # An all-None result means the helper rejected the arguments.
      if ( nexthopAddr, tunnelEndpoint, intfId ) == ( None, None, None ):
         return

   # Build on a copy of the existing entry only when its endpoint is
   # unchanged. Otherwise start a fresh entry, so no previous primary or
   # backup vias are carried over.
   staticTunnelConfigEntry = Globals.staticTunnelConfig.entry.get( tunnelName )
   if staticTunnelConfigEntry and staticTunnelConfigEntry.tep == tunnelEndpoint:
      newStaticTunnelConfigEntry = \
         MplsCliHandler.copyStaticTunnelConfigEntry( staticTunnelConfigEntry )
   else:
      newStaticTunnelConfigEntry = StaticTunnelConfigEntry( tunnelName )
      newStaticTunnelConfigEntry.tep = tunnelEndpoint
   # add newly configured via to either via or backupVia
   newMplsVia = MplsVia( nexthopAddr, intfId, labelOperation( labels ) )
   if 'backup' in args:
      newStaticTunnelConfigEntry.backupVia = newMplsVia
   else:
      # Clear existing vias if any, and add newly created via
      newStaticTunnelConfigEntry.via.clear()
      newStaticTunnelConfigEntry.via[ newMplsVia ] = True
      # NOTE(review): presumably marks the entry as not owned by the
      # static-tunnel config mode -- confirm the flag's meaning.
      newStaticTunnelConfigEntry.inStaticTunnelMode = False

   if staticTunnelConfigEntry:
      # Skip the write when the resulting entry equals the existing one.
      if staticTunnelConfigEntry == newStaticTunnelConfigEntry:
         return
   Globals.staticTunnelConfig.entry.addMember( newStaticTunnelConfigEntry )

def MplsDebugStaticTunnel_noOrDefaultHandler( mode, args ):
   """Remove the primary or backup via of the debug static tunnel NAME.

   With 'backup' only the backup via is reset to an empty MplsVia(); otherwise
   the primary vias are cleared. If the tunnel is then left with neither a
   primary nor a backup via it is deleted; otherwise the updated copy is
   written back. A name with no existing entry is silently ignored.
   """
   tunnelName = args[ 'NAME' ]
   existingEntry = Globals.staticTunnelConfig.entry.get( tunnelName )
   if not existingEntry:
      return

   updatedEntry = MplsCliHandler.copyStaticTunnelConfigEntry( existingEntry )
   if 'backup' in args:
      updatedEntry.backupVia = MplsVia()
   else:
      updatedEntry.via.clear()

   nothingLeft = ( not updatedEntry.via and
                   updatedEntry.backupVia == MplsVia() )
   if nothingLeft:
      del Globals.staticTunnelConfig.entry[ tunnelName ]
   else:
      Globals.staticTunnelConfig.entry.addMember( updatedEntry )
