# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import BasicCli
from CliPlugin import (
   TunnelFibModel,
   TunnelModels,
)
import CliPlugin.TunnelCli
from CliPlugin.TunnelCli import(
   getEndpointFromTunnelId,
   getTunnelEntry,
   getTunnelIdModel,
   tokenMulticastMatcher,
   tokenTunnelFibMatcher,
   tokenTunnelMatcher,
   tunnelIndexMatcher,
)
from CliPlugin.TunnelFibModel import (
   getTunnelFibViaModel,
   getTunnelFibToUse,
   getTunnelViaModelFromDynTunIntf,
   getFecFromIntfId,
   getMplsViaModelFromFecVia,
)
from CliPlugin.TunnelModels import (
   getTunnelViaModelFromTunnelIntf,
)
import CliToken.TunnelCli
import Tac
import Tracing
from TunnelLib import (
   getTunnelViaStatusFromTunnelId,
   isDyTunIntfId,
)
from TypeFuture import TacLazyType
import ShowCommand
import LazyMount
import SmashLazyMount
from Toggles import TunnelToggleLib
from Toggles.TunnelToggleLib import toggleMvpnRsvpP2mpEnabled

# Trace handle for this plugin; t0 is Tracing.trace0, used for the trace
# calls in this file.
__defaultTraceHandle__ = Tracing.Handle( 'TunnelFibCli' )
t0 = Tracing.trace0

# Lazily-resolved TAC types used by the helpers below.
FecIdIntfId = TacLazyType( 'Arnet::FecIdIntfId' )
FecId = TacLazyType( 'Smash::Fib::FecId' )
# Smash mount info ('reader') used by the SmashLazyMount.mount calls in Plugin().
readerInfo = SmashLazyMount.mountInfo( 'reader' )
TunnelType = TacLazyType( 'Tunnel::TunnelTable::TunnelType' )

# Module state populated by Plugin(); these stay None until Plugin() runs.
unifiedTunnelFibViewHelper = None
unifiedTunnelFib = None
tunnelMulticastFib = None
nhgStatus = None
flexAlgoConfig = None

def getAliasEndpointsFromTunnelId( tunnelId ):
   """Return the tunnel's endpoints after the first one, or None.

   GUE tunnels never report alias endpoints. For other tunnel types, every
   endpoint of the tunnel table entry except the first is returned as a
   list. None is returned when there is no entry or no extra endpoints.
   """
   entry = getTunnelEntry( tunnelId )
   tacTunnelId = Tac.Value( "Tunnel::TunnelTable::TunnelId", tunnelId )
   isGue = tacTunnelId.tunnelType() == TunnelType.gueTunnel
   if isGue or not entry:
      return None
   extraEndpoints = list( entry.endpoint.values() )[ 1 : ]
   return extraEndpoints or None

def getAlgorithmFromTunnelId( tunnelId ):
   """Return the flex-algo name for a flex-algo tunnel, else None.

   The algorithm id is extracted from the tunnel id and translated to a name
   via the mounted flex-algo config.
   """
   tacTunnelId = Tac.Value( "Tunnel::TunnelTable::TunnelId", tunnelId )
   if tacTunnelId.isFlexAlgoTunnelType( tacTunnelId.tunnelType() ):
      return flexAlgoConfig.flexAlgoName( tacTunnelId.extractFlexAlgoId() )
   return None

def getGueTunnelEntryFields( tunnelId, tunnelEntry ):
   """Return the GUE encapsulation fields to display for a tunnel FIB entry.

   Args:
      tunnelId: raw tunnel id of the entry.
      tunnelEntry: the tunnel FIB entry for ``tunnelId``.

   Returns:
      A TunnelFibModel.GueTunnelEntryFields built from the GUE encap info
      referenced by ``tunnelEntry.tunnelVia[ 0 ]``. Returns None if the
      tunnel is not a GUE tunnel, the entry has no vias, or the unified
      tunnel FIB has no GUE encap for that via's encapId.
   """
   tid = Tac.Value( "Tunnel::TunnelTable::TunnelId", tunnelId )
   # Compare against the TunnelType enum, as the rest of this file does,
   # rather than a bare string literal.
   if tid.tunnelType() != TunnelType.gueTunnel:
      return None
   if len( tunnelEntry.tunnelVia ) == 0:
      return None

   # The GUE encap info lives in the gueTunnelEncap collection of the
   # unified tunnel FIB, looked up by the via's encapId.
   # NOTE(review): only via index 0 is consulted; presumably all vias of a
   # GUE entry share one encap. Confirm.
   encapId = tunnelEntry.tunnelVia[ 0 ].encapId
   gueTunnelEncap = unifiedTunnelFib.gueTunnelEncap( encapId )
   if not gueTunnelEncap:
      return None

   gueTunnelInfo = gueTunnelEncap.encapInfo
   return TunnelFibModel.GueTunnelEntryFields(
      source=gueTunnelInfo.src, payloadType=gueTunnelInfo.payloadType,
      tos=gueTunnelInfo.tos, ttl=gueTunnelInfo.ttl )

def getViaModelFromTunnelVia( tunnelVia ):
   """Build the list of IP TunnelVia models for one tunnel via.

   A via whose interface is not a FEC-id interface yields a single model
   from its own nexthop and interface. A FEC-id interface yields one model
   per FEC via, in sorted order. A FEC-id interface whose FEC cannot be
   found yields an empty list.
   """
   intfId = tunnelVia.intfId
   if not FecIdIntfId.isFecIdIntfId( intfId ):
      return [ TunnelModels.TunnelVia( nexthop=tunnelVia.nexthop,
                                       interface=intfId,
                                       type='ip' ) ]

   fec = getFecFromIntfId( intfId )
   if fec is None:
      return []
   return [ TunnelModels.TunnelVia(
               nexthop=Tac.Value( "Arnet::IpGenAddr", str( fecVia.hop ) ),
               interface=fecVia.intfId,
               type='ip' )
            for fecVia in sorted( fec.via.values() ) ]

def getMplsViaModelFromTunnelVia( tunnelVia ):
   """Build MplsVia models (with label stack) for one tunnel via.

   The label operation comes from the unified tunnel FIB's labelStackEncap
   for ``tunnelVia.encapId``. If none exists, a default-constructed
   Arnet::MplsLabelOperation is used instead (presumably with an empty
   stack).

   A FEC-id interface yields one MplsVia per FEC via, in sorted order; an
   unresolvable FEC yields an empty list. Any other interface yields a
   single MplsVia from the via's own nexthop and interface. All returned
   models share the same ``labels`` list object.
   """
   encap = unifiedTunnelFib.labelStackEncap( tunnelVia.encapId )
   if encap is None:
      labelOp = Tac.Value( 'Arnet::MplsLabelOperation' )
   else:
      labelOp = encap.labelStack
   # Labels are listed from stack index stackSize - 1 down to index 0.
   labels = [ str( labelOp.labelStack( idx ) )
              for idx in reversed( range( labelOp.stackSize ) ) ]

   intfId = tunnelVia.intfId
   if not FecIdIntfId.isFecIdIntfId( intfId ):
      return [ TunnelModels.MplsVia( nexthop=tunnelVia.nexthop,
                                     interface=intfId,
                                     type='ip',
                                     labels=labels ) ]

   fec = getFecFromIntfId( intfId )
   if fec is None:
      return []
   return [ TunnelModels.MplsVia(
               nexthop=Tac.Value( "Arnet::IpGenAddr", str( fecVia.hop ) ),
               interface=fecVia.intfId,
               type='ip',
               labels=labels )
            for fecVia in sorted( fec.via.values() ) ]

#------------------------------------------------------------------------
# Auxiliary function to extract vias from tunnel entries containing
# optimized FEC ids
#------------------------------------------------------------------------
def getTunnelFecViaModel( fecVia, encapId, resolvingTunnel=None, multicast=False ):
   """Build the via model for a FEC via that resolves over a tunnel.

   Args:
      fecVia: a via of an (optimized) FEC.
      encapId: encap id of the tunnel via the FEC was reached from.
      resolvingTunnel: model of the tunnel this FEC via resolves over; it is
         stored on the returned via model.
      multicast: select the multicast tunnel FIB via getTunnelFibToUse.

   Returns:
      The model returned by getMplsViaModelFromFecVia with
      ``resolvingTunnel`` set, or None if that helper returned None.
   """
   tunnelFibToUse = getTunnelFibToUse( multicast=multicast )

   viaModel = getMplsViaModelFromFecVia( tunnelFibToUse, encapId, fecVia )
   # Callers truth-test the result, so tolerate a None model here instead of
   # raising AttributeError on the assignment.
   if viaModel is not None:
      viaModel.resolvingTunnel = resolvingTunnel

   return viaModel

def getTunnelFibViaFromFec( via, encapId, cachedNHG ):
   """Expand a FEC-id via into a flat list of tunnel FIB via models.

   Each via of the referenced FEC, in sorted order, is handled as follows:
     - FEC-id interface: expanded recursively and its vias spliced in.
     - Dynamic tunnel interface: built with getTunnelFecViaModel, with the
       resolving tunnel taken from that interface.
     - Anything else: a TunnelVia (hop, interface, encapId) is built and
       passed to getTunnelFibViaModel, along with the FEC's ucmpEligible
       flag.
   Falsy via models are dropped. An unresolvable FEC yields an empty list.
   """
   fec = getFecFromIntfId( via.intfId )
   if fec is None:
      return []

   ucmpFlag = Tac.enumValue( "Smash::Fib::FecFlags", "ucmpEligible" )
   isUcmpEligible = bool( fec.fecFlags & ucmpFlag )

   collected = []
   for fecVia in sorted( fec.via.values() ):
      if FecIdIntfId.isFecIdIntfId( fecVia.intfId ):
         # Nested FEC: flatten its vias into ours.
         collected.extend( getTunnelFibViaFromFec( fecVia, encapId, cachedNHG ) )
         continue

      if isDyTunIntfId( fecVia.intfId ):
         model = getTunnelFecViaModel(
            fecVia, encapId,
            resolvingTunnel=getTunnelViaModelFromTunnelIntf( fecVia.intfId ) )
      else:
         hopAddr = Tac.Value( "Arnet::IpGenAddr", str( fecVia.hop ) )
         tacVia = Tac.newInstance( 'Tunnel::TunnelTable::TunnelVia',
                                   hopAddr, fecVia.intfId, encapId )
         model = getTunnelFibViaModel( tacVia, cachedNHG,
                                       ucmpEligible=isUcmpEligible )
      if model:
         collected.append( model )

   return collected

#-------------------------------------------------------------------------
# The "show tunnel fib" command.
#-------------------------------------------------------------------------
def getTunnelFibEntryModel( tunnelId, encapFilter, cachedNHG, debug=False,
                            multicast=False ):
   """Build the TunnelFibEntry CLI model for one tunnel FIB entry.

   Args:
      tunnelId: raw tunnel id (key of the tunnel FIB ``entry`` collection).
      encapFilter: if not None, only vias whose ``encapId.encapType`` equals
         this value are shown.
      cachedNHG: nexthop-group id -> nexthop-group name mapping, passed
         through to the via model builders.
      debug: when True, also fill in the debug-only fields of the model.
      multicast: look the entry up in the multicast tunnel FIB.

   Returns:
      A TunnelFibModel.TunnelFibEntry. Returns None if the entry is missing
      or none of its primary vias survive filtering.
   """
   tacTunnelId = Tac.Value( 'Tunnel::TunnelTable::TunnelId', tunnelId )
   tunnelType = tacTunnelId.tunnelType()
   gueTunnel = tunnelType == TunnelType.gueTunnel
   dsfTunnel = tunnelType == TunnelType.voqFabricTunnel
   tunnelFibToUse = getTunnelFibToUse( multicast=multicast, gueTunnel=gueTunnel,
                                       dsfTunnel=dsfTunnel )
   tunnelFibEntry = tunnelFibToUse.entry.get( tunnelId )
   if tunnelFibEntry is None:
      return None

   # Vias equal to a default-constructed TunnelVia are never displayed.
   # Build that reference value once rather than once per via.
   emptyTunnelVia = Tac.newInstance( "Tunnel::TunnelTable::TunnelVia" )

   def skipVia( tacTunVia ):
      return ( ( encapFilter is not None and
                 encapFilter != tacTunVia.encapId.encapType ) or
               tacTunVia == emptyTunnelVia )

   # Primary vias.
   viaList = []
   ucmpEligibleFlag = tunnelFibEntry.flags.ucmpEligible
   for tunVia in tunnelFibEntry.tunnelVia.values():
      if skipVia( tunVia ):
         continue
      if FecIdIntfId.isSrTePolicyIntfId( tunVia.intfId ):
         tid = FecIdIntfId.tunnelId( tunVia.intfId )
         resolvingTunnel = getTunnelIdModel( tid )
      elif FecIdIntfId.isFecIdIntfId( tunVia.intfId ):
         # Accumulate rather than assign: an entry may have several tunnel
         # vias, and vias collected from earlier ones must not be discarded.
         viaList.extend(
            getTunnelFibViaFromFec( tunVia, tunVia.encapId, cachedNHG ) )
         continue
      else:
         resolvingTunnel = getTunnelViaModelFromDynTunIntf( tunVia.intfId )
      viaModel = getTunnelFibViaModel(
         tunVia, cachedNHG, resolvingTunnel=resolvingTunnel,
         multicast=multicast, gueTunnel=gueTunnel, dsfTunnel=dsfTunnel,
         ucmpEligible=ucmpEligibleFlag )
      if viaModel:
         viaList.append( viaModel )
   if not viaList:
      return None

   # Backup vias. SR-TE policy and FEC-id interfaces are not expected here.
   backupViaList = []
   for backupTunnelVia in tunnelFibEntry.backupTunnelVia.values():
      if skipVia( backupTunnelVia ):
         continue

      assert not FecIdIntfId.isSrTePolicyIntfId( backupTunnelVia.intfId ), \
             "Unexpected SR TE policy interface ID for backup via"
      assert not FecIdIntfId.isFecIdIntfId( backupTunnelVia.intfId ), \
             "Unexpected FEC ID interface ID for backup via"

      resolvingTunnel = getTunnelViaModelFromDynTunIntf( backupTunnelVia.intfId )
      # NOTE(review): debug is passed positionally here but not for primary
      # vias; confirm it lines up with getTunnelFibViaModel's signature.
      viaModel = getTunnelFibViaModel(
         backupTunnelVia, cachedNHG, debug,
         resolvingTunnel=resolvingTunnel, backup=True, multicast=multicast,
         gueTunnel=gueTunnel, dsfTunnel=dsfTunnel )

      if viaModel:
         backupViaList.append( viaModel )

   tunnelViaStatus = getTunnelViaStatusFromTunnelId(
         tunnelId,
         CliPlugin.TunnelCli.programmingStatus,
         multicast=multicast )

   entryModel = TunnelFibModel.TunnelFibEntry(
         tunnelType=tacTunnelId.typeCliStr(), tunnelIndex=tacTunnelId.tunnelIndex(),
         endpoint=getEndpointFromTunnelId( tunnelId ),
         aliasEndpoints=getAliasEndpointsFromTunnelId( tunnelId ),
         algorithm=getAlgorithmFromTunnelId( tunnelId ), vias=viaList,
         backupVias=backupViaList,
         tunnelViaStatus=tunnelViaStatus,
         gueTunnelEntryFields=getGueTunnelEntryFields( tunnelId, tunnelFibEntry ) )
   if debug:
      entryModel.tunnelId = tunnelId
      entryModel.tunnelFecId = FecId.tunnelIdToFecId( tacTunnelId )
      entryModel.interface = tacTunnelId.intfCliStr()
      entryModel.seqNo = tunnelFibEntry.seqNo

   return entryModel

def getTunnelFibModel( typeFilter=None, encapFilter=None, indexFilter=None,
                       debug=False, multicast=False ):
   """Build the TunnelFib CLI model for "show tunnel fib [multicast]".

   Args:
      typeFilter: if not None, only tunnels whose tunnelType() equals it.
      encapFilter: passed to getTunnelFibEntryModel to filter vias.
      indexFilter: if not None, only tunnels whose index equals it (both
         sides compared as int).
      debug: include debug-only fields in each entry.
      multicast: read the multicast tunnel FIB instead of the unified one.

   Entries are grouped into categories keyed by the tunnel type's CLI
   string, and indexed within a category by integer tunnel index.
   """
   t0( 'getTunnelFibModel: typeFilter', typeFilter, 'encapFilter', encapFilter,
       'indexFilter', indexFilter, 'debug', debug )

   def isWanted( tacTunnelId ):
      # Each filter is optional; an unset filter matches every tunnel.
      # No other tunnel types are excluded when typeFilter is None.
      if typeFilter is not None and tacTunnelId.tunnelType() != typeFilter:
         return False
      # Tunnel index is an integer and may legitimately be zero.
      return ( indexFilter is None or
               int( tacTunnelId.tunnelIndex() ) == int( indexFilter ) )

   # nexthop-group id -> nexthop-group name, shared by all entry lookups.
   cachedNHG = { nhgEntry.nhgId: nhgKey.nhgName()
                 for nhgKey, nhgEntry in nhgStatus.nexthopGroupEntry.items() }

   fibModel = TunnelFibModel.TunnelFib()
   sourceFib = tunnelMulticastFib if multicast else unifiedTunnelFib
   for rawTunnelId in sourceFib.entry:
      tacTunnelId = Tac.Value( 'Tunnel::TunnelTable::TunnelId', rawTunnelId )
      if not isWanted( tacTunnelId ):
         continue

      entryModel = getTunnelFibEntryModel( rawTunnelId, encapFilter, cachedNHG,
                                           debug, multicast=multicast )
      if entryModel is None:
         continue

      categoryName = tacTunnelId.typeCliStr()
      if categoryName not in fibModel.categories:
         fibModel.categories[ categoryName ] = TunnelFibModel.TunnelFibCategory()
      category = fibModel.categories[ categoryName ]
      category.entries[ int( tacTunnelId.tunnelIndex() ) ] = entryModel

   t0( 'getTunnelFibModel: finished' )
   return fibModel

# Maps the set of keyword tokens left in the show-command args (after the
# common tokens such as 'show', 'tunnel', 'fib', 'debug', 'index' and
# TUNNEL_INDEX are stripped by the handler) to a
# ( tunnel type filter, encap type filter ) pair. The handler looks for an
# exact set match. The final empty-set "Guard" row matches when no tunnel
# type keyword was given, and means "no filter".
unicastParseTable = [
      ( frozenset( [ 'bgp', 'labeled-unicast' ] ), 'bgpLuTunnel', 'mplsEncap' ),
      ( frozenset( [ 'bgp', 'labeled-unicast', 'forwarding' ] ),
         'bgpLuForwardingTunnel', 'mplsEncap' ),
      ( frozenset( [ 'bgp', 'udp' ] ), 'gueTunnel', 'gueEncap' ),
      ( frozenset( [ 'static', 'mpls' ] ), 'staticTunnel', 'mplsEncap' ),
      ( frozenset( [ 'static', 'interface', 'gre' ] ),
         'staticInterfaceTunnel', 'greEncap' ),
      ( frozenset( [ 'static', 'interface', 'ipsec' ] ),
         'staticInterfaceTunnel', 'ipSecEncap' ),
      ( frozenset( [ 'static', 'interface', 'gre-over-ipsec' ] ),
         'staticInterfaceTunnel', 'ipSecGreEncap' ),
      ( frozenset( [ 'ldp' ] ), 'ldpTunnel', 'mplsEncap' ),
      ( frozenset( [ 'rsvp' ] ), 'rsvpLerTunnel', 'mplsEncap' ),
      ( frozenset( [ 'rsvp-frr' ] ), 'rsvpFrrTunnel', 'mplsEncap' ),
      ( frozenset( [ 'rsvp-sub' ] ), 'rsvpLerSubTunnel', 'mplsEncap' ),
      ( frozenset( [ 'ti-lfa' ] ), 'tiLfaTunnel', 'mplsEncap' ),
      ( frozenset( [ 'isis', 'segment-routing' ] ), 'srTunnel', 'mplsEncap' ),
      ( frozenset( [ 'ospf', 'segment-routing' ] ), 'ospfSrTunnel', 'mplsEncap' ),
      ( frozenset( [ 'traffic-engineering', 'segment-routing', 'policy' ] ),
         'srTeSegmentListTunnel', 'mplsEncap' ),
      # A None encap filter means vias of any encap type are shown.
      ( frozenset( [ 'nexthop-group' ] ), 'nexthopGroupTunnel', None ),
      ( frozenset( [ 'isis', 'flex-algo' ] ), 'isisFlexAlgoTunnel', 'mplsEncap' ),
      ( frozenset( [ 'segment-routing', 'ipv6', 'transport' ] ),
         'srv6TransportTunnel', 'srv6TunnelEncap' ),
      ( frozenset( [ 'voq', 'fabric' ] ), 'voqFabricTunnel', 'dsfEncap' ),
      ( frozenset(), None, None ), # Guard
      ]

# Same scheme as unicastParseTable, for "show tunnel fib multicast".
multicastParseTable = [
      ( frozenset( [ 'mldp' ] ), 'mldpP2mpTunnel', 'mplsEncap' ),
      ( frozenset( [ 'static' ] ), 'staticMcastTunnel', 'mplsEncap' ),
      ( frozenset( [ 'rsvp' ] ), 'rsvpP2mpTunnel', 'mplsEncap' ),
      ( frozenset(), None, None ), # Guard
      ]

class ShowTunnelFibCmd( ShowCommand.ShowCliCommandClass ):
   # Implements "show tunnel fib [ debug ] [ <tunnel-type> [ TUNNEL_INDEX ] ]".
   #
   # Normally, there would be less dup code, but for 'isis flex algo',
   # the tunnel index token is preceded by 'index'.
   #
   # The syntax string is assembled piecewise. Optional tunnel types are
   # appended only when their feature toggles are enabled, and the closing
   # ')' and ']' are appended last, so the order of these statements matters.
   # Every keyword that remains after the handler strips the common tokens
   # must appear as a row in unicastParseTable.
   syntax = '''
         show tunnel fib [ debug ]
         [ ( isis flex-algo [ index TUNNEL_INDEX ] )
         | ( ldp
           | nexthop-group
           | rsvp
           | rsvp-frr
           | rsvp-sub
           | ti-lfa
           | ( bgp labeled-unicast [ forwarding ] )
           | ( isis segment-routing )
           | ( static ( mpls
                      | ( interface ( gre
                                    | ipsec
                                    | gre-over-ipsec ) ) ) )
           | ( traffic-engineering segment-routing policy )'''
   if TunnelToggleLib.toggleSrv6TransportTunnelEnabled():
      syntax += '| ( segment-routing ipv6 transport )'
   if TunnelToggleLib.toggleOspfSegmentRoutingTunnelEnabled():
      syntax += '| ( ospf segment-routing )'
   if TunnelToggleLib.toggleDynamicGueTunnelsEnabled():
      syntax += '| ( bgp udp )'
   if TunnelToggleLib.toggleDsfPhase1Enabled():
      syntax += '| ( voq fabric )'
   # Add optional tunnel index.
   syntax += ' [ TUNNEL_INDEX ]'
   syntax += ')' # Close the paren before the 'ldp' keyword.
   syntax += ']' # Close the bracket after the 'debug' keyword.

   # Keyword -> help text or matcher for every token used in the syntax.
   data = {
         'bgp': CliToken.TunnelCli.bgpKw,
         'debug': 'Show debugging information',
         'fabric': CliToken.TunnelCli.fabricKw,
         'fib': tokenTunnelFibMatcher,
         'flex-algo': 'Flexible algorithm',
         'gre': 'Generic Routing Encapsulation',
         'ipsec': 'IP Security',
         'gre-over-ipsec': 'GRE over IPsec',
         'index': 'Tunnel index',
         'interface': 'Tunnel interface',
         'ipv6': 'IPv6',
         'isis': CliToken.TunnelCli.isisKw,
         'ospf': CliToken.TunnelCli.ospfKw,
         'labeled-unicast': CliToken.TunnelCli.labeledUnicastKw,
         'ldp': CliToken.TunnelCli.ldpKw,
         'mpls': CliToken.TunnelCli.mplsKw,
         'nexthop-group': 'Nexthop Group',
         'policy': CliToken.TunnelCli.policyKw,
         'forwarding': CliToken.TunnelCli.forwardingKw,
         'rsvp': 'Resource Reservation Protocol LER',
         'rsvp-frr': 'Resource Reservation Protocol Fast Reroute',
         'rsvp-sub': 'Resource Reservation Protocol LER sub tunnel',
         'segment-routing': CliToken.TunnelCli.segmentRoutingKw,
         'static': CliToken.TunnelCli.staticKw,
         'ti-lfa': CliToken.TunnelCli.tilfaKw,
         'traffic-engineering': CliToken.TunnelCli.trafficEngineeringKw,
         'transport': 'Basic SRv6 transport',
         'tunnel': tokenTunnelMatcher,
         'TUNNEL_INDEX': tunnelIndexMatcher,
         'udp': CliToken.TunnelCli.udpKw,
         'voq': CliToken.TunnelCli.voqKw,
         }
   cliModel = TunnelFibModel.TunnelFib

   @staticmethod
   def handler( mode, args ):
      # Pops TUNNEL_INDEX, 'debug' and the common keywords out of args (in
      # place). The remaining keyword set is then looked up in
      # unicastParseTable to obtain the tunnel-type and encap filters.
      t0( 'ShowTunnelFibCmd handler args:', list( args.values() ) )
      tunnelIndex = args.pop( 'TUNNEL_INDEX', None )
      debug = bool( args.pop( 'debug', None ) )
      # Strip tokens common to every form of the command ('index' only
      # appears in the isis flex-algo form).
      for key in [ 'show', 'tunnel', 'fib', 'index' ]:
         args.pop( key, None )
      argSet = set( args )
      for tokenSet, tunnelType, encapType in unicastParseTable:
         if tokenSet == argSet:
            return getTunnelFibModel( typeFilter=tunnelType, encapFilter=encapType,
                                      indexFilter=tunnelIndex, debug=debug )
      # Only reachable if the syntax and unicastParseTable disagree.
      raise Tac.InternalException( 'Parse table inconsistency' )

BasicCli.addShowCommandClass( ShowTunnelFibCmd )

# -------------------------------------------------------------------------
# The "show tunnel fib multicast" command.
# -------------------------------------------------------------------------
# Syntax and token data for "show tunnel fib multicast". The 'rsvp'
# alternative is offered only when the MVPN RSVP P2MP toggle is enabled;
# in that case both the syntax and the data are replaced/extended below.
showTunnelFibSyntax = '''show tunnel fib multicast
            [ ( mldp | static ) [ TUNNEL_INDEX ] ]'''
showTunnelFibData = {
      'tunnel': tokenTunnelMatcher,
      'multicast': tokenMulticastMatcher,
      'fib': tokenTunnelFibMatcher,
      'mldp': CliToken.TunnelCli.mldpKw,
      'static': CliToken.TunnelCli.staticKw,
      'TUNNEL_INDEX': tunnelIndexMatcher
      }
if toggleMvpnRsvpP2mpEnabled():
   showTunnelFibSyntax = '''show tunnel fib multicast
               [ ( mldp | rsvp | static ) [ TUNNEL_INDEX ] ]'''
   showTunnelFibData[ 'rsvp' ] = CliToken.TunnelCli.rsvpKw

class ShowTunnelFibMulticastCmd( ShowCommand.ShowCliCommandClass ):
   # Implements "show tunnel fib multicast [ <type> [ TUNNEL_INDEX ] ]" using
   # the module-level syntax/data (which depend on the RSVP P2MP toggle).
   syntax = showTunnelFibSyntax
   data = showTunnelFibData
   cliModel = TunnelFibModel.TunnelFib

   @staticmethod
   def handler( mode, args ):
      # Remove the index and every common keyword from args (in place). The
      # leftover keyword set selects a row of multicastParseTable.
      tunnelIndex = args.pop( 'TUNNEL_INDEX', None )
      for keyword in ( 'show', 'tunnel', 'multicast', 'fib', 'index' ):
         args.pop( keyword, None )
      remaining = set( args )

      row = next( ( entry for entry in multicastParseTable
                    if entry[ 0 ] == remaining ), None )
      if row is None:
         # Only reachable if the syntax and multicastParseTable disagree.
         raise Tac.InternalException( 'Parse table inconsistency' )
      _, tunnelType, encapType = row
      return getTunnelFibModel( typeFilter=tunnelType, encapFilter=encapType,
                                indexFilter=tunnelIndex, debug=False,
                                multicast=True )

BasicCli.addShowCommandClass( ShowTunnelFibMulticastCmd )

#------------------------------------------------------------------------
# Plugin
#------------------------------------------------------------------------
def Plugin( entityManager ):
   """Plugin entry point: set up the module-level state read by this file.

   Creates a UnifiedTunnelFibViewHelper in "reader" mode and exposes its
   unifiedView as ``unifiedTunnelFib``. Then mounts:
     - the multicast tunnel FIB (tunnel/tunnelMfib),
     - the nexthop-group entry status (routing/nexthopgroup/entrystatus),
     - the flex-algo config (te/flexalgo/config).
   """
   global nhgStatus, unifiedTunnelFibViewHelper, unifiedTunnelFib
   global tunnelMulticastFib, flexAlgoConfig

   # The helper itself is kept in a global alongside the view it exposes.
   helper = Tac.newInstance(
      "Tunnel::TunnelFib::UnifiedTunnelFibViewHelper",
      entityManager.cEntityManager(), "reader" )
   unifiedTunnelFibViewHelper = helper
   unifiedTunnelFib = helper.unifiedView

   tunnelMulticastFib = SmashLazyMount.mount(
      entityManager, 'tunnel/tunnelMfib', 'Tunnel::TunnelFib::TunnelFib',
      readerInfo )
   nhgStatus = SmashLazyMount.mount(
      entityManager, "routing/nexthopgroup/entrystatus",
      "NexthopGroup::EntryStatus", readerInfo )
   flexAlgoConfig = LazyMount.mount(
      entityManager, 'te/flexalgo/config', 'FlexAlgo::Config', 'r' )
