# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
'''
Add mLDP MPLS Support to the Etba data plane

This code depends on the mplsBridgeInit and mplsAgentInit related
infrastructure / base code, and so will be registered in the MplsEtba
Plugin as a helper, instead of being registered directly in the Plugin
method in this module itself.
'''

from Arnet import (
      PktParserTestLib,
      IpGenAddr,
      IpGenPrefix,
      IpTestLib,
)
from collections import defaultdict
from TypeFuture import TacLazyType
import SharedMem
import Smash
import Tac
import Tracing
from EbraTestBridgeLib import (
   PKTEVENT_ACTION_ADD,
   applyVlanChanges,
)
from EbraTestBridgePlugin.MplsEtbaDecapUtils import (
   findDecapRoute,
   findIntfPortInVrf,
   getVlanPhyIntf,
   IpStatusReactor,
   findDecapMulticastIpLookupVia,
)
import Cell
from MplsEtbaLib import BUD_ROLE_SPL_PKT

# Tracing handle for this module; the trace levels below are aliased by how
# they are used in this file.
th = Tracing.Handle( 'EbraTestBridgeMplsMldp' )
# Errors / unexpected conditions (e.g. invalid via address family).
terr = th.trace0
# General informational tracing of the decap decisions.
tinfo = th.trace1
# Verbose tracing (used for dumping the rewritten packet).
tverb = th.trace2
# Function-entry tracing (used with the incoming packet dump).
tfunc = th.trace8
# Wrapper used to hex-dump packet data in trace output.
PacketFormat = Tracing.HexDump

# Tac types looked up via TacLazyType (presumably resolved on first use).
EthIntfId = TacLazyType( 'Arnet::EthIntfId' )
LfibViaSetType = TacLazyType( 'Mpls::LfibViaSetType' )
MplsLabel = TacLazyType( 'Arnet::MplsLabel' )
VlanIntfId = TacLazyType( 'Arnet::VlanIntfId' )

def getVrfNameFromVrfId( bridge, vrfId ):
   '''
   Look up the VRF name that bridge.vrfIdMapStatus maps vrfId to.

   Returns None when the VrfIdMap status is not available, or when it has
   no (truthy) entry for vrfId.
   '''
   vrfIdMap = bridge.vrfIdMapStatus
   vrfEntry = vrfIdMap.vrfIdToName.get( vrfId ) if vrfIdMap else None
   vrfName = vrfEntry.vrfName if vrfEntry else None
   tinfo( 'getVrfNameFromVrfId vrfName ', vrfName, ' for vrfId ', vrfId )
   return vrfName

def getVrfIdFromPmsiIntfId( bridge, pmsiIntfId ):
   '''
   Find the VrfId whose MvpnIntfStatus entry carries pmsiIntfId.

   Scans bridge.mvpnIntfStatus.vrfIntfId in iteration order and returns the
   first VrfId whose pmsiIntfId matches, or None when nothing matches.
   '''
   matchingVrfIds = ( candidateVrfId for candidateVrfId, vrfIntf
                      in bridge.mvpnIntfStatus.vrfIntfId.items()
                      if vrfIntf.pmsiIntfId == pmsiIntfId )
   resVrfId = next( matchingVrfIds, None )
   tinfo( 'getVrfIdFromPmsiIntfId vrfId ', resVrfId,
          ' for pmsiIntfId ', pmsiIntfId )
   return resVrfId

def mplsMldpDecapPreTunnelHandler( bridge, dstMacAddr, data, srcPort ):
   '''
   Terminate any frame with a bottom-of-stack L3 mLDP label, and route the
   inner IP packet in the correct VRF.

   Only frames of the form EthHdr + MplsHdr + payload or
   EthHdr + EthDot1QHdr + MplsHdr + payload are considered.  Anything else
   (e.g. double-tagged frames, non-BOS labels) is left alone.

   The label is first looked up as a multicast decap route.  Failing that,
   the transitLfib is checked for a BUD role entry.  In that case the
   original frame info and the PMSI vlan are stored in
   bridge.packetContext[ BUD_ROLE_SPL_PKT ].

   The VRF comes from the via.  For 'viaTypeMulticastIpLookup' vias it is
   resolved from the PMSI interface (via.ingressIntf) through MvpnIntfStatus
   and the VrfIdMap.  Otherwise via.vrfName is used.

   To have ETBA do the routing in the correct VRF, we fake the source
   interface by returning a new chgSrcPort interface:
     - for multicast ipLookup vias, the physical interface of the PMSI vlan;
     - for plain ipLookup vias, any interface in the VRF (see
       'findIntfPortInVrf').

   The MPLS TTL from the MPLS header will be copied (unchanged) to the
   underlying IP TTL or IPv6 hop-limit field.  It may be decremented by the
   kernel during any applicable routing action.  IPv4 MPLS OAM packets
   (destination in 127.0.0.0/8, UDP destination port 3503) are trapped with
   srcPort.trapFrame instead of being routed.

   Returns a 4-tuple ( chgData, chgSrcPort, drop, None ):
     - ( None, None, False, None ) when this handler does not take the frame,
     - ( None, None, True, None ) when the frame must be dropped (or was
       trapped),
     - ( newData, port, False, None ) with the decapsulated frame (vlan-tagged
       when a vlan applies) and the faked source port otherwise.
   '''
   func = 'mplsMldpDecapPreTunnelHandler'
   noMatch = ( None, None, False, None )
   drop = ( None, None, True, None )
   tfunc( func, PacketFormat( data ) )
   if not bridge.mplsRoutingInfo_.mplsRouting:
      tinfo( func, 'mplsRouting is disabled' )
      return noMatch

   ( pkt, headers, _ ) = PktParserTestLib.parsePktStr( data )

   # Only terminate packets that have the following format:
   # 1. EthHdr + MplsHdr + Payload
   # 2. EthHdr + EthDot1QHdr + MplsHdr + Payload
   #
   # Skip packets such as:
   # EthHdr + EthDot1QHdr + EthDot1QHdr + MplsHdr + Payload
   #
   # Guard against an empty header list so a frame the parser could not
   # decode at all is a "no match" rather than an IndexError.
   headerType, ethHdr = headers[ 0 ] if headers else ( None, None )
   if headerType != 'EthHdr' or not ethHdr:
      tinfo( func, 'cannot find ethHdr' )
      return noMatch

   mplsHdr = PktParserTestLib.findHeader( headers, "MplsHdr" )
   if not mplsHdr:
      # If we can't find mplsHdr at all, then no match.
      tinfo( func, 'cannot find MPLS header' )
      return noMatch
   tinfo( func, 'found MPLS header' )

   if ethHdr.ethType == 'ethTypeMpls':
      # Eth header followed by MPLS header
      pass
   elif ethHdr.ethType == 'ethTypeDot1Q':
      # An MPLS header was found after the Eth header, so a second header
      # must exist here.
      headerType, dot1qHdr = headers[ 1 ]
      assert headerType == "EthDot1QHdr" and dot1qHdr
      if dot1qHdr.ethType != 'ethTypeMpls':
         # A vlan tag followed by other ether type (e.g. a second tag),
         # no match
         tinfo( func, 'not MPLS (type: %s)', dot1qHdr.ethType )
         return noMatch
   else:
      tinfo( func, 'not MPLS (type: %s)', ethHdr.ethType )
      return noMatch

   if not mplsHdr.bos:
      tinfo( func, 'non-BOS label', mplsHdr.label )
      return noMatch

   label = mplsHdr.label
   via = findDecapRoute( bridge, label, viaSetType=LfibViaSetType.multicast )
   tinfo( func, 'label is ', label )
   if via is None:
      tinfo( func, 'no decap route for label based on viaSet ', label )
      # Check the transitLfib for BUD Role Entry
      via = findDecapMulticastIpLookupVia( bridge, label )
      if via is None:
         return noMatch
      tinfo( func, 'found BUD route for label', label )
      # ingressIntf of multicastIpLookup via should point to PMSI vlan
      assert VlanIntfId.isVlanIntfId( via.ingressIntf )
      pmsiVlan = VlanIntfId.vlanId( via.ingressIntf )
      # Remember the original frame and PMSI vlan for the BUD role handling;
      # the consumer of BUD_ROLE_SPL_PKT lives outside this module.
      bridge.packetContext[ BUD_ROLE_SPL_PKT ] = [ dstMacAddr, srcPort, data,
                                                   pmsiVlan ]

   if via.viaType == 'viaTypeMulticastIpLookup':
      vrfId = getVrfIdFromPmsiIntfId( bridge, via.ingressIntf )
      if vrfId is None:
         return drop
      vrfName = getVrfNameFromVrfId( bridge, vrfId )
      if vrfName is None:
         return drop
   else:
      vrfName = via.vrfName

   tinfo( func, ' vrfName ', vrfName, ' for via ', via )
   addressFamily = via.payloadType

   if addressFamily == 'ipv4':
      newEthType = 'ethTypeIp'
      expectedIpVersion = 4
   elif addressFamily == 'ipv6':
      newEthType = 'ethTypeIp6'
      expectedIpVersion = 6
   else:
      terr( func, 'invalid via addressFamily', addressFamily, 'vrf', vrfName )
      return noMatch

   if mplsHdr.ttl == 0:
      tinfo( func, 'drop, MPLS ttl', mplsHdr.ttl )
      return drop

   # The IP packet starts right after the 4-byte bottom-of-stack MPLS header.
   ipPacket = pkt.stringValue[ mplsHdr.offset + 4 : ]
   if not ipPacket:
      tinfo( func, 'drop, no data for IP packet' )
      return drop
   # The IP version is the high nibble of the first byte.
   ipVersion = ipPacket[ 0 ] >> 4
   if ipVersion != expectedIpVersion:
      tinfo( func, 'drop, ipVersion', ipVersion, 'does not match expected',
             addressFamily )
      return drop

   # Construct new packet with the Dot1Q and MPLS header removed and reparse
   # i.e., only Ethernet Header (with its ethType rewritten) with IP payload.
   ethHdr.ethType = newEthType
   newData = pkt.stringValue[ : PktParserTestLib.EthHdrSize ] + ipPacket
   ( pkt, headers, _ ) = PktParserTestLib.parsePktStr( newData )
   ipHdrName = 'IpHdr' if addressFamily == 'ipv4' else 'Ip6Hdr'
   ipHdr = PktParserTestLib.findHeader( headers, ipHdrName )
   if not ipHdr:
      # The payload claimed the right IP version but no IP header could be
      # parsed (e.g. truncated); the rewrites below require ipHdr.
      tinfo( func, 'drop, cannot parse IP header' )
      return drop

   if addressFamily == 'ipv4':
      udpHdr = PktParserTestLib.findHeader( headers, 'UdpHdr' )
      # Trap the OAM MPLS packet if this is Egress LSR
      ipv4OamPrefix = IpGenPrefix( '127.0.0.0/8' )
      if ( ipv4OamPrefix.contains( IpGenAddr( ipHdr.dst ) ) and
           udpHdr and udpHdr.dstPort == 3503 ):
         tinfo( func, "trapping MPLS OAM IP packet" )
         srcPort.trapFrame( data )
         return drop

      # The IP TTL will be decremented in the routing phase, not here
      ipHdr.ttl = mplsHdr.ttl
      # Clear checksum and recompute
      ipHdr.checksum = 0
      ipHdr.checksum = ipHdr.computedChecksum
   else:
      # The IPv6 hop limit will be decremented in the routing phase, not here
      ipHdr.hopLimit = mplsHdr.ttl
   # Rewrite the destination MAC to the one IpTestLib.ipMcastMacAddr derives
   # from the IP destination address.
   ethHdr = PktParserTestLib.findHeader( headers, 'EthHdr' )
   ethHdr.dst = IpTestLib.ipMcastMacAddr( ipHdr.dst )

   tverb( func, PacketFormat( newData ) )
   if via.viaType == 'viaTypeMulticastIpLookup':
      # ingressIntf of multicastIpLookup via should point to PMSI vlan
      assert VlanIntfId.isVlanIntfId( via.ingressIntf )
      vlanId = VlanIntfId.vlanId( via.ingressIntf )
      phyIntfName = getVlanPhyIntf( bridge, vlanId )
      port = bridge.port.get( phyIntfName )
   else:
      # Non MVPN static config where the viaType is viaTypeIpLookup
      assert via.viaType == 'viaTypeIpLookup'
      # pick some vlan, port belonging to the vrf of via
      port, vlanId = findIntfPortInVrf( bridge, vrfName,
                                        bridge.mplsMldpVrfInterfaces_ )

   if port is None:
      terr( func, 'no interface for vrf', vrfName )
      return noMatch
   tinfo( func, 'pop label', label, 'in vrf', vrfName, 'src port', port.name() )
   data = pkt.stringValue
   tinfo( func, 'pkt is ', data )
   if vlanId:
      # This packet goes to an SVI interface, insert vlan tag for the trunk ports.
      data = applyVlanChanges( pkt.stringValue, 0, vlanId, PKTEVENT_ACTION_ADD )
   tinfo( func, 'port is ', port )
   return ( data, port, False, None )

def doMplsMldpDecapBridgeInit( bridge, decapLfib, ipStatus ):
   '''
   Attach the mLDP decap state to the bridge.

   Stores decapLfib and ipStatus on the bridge, creates an empty
   defaultdict( set ) for VRF interfaces (presumably filled in by the
   reactor), and creates an IpStatusReactor over ipStatus, the bridge and
   that map.
   '''
   bridge.decapLfib_ = decapLfib
   bridge.mplsMldpIpStatus_ = ipStatus
   vrfInterfaces = defaultdict( set )
   bridge.mplsMldpVrfInterfaces_ = vrfInterfaces
   bridge.mplsMldpIpStatusReactorMldp_ = IpStatusReactor( ipStatus, bridge,
                                                          vrfInterfaces )

def mplsMldpDecapBridgeInit( bridge ):
   '''
   Bridge init handler: look up or mount the state used by the mLDP decap
   path and store it on the bridge.

   Sets bridge.mvpnVrfStatus, bridge.mvpnIntfStatus, bridge.vrfNameStatus,
   bridge.vrfIdMapStatus and bridge.transitLfib_.  It also calls
   doMplsMldpDecapBridgeInit, which stores the decapLfib and IP status and
   starts the IP status reactor.
   '''
   func = 'mplsMldpDecapBridgeInit'
   tinfo( func )
   decapLfib = bridge.sEm().getTacEntity( 'mpls/decapLfib' )
   ipStatus = bridge.em().entity( 'ip/status' )
   tacType = 'Routing::Multicast::MvpnVrfStatus'
   bridge.mvpnVrfStatus = bridge.sEm().doMount(
         Tac.Type( tacType ).mountPath( "ipv4", "bgp" ),
         tacType, Smash.mountInfo( 'shadow' ) )
   tacType = 'Routing::Multicast::MvpnIntfStatus'
   bridge.mvpnIntfStatus = bridge.sEm().getTacEntity(
         Tac.Type( tacType ).mountPath() )
   bridge.vrfNameStatus = bridge.em().entity( Cell.path( 'vrf/vrfNameStatus' ) )
   doMplsMldpDecapBridgeInit( bridge, decapLfib, ipStatus )
   # NOTE(review): getEntity is subscripted here rather than called (unlike
   # getTacEntity above); presumably an indexed accessor that yields a falsy
   # value when the path is not mounted yet -- confirm.
   bridge.vrfIdMapStatus = bridge.sEm().getEntity[ "vrf/vrfIdMapStatus" ]
   if not bridge.vrfIdMapStatus:
      bridge.vrfIdMapStatus = bridge.sEm().doMount( "vrf/vrfIdMapStatus",
                                             "Vrf::VrfIdMap::Status",
                                             Smash.mountInfo( 'keyshadow' ) )
   bridge.transitLfib_ = bridge.sEm().doMount( 'mpls/transitLfib',
         'Mpls::LfibStatus', Smash.mountInfo( 'keyshadow' ) )
   # Only report completion once every mount above has been done.
   tinfo( func, 'finished' )

def mplsMldpDecapAgentInit( em ):
   '''
   Agent init handler: set up the mounts the mLDP decap path relies on.

   Mounts 'mpls/decapLfib' through a shared-memory entity manager built on
   em.  It then mounts the IP status entities via Ira::IraIpStatusMounter on
   em's active mount group, and mounts the VRF name status through em with
   mode 'r'.
   '''
   # Mount the decapLfib for VRF ipLookup action
   SharedMem.entityManager( sysdbEm=em ).doMount(
      'mpls/decapLfib', 'Mpls::LfibStatus', Smash.mountInfo( 'keyshadow' ) )
   # Mount IP status to maintain a mapping of VRF:interface
   mountGroup = em.activeMountGroup()
   ipStatusMounter = Tac.Type( "Ira::IraIpStatusMounter" )
   ipStatusMounter.doMountEntities( mountGroup.cMg_, True, False )
   em.mount( Cell.path( 'vrf/vrfNameStatus' ),
             'Vrf::VrfIdMap::NameToIdMapWrapper', 'r' )

def pluginHelper( ctx ):
   '''
   Register this module's handlers with the MplsEtba plugin context.

   Registers, in order, the pre-tunnel handler, the bridge init handler
   and the agent init handler.
   '''
   registrations = (
      ( 'registerPreTunnelHandler', mplsMldpDecapPreTunnelHandler ),
      ( 'registerBridgeInitHandler', mplsMldpDecapBridgeInit ),
      ( 'registerAgentInitHandler', mplsMldpDecapAgentInit ),
   )
   for registrarName, handler in registrations:
      getattr( ctx, registrarName )( handler )
