# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import functools
from collections import (
   defaultdict,
   namedtuple,
)

from Arnet import (
      EthAddr,
      IpGenAddr,
      IpGenPrefix,
      Ip6Prefix,
      Prefix,
   )
from GenericReactor import GenericReactor
from IpLibConsts import DEFAULT_VRF
from TypeFuture import TacLazyType
import Tac
import Tracing

# Module trace handle plus shorthands for its trace levels 2, 3 and 8.  t3 is
# used for the depth-indented resolution traces (ForwardingHelper._dtrace) and
# t8 for per-lookup details.
th = Tracing.Handle( 'ForwardingHelper' )
t2 = th.trace2
t3 = th.trace3
t8 = th.trace8

# Tac types used by this module.  Most are declared through TacLazyType
# (presumably deferring the type load until first use); MplsLabelConversion and
# the VrfId type below are obtained eagerly through Tac.Type at import time.
ArpKey = TacLazyType( 'Arp::Table::ArpKey' )
DynamicTunnelIntfId = TacLazyType( 'Arnet::DynamicTunnelIntfId' )
FecAdjType = TacLazyType( 'Smash::Fib::AdjType' )
FecId = TacLazyType( 'Smash::Fib::FecId' )
FecIdIntfId = TacLazyType( 'Arnet::FecIdIntfId' )
IntfId = TacLazyType( 'Arnet::IntfId' )
MplsLabel = TacLazyType( 'Arnet::MplsLabel' )
NexthopGroupEntryKey = TacLazyType( 'NexthopGroup::NexthopGroupEntryKey' )
NexthopGroupId = TacLazyType( 'Routing::NexthopGroup::NexthopGroupId' )
NexthopGroupIntfId = TacLazyType( 'Arnet::NexthopGroupIntfId' )
RouteType = TacLazyType( 'Routing::RouteType' )
TunnelId = TacLazyType( 'Tunnel::TunnelTable::TunnelId' )
TunnelType = TacLazyType( 'Tunnel::TunnelTable::TunnelType' )
MplsLabelConversion = Tac.Type( "Arnet::MplsLabelConversion" )

# Fallback VRF ID for ARP/neighbor-table keys, used by
# ForwardingHelper.getArpVrfId when an interface's VRF ID cannot be determined.
ARP_SMASH_DEFAULT_VRF_ID = Tac.Type( 'Vrf::VrfIdMap::VrfId' ).defaultVrf

# Result of ForwardingHelper.getRouteAndFec(); the sentinel below means "no
# route / no FEC" and is compared against with ==.
RouteAndFec = namedtuple( 'RouteAndFec', 'route fec' )
noMatchRouteAndFec = RouteAndFec( route=None, fec=None )

# Per-via result of ForwardingHelper.resolveL3NexthopGroup(); the sentinel below
# is what that method returns on failure.
L3NhgInfo = namedtuple( 'L3NhgInfo', 'route fec dstIp labelStack' )
noMatchL3NhgInfo = L3NhgInfo._make( ( None, ) * len( L3NhgInfo._fields ) )


# Per-via result of hierarchical nexthop resolution (getResolvedNexthopInfo).
NexthopInfo = namedtuple( 'NexthopInfo', (
   'resolved',
   'dstMac',
   'nexthopIp',
   'intf',            # NB: L3 IntfId
   'labelStack',      # NB: ordered [ Top, ..., Bottom ]
   'srcIntfIdStack',  # NB: one intfId per label in labelStack, naming the
                      #     source that contributed that label
   'route',
) )
noMatchNexthopInfo = NexthopInfo( resolved=False, dstMac=None, nexthopIp=None,
                                  intf=None, labelStack=None,
                                  srcIntfIdStack=None, route=None )
# Kept as a list (not a tuple): callers test results with
# `infoList == noMatchNexthopInfoList`, and a list never equals a tuple.
noMatchNexthopInfoList = [ noMatchNexthopInfo ]

class Af:
   '''Address-family names; compared against IpGenAddr(...).af in this module.'''
   ipv4 = 'ipv4'
   ipv6 = 'ipv6'

class Counter:
   '''A mutable tally starting at initValue and adjusted through inc().'''

   def __init__( self, initValue=0 ):
      self._value = initValue

   def inc( self, amount=1 ):
      '''Add amount (default 1) to the running total.'''
      self._value += amount

   @property
   def value( self ):
      '''The current total (read-only).'''
      return self._value

class MplsTunnelInfo:
   '''Helper class used by resolveTunnel()

   It is either resolved, having a valid labelStack + srcIntfIdStack + dstMac +
   intfId, or it will have an error message specified.

   Equality, ordering and hashing all compare repr() strings, so two instances
   are equal when their printed forms match.  nexthop is informational only: it
   is not part of repr() and therefore not part of equality.
   '''
   def __init__( self, labelStack=None, srcIntfIdStack=None, dstMac=None,
                 intfId=None, nexthop=None, useBackupVias=False, errMsg=None ):
      self.resolved = ( labelStack is not None and
                        srcIntfIdStack is not None and
                        dstMac is not None and
                        intfId is not None )
      # An unresolved instance must explain why
      assert self.resolved or errMsg
      self.labelStack = labelStack
      self.srcIntfIdStack = srcIntfIdStack
      self.dstMac = dstMac
      self.intfId = intfId
      self.errMsg = errMsg
      self.nexthop = nexthop # informational
      self.useBackupVias = useBackupVias

   def __repr__( self ):
      if not self.resolved:
         return 'MplsTunnelInfo( errMsg=%s )' % repr( self.errMsg )
      if self.labelStack:
         # Every label must have a matching source intfId
         assert len( self.labelStack ) == len( self.srcIntfIdStack )
      return 'MplsTunnelInfo( %s, %s, %s, %s, useBackupVias=%r )' % (
         self.labelStack,
         self.srcIntfIdStack,
         self.dstMac,
         self.intfId,
         self.useBackupVias )

   def __eq__( self, other ):
      return repr( self ) == repr( other )

   def __lt__( self, other ):
      return repr( self ) < repr( other )

   def __ne__( self, other ):
      return repr( self ) != repr( other )

   def __le__( self, other ):
      return repr( self ) <= repr( other )

   def __gt__( self, other ):
      return repr( self ) > repr( other )

   def __ge__( self, other ):
      return repr( self ) >= repr( other )

   def __hash__( self ):
      return hash( repr( self ) )

class RouteOrFecChangeReactor:
   '''
   Registers GenericReactors on the two route tries ('changedRoute') and on the
   three forwarding-status collections ('fec').  Every notification kicks the
   timer of hwStatusSm, if one is set.  hwStatusSm may be None initially and
   supplied later through setHwStatusSm().
   '''
   def __init__( self, trie, trie6,
                 fwdStatus, fwd6Status, fwdGenStatus,
                 hwStatusSm ):
      self.trie_, self.trie6_ = trie, trie6
      self.fwdStatus_ = fwdStatus
      self.fwd6Status_ = fwd6Status
      self.fwdGenStatus_ = fwdGenStatus
      # Set before any reactor is created so callbacks always see the attribute
      self.hwStatusSm_ = hwStatusSm

      routeAttrs = [ 'changedRoute' ]
      self.routeChangeReactor_ = GenericReactor( trie, routeAttrs,
                                                 self.handleRoute )
      self.routeChangeReactor6_ = GenericReactor( trie6, [ 'changedRoute' ],
                                                  self.handleRoute )

      self.fecChangeReactor_ = GenericReactor( fwdStatus, [ 'fec' ],
                                               self.handleFec )
      self.fecChangeReactor6_ = GenericReactor( fwd6Status, [ 'fec' ],
                                                self.handleFec )
      self.fecChangeReactorGen_ = GenericReactor( fwdGenStatus, [ 'fec' ],
                                                  self.handleFec )

   def setHwStatusSm( self, hwStatusSm ):
      self.hwStatusSm_ = hwStatusSm

   def _kickHwStatusTimer( self ):
      '''Kick the hardware-status SM's timer if an SM has been attached.'''
      statusSm = self.hwStatusSm_
      if statusSm:
         statusSm.kickTimer()

   def handleRoute( self, *args, **kwargs ):
      self._kickHwStatusTimer()

   def handleFec( self, *args, **kwargs ):
      self._kickHwStatusTimer()

def unlabeledFirst( viaA, viaB ):
   '''cmp-style comparator (for functools.cmp_to_key) that orders vias without
   a valid MPLS label before labeled ones; vias in the same group fall back to
   their own ordering.  Returns -1, 0 or 1.'''
   labeledA = viaA.getRawAttribute( 'mplsLabel' ).isValid()
   labeledB = viaB.getRawAttribute( 'mplsLabel' ).isValid()
   if labeledA == labeledB:
      # Same group: classic cmp() emulation
      return ( viaA > viaB ) - ( viaA < viaB )
   return 1 if labeledA else -1

# BUG470932: Enhance this to resolve all vias if needed
def resolveTunnel( fwdHelper, tunnelId, depth=0 ):
   """Return an MplsTunnelInfo object with a resolved label stack,
   outer destination MAC address and egress interface, or an unresolved
   MplsTunnelInfo object with a relevant error message. This method will resolve
   an (LU) tunnel resolving over an SRTE policy. However, it will not resolve
   SR-TE segment list tunnels directly, use resolveSrTeTunnel() instead.

   If the tunnel's primary via uses a hierarchical intfId (dynamic tunnel,
   SR-TE policy or FEC ID intfId), that via is resolved recursively through
   fwdHelper.resolveHierarchical(), and the resulting labels are placed above
   this tunnel's own labels.  Otherwise the primary via is used when the entry
   marks primary vias usable and its interface is up (or no interface status
   directory is available); if not, the first backup via is used.

   The returned label stack is TOS-first.  Implicit-null labels are stripped,
   and a stack made only of implicit-nulls collapses to a single one.  An empty
   stack is reported as an error.
   """

   # pylint: disable=protected-access
   trace = lambda *args: fwdHelper._dtrace( depth, *args )
   tunnelFibEntry = fwdHelper.tunnelFib_.entry.get( tunnelId )
   if not tunnelFibEntry or not tunnelFibEntry.tunnelVia:
      return MplsTunnelInfo(
         errMsg="no tunnel FIB entry" )

   # Recursive resolution over backupTunnelVia not supported yet
   # Using only primary interface for now
   primaryIntfId = tunnelFibEntry.tunnelVia[ 0 ].intfId
   isDynamicTunnelIntf = DynamicTunnelIntfId.isDynamicTunnelIntfId( primaryIntfId )
   if ( isDynamicTunnelIntf or
        FecIdIntfId.isSrTePolicyIntfId( primaryIntfId ) or
        FecIdIntfId.isFecIdIntfId( primaryIntfId ) ):
      trace( 'resolveTunnel: Hierarchical tunnel intfId', primaryIntfId )
      # Recursively retrieve the resolved nexthop
      nexthopInfoList = fwdHelper.resolveHierarchical( True, l3IntfId=primaryIntfId,
                                                       depth=depth + 1 )
      if nexthopInfoList == noMatchNexthopInfoList:
         # Only a DynamicTunnelIntfId encodes a tunnel ID.  SR-TE policy and FEC
         # ID intfIds must not be decoded as one, so report the intfId itself.
         unresolved = ( DynamicTunnelIntfId.tunnelId( primaryIntfId )
                        if isDynamicTunnelIntf else primaryIntfId )
         return MplsTunnelInfo( errMsg="tunnel {} not resolved".format(
            unresolved ) )
      nexthopInfo = nexthopInfoList[ 0 ]
      leafLabelStack = nexthopInfo.labelStack
      leafSrcIntfIdStack = nexthopInfo.srcIntfIdStack
      intfId = nexthopInfo.intf
      nexthop = nexthopInfo.nexthopIp
      dstMac = nexthopInfo.dstMac
      useBackupVias = False # Currently only resolving over primary via
      encapId = tunnelFibEntry.tunnelVia[ 0 ].encapId
   else:
      # Pick the primary via from tunnel table entry if arePrimaryViasUsable
      # is set to True and primary interface is up. Otherwise pick backup via.
      # Always pick the first via and we don't implement any ECMP hashing
      # here.
      linkUp = not fwdHelper.ethIntfStatusDir_ or \
               ( primaryIntfId in fwdHelper.ethIntfStatusDir_.intfStatus and \
                 fwdHelper.ethIntfStatusDir_.intfStatus.get(
                    primaryIntfId ).operStatus == 'intfOperUp' )
      if tunnelFibEntry.arePrimaryViasUsable and linkUp:
         tunnelVia = tunnelFibEntry.tunnelVia[ 0 ]
         useBackupVias = False
      elif tunnelFibEntry.backupTunnelVia:
         tunnelVia = tunnelFibEntry.backupTunnelVia[ 0 ]
         useBackupVias = True
      else:
         return MplsTunnelInfo( errMsg="No usable primary or backup via"
                                " found for {}".format( tunnelId ) )
      trace( 'resolveTunnel: Resolving leaf tunnel with tunnelId', tunnelId )
      leafLabelStack = []
      leafSrcIntfIdStack = []
      intfId = tunnelVia.intfId
      nexthop = tunnelVia.nexthop
      dstMac = fwdHelper.resolveL2Nexthop( tunnelVia.intfId, tunnelVia.nexthop )
      if dstMac is None:
         return MplsTunnelInfo( errMsg="L2 resolution failed for {}, {}".format(
            tunnelVia.intfId, tunnelVia.nexthop ) )
      encapId = tunnelVia.encapId

   labelStack = []
   srcIntfIdStack = []
   tunnelFib = fwdHelper.tunnelFib_
   labelStackEncap = tunnelFib.labelStackEncap.get( encapId )
   tunnelIntfId = DynamicTunnelIntfId.tunnelIdToIntfId(
         tunnelId.value
         if isinstance( tunnelId, Tac.Type( 'Tunnel::TunnelTable::TunnelId' ) )
         else tunnelId )
   if labelStackEncap:
      # Build the label/intfId stacks.  The encap's calculated stack is reversed
      # to get the TOS-first order used throughout this module (presumably it
      # is stored BOS-first, like NHG label stacks).
      viaLabels = labelStackEncap.labelStack
      for i in range( viaLabels.calculatedStackSize ):
         labelStack.append( viaLabels.calculatedLabelStack( i ) )
      labelStack.reverse()
      srcIntfIdStack = [ tunnelIntfId ] * len( labelStack )
      # Labels from the hierarchical (outer) resolution sit above ours
      labelStack = leafLabelStack + labelStack
      srcIntfIdStack = leafSrcIntfIdStack + srcIntfIdStack

      # Sanitize the implicit-nulls from the labelStack and intfIds
      if all( x == MplsLabel.implicitNull for x in labelStack ):
         # Only imp-nulls; just reduce to one
         labelStack = labelStack[ : 1 ]
         srcIntfIdStack = srcIntfIdStack[ : 1 ]
      else:
         indicesToDel = [ i for i, v in enumerate( labelStack )
                          if v == MplsLabel.implicitNull ]
         # Delete from the end so earlier indices stay valid
         for i in reversed( indicesToDel ):
            del labelStack[ i ]
            del srcIntfIdStack[ i ]

   if not labelStack:
      return MplsTunnelInfo( errMsg="Invalid (empty) tunnel label stack" )

   return MplsTunnelInfo( labelStack=labelStack,
                          srcIntfIdStack=srcIntfIdStack,
                          dstMac=dstMac,
                          intfId=intfId,
                          nexthop=nexthop,
                          useBackupVias=useBackupVias )

def getNhgSize( nhg, isBackup=False ):
   '''
   Return the number of entries in a Routing::NexthopGroupConfigEntry.

   With isBackup=True this is the number of backup entries.  Otherwise it is
   the configured size when non-zero.  For an auto-sized group (size 0) it is
   one past the highest index whose destination IP is not the zero address,
   or 0 if there is no such index.
   '''
   if isBackup:
      return len( nhg.backupEntry )
   if nhg.size:
      return nhg.size

   # Scan indexes from the highest down so the first non-zero destination
   # found determines the size, without examining the lower entries.
   populatedIndexes = list( nhg.destinationIpIntfInternal )
   for idx in reversed( populatedIndexes ):
      if not nhg.destinationIp( idx ).isAddrZero:
         return idx + 1

   # Either a brand-new NHG with an empty array, or every entry is zero
   return 0

class HfecInfo:
   '''
   One intermediate result of hierarchical FEC (HFEC) resolution.

   Attributes are normalized on construction:
     hop            -- str of IpGenAddr( hop ); hop is required (must be truthy)
     l3IntfId       -- IntfId string value, or None
     dstMac         -- str of EthAddr( dstMac ), or None
     labelStack     -- list of integer labels as given; use [] for no labels
     srcIntfIdStack -- list parallel to labelStack: for every label, the
                       special intfId that contributed it
     route          -- str of the IpGenPrefix resolving hop, or None.  Purely
                       informational; set whenever an IP route resolved hop.
   '''
   def __init__( self, hop, l3IntfId, dstMac, labelStack, srcIntfIdStack, route ):
      assert hop

      def normalized( value, convert ):
         # Falsy inputs (None, '') are stored as None
         return convert( str( value ) ) if value else None

      self.hop = str( IpGenAddr( str( hop ) ) )
      self.l3IntfId = normalized( l3IntfId, lambda s: IntfId( s ).stringValue )
      self.dstMac = normalized( dstMac, lambda s: str( EthAddr( s ) ) )
      self.labelStack = labelStack
      self.srcIntfIdStack = srcIntfIdStack
      self.route = normalized( route, lambda s: str( IpGenPrefix( s ) ) )

   def __str__( self ):
      return repr( self )

   def __repr__( self ):
      fields = ( self.hop, self.l3IntfId, self.dstMac, self.labelStack,
                 self.srcIntfIdStack, self.route )
      return 'HfecInfo( {}, {}, {}, {}, {}, {} )'.format( *fields )

class ForwardingHelper:
   '''
   Helper object to assist in forwarding functions.  This is used to share
   certain specific forwarding-related code with other modules.
   '''

   def __init__( self,
                 bridgingStatus,
                 vrfRoutingStatus,
                 vrfRouting6Status,
                 forwardingStatus,
                 forwarding6Status,
                 forwardingGenStatus,
                 srTeForwardingStatus,
                 srTeSegmentListTunnelTable,
                 nhgEntryStatus,
                 arpSmash,
                 tunnelFib,
                 trie4,
                 trie6,
                 vrfNameStatus,
                 intfConfigDir,
                 ethIntfStatusDir=None,
                 ):
      '''
      Keep references to every status/config object used during resolution.
      ethIntfStatusDir is optional; when it is not provided, resolveTunnel()
      considers a tunnel's primary interface to be up.
      '''
      # Interface / L2 state
      self.bridgingStatus = bridgingStatus
      self.intfConfigDir_ = intfConfigDir
      self.ethIntfStatusDir_ = ethIntfStatusDir

      # Per-VRF routing state and route tries
      self.vrfRoutingStatus_ = vrfRoutingStatus
      self.vrfRouting6Status_ = vrfRouting6Status
      self.trie4 = trie4
      self.trie6 = trie6
      self.vrfNameStatus_ = vrfNameStatus

      # FEC tables
      self.forwardingStatus_ = forwardingStatus
      self.forwarding6Status_ = forwarding6Status
      self.forwardingGenStatus_ = forwardingGenStatus
      self.srTeForwardingStatus_ = srTeForwardingStatus

      # Tunnels, nexthop groups and neighbor resolution
      self.srTeSegmentListTunnelTable_ = srTeSegmentListTunnelTable
      self.tunnelFib_ = tunnelFib
      self.nhgEntryStatus_ = nhgEntryStatus
      self.arpSmash_ = arpSmash

      # Per-FIB hit counters, bumped by resolveFecId()
      self._counts = defaultdict( Counter )

   def getCounterValue( self, name ):
      return self._counts[ name ].value

   @staticmethod
   def isNexthopGroupRoute( route ):
      '''True when route's routeType is RouteType.nexthopGroup.'''
      return route.routeType == RouteType.nexthopGroup

   def resolveFecId( self, fecId ):
      '''
      Look fecId up in the Gen, IPv4, IPv6 and SR-TE FEC tables, in that order,
      and return the first FEC found, or None if no table has it.  The hit
      counter of the matching table is incremented.
      '''
      func = 'resolveFecId:'
      counts = self._counts
      tables = (
         ( "Gen", self.forwardingGenStatus_, counts[ 'genAdj' ] ),
         ( Af.ipv4, self.forwardingStatus_, counts[ 'v4Adj' ] ),
         ( Af.ipv6, self.forwarding6Status_, counts[ 'v6Adj' ] ),
         ( 'SR-TE', self.srTeForwardingStatus_, counts[ 'srTeAdj' ] ),
      )

      # The FEC ID space is global across all FEC tables (Gen, IPv4, IPv6,
      # SR-TE), so each table is searched and the first match wins.  IPv4
      # routes may point to IPv6 FECs and vice versa, and any type of route may
      # point to an SR-TE FEC.
      for tableName, status, hits in tables:
         fec = status.fec.get( fecId )
         if not fec:
            continue
         t8( func, 'found adjacency in', tableName, 'fib' )
         hits.inc()
         return fec

      t8( func, 'adjacency not found in any FIB' )
      return None

   def getRouteAndFec( self, dstIp, vrfName=DEFAULT_VRF ):
      '''
      Find dstIp's longest-matching prefix in vrfName's trie, then that
      prefix's route and the route's FEC.  Returns RouteAndFec( route, fec ),
      or noMatchRouteAndFec when the address family is not IPv4/IPv6, when
      there is no route, or when the route's FEC cannot be found.
      '''
      func = 'getRouteAndFec:'
      af = IpGenAddr( dstIp ).af

      if af == Af.ipv4:
         routePrefix = self.trie4[ vrfName ].longestMatch[ Prefix( dstIp ) ]
         routingStatus = self.vrfRoutingStatus_.get( vrfName )
      elif af == Af.ipv6:
         routePrefix = self.trie6[ vrfName ].longestMatch( Ip6Prefix( dstIp ) )
         routingStatus = self.vrfRouting6Status_.get( vrfName )
      else:
         t8( func, 'invalid af', af )
         return noMatchRouteAndFec

      route = routingStatus.route.get( routePrefix ) if routingStatus else None
      if not route:
         t8( func, 'found no match for dstIp', dstIp, 'routePrefix', routePrefix )
         return noMatchRouteAndFec

      fec = self.resolveFecId( route.fecId )
      if not fec:
         t8( func, 'found no adjacency for', routePrefix )
         return noMatchRouteAndFec

      return RouteAndFec( route, fec )

   def getNhgSmashEntry( self, nhgId=None, nhgName=None ):
      '''
      Return the Smash NexthopGroupEntry for nhgName or nhgId, or None.

      The Smash collection is keyed by name, so a name lookup is a direct
      keyed get.  A lookup by nhgId has to scan every entry (O(n)).  A null or
      falsy nhgId with no name yields None.
      '''
      nhgSmashEntries = self.nhgEntryStatus_.nexthopGroupEntry
      if nhgName:
         key = NexthopGroupEntryKey()
         key.nhgNameIs( nhgName )
         return nhgSmashEntries.get( key )

      if not ( nhgId and nhgId != NexthopGroupId.null ):
         return None
      matches = ( e for e in nhgSmashEntries.values() if e.nhgId == nhgId )
      return next( matches, None )

   def getNhgFromSysdb( self, nhgId=None, nhgName=None ):
      '''
      Return the NexthopGroupConfigEntry for nhgName (or, if no name is given,
      for nhgId), or None if it cannot be found.

      The Sysdb config collection is keyed by name.  A lookup by nhgId first
      translates the ID to a name through an O(n) Smash scan
      (getNhgSmashEntry); passing nhgName skips that scan.
      '''
      func = 'getNhgFromSysdb:'

      # The ira-etba-root tac object must exist
      iraEtbaRoot = Tac.root.get( 'ira-etba-root' )
      assert iraEtbaRoot, "Couldn't find ira-etba-root tac object"

      if not nhgName:
         if not nhgId or nhgId == NexthopGroupId.null:
            t8( func, 'invalid ID', nhgId, 'and name', nhgName )
            return None
         smashEntry = self.getNhgSmashEntry( nhgId=nhgId )
         if not smashEntry:
            t8( func, "no nexthop-group for ID", nhgId )
            return None
         nhgName = smashEntry.key.nhgName()
         t8( func, 'found nexthop-group', nhgName, 'for ID', nhgId )

      nhgConfig = iraEtbaRoot.nexthopGroupConfig.nexthopGroup.get( nhgName )
      if not nhgConfig:
         t8( func, nhgName, 'not found in nhg config collection' )
      return nhgConfig

   def getNexthopGroupVia( self, nhgId, entry=None ):
      nhg = self.getNhgFromSysdb( nhgId )
      nhgViaList = []
      if not nhg:
         return nhgViaList

      nhgSize = getNhgSize( nhg )
      nhgEntryIndexes = [ entry ] if entry else list( range( nhgSize ) )
      for index in nhgEntryIndexes:
         entry = nhg.entry[ index ]
         dstIp = entry.destinationIpIntf.destIp
         mplsLabelStack = entry.mplsLabelStack
         # The NHG's labelStack is stored such that index 0 is the BOS and the
         # highest filled index is the TOS. This is the opposite of how labelStack
         # lists are usually represented in code and visually in CLI. As a result,
         # we should retrieve the labels in reverse order. E.g.
         # NHG Config LabelStack: (BOS) [ 30, 20 ] (TOS)
         # Typical representation: (TOS) [ 20, 30 ] (BOS)
         nhgMplsLabelStack = []
         for i in range( mplsLabelStack.stackSize ):
            nhgMplsLabelStack.append( mplsLabelStack.labelStack( i ) )
         nhgMplsLabelStack.reverse()
         nhgViaList.append( ( dstIp, nhgMplsLabelStack ) )
      return nhgViaList

   def resolveNexthopGroupTunnel( self, depth, vrfName, tunnelId,
                                  selectOneVia=True ):
      '''
      Returns a list of resolved NHG vias (HfecInfo) for the nexthop-group
      tunnel tunnelId.  Each NHG via hop is resolved over plain IP routes, and
      the NHG's own labels are placed below the resolved labels.

      If we are only selecting a single via:
         Autosize NHG: Returns the first resolved via
         Fixed size NHG: Returns the first via (empty list if it is unresolved)
      This keeps the behavior consistent with PD.

      An empty list is returned when the tunnel entry or the nexthop-group
      config cannot be found.
      '''
      retList = []
      tunnelFibEntry = self.tunnelFib_.entry.get( tunnelId )
      if not tunnelFibEntry or not tunnelFibEntry.tunnelVia:
         return retList
      nhgIntfId = tunnelFibEntry.tunnelVia[ 0 ].intfId
      nhgId = NexthopGroupIntfId.nexthopGroupId( nhgIntfId )
      nhg = self.getNhgFromSysdb( nhgId )
      if not nhg:
         # Without the config entry the group can be neither sized nor resolved
         return retList
      isAutoSize = nhg.size == 0
      nhgViaList = self.getNexthopGroupVia( nhgId )
      for nhgHop, nhgLabelStack in nhgViaList:
         nhgIntfIdStack = [ nhgIntfId ] * len( nhgLabelStack )
         infoList = self.resolveHierarchical( allowEncap=False, hop=nhgHop,
                                              depth=depth + 1, vrfName=vrfName,
                                              selectOneVia=selectOneVia )
         if infoList == noMatchNexthopInfoList:
            if selectOneVia and not isAutoSize:
               return retList
            continue

         for info in infoList:
            ret = HfecInfo( info.nexthopIp, info.intf, info.dstMac,
                            info.labelStack + nhgLabelStack,
                            info.srcIntfIdStack + nhgIntfIdStack, info.route )
            retList.append( ret )
            if selectOneVia:
               break
         # One resolved via is all that was asked for; stop resolving the
         # remaining NHG entries.
         if selectOneVia and retList:
            break
      return retList

   # This method is not being used right now. But it is planned to be used in
   # future by MplsUtils package. Hence keeping this around.
   def resolveL3NexthopGroup( self, nhgId, vrfName=DEFAULT_VRF ):
      '''
      For every via of nexthop group nhgId, look up the route and FEC for the
      via's destination IP in vrfName.  Returns a list of
      L3NhgInfo( route, fec, dstIp, labelStack ), one per via (an empty list
      if the group has no vias).

      NOTE(review): on failure (a missing destination IP, no route/FEC, or a
      route that is itself a nexthop-group route) the bare noMatchL3NhgInfo
      tuple is returned, not a list.  Callers must check for it before
      iterating.
      '''
      func = 'resolveL3NexthopGroup:'
      resolved = []
      for dstIp, nhgMplsLabelStack in self.getNexthopGroupVia( nhgId ):
         if dstIp is None:
            return noMatchL3NhgInfo
         routeAndFec = self.getRouteAndFec( str( dstIp ), vrfName=vrfName )
         if routeAndFec == noMatchRouteAndFec:
            return noMatchL3NhgInfo
         route, fec = routeAndFec
         if self.isNexthopGroupRoute( route ):
            t8( func, 'nested nhgs are not supported. Returning noMatch' )
            return noMatchL3NhgInfo
         resolved.append( L3NhgInfo( route, fec, dstIp, nhgMplsLabelStack ) )
      return resolved

   def getArpVrfId( self, intf ):
      '''
      Return the VRF ID to use in ARP/neighbor keys for intf.  This is the ID
      of the interface's configured VRF when it can be determined, and
      ARP_SMASH_DEFAULT_VRF_ID otherwise.
      '''
      fallback = ARP_SMASH_DEFAULT_VRF_ID
      intfConfigDir = self.intfConfigDir_
      if not intfConfigDir or intf not in intfConfigDir.intfConfig:
         return fallback
      vrfName = intfConfigDir.intfConfig[ intf ].vrf

      vrfNameStatus = self.vrfNameStatus_
      if not vrfNameStatus or not vrfNameStatus.nameToIdMap:
         return fallback
      if vrfName not in vrfNameStatus.nameToIdMap.vrfNameToId:
         return fallback
      return vrfNameStatus.nameToIdMap.vrfNameToId[ vrfName ]

   def resolveL2Nexthop( self, intf, hop ):
      '''
      Return the MAC address that ARP (IPv4) or the neighbor table (IPv6) holds
      for hop on intf, or None if hop is falsy or there is no entry.  The table
      key includes a VRF ID, which is derived from the interface.
      '''
      if not hop:
         return None
      genHop = IpGenAddr( str( hop ) )
      arpKey = ArpKey( self.getArpVrfId( intf ), genHop, intf )
      tableByAf = {
         Af.ipv4: self.arpSmash_.arpEntry,
         Af.ipv6: self.arpSmash_.neighborEntry,
      }
      arpEntry = tableByAf[ genHop.af ].get( arpKey )
      return arpEntry.ethAddr if arpEntry else None

   def resolveL2NexthopForFibVia( self, hop, fibVia ):
      '''
      For an IP hop that resolves through fibVia, return ( viaHop, dstMac ).

      viaHop is fibVia's own hop unless that hop is the zero address (e.g. a
      connected route), in which case the "outer" hop is used.  dstMac is the
      L2 resolution of viaHop on fibVia's interface (None if unresolved).
      '''
      viaGenHop = IpGenAddr( str( fibVia.hop ) )
      viaHop = hop if viaGenHop.isAddrZero else str( viaGenHop )
      dstMac = self.resolveL2Nexthop( fibVia.intfId, viaHop )
      return viaHop, dstMac

   def _dtrace( self, depth, *args ):
      '''Emit args at trace level 3, indented two spaces per recursion depth.'''
      indent = ' ' * ( 2 * depth )
      t3( indent, *args )

   def _isSpecialIntfId( self, l3IntfId ):
      '''True if l3IntfId is a dynamic-tunnel, FEC-ID, nexthop-group-ID FEC, or
      nexthop-group intfId, i.e. one that needs hierarchical resolution.'''
      predicates = (
         DynamicTunnelIntfId.isDynamicTunnelIntfId,
         FecIdIntfId.isFecIdIntfId,
         FecIdIntfId.isNexthopGroupIdIntfId,
         NexthopGroupIntfId.isNexthopGroupIntfId,
      )
      for isSpecial in predicates:
         if isSpecial( l3IntfId ):
            return True
      return False

   def _resolveHfecL3IntfId( self, allowEncap, hop, l3IntfId, depth,
                             vrfName=DEFAULT_VRF, selectOneVia=True ):
      '''
      Resolve one hierarchical level for l3IntfId.  Returns a non-empty list of
      HfecInfo, or None if nothing resolved.

      l3IntfId is handled as:
        - a DynamicTunnelIntfId: resolved as an SR-TE segment-list tunnel, a
          nexthop-group tunnel, or otherwise through resolveTunnel();
        - a nexthop-group intfId (NexthopGroupIntfId, or a FecIdIntfId carrying
          a nexthop-group ID): each NHG via hop is resolved over plain IP
          routes, and the NHG's labels are placed below the resolved labels;
        - a FecIdIntfId: the FEC is resolved recursively;
        - anything else: treated as an already L3-resolved interface.

      The tunnel and nexthop-group cases return None when allowEncap is False.
      With selectOneVia, only the first resolved via is kept from each
      recursive result, and a nexthop group stops at its first resolved via.
      '''
      trace = lambda *args: self._dtrace( depth, *args )
      assert l3IntfId, 'no l3IntfId provided'
      nhgId = None
      if NexthopGroupIntfId.isNexthopGroupIntfId( l3IntfId ):
         nhgId = NexthopGroupIntfId.nexthopGroupId( l3IntfId )
      elif FecIdIntfId.isNexthopGroupIdIntfId( l3IntfId ):
         nhgId = FecIdIntfId.intfIdToNexthopGroupId( l3IntfId )

      retList = []
      if DynamicTunnelIntfId.isDynamicTunnelIntfId( l3IntfId ):
         viaTunnelId = DynamicTunnelIntfId.tunnelId( l3IntfId )
         tunnelInfoList = []
         if not allowEncap:
            trace( 'encap not allowed, refusing to resolve tunnelId', viaTunnelId )
            return None
         if TunnelId( viaTunnelId ).tunnelType() == TunnelType.srTeSegmentListTunnel:
            tunnelInfoList = self.resolveSrTeTunnel( viaTunnelId,
                                                     depth=depth + 1,
                                                     selectOneVia=selectOneVia )
         elif TunnelId( viaTunnelId ).tunnelType() == TunnelType.nexthopGroupTunnel:
            # Produces HfecInfo directly, so it goes straight into retList
            retList = self.resolveNexthopGroupTunnel( depth + 1, vrfName,
                                                      viaTunnelId,
                                                      selectOneVia=selectOneVia )
         else:
            tunnelInfoList = resolveTunnel( self, viaTunnelId, depth=depth + 1 )
         if not tunnelInfoList and not retList:
            return None
         # resolveTunnel() returns a single MplsTunnelInfo rather than a list
         if not isinstance( tunnelInfoList, list ):
            tunnelInfoList = [ tunnelInfoList ]
         for tunnelInfo in tunnelInfoList:
            trace( 'tunnel info:', tunnelInfo )
            if not tunnelInfo.resolved:
               trace( 'tunnel resolution failed:', tunnelInfo.errMsg )
               continue
            ret = HfecInfo( tunnelInfo.nexthop, tunnelInfo.intfId, tunnelInfo.dstMac,
                            tunnelInfo.labelStack, tunnelInfo.srcIntfIdStack, None )
            retList.append( ret )
            if selectOneVia:
               break
      elif nhgId:
         trace( 'fecIdIntfId', l3IntfId, 'nhgId', nhgId )
         if not allowEncap:
            trace( 'encap not allowed, refusing to resolve nhgId', nhgId )
            return None
         nhg = self.getNhgFromSysdb( nhgId )
         if not nhg:
            # Without the config entry the group can be neither sized nor resolved
            trace( 'no nexthop-group config for nhgId', nhgId )
            return None
         isAutoSize = nhg.size == 0
         nhgViaList = self.getNexthopGroupVia( nhgId )
         if not nhgViaList:
            return None
         # TODO: Support NHG resolution in non-default VRF.
         # Override allowEncap to false when resolving the nexthop group hop since
         # nexthop group hops can only resolve using unlabelled IP routes.
         for nhgHop, nhgLabelStack in nhgViaList:
            nhgIntfIdStack = [ l3IntfId ] * len( nhgLabelStack )
            infoList = self.resolveHierarchical( allowEncap=False, hop=nhgHop,
                                                 depth=depth + 1, vrfName=vrfName,
                                                 selectOneVia=selectOneVia )
            # For autosized NHGs, we will choose the first resolved via in the case
            # where we selectOneVia, same as PD
            if infoList == noMatchNexthopInfoList:
               if selectOneVia and not isAutoSize:
                  trace( 'HFEC resolution failed for nhgId', nhgId )
                  return None
               continue

            for info in infoList:
               ret = HfecInfo( info.nexthopIp, info.intf, info.dstMac,
                               info.labelStack + nhgLabelStack,
                               info.srcIntfIdStack + nhgIntfIdStack, info.route )
               retList.append( ret )
               if selectOneVia:
                  break
            # One resolved via is all that was asked for; don't resolve the
            # remaining NHG entries.
            if selectOneVia and retList:
               break
      elif FecIdIntfId.isFecIdIntfId( l3IntfId ):
         fecId = FecIdIntfId.intfIdToFecId( l3IntfId )
         trace( 'fecIdIntfId', l3IntfId, 'fecId', fecId )
         infoList = self.resolveHierarchical( allowEncap, hop=hop, fecId=fecId,
                                              depth=depth + 1, vrfName=vrfName,
                                              selectOneVia=selectOneVia )
         if infoList == noMatchNexthopInfoList:
            trace( 'HFEC resolution failed for FEC intfId', l3IntfId )
            return None

         for info in infoList:
            ret = HfecInfo( info.nexthopIp, info.intf, info.dstMac, info.labelStack,
                            info.srcIntfIdStack, info.route )
            retList.append( ret )
            if selectOneVia:
               break
      else:
         trace( 'already L3 resolved', hop, l3IntfId )
         ret = HfecInfo( hop, l3IntfId, None, [], [], None )
         retList.append( ret )
      if not retList:
         return None
      # Should not be able to get here with a labelStack for allowEncap=False.  In
      # all branches above, we're either explicitly checking allowEncap, or else
      # calling resolveHierarchical with allowEncap=False.
      for ret in retList:
         assert allowEncap or not ret.labelStack
      return retList

   def _fibViaLabelStack( self, fibVia ):
      if fibVia.getRawAttribute( 'mplsLabel' ).isValid():
         return [ fibVia.mplsLabel ]
      else:
         return []

   def _resolveHfecFecId( self, allowEncap, hop, fecId, depth, vrfName=DEFAULT_VRF,
                          selectOneVia=True ):
      '''
      Resolve the FEC identified by fecId into a list of HfecInfo.

      - nextHopGroupAdj / tunnelFibAdj FecIds are resolved by recursing into
        resolveHierarchical with the FecId's hierarchical intfId.
      - Any other FecId is looked up with resolveFecId (usedByTunnel*Adj FecIds
        are first mapped to the matching fib*Adj type).  Vias using a "special"
        HFEC intfId are tried first by recursing into resolveHierarchical; only
        if none of them yields a usable result are the remaining vias
        L2-resolved directly, unlabeled vias first.

      On the resolveFecId path, a candidate carrying a label stack is skipped
      when allowEncap is False.

      Returns a non-empty list of HfecInfo (at most one entry when selectOneVia
      is True), or None if nothing could be resolved.
      '''
      trace = lambda *args: self._dtrace( depth, *args )
      fecIdAdjType = FecId( fecId ).adjType()
      fecIdIntfId = FecIdIntfId.fecIdToIntfId( fecId )
      trace( 'fecIdAdjType:', fecIdAdjType, 'for fecId', fecId )
      retList = []
      if fecIdAdjType in [ FecAdjType.nextHopGroupAdj, FecAdjType.tunnelFibAdj ]:
         fecIdIntfId = FecId.fecIdToHierarchicalIntfId( fecId )
         infoList = self.resolveHierarchical( allowEncap, hop=hop,
                                              l3IntfId=fecIdIntfId,
                                              depth=depth + 1, vrfName=vrfName,
                                              selectOneVia=selectOneVia )

         if infoList == noMatchNexthopInfoList:
            trace( 'HFEC resolution failed for', fecIdAdjType, 'adj', fecIdIntfId )
            return None

         for info in infoList:
            ret = HfecInfo( info.nexthopIp, info.intf, info.dstMac, info.labelStack,
                            info.srcIntfIdStack, info.route )
            retList.append( ret )
            if selectOneVia:
               break
      else:
         # FecIds with a usedByTunnel* adj type are looked up under the matching
         # fib* adj type.  fecIdAdjType still describes the FecId passed in,
         # since fecId has not been reassigned at this point.
         if fecIdAdjType == 'usedByTunnelGenAdj':
            fecId = FecId( FecId.fecIdToNewAdjType( 'fibGenAdj', fecId ) )
         elif fecIdAdjType == 'usedByTunnelV4Adj':
            fecId = FecId( FecId.fecIdToNewAdjType( 'fibV4Adj', fecId ) )
         elif fecIdAdjType == 'usedByTunnelV6Adj':
            fecId = FecId( FecId.fecIdToNewAdjType( 'fibV6Adj', fecId ) )

         fec = self.resolveFecId( fecId )
         if fec is None:
            trace( 'HFEC resolution failed for fecId', fecId )
            return None
         # If any via is using a "special" HFEC IntfId value, then we can just pick
         # the first resolvable HFEC via.  If there are no HFEC vias, then we can
         # continue to do L3 resolution on any other vias in the FEC.
         for via in fec.via.values():
            if not self._isSpecialIntfId( via.intfId ):
               continue
            infoList = self.resolveHierarchical( allowEncap, hop=via.hop,
                                                 l3IntfId=via.intfId,
                                                 depth=depth + 1, vrfName=vrfName,
                                                 selectOneVia=selectOneVia )

            # If this via didn't resolve, try others.
            if infoList == noMatchNexthopInfoList:
               continue

            # The FIB via's own labels do not depend on the recursive result.
            fibLabelStack = self._fibViaLabelStack( via )
            for info in infoList:
               trace( 'resolving with special intfId', via.intfId )
               labelStack = info.labelStack + fibLabelStack
               srcIntfIdStack = ( info.srcIntfIdStack +
                                  [ fecIdIntfId ] * len( fibLabelStack ) )
               ret = HfecInfo( info.nexthopIp, info.intf, info.dstMac,
                               labelStack, srcIntfIdStack, info.route )
               if not allowEncap and ret.labelStack:
                  trace( 'HFEC resolution failed because allowEncap=False', ret )
                  continue
               retList.append( ret )
               if selectOneVia:
                  break
            # Stop scanning only once some via actually produced a usable result.
            # If every candidate of this via was rejected (e.g. labelled while
            # allowEncap=False), keep trying the remaining special vias.
            if selectOneVia and retList:
               break
         if not retList:
            # Since there are no resolved HFEC vias, try to find a regular via,
            # starting with the unlabelled ones.
            for via in sorted( fec.via.values(),
                               key=functools.cmp_to_key( unlabeledFirst ) ):
               viaHop, dstMac = self.resolveL2NexthopForFibVia( hop, via )
               if dstMac is None:
                  trace( 'skip L2-unresolved via, viaHop', viaHop, 'via', via )
                  continue
               labelStack = self._fibViaLabelStack( via )
               srcIntfIdStack = [ fecIdIntfId ] * len( labelStack )
               ret = HfecInfo( viaHop, via.intfId, None, labelStack, srcIntfIdStack,
                               None )
               if not allowEncap and ret.labelStack:
                  trace( 'HFEC resolution failed because allowEncap=False', ret )
                  continue
               retList.append( ret )
               if selectOneVia:
                  break
         if not retList:
            trace( 'HFEC resolution failed: no usable vias', fec )
            return None
      return retList

   def _resolveHfecHop( self, allowEncap, hop, depth, vrfName=DEFAULT_VRF,
                        selectOneVia=True ):
      '''
      Resolve a nexthop address through a route lookup in vrfName.

      The matched route's FEC is resolved recursively with resolveHierarchical.
      Each resulting HfecInfo carries the route reported by the recursion, or
      falls back to the key of the route matched here.

      Returns a non-empty list of HfecInfo (at most one entry when selectOneVia
      is True), or None if the route lookup or the FEC resolution failed.
      '''
      trace = lambda *args: self._dtrace( depth, *args )
      routeInfo = self.getRouteAndFec( hop, vrfName=vrfName )

      if routeInfo == noMatchRouteAndFec:
         trace( 'L3 resolution failed for hop', hop )
         return None
      fecId = routeInfo.fec.fecId
      infoList = self.resolveHierarchical( allowEncap, hop=hop, fecId=fecId,
                                           depth=depth + 1, vrfName=vrfName,
                                           selectOneVia=selectOneVia )

      if infoList == noMatchNexthopInfoList:
         trace( 'L3 resolution failed for hop', hop )
         return None

      # The matched route key is the same for every candidate; look it up once.
      routeKey = routeInfo.route.key
      retList = []
      for info in infoList:
         ret = HfecInfo( info.nexthopIp, info.intf, info.dstMac, info.labelStack,
                         info.srcIntfIdStack, info.route or routeKey )
         # allowEncap is passed through unchanged to the recursive
         # resolveHierarchical call above, which is expected to reject labelled
         # results when it is False; so a label stack here with
         # allowEncap=False would indicate a resolution bug.
         assert allowEncap or not ret.labelStack
         retList.append( ret )
         if selectOneVia:
            break
      return retList

   def resolveHierarchical( self, allowEncap, hop=None, l3IntfId=None, fecId=None,
                            depth=0, vrfName=DEFAULT_VRF, selectOneVia=True ):
      '''
      Resolve a nexthop through every available resolution source, recursing into
      hierarchical FECs where needed.

      An HFEC forms a tree, so this amounts to a depth-first search for the nodes
      that have an L2 nexthop.

      The L3 source is picked by precedence: l3IntfId if given, else fecId, else
      hop; at least one of the three must be supplied.  Candidates without a
      destination MAC are L2-resolved here (their dstMac is updated in place), and
      candidates whose L2 resolution fails are skipped.

      Returns a list of resolved NexthopInfo (at most one when selectOneVia is
      True).  If the chosen source yields nothing, noMatchNexthopInfoList is
      returned; if no candidate L2-resolves, a list holding only
      noMatchNexthopInfo is returned.
      '''
      assert hop or l3IntfId or fecId, ( 'At least one among hop, l3IntfId, fecId '
                                         'must be provided' )

      trace = lambda *args: self._dtrace( depth, *args )
      hop = str( hop ) if hop else hop
      trace( 'resolveHierarchical, allowEncap', allowEncap, 'hop', hop,
             'l3IntfId', l3IntfId, 'fecId', fecId, "vrf", vrfName )

      sharedKwargs = { 'vrfName': vrfName, 'selectOneVia': selectOneVia }
      if l3IntfId:
         candidates = self._resolveHfecL3IntfId( allowEncap, hop, l3IntfId, depth,
                                                 **sharedKwargs )
      elif fecId:
         candidates = self._resolveHfecFecId( allowEncap, hop, fecId, depth,
                                              **sharedKwargs )
      else:
         candidates = self._resolveHfecHop( allowEncap, hop, depth,
                                            **sharedKwargs )

      if not candidates or candidates == noMatchNexthopInfoList:
         return noMatchNexthopInfoList

      resolved = []
      for candidate in candidates:
         if not candidate.dstMac:
            l2Mac = self.resolveL2Nexthop( candidate.l3IntfId, candidate.hop )
            if l2Mac is None:
               trace( 'L2 resolution failed for', candidate.l3IntfId,
                      candidate.hop )
               continue
            # Remember the freshly resolved MAC on the candidate itself.
            candidate.dstMac = l2Mac
         assert len( candidate.labelStack ) == len( candidate.srcIntfIdStack )
         nexthopInfo = NexthopInfo( True, candidate.dstMac, candidate.hop,
                                    candidate.l3IntfId, candidate.labelStack,
                                    candidate.srcIntfIdStack, candidate.route )
         resolved.append( nexthopInfo )
         trace( 'resolved to', nexthopInfo )
         if selectOneVia:
            break

      return resolved or [ noMatchNexthopInfo ]

   def getResolvedNexthopInfo( self, dstIp, intf=None, vrfName=DEFAULT_VRF ):
      '''
      Run L3 and L2 nexthop resolution for dstIp (through intf, if given) in
      vrfName with encapsulation allowed, and return only the first entry of the
      list produced by resolveHierarchical.
      '''
      candidates = self.resolveHierarchical( allowEncap=True, hop=dstIp,
                                             l3IntfId=intf, vrfName=vrfName )
      firstCandidate = candidates[ 0 ]
      return firstCandidate

   def getSrTeTunnelTableEntry( self, tunnelId ):
      '''Look up tunnelId in the SR-TE segment-list tunnel table; None if absent.'''
      return self.srTeSegmentListTunnelTable_.entry.get( tunnelId, None )

   def resolveSrTeTunnel( self, tunnelId, depth=0, selectOneVia=True ):
      '''
      Resolve an SR-TE segment-list tunnel into a list of MplsTunnelInfo.

      The vias of the tunnel table entry are visited in sorted order and
      resolved according to their intfId:
        - FecIdIntfId: the next-level FEC is resolved with resolveHierarchical
          and the via's own label stack is used.
        - DynamicTunnelIntfId: the sub-tunnel is resolved with resolveTunnel,
          this via's labels are appended after the sub-tunnel's labels, and
          implicit-null labels are then sanitized out.
        - anything else: the via's nexthop is L2-resolved directly.

      Returns None when there is no tunnel entry or a FEC / sub-tunnel via fails
      to resolve; a one-element list holding an errMsg-only MplsTunnelInfo when
      L2 resolution of a direct via fails; otherwise the list of resolved
      MplsTunnelInfo (at most one when selectOneVia is True).
      '''
      trace = lambda *args: self._dtrace( depth, *args )
      tunnelIntfId = DynamicTunnelIntfId.tunnelIdToIntfId( tunnelId )
      trace( 'resolveSrTeTunnel: tunnel intfId', tunnelIntfId )
      slTunnelEntry = self.getSrTeTunnelTableEntry( tunnelId )
      if not slTunnelEntry:
         trace( 'resolveSrTeTunnel: no tunnel entry found' )
         return None
      retList = []
      for via in sorted( slTunnelEntry.via.values() ):
         stackSize = via.labels.stackSize
         # Read the via's labels in reverse index order, so that
         # via.labels.labelStack( 0 ) ends up as the last element.
         labelStack = [ via.labels.labelStack( i )
                        for i in reversed( range( stackSize ) ) ]
         srcIntfIdStack = [ tunnelIntfId ] * stackSize
         if FecIdIntfId.isFecIdIntfId( via.intfId ):
            trace( 'resolveSrTeTunnel: Resolving nextLevelFecId via:', via.intfId )
            infoList = self.resolveHierarchical( True, l3IntfId=via.intfId,
                                                 depth=depth + 1,
                                                 selectOneVia=selectOneVia )
            if infoList == noMatchNexthopInfoList:
               return None
            for info in infoList:
               # Copy the stacks so separate results never share list objects.
               ret = MplsTunnelInfo( labelStack=list( labelStack ),
                                     srcIntfIdStack=list( srcIntfIdStack ),
                                     dstMac=info.dstMac, intfId=info.intf,
                                     nexthop=info.nexthopIp )
               retList.append( ret )
               if selectOneVia:
                  break
         elif DynamicTunnelIntfId.isDynamicTunnelIntfId( via.intfId ):
            viaTunnelId = DynamicTunnelIntfId.tunnelId( via.intfId )
            trace( 'resolveSrTeTunnel: Resolving sub-tunnel via:', viaTunnelId )
            tunnelInfoList = resolveTunnel( self, viaTunnelId, depth=depth + 1 )
            if not tunnelInfoList:
               return None
            if not isinstance( tunnelInfoList, list ):
               tunnelInfoList = [ tunnelInfoList ]
            for tunnelInfo in tunnelInfoList:
               if not tunnelInfo.resolved:
                  return None
               # The full resolved stack is the sub-tunnel's (e.g. tilfa-tunnel)
               # label stack followed by this SR-TE via's labels.  New lists are
               # built rather than extending tunnelInfo's lists in place, so the
               # objects returned by resolveTunnel are left unmodified.
               fullLabelStack, fullSrcIntfIdStack = self._stripImplicitNullLabels(
                  list( tunnelInfo.labelStack ) + labelStack,
                  list( tunnelInfo.srcIntfIdStack ) + srcIntfIdStack )
               ret = MplsTunnelInfo( labelStack=fullLabelStack,
                                     srcIntfIdStack=fullSrcIntfIdStack,
                                     dstMac=tunnelInfo.dstMac,
                                     intfId=tunnelInfo.intfId,
                                     nexthop=tunnelInfo.nexthop )
               retList.append( ret )
               if selectOneVia:
                  break
         else:
            dstMac = self.resolveL2Nexthop( via.intfId, str( via.nexthop ) )
            if dstMac is None:
               errMsg = "L2 resolution failed for {}, {}".format( via.intfId,
                                                                   via.nexthop )
               return [ MplsTunnelInfo( errMsg=errMsg ) ]
            ret = MplsTunnelInfo( labelStack=labelStack,
                                  srcIntfIdStack=srcIntfIdStack,
                                  dstMac=dstMac,
                                  intfId=via.intfId,
                                  nexthop=via.nexthop )
            retList.append( ret )
         if selectOneVia:
            break
      return retList

   @staticmethod
   def _stripImplicitNullLabels( labelStack, srcIntfIdStack ):
      '''
      Sanitize implicit-null labels out of a label stack and its parallel
      srcIntfIdStack, returning new ( labelStack, srcIntfIdStack ) lists.

      If every label is implicit-null, both stacks are reduced to their first
      entry.  Otherwise every implicit-null label is removed together with the
      srcIntfIdStack entry at the same position.
      '''
      implicitNull = MplsLabel.implicitNull
      if all( label == implicitNull for label in labelStack ):
         # Only imp-nulls; just reduce to one.
         return labelStack[ : 1 ], srcIntfIdStack[ : 1 ]
      nullPositions = { i for i, label in enumerate( labelStack )
                        if label == implicitNull }
      keptLabels = [ label for i, label in enumerate( labelStack )
                     if i not in nullPositions ]
      keptIntfIds = [ intfId for i, intfId in enumerate( srcIntfIdStack )
                      if i not in nullPositions ]
      return keptLabels, keptIntfIds

def forwardingHelperFactory( bridge ):
   '''
   Build a ForwardingHelper from the status, config and table attributes held by
   bridge.
   '''
   helperKwargs = dict(
      bridgingStatus=bridge.brStatus,
      vrfRoutingStatus=bridge.vrfRoutingStatus_,
      vrfRouting6Status=bridge.vrfRouting6Status_,
      forwardingStatus=bridge.forwardingStatus_,
      forwarding6Status=bridge.forwarding6Status_,
      forwardingGenStatus=bridge.forwardingGenStatus_,
      srTeForwardingStatus=bridge.srTeForwardingStatus_,
      srTeSegmentListTunnelTable=bridge.srTeSegmentListTunnelTable,
      nhgEntryStatus=bridge.nhgEntryStatus_,
      arpSmash=bridge.arpSmash_,
      tunnelFib=bridge.tunnelFib_,
      trie4=bridge.vrfTrie_,
      trie6=bridge.vrfTrie6_,
      vrfNameStatus=bridge.vrfNameStatus_,
      intfConfigDir=bridge.intfConfigDir_,
      ethIntfStatusDir=bridge.ethIntfStatusDir_,
   )
   return ForwardingHelper( **helperKwargs )

def forwardingHelperKwFactory( **kwargs ):
   '''
   Build a ForwardingHelper from keyword arguments, passing None for every
   constructor argument listed below that the caller does not supply.  Any
   supplied keyword (including ones not in the list) is forwarded unchanged.

   This factory function should only be used for tests that are mocking or
   using only small pieces of the helper.  Actual plugins should build the
   helper from a bridge instead, e.g. with forwardingHelperFactory( bridge )
   defined in this module (the previous text referred to
   mplsEtbaForwardingHelperFactory( bridge ); NOTE(review): confirm whether that
   wrapper still exists).  Since a bridge-based factory needs a bridge object to
   be instantiated, this factory is useful where initializing a bridge is not
   worthwhile.
   '''
   args = [
      "bridgingStatus",
      "vrfRoutingStatus",
      "vrfRouting6Status",
      "forwardingStatus",
      "forwarding6Status",
      "forwardingGenStatus",
      "srTeForwardingStatus",
      "srTeSegmentListTunnelTable",
      "nhgEntryStatus",
      "arpSmash",
      "tunnelFib",
      "trie4",
      "trie6",
      "vrfNameStatus",
      "intfConfigDir",
      "ethIntfStatusDir",
   ]
   # Every known argument defaults to None; caller-supplied values win.
   defaultKwArgs = dict.fromkeys( args )
   defaultKwArgs.update( kwargs )
   return ForwardingHelper( **defaultKwArgs )
