#!/usr/bin/env python3
# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac
import Tracing
import SharedMem
import Smash
import SmashLazyMount
import Shark
import MplsEtbaLib
from MplsPktHelper import isMpls # pylint: disable=no-name-in-module
import ForwardingHelper
from EbraTestBridgePort import EbraTestPort
from Arnet import PktParserTestLib
from Arnet.EthTestLib import EthHdrSize
import Arnet.MplsLib
from SysConstants.if_ether_h import ETH_P_MPLS_UC
from EbraTestBridgeLib import (
   PKTEVENT_ACTION_ADD,
   PKTEVENT_ACTION_REPLACE,
   PKTEVENT_ACTION_REMOVE,
   PKTEVENT_ACTION_NONE,
   applyVlanChanges,
   )
from L2RibLib import traverseAndValidateMultiTable
from TypeFuture import TacLazyType

# Tac type handles, created through TacLazyType, used throughout this plugin.
DynamicTunnelIntfIdType = TacLazyType( 'Arnet::DynamicTunnelIntfId' )
FecIdIntfIdType = TacLazyType( 'Arnet::FecIdIntfId' )
FecIdType = TacLazyType( 'Smash::Fib::FecId' )
TunnelIdType = TacLazyType( "Tunnel::TunnelTable::TunnelId" )
TunnelType = TacLazyType( "Tunnel::TunnelTable::TunnelType" )
DestType = TacLazyType( "L2Rib::DestType" )

# Trace handle for this plugin and shorthands for the trace levels it uses.
handle = Tracing.Handle( 'EtbaEvpnMpls' )
t0 = handle.trace0
t2 = handle.trace2
t8 = handle.trace8

# Constants used in this plugin
# ETH_ADDR_ZERO is the key used to look up a vlan's flood-set in resolveDstMac.
ETH_ADDR_ZERO = Tac.Type( "Arnet::EthAddr" ).ethAddrZero

# Global state mounted from sysdb/smash accessed by various plugin
# methods. Everything below starts out as None; most of it is assigned in
# handleAgentInit.
l2RibHostOutput_ = None
loadBalanceTable_ = None
destTable_ = None
labelTable_ = None
floodSetOutput_ = None
mplsTrunkIntfDir_ = None
# Name of the MplsTrunk interface ( set to 'MplsTrunk1' in handleAgentInit ).
mplsTrunkIntf_ = None
# EbraTestMplsTrunkPort created by MplsTrunkInterfaceStatusDirReactor.
mplsTrunkPort_ = None
mplsConfig_ = None
brInputConfig_ = None
stpPortMode_ = None
managedIntfList_ = None
evpnStatus_ = None
mplsLfibStatus_ = None
lFibStatusConsumerSm_ = None
mplsEsFilterLfib_ = None
tunnelFib_ = None
arpTable_ = None
mlagStatus_ = None
# Bridge handed to handleBridgeInit.
bridge_ = None

# global sm
# EvpnMpls::MtiSm instance created in handleBridgeInit.
mtiSm_ = None

# Round robin hashing for overlay ECMP for each packet.
# packetCount is incremented by getNextDestinationIndex on every call.
packetCount = 0
# Reactor instances created by this plugin, keyed by a short name
# ( 'mtid' and 'l2Rib' ).
reactors_ = {}

# When a packet enters with MplsTunnelHdr for L2EVPN, this floodFilter
# set is updated in PreTunnelHandler. It is for the VLAN that the
# packet entered the DUT and will be used by PktReplicationHandler to
# remove the interfaces in mplsFloodFilter from the final list of
# interfaces that the packet must be replicated.
mplsFloodFilter_ = set()

def rewriteVlanTag( innerData, rewriteVlan ):
   """Return innerData with its dot1q tag adjusted according to rewriteVlan.

   The action is chosen from whether the parsed packet has an EthDot1QHdr
   and whether rewriteVlan is set:
     tag found,     rewriteVlan set     -> PKTEVENT_ACTION_REPLACE
     tag found,     rewriteVlan not set -> PKTEVENT_ACTION_REMOVE
     tag not found, rewriteVlan set     -> PKTEVENT_ACTION_ADD
     tag not found, rewriteVlan not set -> innerData is returned untouched
   """
   t8( "Modifying packet to carry vlan", rewriteVlan )
   ( _pkt, pktHeaders, _offset ) = PktParserTestLib.parsePktStr( innerData )
   tagged = bool( PktParserTestLib.findHeader( pktHeaders, 'EthDot1QHdr' ) )
   wantVlan = bool( rewriteVlan )
   if not tagged and not wantVlan:
      t8( "No innerDot1q and no rewrite vlan. No action needed." )
      return innerData
   # At this point at least one of tagged / wantVlan holds.
   if tagged:
      vlanAction = PKTEVENT_ACTION_REPLACE if wantVlan else PKTEVENT_ACTION_REMOVE
   else:
      vlanAction = PKTEVENT_ACTION_ADD
   t8( "Final vlan action", vlanAction )
   return applyVlanChanges( innerData, 0, rewriteVlan, vlanAction )

def postBridging( bridge, srcMac, destMac, vlanId,
                  data, srcPortName, destIntf, finalIntfs ):
   """Post-bridging hook: deny local bridging inside an etree leaf VLAN.

   Returns ( 'deny', None, None, None ) when MPLS routing is enabled on the
   bridge, the VLAN's etreeRole is 'etreeRoleLeaf', and neither srcPortName
   nor destIntf contains "MplsTrunk" or "Cpu". In every other case the
   result is ( None, None, None, None ).
   """
   noChange = ( None, None, None, None )
   if not bridge.mplsRoutingInfo_.mplsRouting or not vlanId:
      return noChange

   vlanConfig = bridge.brConfig_.vlanConfig.get( vlanId )
   if not ( vlanConfig and vlanConfig.etreeRole == 'etreeRoleLeaf' ):
      return noChange

   # Traffic with the MplsTrunk or Cpu port on either side is left alone.
   exempt = any( marker in portName
                 for marker in ( "MplsTrunk", "Cpu" )
                 for portName in ( srcPortName, destIntf ) )
   if exempt:
      return noChange

   t0( "etree leaf: local bridging is disabled" )
   return ( 'deny', None, None, None )

def intfIdToTunnelId( intfId ):
   """Return the tunnelId for intfId.

   intfId is either an SR-TE policy FecIdIntfId or a DynamicTunnelIntfId;
   the conversion is done by whichever type matches.
   """
   isSrTePolicy = FecIdIntfIdType.isSrTePolicyIntfId( intfId )
   converter = FecIdIntfIdType if isSrTePolicy else DynamicTunnelIntfIdType
   return converter.tunnelId( intfId )

def resolveTunnel( tunnelId ):
   """Resolve tunnelId via ForwardingHelper.resolveTunnel.

   Thin wrapper that supplies the forwarding helper of the global bridge_,
   so callers only need to pass the tunnelId. The unit tests in
   Evpn/test/EvpnMplsEtbaPluginTest.py call resolveTunnel( tunnelId ) from
   this module directly.
   """
   fwdHelper = bridge_.fwdHelper
   return ForwardingHelper.resolveTunnel( fwdHelper, tunnelId )

def getNextDestinationIndex( numMplsDest ):
   """Return the next round-robin index in [ 0, numMplsDest ).

   Bumps the module-level packetCount on every call, so consecutive calls
   cycle through the destinations. Used to enforce overlay ECMP hashing
   from the LoadBalance table.
   """
   global packetCount
   packetCount = packetCount + 1
   nextIndex = packetCount % numMplsDest
   return nextIndex

def getLBNext( objTuple ):
   """Pick one 'next' entry of the loadBalance that objTuple points to.

   The entry is chosen round-robin through getNextDestinationIndex.
   Returns None when the loadBalance is missing or has no next entries.
   """
   lbEntry = loadBalanceTable_.lb.get( objTuple.objId )
   if not lbEntry or not lbEntry.lb.next:
      return None
   candidates = lbEntry.lb.next
   return candidates[ getNextDestinationIndex( len( candidates ) ) ]

def getLabelData( objTuple ):
   """Return the labelTable_ entry reached from objTuple, or None.

   objTuple either points at a label directly ( tableTypeLabel ) or at a
   loadBalance ( tableTypeLoadBalance ) whose selected next entry must
   itself be of tableTypeLabel.
   """
   labelRef = None
   if objTuple.tableType == 'tableTypeLabel':
      labelRef = objTuple
   elif objTuple.tableType == 'tableTypeLoadBalance':
      picked = getLBNext( objTuple )
      if picked and picked.tableType == 'tableTypeLabel':
         labelRef = picked
   if labelRef is None:
      return None
   return labelTable_.label.get( labelRef.objId )

def getTunnelData( objTuple ):
   """Return the destTable_ entry reached from objTuple if it is a tunnel.

   objTuple either points at a dest directly ( tableTypeDest ) or at a
   loadBalance ( tableTypeLoadBalance ) whose selected next entry must be
   of tableTypeDest. None is returned when no entry is found or when its
   destType is not 'destTypeTunnel'.
   """
   destRef = None
   if objTuple.tableType == 'tableTypeDest':
      destRef = objTuple
   elif objTuple.tableType == 'tableTypeLoadBalance':
      picked = getLBNext( objTuple )
      if picked and picked.tableType == 'tableTypeDest':
         destRef = picked
   if destRef is None:
      return None
   entry = destTable_.dest.get( destRef.objId )
   if entry and entry.dest.destType == 'destTypeTunnel':
      return entry
   return None

def resolveDstMac( dstMacAddr, vlanId ):
   """Resolve ( dstMacAddr, vlanId ) to a list of ( evpnLabel, tunnelId ).

   On an L2RIB host output hit, the list holds at most one tuple built from
   the label and tunnel dest the host points to. On a miss, the list holds
   one tuple for every 'destTypeMpls' member of the vlan's flood-set.
   """
   resolved = []
   lookupKey = Tac.Value( "Bridging::MacVlanPair", dstMacAddr, vlanId )
   ribHost = l2RibHostOutput_.host.get( lookupKey )
   if ribHost:
      t8( "DstMacAddr", dstMacAddr, "vlanId", vlanId, "is a hit" )
      labelData = getLabelData( ribHost.next ) if ribHost.nextIsStored else None
      tunnelDest = getTunnelData( labelData.next ) if labelData else None
      if labelData and tunnelDest:
         tunnelId = intfIdToTunnelId( tunnelDest.tunnel.intfId )
         resolved.append( ( labelData.label, tunnelId ) )
   else:
      t8( "DstMacAddr", dstMacAddr, "vlanId", vlanId, "is a miss. Flood the packet" )
      vlanFloodSet = floodSetOutput_.vlanFloodSet.get( vlanId )
      floodSet = vlanFloodSet.floodSet.get( ETH_ADDR_ZERO ) if vlanFloodSet else None
      floodDests = floodSet.destSet if floodSet else ()
      resolved.extend(
         ( floodDest.mpls.label, intfIdToTunnelId( floodDest.mpls.intfId ) )
         for floodDest in floodDests
         if floodDest.destType == 'destTypeMpls' )

   t8( "DstMacAddr", dstMacAddr, "vlanId", vlanId, "resolves to", resolved )
   return resolved

def constructEthHeader( srcMac, dstMac ):
   """Build the outer ethernet header ( ethType ETH_P_MPLS_UC ) and return
   the stringValue of the packet that holds it."""
   outerPkt = Tac.newInstance( 'Arnet::Pkt' )
   outerPkt.newSharedHeadData = EthHdrSize
   hdrWrapper = Tac.newInstance( 'Arnet::EthHdrWrapper', outerPkt, 0 )
   # Fill in the header fields through the wrapper placed at offset 0.
   hdrWrapper.src = srcMac
   hdrWrapper.dst = dstMac
   hdrWrapper.typeOrLen = ETH_P_MPLS_UC
   return outerPkt.stringValue

def constructTunnelPacket( data, innerVlanId, labelStack, outerSrcMac,
                           outerDestMac, addControlWord ):
   """Return data, rewritten to carry innerVlanId, encapsulated behind an
   outer ethernet header, the MPLS label headers for labelStack and, when
   addControlWord is set, a 4-byte all-zero control word."""
   t8( "Constructing tunneled packet with innerVlan", innerVlanId,
       "labelStack", labelStack, "outerSrcMac", outerSrcMac,
       "outerDestMac", outerDestMac )
   payload = rewriteVlanTag( data, innerVlanId )
   outerEth = constructEthHeader( outerSrcMac, outerDestMac )
   labelHdrs = Arnet.MplsLib.constructMplsHeader( labelStack )
   encap = outerEth + labelHdrs
   if addControlWord:
      encap = encap + b'\x00' * 4
   return encap + payload

def getTunnelFrame( tunnelId, evpnLabel, ingressVlanId, dstMacAddr, ingressPort,
                    bridgeMac, innerData ):
   """Build the MPLS-encapsulated frame for innerData and find its egress.

   The tunnel identified by tunnelId is resolved to a label stack, next-hop
   MAC and interface. innerData is rewritten with the dot1q tag mapped from
   ingressVlanId and wrapped with the tunnel labels, evpnLabel and, for BUM
   packets with a known MPLS-encap remote ESI label, that ESI label.

   Returns ( tunnelData, egressIntf ). egressIntf is None when the tunnel
   resolves to a VLAN interface and the outer DMAC lookup in that VLAN finds
   no stored interface. On failure the result is ( None, None ), except that
   an SR-TE policy whose tunnel list cannot be resolved yields
   ( False, False ).
   """
   # pylint: disable-next=isinstance-second-argument-not-valid-type
   if not isinstance( tunnelId, TunnelIdType ):
      tunnelId = TunnelIdType( tunnelId )

   if tunnelId.tunnelType() == TunnelType.srTePolicyTunnel:
      policyFec = bridge_.pfs.fec.get( FecIdType.tunnelIdToFecId( tunnelId ) )
      if not policyFec:
         return ( None, None )
      # We do not support ECMP for SR-TE + L2Evpn at the segment-list level.
      candidates = bridge_.fwdHelper.resolveSrTeTunnel(
         policyFec.via[ 0 ].tunnelId )
      if not candidates:
         return ( False, False )
      # Using only primary interface for now
      tunnelInfo = candidates[ 0 ]
   else:
      tunnelInfo = resolveTunnel( tunnelId )
   t8( "Resolved tunnelInfo", repr( tunnelInfo ) )
   if not tunnelInfo.resolved:
      return ( None, None )

   MplsEtbaLib.removeImpNullFromLabelStack( tunnelInfo.labelStack )

   # Remote ESI label known for this tunnel on the ingress segment, if any.
   tunnelIntfId = tunnelId.intfId()
   segment = evpnStatus_.segmentStatus.get( ingressPort )
   esiLabel = segment.remoteEsiLabel.get( tunnelIntfId ) if segment else None

   # Inner DMAC lookup on the ingress vlan to decide on ESI label.
   assert dstMacAddr is not None, "Invalid innere destination MAC"
   innerKey = Tac.Value( "Bridging::MacVlanPair", dstMacAddr, ingressVlanId )
   isBum = l2RibHostOutput_.host.get( innerKey ) is None

   innerLabels = [ evpnLabel ]
   if isBum and esiLabel and esiLabel.encapType == 'encapMpls':
      innerLabels.append( esiLabel.mplsLabel() )

   tunnelData = constructTunnelPacket(
      innerData,
      evpnStatus_.vlanToDot1q.get( ingressVlanId ),
      tunnelInfo.labelStack + innerLabels,
      bridgeMac,
      tunnelInfo.dstMac,
      evpnStatus_.controlWordPresent )

   vlanIntfIdType = Tac.Type( 'Arnet::VlanIntfId' )
   if not vlanIntfIdType.isVlanIntfId( tunnelInfo.intfId ):
      return ( tunnelData, tunnelInfo.intfId )

   # When egressInterface resolves to an SVI port, we need to do a
   # DMAC lookup to find the underlying physical port in that VLAN.
   outerKey = Tac.Value( "Bridging::MacVlanPair", tunnelInfo.dstMac,
                         vlanIntfIdType.vlanId( tunnelInfo.intfId ) )
   outerHost = l2RibHostOutput_.host.get( outerKey )
   t8( "Egress port lookup for", outerKey, "returned", outerHost )
   if outerHost and outerHost.host.intfIsStored:
      return ( tunnelData, outerHost.host.intf )
   return ( tunnelData, None )

class EbraTestMplsTrunkPort( EbraTestPort ):
   """Bridge port standing in for the MplsTrunk interface used by L2EVPN.

   Frames sent to this port are MPLS-encapsulated and handed to the egress
   port that the tunnel resolves to.
   """
   def __init__( self, bridge, tapDevice, trapDevice, intfConfig, intfStatus ):
      # This port has no tap/trap device and no intfConfig; only the
      # intfStatus is expected.
      assert tapDevice is None
      assert trapDevice is None
      assert intfConfig is None
      assert intfStatus is not None
      EbraTestPort.__init__( self, bridge, intfConfig, intfStatus )

   def initialized( self ):
      """Truthy only when every entity needed to send frames is available"""
      return ( evpnStatus_ and tunnelFib_ and l2RibHostOutput_ and arpTable_ and
               mlagStatus_ )

   def sendFrame( self, data, srcMacAddr, dstMacAddr, srcPortName,
                  priority, vlanId, vlanAction ):
      """Encapsulate data for every ( evpnLabel, tunnelId ) that the
      destination MAC resolves to, and send each resulting frame out of the
      bridge port mapped to the tunnel's egress interface. Frames that
      ingressed via the MLAG peer-link are ignored."""
      if not self.initialized():
         t8( "Entities needed for encap not found" )
         return
      peerLink = mlagStatus_.peerLinkIntf
      if peerLink and srcPortName == peerLink.intfId:
         t8( "MplsTrunkPort: Ignore packets ingressing via peer-link", srcPortName )
         return

      for ( evpnLabel, tunnelId ) in resolveDstMac( dstMacAddr, vlanId ):
         ( tunnelFrame, egressIntf ) = getTunnelFrame( tunnelId, evpnLabel,
                                                       vlanId, dstMacAddr,
                                                       srcPortName,
                                                       self.bridge_.bridgeMac(),
                                                       data )
         egressPort = self.bridge_.port.get( egressIntf )
         if not egressPort:
            t8( egressIntf, "is not mapped to any egress EtbaTestPort" )
            continue
         egressPort.sendFrame( tunnelFrame, None, None, None, None, None,
                               PKTEVENT_ACTION_NONE )

   def trapFrame( self, data ):
      """Intentionally a no-op: this port never traps or handles frames."""

def _validNext( objTuple ):
   """Validate objTuple against the L2Rib dest, label and loadBalance tables.

   Every dest type except destTypeVxlan is accepted.
   """
   acceptedDestTypes = [
      DestType.destTypeIntf,
      DestType.destTypeMpls,
      DestType.destTypeTunnel,
      DestType.destTypeCpu,
   ]
   return traverseAndValidateMultiTable( objTuple, destTable_, labelTable_,
                                         loadBalanceTable_, 1,
                                         expectedDestTypes=acceptedDestTypes )

def isEvpnRemoteHost( hostAndSource ):
   """Return True for a 'sourceBgp' host whose entryType is
   'evpnDynamicRemoteMac' or 'evpnConfiguredRemoteMac'; False for any other
   host and for a missing one."""
   remoteTypes = ( 'evpnDynamicRemoteMac', 'evpnConfiguredRemoteMac' )
   if not hostAndSource:
      return False
   isRemoteType = hostAndSource.entryType in remoteTypes
   fromBgp = hostAndSource.source == 'sourceBgp'
   return fromBgp and isRemoteType

# sourceBgp; evpnDynamicRemoteMac, evpnConfiguredRemoteMac and evpnIntfDynamicMac
# entryTypes are the ones we are concerned in L2RIB updates. Only returns true
# if correct entries exist in L2Rib output tables.
def isMplsEvpnHost( hostAndSource ):
   """Tell whether hostAndSource is a 'sourceBgp' host this plugin handles.

   Remote entry types ( evpnDynamicRemoteMac, evpnConfiguredRemoteMac )
   qualify only when the host has a stored next that _validNext accepts.
   Local entry types ( evpnIntfDynamicMac, evpnIntfStaticMac ) qualify
   without that check.
   """
   remoteTypes = ( 'evpnDynamicRemoteMac', 'evpnConfiguredRemoteMac' )
   localTypes = ( 'evpnIntfDynamicMac', 'evpnIntfStaticMac' )

   nextResolves = ( hostAndSource.nextIsStored and
                    _validNext( hostAndSource.host.next ) )
   entryType = hostAndSource.entryType
   if nextResolves and entryType in remoteTypes:
      acceptable = True
   else:
      acceptable = entryType in localTypes
   t0( "isMplsEvpnHost - hostAndSource %s" % str( hostAndSource.host ) )
   return hostAndSource.source == 'sourceBgp' and acceptable

# Used to fetch the destination interface for programming the MAC
# in the bridge. Applicable only for hosts whose destination matters
# for this plugin ( intf/mpls ).
def getHostInterface( host ):
   """Return the interface the bridge should learn host on.

   Remote EVPN entry types map to the MplsTrunk interface, local EVPN
   interface entry types map to the host's own stored interface, and any
   other entry type yields None.
   """
   entryType = host.entryType
   if entryType in ( 'evpnDynamicRemoteMac', 'evpnConfiguredRemoteMac' ):
      return mplsTrunkIntf_
   if entryType in ( 'evpnIntfDynamicMac', 'evpnIntfStaticMac' ):
      assert host.intfIsStored
      return host.intf
   return None

class L2RibHostOutputReactor( Tac.Notifiee ):
   """Keeps the bridge's learned hosts in step with L2RIB hostOutput.

   Hosts accepted by isMplsEvpnHost are learned on the interface returned
   by getHostInterface. EVPN remote hosts that are not yet accepted are put
   on a retry list that is re-examined from a clock notifiee. When a host
   disappears from hostOutput, it is removed from the bridge only if the
   bridge learned it on the MplsTrunk interface or as an
   evpnIntfDynamicMac/evpnIntfStaticMac entry."""

   notifierTypeName = 'L2Rib::HostOutput'
   def __init__( self, l2RibOutput, brStatus, bridge, mplsTrunkIntf ):
      self.bridge_ = bridge
      self.l2RibOutput_ = l2RibOutput
      self.brStatus_ = brStatus
      self.mplsTrunkIntf_ = mplsTrunkIntf
      # Delay added to Tac.now() when a retry is scheduled.
      self.retryInterval = 0.01
      # Host keys waiting for a retry ( key -> True ).
      self.retryDict = {}
      # timeMin stays at Tac.endOfTime while no retry is pending.
      self.retryAct = Tac.ClockNotifiee( handler=self._doMaybeAddHost,
                                         timeMin=Tac.endOfTime )
      Tac.Notifiee.__init__( self, self.l2RibOutput_ )
      t0( "Initialized L2RibHostOutputReactor" )
      # Process the hosts that already exist.
      self.handleHost( hostKey=None )

   def __del__( self ):
      self.retryAct.timeMin = Tac.endOfTime

   def _doMaybeAddHost( self ):
      """Retry every pending host key; park the timer once none are left."""
      for pendingKey in list( self.retryDict ):
         self.handleHost( pendingKey )
      if not self.retryDict:
         self.retryAct.timeMin = Tac.endOfTime

   @Tac.handler( 'host' )
   def handleHost( self, hostKey ):
      """Add or delete the host for hostKey; a hostKey of None processes
      every host currently in hostOutput."""
      t8( "handleHost:", hostKey )
      if hostKey is not None:
         current = self.l2RibOutput_.host.get( hostKey, None )
         if current:
            self.addHost( current )
         else:
            self.deleteHost( hostKey )
         return

      for existing in self.l2RibOutput_.host.values():
         self.addHost( existing )

   def addHost( self, hostEntry ):
      """Learn hostEntry in the bridge, or queue it for a retry when it is
      an EVPN remote host that isMplsEvpnHost does not accept yet."""
      t8( "addHost %s, vlan%d" % ( hostEntry.macAddr, hostEntry.vlanId ) )
      if isMplsEvpnHost( hostEntry ):
         self.bridge_.learnAddr( hostEntry.vlanId, hostEntry.macAddr,
                                 getHostInterface( hostEntry ),
                                 hostEntry.entryType,
                                 hostEntry.seqNo )
         # The host made it into the bridge, so it no longer needs a retry.
         self.retryDict.pop( hostEntry.key, None )
         return
      if not isEvpnRemoteHost( hostEntry ):
         return
      t8( "addHost %s, vlan%d to retry list" %
          ( hostEntry.macAddr, hostEntry.vlanId ) )
      self.retryDict[ hostEntry.key ] = True
      if self.retryAct.timeMin == Tac.endOfTime:
         self.retryAct.timeMin = Tac.now() + self.retryInterval

   def deleteHost( self, hostKey ):
      """Remove the host for hostKey from the bridge if it is an EVPN host."""
      t8( "deleteHost %s, vlan%d" % ( hostKey.macaddr, hostKey.vlanId ) )
      fdbStatus = self.brStatus_.fdbStatus.get( hostKey.vlanId )
      learned = fdbStatus.learnedHost.get( hostKey.macaddr ) if fdbStatus else None
      if not learned:
         return
      # Restrict the removal to hosts whose interface is the MplsTrunkIntf or
      # whose entryType is evpnIntfDynamicMac/evpnIntfStaticMac.
      ownedByEvpn = ( learned.intf == self.mplsTrunkIntf_ or
                      learned.entryType in ( 'evpnIntfDynamicMac',
                                             'evpnIntfStaticMac' ) )
      if ownedByEvpn:
         self.bridge_.deleteMacAddressEntry( hostKey.vlanId, hostKey.macaddr )

class MplsTrunkInterfaceStatusDirReactor( Tac.Notifiee ):
   """Reacts to the MplsTrunk interface status appearing or disappearing.

   When the status exists, an EbraTestMplsTrunkPort is added to the bridge
   and an L2RibHostOutputReactor is created. When it does not, the
   L2RibHostOutputReactor ( if any ) is dropped and the port is deleted
   from the bridge."""
   notifierTypeName = 'Mpls::MtiStatusDir'

   def __init__( self, mtiStatusDir, bridge ):
      self.mtiStatusDir_ = mtiStatusDir
      self.bridge_ = bridge
      Tac.Notifiee.__init__( self, self.mtiStatusDir_ )
      t0( "Initialized MplsTrunkInterfaceStatusDirReactor" )
      # Pick up a status that already exists.
      self.handleIntfStatus()

   @Tac.handler( 'mtiStatus' )
   def handleIntfStatus( self, key=None ):
      global mplsTrunkPort_
      # There should only be a maximum of 1 MtiStatus
      assert len( self.mtiStatusDir_.mtiStatus.keys() ) <= 1
      # "key" is disregarded: there is only ever one MplsTrunk interface, so
      # the global mplsTrunkIntf_ is used for the lookup.
      trunkStatus = self.mtiStatusDir_.mtiStatus.get( mplsTrunkIntf_ )
      if not trunkStatus:
         if reactors_.get( 'l2Rib' ):
            t8( "Deleting L2RibHostOutputReactor due to MplsTrunkIntf removal" )
            del reactors_[ 'l2Rib' ]
         self.bridge_.delPort( mplsTrunkIntf_ )
         return

      mplsTrunkPort_ = EbraTestMplsTrunkPort( self.bridge_, None, None, None,
                                              trunkStatus )
      self.bridge_.addPort( mplsTrunkPort_ )
      t8( "Creating L2RibHostOutputReactor" )
      reactors_[ 'l2Rib' ] = L2RibHostOutputReactor( l2RibHostOutput_,
                                                     self.bridge_.brStatus_,
                                                     self.bridge_,
                                                     mplsTrunkIntf_ )

def handleAgentInit( em ):
   """Mount all the entities needed.

   Agent-init handler registered by Plugin. Using the entity manager em, it
   mounts the shared-memory tables and Sysdb entities this plugin reads or
   writes and stores them in the module-level globals. It also fixes the
   MplsTrunk interface name and creates the Mpls::LfibStatusConsumerSm.
   """
   t2( "EvpnMpls plugin initialization" )
   mg = em.mountGroup()
   shmemEm = SharedMem.entityManager( sysdbEm=em )
   smashMountInfo = SmashLazyMount.mountInfo( 'keyshadow' )

   # L2Rib host output: MAC/VLAN lookups.
   global l2RibHostOutput_
   l2RibHostOutput_ = shmemEm.doMount( "bridging/l2Rib/hostOutput",
                                       "L2Rib::HostOutput", smashMountInfo )
   # L2Rib flood-set output is mounted through its own shared-memory mount
   # group with Shark mount info; that group is closed right away.
   global floodSetOutput_
   shmemMg = shmemEm.getMountGroup()
   floodSetOutput_ = shmemMg.doMount( "bridging/l2Rib/floodOutput",
                                      "L2Rib::FloodSetOutput",
                                      Shark.mountInfo( 'shadow' ) )
   shmemMg.doClose()

   global arpTable_
   arpTable_ = shmemEm.doMount( "arp/status", "Arp::Table::Status",
                                Smash.mountInfo( 'shadow' ) )
   global evpnStatus_
   evpnStatus_ = mg.mount( "evpn/status",
                           "Evpn::EvpnStatus", "r" )

   global tunnelFib_
   tunnelFib_ = shmemEm.doMount( "tunnel/tunnelFib", "Tunnel::TunnelFib::TunnelFib",
                                 smashMountInfo )
   # The decap LFIB and the ethernet-segment filter LFIB are both consulted
   # by getEvpnViaKey.
   global mplsLfibStatus_
   mplsLfibStatus_ = shmemEm.doMount( "mpls/decapLfib", "Mpls::LfibStatus",
                                      smashMountInfo )
   global mplsEsFilterLfib_
   mplsEsFilterLfib_ = shmemEm.doMount( "mpls/evpnEthernetSegmentFilterLfib",
                                        "Mpls::LfibStatus", smashMountInfo )

   global lFibStatusConsumerSm_
   lFibStatusConsumerSm_ = Tac.newInstance(
   "Mpls::LfibStatusConsumerSm", mplsLfibStatus_ )

   # L2Rib loadBalance/dest/label output tables used to walk from a host to
   # its label and tunnel.
   global loadBalanceTable_
   loadBalanceTable_ = shmemEm.doMount( "bridging/l2Rib/lbOutput",
                                        "L2Rib::LoadBalanceOutput", smashMountInfo )
   global destTable_
   destTable_ = shmemEm.doMount( "bridging/l2Rib/destOutput", "L2Rib::DestOutput",
                                 smashMountInfo )
   global labelTable_
   labelTable_ = shmemEm.doMount( "bridging/l2Rib/labelOutput", "L2Rib::LabelOutput",
                                  smashMountInfo )
   # Mount necessary entities for MtiSm to manage MplsTrunk interface
   global mplsTrunkIntf_
   mplsTrunkIntf_ = 'MplsTrunk1'

   global mplsTrunkIntfDir_
   mplsTrunkIntfDir_ = mg.mount( "interface/status/eth/mpls",
                                 "Mpls::MtiStatusDir", "w" )

   global mplsConfig_
   mplsConfig_ = mg.mount( "routing/mpls/config", "Mpls::Config", "r" )

   global brInputConfig_
   brInputConfig_ = mg.mount( "bridging/input/config/mpls", 
                              "Bridging::Input::Config", "wc" )

   global stpPortMode_
   stpPortMode_ = mg.mount( "stp/portMode/mpls", 
                            "EthIntf::PortModeConfig", "wc" )

   global managedIntfList_
   managedIntfList_ = em.getLocalEntity( "interface/status/managedIntfList" )

   # Mount Mlag::Status to identify the peer-link interface.
   global mlagStatus_
   mlagStatus_ = mg.mount( "mlag/status",
                           "Mlag::Status", "r" )

   # Only traces completion; nothing else is done when the mounts finish.
   def finishMounts():
      t8( "EvpnMpls plugin mounts complete" )
   mg.close( finishMounts )

def handleBridgeInit( bridge ):
   """Bridge-init handler: create the MtiSm and the MplsTrunk interface
   status reactor, then remember bridge in the global bridge_."""
   global mtiSm_, bridge_
   t2( "EvpnMpls bridgeInit" )

   t8( "Instantiating MtiSm" )
   mtiSm_ = Tac.newInstance( "EvpnMpls::MtiSm", mplsConfig_, brInputConfig_,
                             stpPortMode_, mplsTrunkIntfDir_, managedIntfList_,
                             False )

   t8( "Creating MplsTrunkInterfaceStatusDirReactor" )
   # The reactor must not have been created before.
   assert not reactors_.get( 'mtid', None )
   reactors_[ 'mtid' ] = MplsTrunkInterfaceStatusDirReactor( mplsTrunkIntfDir_,
                                                             bridge )

   bridge_ = bridge

# EvpnRemoteMacs must exist in the bridge FdbStatus till they are
# explicitly removed from L2RibHostOutput. Re-populate the entries if
# they get aged out/flushed for some reason.
def handleAgedMac( bridge, vlanId, macAddr ):
   """Post-aging handler: learn the aged ( vlanId, macAddr ) again when
   L2RIB hostOutput still has it as a host accepted by isMplsEvpnHost."""
   t8( "Aged MAC notification for", macAddr, "vlan", vlanId )
   agedKey = Tac.Value( "Bridging::MacVlanPair", macAddr, vlanId )
   ribHost = l2RibHostOutput_.host.get( agedKey )
   if not ribHost or not isMplsEvpnHost( ribHost ):
      return
   bridge.learnAddr( vlanId, macAddr, getHostInterface( ribHost ),
                     ribHost.entryType, ribHost.seqNo )

# Hop around the collections in LfibStatus to find the viaKey. Ignore
# routes whose source is not lfibSourceBgpL2Evpn.
def getEvpnViaKey( label ):
   """Return the first viaKey of the L2EVPN LFIB route for label, or None.

   The route is looked up in the decap LFIB first and in the ES filter LFIB
   otherwise; the viaSet is then read from the LFIB that supplied the
   route. Routes whose source is not 'lfibSourceBgpL2Evpn' are ignored.
   """
   t8( "Attempt to fetch viaSetKey for label", label )
   if not ( mplsLfibStatus_ and mplsEsFilterLfib_ ):
      # Etba expects both Lfib mounted
      t8( "Either decap lfib or es filter lfib or both not found." )
      return None

   routeKey = Tac.Type( "Mpls::RouteKey" ).fromLabel( label )
   decapRoute = mplsLfibStatus_.lfibRoute.get( routeKey, None )
   esFilterRoute = mplsEsFilterLfib_.lfibRoute.get( routeKey, None )
   if decapRoute:
      lfibRoute, owningLfib = decapRoute, mplsLfibStatus_
   else:
      lfibRoute, owningLfib = esFilterRoute, mplsEsFilterLfib_

   if not lfibRoute:
      t8( "No LfibRoute exists in decapLfib and es filter lfib for", routeKey )
      return None
   if lfibRoute.source != 'lfibSourceBgpL2Evpn':
      t8( "Source is not recognized", lfibRoute.source )
      return None

   viaSet = owningLfib.viaSet.get( lfibRoute.viaSetKey )
   if not viaSet or not viaSet.viaKey:
      t8( "No viaSet or viaKeys exists for", lfibRoute.viaSetKey )
      return None
   # For L2EVPN we only have one viaKey in the viaSet and pick the
   # first one.
   return viaSet.viaKey[ 0 ]

def updateMplsFloodFilter( vlan ):
   """Add the floodFilter keys of vlan's EVPN flood status, when there is
   one, to the module-level mplsFloodFilter_ set."""
   vlanFloodStatus = evpnStatus_.vlanFloodStatus.get( vlan )
   if not vlanFloodStatus:
      return
   mplsFloodFilter_.update( vlanFloodStatus.floodFilter.keys() )
   t8( "Updating Mpls Flood-filter to contain", mplsFloodFilter_ )

def handleVlanVia( viaKey, innerData ):
   """Rewrite innerData to carry the vlanId of the vlanVia for viaKey.

   Returns ( True, rewrittenData ), or ( False, innerData ) when viaKey is
   not in the decap LFIB's vlanVia collection.
   """
   vlanVias = mplsLfibStatus_.vlanVia
   if viaKey in vlanVias:
      return ( True, rewriteVlanTag( innerData, vlanVias[ viaKey ].vlanId ) )
   return ( False, innerData )

def handleVlanFloodVia( viaKey, innerData ):
   """Rewrite innerData to carry the vlanId of the evpnVlanFloodVia for
   viaKey, after adding that VLAN's flood filter to the MPLS flood-filter.

   Returns ( True, rewrittenData ), or ( False, innerData ) when viaKey is
   not in the decap LFIB's evpnVlanFloodVia collection.
   """
   floodVias = mplsLfibStatus_.evpnVlanFloodVia
   if viaKey in floodVias:
      floodVlan = floodVias[ viaKey ].vlanId
      updateMplsFloodFilter( floodVlan )
      return ( True, rewriteVlanTag( innerData, floodVlan ) )
   return ( False, innerData )

def rewriteVlanFromDot1qTag( innerData, dot1qTagToVlanColl ):
   """Look up the rewrite VLAN for innerData in dot1qTagToVlanColl, keyed
   by the VLAN id of innerData's dot1q tag.

   Returns None when the collection is missing, when innerData has no
   EthDot1QHdr, or when the tag has no mapping.
   """
   if not dot1qTagToVlanColl:
      t8( "Dot1qTagToVlanColl is not found" )
      return None
   ( _pkt, pktHeaders, _offset ) = PktParserTestLib.parsePktStr( innerData )
   dot1qHdr = PktParserTestLib.findHeader( pktHeaders, 'EthDot1QHdr' )
   if not dot1qHdr:
      t8( "Inner dot1QHeader is not found" )
      return None
   return dot1qTagToVlanColl.dot1qTagToVlanId.get( dot1qHdr.tagControlVlanId )

def handleVlanAwareVia( viaKey, innerData ):
   """Translate innerData's VLAN using its dot1q tag and the
   evpnVlanAwareVia for viaKey.

   Returns ( True, rewrittenData ) when a rewrite VLAN is found; otherwise
   ( False, innerData ).
   """
   awareVias = mplsLfibStatus_.evpnVlanAwareVia
   if viaKey not in awareVias:
      return ( False, innerData )
   tagToVlanKey = awareVias[ viaKey ].dot1qTagToVlanKey
   tagToVlanColl = mplsLfibStatus_.dot1qTagToVlanColl.get( tagToVlanKey, None )
   rewriteVlan = rewriteVlanFromDot1qTag( innerData, tagToVlanColl )
   if rewriteVlan:
      return ( True, rewriteVlanTag( innerData, rewriteVlan ) )
   return ( False, innerData )

def handleVlanAwareFloodVia( viaKey, innerData ):
   """Translate innerData's VLAN using its dot1q tag and the
   evpnVlanAwareFloodVia for viaKey, and add the resulting VLAN's flood
   filter to the MPLS flood-filter.

   Returns ( True, rewrittenData ) when a rewrite VLAN is found; otherwise
   ( False, innerData ).
   """
   awareFloodVias = mplsLfibStatus_.evpnVlanAwareFloodVia
   if viaKey not in awareFloodVias:
      return ( False, innerData )
   tagToVlanKey = awareFloodVias[ viaKey ].dot1qTagToVlanKey
   tagToVlanColl = mplsLfibStatus_.dot1qTagToVlanColl.get( tagToVlanKey, None )
   rewriteVlan = rewriteVlanFromDot1qTag( innerData, tagToVlanColl )
   if rewriteVlan:
      updateMplsFloodFilter( rewriteVlan )
      return ( True, rewriteVlanTag( innerData, rewriteVlan ) )
   return ( False, innerData )

def handleEvpnESVia( viaKey, innerData ):
   """Add the interface of the ethernet-segment filter via for viaKey to
   the MPLS flood-filter. innerData itself is never modified.

   Returns ( True, innerData ), or ( False, innerData ) when the ES filter
   LFIB has no such via.
   """
   esFilterVia = mplsEsFilterLfib_.evpnEthernetSegmentFilterVia.get( viaKey )
   if esFilterVia:
      filteredIntf = esFilterVia.intfId
      t8( "Filtering interface", filteredIntf, "from flood-set" )
      if filteredIntf:
         mplsFloodFilter_.add( filteredIntf )
      t8( "Final list of filtered interfaces", mplsFloodFilter_ )
      return ( True, innerData )
   return ( False, innerData )

# Maps the viaType of a via key to the handler that dispatchLabelHandler
# calls for it. Each handler takes ( viaKey, innerData ) and returns a
# ( handled, innerData ) tuple.
viaTypeHandler = {
   'viaTypeVlan': handleVlanVia,
   'viaTypeEvpnVlanAware': handleVlanAwareVia,
   'viaTypeEvpnVlanFlood': handleVlanFloodVia,
   'viaTypeEvpnVlanAwareFlood': handleVlanAwareFloodVia,
   'viaTypeEvpnEthernetSegment': handleEvpnESVia
}

def dispatchLabelHandler( label, innerData ):
   """Find the via key for label and run the handler registered for its
   viaType in viaTypeHandler.

   Returns the handler's ( handled, innerData ) result, or ( False, False )
   when there is no via key or no handler for its viaType.
   """
   viaKey = getEvpnViaKey( label )
   t8( "Dispatching via handler for", viaKey )
   if not viaKey:
      return ( False, False )
   viaHandler = viaTypeHandler.get( viaKey.viaType )
   if viaHandler is None:
      return ( False, False )
   return viaHandler( viaKey, innerData )

def handleTunnelPkt( bridge, dstMacAddr, data, srcPort ):
   """Pre-tunnel handler: decap an MPLS packet whose labels map to L2EVPN
   vias and hand the inner frame back for bridging.

   Only MPLS frames addressed to the bridge MAC and carrying at most two
   labels are handled. Each label is passed to dispatchLabelHandler, which
   may rewrite the inner frame's VLAN and update mplsFloodFilter_ ( the
   filter is cleared at the start of every call ). NOTE: Outer LSP Label
   must be stripped if we follow PHP.

   Returns ( changedData, srcPort=MTI port, drop=False,
   highOrderVlanBits=None ) when the packet is decapped, otherwise
   ( False, False, False, None ).
   """
   notHandled = ( False, False, False, None )
   mplsFloodFilter_.clear()
   if not mplsTrunkIntf_:
      t8( "MplsTrunkInterface is not available" )
      return notHandled

   if not isMpls( data ):
      t8( "Non-MPLS frame (MplsPktHelper), skip" )
      return notHandled

   bridgeMacAddr = bridge.bridgeMac()
   if dstMacAddr != bridgeMacAddr:
      t8( "dstMacAddr", dstMacAddr, "is different from bridgeMac", bridgeMacAddr )
      return notHandled

   # Parse the packet and collect every MPLS header, one label at a time.
   ( pkt, headers, curOffset ) = PktParserTestLib.parsePktStr( data )
   firstHdr = PktParserTestLib.findHeader( headers, "MplsHdr" )
   if not firstHdr:
      t8( "Cannot find MplsHdr" )
      return notHandled

   # We are done only when bottom-of-stack bit is set.
   labelHdrs = [ firstHdr ]
   while labelHdrs[ -1 ].bos is False:
      moreHeaders = []
      curOffset = PktParserTestLib.mplsParser( pkt, curOffset, moreHeaders )
      nextHdr = PktParserTestLib.findHeader( moreHeaders, "MplsHdr" )
      assert nextHdr is not None
      labelHdrs.append( nextHdr )

   # Some more validation to only handle at most 2 inner labels.
   if len( labelHdrs ) > 2:
      t8( "Evpn Mpls doesn't recognize", len( labelHdrs ), "labels" )
      return notHandled

   # Skip the 4-byte control word when one is present.
   payloadOffset = curOffset + 4 if evpnStatus_.controlWordPresent else curOffset
   innerData = data[ payloadOffset: ]
   for labelHdr in labelHdrs:
      ( validLabel, innerData ) = dispatchLabelHandler( labelHdr.label,
                                                        innerData )
      if not validLabel:
         return notHandled

   # Return the modified packet with srcPort as MplsTrunkInterface.
   return ( innerData, mplsTrunkPort_, False, None )

# EvpnVlanFloodStatus::floodFilter contains all non-DF ports that
# must not be part of flood-set and used for filtering naked BUM
# packets. MplsFloodFilter managed by preTunnelHandler is used
# for tunneled BUM packets.
def handleFloodSet( bridge, finalIntfs, srcPort=None, dropReasonList=None,
                    vlanId=None, dstMacAddr=None, data=None ):
   """Packet-replication handler: remove EVPN-filtered interfaces from
   finalIntfs ( modified in place ) for BUM packets.

   Packets that entered on the MplsTrunk interface are filtered with
   mplsFloodFilter_; all other packets are filtered with the floodFilter of
   the VLAN's EVPN flood status. Nothing is done when srcPort is unknown,
   vlanId is missing or outside 1..4094, the destination is a known host,
   or MPLS EVPN is not in use.
   """
   t8( "handleFloodSet: EbraFloodSet", finalIntfs, "srcPort", srcPort,
       "mplsTrunkIntf", mplsTrunkIntf_, "vlanId", vlanId, "dstMacAddr", dstMacAddr )
   if srcPort is None:
      t8( "Unknown source interface" )
      return
   if vlanId is None or vlanId < 1 or vlanId > 4094:
      t8( "vlanId is None or out of bounds" )
      return
   if dstMacAddr:
      ( destMatch, _ ) = bridge.destLookup( vlanId, dstMacAddr )
      if destMatch:
         t8( "Not a BUM destination" )
         return

   if evpnStatus_ is None or not evpnStatus_.mplsEnabled:
      t8( "MPLS EVPN is not in use" )
      return

   vlanFloodStatus = evpnStatus_.vlanFloodStatus.get( vlanId ) \
                     if evpnStatus_ else None
   if srcPort.name() == mplsTrunkIntf_:
      floodFilter = mplsFloodFilter_
   elif vlanFloodStatus:
      floodFilter = set( vlanFloodStatus.floodFilter.keys() )
   else:
      floodFilter = set()
   t8( "floodFilter", floodFilter, "srcPort", srcPort.name() )

   # Remove interfaces in floodFilter from the finalIntfs collection.
   suppressed = [ intf for intf in finalIntfs if intf in floodFilter ]
   for intf in suppressed:
      t8( "Flooding suppressed by Evpn on", intf )
      del finalIntfs[ intf ]

def Plugin( ctx ):
   """Plugin entry point: register every handler of this module with ctx."""
   t2( "MplsEvpn plugin registering" )

   # EbraTestMplsTrunkPort is the logical MplsTrunkIntf port that adds the
   # Mpls encap header when sending frames to remote TEPs.
   ctx.registerInterfaceHandler( 'Interface::EthPhyIntfStatus',
                                 EbraTestMplsTrunkPort )

   # Agent init: mount the state this plugin needs.
   ctx.registerAgentInitHandler( handleAgentInit )

   # Bridge init: create the reactors once the Etba mounts are complete.
   ctx.registerBridgeInitHandler( handleBridgeInit )

   # Post aging: re-add any aged EVPN remote MACs.
   ctx.registerPostAgingHandler( handleAgedMac )

   # Pre tunnel: decap the Mpls header so the packet drops into the L2
   # switching pipeline.
   ctx.registerPreTunnelHandler( handleTunnelPkt )

   # Packet replication: suppress flooding based on the ingress interface
   # ( srcPort ) where the packet entered the DUT.
   ctx.registerPacketReplicationHandler( handleFloodSet )

   # Post bridging: modifications and action changes based on egress intf.
   ctx.registerPostBridgingHandler( postBridging )
