# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
'''
Add L3 EVPN MPLS Support to the Etba data plane

This code depends on the mplsBridgeInit and mplsAgentInit related
infrastructure / base code, and so will be registered in the MplsEtba
Plugin as a helper, instead of being registered directly in the Plugin
method in this module itself.
'''

from Arnet import (
      IpGenAddr,
      IpGenPrefix,
      PktParserTestLib,
)
from collections import defaultdict
from IpLibConsts import DEFAULT_VRF
from MplsPktHelper import isMpls  # pylint: disable=no-name-in-module
from EbraTestBridgePlugin.MplsEtbaDecapUtils import (
    IpStatusReactor,
    findDecapRoute,
    findIntfPortInVrf,
)
import SharedMem
import Smash
import Tac
import Tracing
from EbraTestBridgeLib import (
   HANDLER_PRIO_HIGH,
   PKTEVENT_ACTION_ADD,
   applyVlanChanges,
)
from TypeFuture import TacLazyType

th = Tracing.Handle( 'EbraTestBridgeMplsL3Evpn' )
# Trace shorthands used throughout this module: terr (trace0) for error
# conditions, tinfo (trace1) for per-packet decisions, tverb (trace2) for
# verbose packet dumps and tfunc (trace8) for handler entry.
terr = th.trace0
tinfo = th.trace1
tverb = th.trace2
tfunc = th.trace8
PacketFormat = Tracing.HexDump

# Lazily-resolved Tac types.
EthIntfId = TacLazyType( 'Arnet::EthIntfId' )
MplsLabel = TacLazyType( 'Arnet::MplsLabel' )
PayloadType = TacLazyType( 'Mpls::PayloadType' )
VlanIntfId = TacLazyType( 'Arnet::VlanIntfId' )

# Canned ( chgData, chgSrcPort, drop ) results returned by the label handlers:
# noMatch means the handler did not act on the frame, drop discards it.
noMatch = ( None, None, False )
drop = ( None, None, True )
# Four-element variants returned by the registered preTunnel handlers. The
# fourth element is always None in this module (presumably a high-VLAN
# override, per the name -- confirm against the EbraTestBridge contract).
noMatchWithHighVlan = ( None, None, False, None )
dropWithHighVlan = ( None, None, True, None )
# UDP destination port identifying MPLS LSP ping (OAM) packets.
lspPingUdpPort = 3503

def isExpNull( label ):
   '''Return True if `label` is the IPv4 or the IPv6 explicit-null label.'''
   return ( label == MplsLabel.explicitNullIpv4 or
            label == MplsLabel.explicitNullIpv6 )

def isEntropyLabelIndicator( label ):
   '''Return True if `label` is the Entropy Label Indicator (ELI).'''
   eli = MplsLabel.entropyLabelIndicator
   return label == eli

def isValidMplsIpPayload( mplsHdr, pkt, addressFamily, func=None ):
   '''
   Check that the payload under `mplsHdr` can be decapsulated as an IP packet.

   The payload is rejected when:
     - the MPLS TTL is 0,
     - no bytes follow the MPLS header,
     - for ipv4/ipv6, the IP version nibble does not match `addressFamily`,
     - for autoDecide, the version nibble is neither 4 nor 6. Callers choose
       IPv4/IPv6 from that nibble, so any other value cannot be decapsulated.

   mplsHdr: parsed MPLS header of the label being popped (ttl/offset are read).
   pkt: parsed packet whose stringValue holds the frame bytes.
   addressFamily: Mpls::PayloadType from the decap route's via.
   func: caller name used as the trace prefix.
   Returns True if the payload is valid, False otherwise.
   '''
   func = func or 'isValidMplsIpPayload'

   if mplsHdr.ttl == 0:
      tinfo( func, 'drop, MPLS ttl', mplsHdr.ttl )
      return False

   ipPacket = pkt.stringValue[ mplsHdr.offset + 4 : ]
   if not ipPacket:
      tinfo( func, 'drop, no data for IP packet' )
      return False

   ipVersion = ipPacket[ 0 ] >> 4
   if addressFamily == PayloadType.autoDecide:
      if ipVersion not in ( 4, 6 ):
         tinfo( func, 'drop, ipVersion', ipVersion, 'is not IPv4 or IPv6' )
         return False
   elif ipVersion != int( addressFamily[ -1 ] ):
      # The AF name ends in its version digit ('ipv4' / 'ipv6').
      tinfo( func, 'drop, ipVersion', ipVersion, 'does not match expected',
             addressFamily )
      return False
   return True

def isValidAF( af, vrfName, func=None ):
   '''
   Return True if `af` is a payload type this module can decapsulate
   (PayloadType ipv4, ipv6 or autoDecide). Otherwise log an error and
   return False.

   vrfName is only used in the error trace. func is the caller name used as
   the trace prefix and defaults to 'isValidAF', like the sibling helpers.
   '''
   func = func or 'isValidAF'
   if af in ( PayloadType.ipv4, PayloadType.ipv6, PayloadType.autoDecide ):
      return True
   terr( func, 'invalid via addressFamily', af, 'vrf', vrfName )
   return False

def decappedPktAndPort( bridge, vrfName, data, label, func=None ):
   '''
   Pair a decapsulated frame with a source port that lives in `vrfName`.

   The port (and optional VLAN) come from findIntfPortInVrf over the bridge's
   VRF interface map. When the VRF interface is an SVI (a VLAN id is
   returned), a VLAN tag is added to the frame for the trunk ports.

   Returns ( data, port, False ), or noMatch if the VRF has no interface.
   '''
   func = func or 'decappedPktAndPort'
   vrfPort, vrfVlan = findIntfPortInVrf( bridge, vrfName,
                                         bridge.mplsL3VrfInterfaces_ )
   if vrfPort is None:
      terr( func, 'no interface for vrf', vrfName )
      return noMatch

   tinfo( func, 'pop label', label, 'in vrf', vrfName, 'src port', vrfPort.name() )
   frame = data
   if vrfVlan:
      # Destined to an SVI: tag the frame so trunk ports see the VLAN.
      frame = applyVlanChanges( frame, 0, vrfVlan, PKTEVENT_ACTION_ADD )
   return ( frame, vrfPort, False )

def isValidMplsDecapPacket( bridge, pkt, headers, func=None ):
   '''
   Terminate only if the received packet is in one of the supported formats:
     1. EthHdr + MplsHdr + Payload
     2. EthHdr + EthDot1QHdr + MplsHdr + Payload

   Anything else is rejected, e.g.
   EthHdr + EthDot1QHdr + EthDot1QHdr + MplsHdr + Payload.

   bridge and pkt are not used here; they are part of the common call shape.
   headers: ( headerType, header ) sequence from PktParserTestLib.parsePktStr.
   func: caller name used as the trace prefix.
   Returns True if the frame may be MPLS-decapsulated, False otherwise.
   Malformed frames (missing or truncated headers) return False; they do not
   raise.
   '''
   func = func or 'isValidMplsDecapPacket'

   if not headers:
      tinfo( func, 'cannot find ethHdr' )
      return False
   headerType, ethHdr = headers[ 0 ]
   if headerType != 'EthHdr' or not ethHdr:
      tinfo( func, 'cannot find ethHdr' )
      return False

   mplsHdr = PktParserTestLib.findHeader( headers, "MplsHdr" )
   if not mplsHdr:
      # If we can't find mplsHdr at all, then no match.
      tinfo( func, 'cannot find MPLS header' )
      return False

   if ethHdr.ethType == 'ethTypeMpls':
      # Eth header followed by MPLS header
      return True

   if ethHdr.ethType != 'ethTypeDot1Q':
      tinfo( func, 'not MPLS (type: %s)' % ethHdr.ethType )
      return False

   # Exactly one vlan tag may sit between the Eth and MPLS headers. This runs
   # on received frames, so a missing/truncated tag is a no-match, not an
   # assertion failure.
   if len( headers ) < 2:
      tinfo( func, 'cannot find dot1q header' )
      return False
   headerType, dot1qHdr = headers[ 1 ]
   if headerType != 'EthDot1QHdr' or not dot1qHdr:
      tinfo( func, 'cannot find dot1q header' )
      return False
   if dot1qHdr.ethType != 'ethTypeMpls':
      # A vlan tag followed by other ether type (incl. a second tag), no match
      tinfo( func, 'not MPLS (type: %s)' % dot1qHdr.ethType )
      return False
   return True

def isMplsOamHdr( headers ):
   '''
   Return True if the parsed headers form an MPLS OAM (LSP ping) payload.

   An OAM payload is UDP to lspPingUdpPort, carried either in IPv4 with a
   destination in 127.0.0.0/8, or in IPv6 with a destination in
   0:0:0:0:0:FFFF:7F00::/104 or ::1/128. When an IPv4 header is present it
   decides the result, and IPv6 is not consulted.
   '''
   func = 'isMplsOamHdr'
   v4Hdr = PktParserTestLib.findHeader( headers, 'IpHdr' )
   v6Hdr = PktParserTestLib.findHeader( headers, 'Ip6Hdr' )
   udp = PktParserTestLib.findHeader( headers, 'UdpHdr' )
   if not udp:
      return False

   # Trap the OAM MPLS packet if this is Egress LSR
   if v4Hdr:
      v4Loopback = IpGenPrefix( '127.0.0.0/8' )
      if not v4Loopback.contains( IpGenAddr( v4Hdr.dst ) ):
         return False
      if udp.dstPort != lspPingUdpPort:
         return False
      tinfo( func, "Skip processing MPLS OAM IP packet in reserved handlers" )
      return True

   if not v6Hdr:
      return False
   v6OamPrefixes = ( IpGenPrefix( '0:0:0:0:0:FFFF:7F00::/104' ),
                     IpGenPrefix( '::1/128' ) )
   v6Dst = IpGenAddr( v6Hdr.dst.stringValue )
   isOamDst = any( prefix.contains( v6Dst ) for prefix in v6OamPrefixes )
   if isOamDst and udp.dstPort == lspPingUdpPort:
      tinfo( func, "Skip processing MPLS OAM IPv6 packet in reserved handlers" )
      return True
   return False

def isMplsOamPkt( bridge, destMac, data, srcPort ):
   '''
   Skip processing of MPLS OAM packet with reserved labels, as it needs to be
   punted to CPU at a later stage with reserved label.

   Walks the top of the label stack:
     - a router-alert label means OAM;
     - explicit-null and entropy (ELI + EL) labels are stripped using their
       handlers;
     - any other label means this is not the egress LSR, so the packet is not
       OAM.
   Once no MPLS header remains, the exposed IP/UDP headers are checked with
   isMplsOamHdr.

   Returns True for an MPLS OAM packet, False otherwise.
   '''
   func = 'isMplsOamPkt'

   if not bridge.mplsRoutingInfo_.mplsRouting:
      tinfo( func, 'mplsRouting is disabled' )
      return False

   if not isMpls( data ):
      tinfo( func, "Non-MPLS frame (MplsPktHelper), skip" )
      return False

   ( pkt, headers, _ ) = PktParserTestLib.parsePktStr( data )
   if not isValidMplsDecapPacket( bridge, pkt, headers, func ):
      return False

   mplsHdr = PktParserTestLib.findHeader( headers, "MplsHdr" )
   # The label handlers build and return new frames. Only the local newData
   # is re-bound; the caller's data is left untouched.
   newData = data
   while mplsHdr:
      if mplsHdr.label == MplsLabel.routerAlert:
         return True
      if isExpNull( mplsHdr.label ):
         labelHandler = explicitNullLabelHandler
      elif isEntropyLabelIndicator( mplsHdr.label ):
         labelHandler = entropyLabelHandler
      else:
         # This is not an Egress LSR for the current packet
         return False

      ( newData, _, dropPkt ) = labelHandler( bridge, destMac, newData, srcPort )
      if newData is None or dropPkt:
         return False

      ( _, headers, _ ) = PktParserTestLib.parsePktStr( newData )
      mplsHdr = PktParserTestLib.findHeader( headers, "MplsHdr" )

   ( _, headers, _ ) = PktParserTestLib.parsePktStr( newData,
                                                     parseBeyondMpls=True )
   return isMplsOamHdr( headers )

def entropyLabelHandler( bridge, dstMacAddr, data, srcPort ):
   '''
   Strip the ELI + EL when ELI appears as top label during packet processing
   and return the rest of the packet.

   The ELI must not be the bottom-of-stack label. If the EL is bottom of
   stack, stripping it exposes the payload directly. In that case the ethType
   of the header just before MPLS is set to IPv4/IPv6 from the payload's
   version nibble, and any other payload is dropped.

   Returns a ( data, port, drop ) triple:
     - noMatch when the frame is not a supported MPLS frame,
     - drop when the label stack or payload is invalid,
     - otherwise ( strippedData, False, False ).
   '''
   func = 'entropyLabelHandler'

   ( pkt, headers, _ ) = PktParserTestLib.parsePktStr( data )
   if not isValidMplsDecapPacket( bridge, pkt, headers, func ):
      return noMatch

   mplsHdr = PktParserTestLib.findHeader( headers, "MplsHdr" )

   if mplsHdr.bos and isEntropyLabelIndicator( mplsHdr.label ):
      tinfo( func, 'ELI cannot be BOS, dropping packet' )
      return drop

   # Pop the ELI MPLS label header
   data = ( pkt.stringValue[ : mplsHdr.offset ] +
            pkt.stringValue[ mplsHdr.offset + 4 : ] )
   tinfo( func, 'popped ELI mpls header out of mpls packet' )

   ( pkt, headers, _ ) = PktParserTestLib.parsePktStr( data )
   mplsHdr = PktParserTestLib.findHeader( headers, "MplsHdr" )
   if not mplsHdr:
      # A non-BOS ELI must be followed by the EL; the stack is truncated.
      tinfo( func, 'no EL after ELI, dropping packet' )
      return drop

   if mplsHdr.bos:
      # Possible OAM packet when EL is BOS, guess the ethType from the payload
      payloadOffset = mplsHdr.offset + 4
      if payloadOffset >= len( pkt.stringValue ):
         tinfo( func, 'no payload after EL, dropping packet' )
         return drop
      versionBits = pkt.rawByte[ payloadOffset ] & 0xf0
      if versionBits == 0x40:
         payloadEthType = 'ethTypeIp'
      elif versionBits == 0x60:
         payloadEthType = 'ethTypeIp6'
      else:
         tinfo( func, 'Unknown payload, dropping packet' )
         return drop
      # Any Dot1Q tag is kept in the output frame, so the ethType to rewrite is
      # the one right before the MPLS header. That is the tag's own ethType
      # when a tag is present; the outer Ethernet ethType must stay Dot1Q.
      dot1qHdr = PktParserTestLib.findHeader( headers, 'EthDot1QHdr' )
      if dot1qHdr:
         dot1qHdr.ethType = payloadEthType
      else:
         ethHdr = PktParserTestLib.findHeader( headers, 'EthHdr' )
         ethHdr.ethType = payloadEthType

   # Pop the EL MPLS label header
   data = ( pkt.stringValue[ : mplsHdr.offset ] +
            pkt.stringValue[ mplsHdr.offset + 4 : ] )
   tinfo( func, 'popped EL mpls header out of mpls packet' )

   return ( data, False, False )

def explicitNullLabelHandler( bridge, dstMacAddr, data, srcPort ):
   '''
   Strip the outer label if it's an explicit null, until it finds the
   non-explicit null label or until it finds the payload which can be routed in
   the correct VRF if it's just the plain IP packet.

   Non-BOS explicit nulls are popped one at a time. As soon as a label that is
   not an explicit null is on top, ( data, False, False ) is returned.

   A BOS explicit null is handled as follows:
     - its decap route gives the VRF and payload type; autoDecide is resolved
       from the IP version nibble;
     - the Dot1Q tag (if any) and the MPLS header are removed;
     - the MPLS TTL is copied into the IP TTL / IPv6 hop limit;
     - the frame is handed to decappedPktAndPort.

   Returns a ( data, port, drop ) triple:
     - noMatch when the frame cannot be handled (not an explicit null, no
       decap route, invalid AF, or no VRF interface),
     - drop for an invalid IP payload,
     - otherwise the rewritten frame.
   '''
   func = 'explicitNullLabelHandler'

   ( pkt, headers, _ ) = PktParserTestLib.parsePktStr( data )
   if not isValidMplsDecapPacket( bridge, pkt, headers, func ):
      return noMatch

   mplsHdr = PktParserTestLib.findHeader( headers, "MplsHdr" )
   hasExpNullLabel = isExpNull( mplsHdr.label )
   while hasExpNullLabel:
      ( pkt, headers, _ ) = PktParserTestLib.parsePktStr( data )
      ethHdr = PktParserTestLib.findHeader( headers, 'EthHdr' )
      mplsHdr = PktParserTestLib.findHeader( headers, "MplsHdr" )
      if not mplsHdr.bos:
         # Pop the MPLS header until it finds non-explicit null label
         data = ( pkt.stringValue[ : mplsHdr.offset ] +
                  pkt.stringValue[ mplsHdr.offset + 4 : ] )
         ( pkt, headers, _ ) = PktParserTestLib.parsePktStr( data )
         mplsHdr = PktParserTestLib.findHeader( headers, "MplsHdr" )
         tinfo( func, 'popped explicit null mpls header out of mpls packet' )
         hasExpNullLabel = mplsHdr and isExpNull( mplsHdr.label )
         if not hasExpNullLabel:
            return ( data, False, False )
         continue

      label = mplsHdr.label
      via = findDecapRoute( bridge, label )
      if via is None:
         tinfo( func, 'no decap route for label', label )
         return noMatch
      vrfName = via.vrfName
      addressFamily = via.payloadType

      if not isValidAF( addressFamily, vrfName, func ):
         return noMatch

      if not isValidMplsIpPayload( mplsHdr, pkt, addressFamily, func ):
         return drop

      # If the payloadType is autoDecide, set it to IP or IPv6 here
      if addressFamily == PayloadType.autoDecide:
         ipPacket = pkt.stringValue[ mplsHdr.offset + 4 : ]
         ipVersion = ipPacket[ 0 ] >> 4
         addressFamily = PayloadType.ipv4 if ipVersion == 4 else PayloadType.ipv6

      # Construct new packet with the Dot1Q and MPLS header removed and reparse
      ethHdr.ethType = ( 'ethTypeIp' if addressFamily == PayloadType.ipv4 else
                         'ethTypeIp6' )
      newData = ( pkt.stringValue[ : PktParserTestLib.EthHdrSize ] +
                  pkt.stringValue[ mplsHdr.offset + 4 : ] )
      ( pkt, headers, _ ) = PktParserTestLib.parsePktStr( newData )

      if addressFamily == PayloadType.ipv4:
         ipHdr = PktParserTestLib.findHeader( headers, 'IpHdr' )
         if not ipHdr:
            tinfo( func, 'drop, cannot parse IP header' )
            return drop
         # The IP TTL will be decremented in the routing phase, not here
         ipHdr.ttl = mplsHdr.ttl
         # Clear checksum and recompute
         ipHdr.checksum = 0
         ipHdr.checksum = ipHdr.computedChecksum
      else:
         ipHdr = PktParserTestLib.findHeader( headers, 'Ip6Hdr' )
         if not ipHdr:
            tinfo( func, 'drop, cannot parse IPv6 header' )
            return drop
         # The IPv6 hop limit will be decremented in the routing phase, not here
         ipHdr.hopLimit = mplsHdr.ttl

      tverb( func, PacketFormat( newData ) )

      return decappedPktAndPort( bridge, vrfName, pkt.stringValue, label, func )
   return noMatch

def isReservedLabel( label ):
   '''Return True for labels numerically below MplsLabel.unassignedMin.'''
   firstUnreserved = MplsLabel.unassignedMin
   return label < firstUnreserved

def reservedLabelHandler( bridge, destMac, data, srcPort ):
   '''
   This handler is responsible for handling all the reserved labels by invoking
   their respective sub-handlers (currently explicit null and the entropy
   label indicator).

   MPLS OAM packets (see isMplsOamPkt) are left untouched so they can be punted
   later. While the top label is a handled reserved label, the matching
   sub-handler is applied and the frame is reparsed. Processing stops at the
   first non-reserved or unhandled reserved label.

   Returns a ( chgData, chgSrcPort, drop, None ) tuple:
     - noMatchWithHighVlan when nothing was changed,
     - the sub-handler's drop/no-match result (with a trailing None) when it
       fails,
     - otherwise the rewritten frame and the port set by the last sub-handler.
   '''
   func = 'reservedLabelHandler'

   if not bridge.mplsRoutingInfo_.mplsRouting:
      tinfo( func, 'mplsRouting is disabled' )
      return noMatchWithHighVlan

   if not isMpls( data ):
      tinfo( func, "Non-MPLS frame (MplsPktHelper), skip" )
      return noMatchWithHighVlan

   if isMplsOamPkt( bridge, destMac, data, srcPort ):
      return noMatchWithHighVlan

   ( pkt, headers, _ ) = PktParserTestLib.parsePktStr( data )
   if not isValidMplsDecapPacket( bridge, pkt, headers, func ):
      return noMatchWithHighVlan

   mplsHdr = PktParserTestLib.findHeader( headers, "MplsHdr" )
   pktModified = False
   dropPkt = False
   port = None
   while mplsHdr and isReservedLabel( mplsHdr.label ):
      if isExpNull( mplsHdr.label ):
         labelHandler = explicitNullLabelHandler
      elif isEntropyLabelIndicator( mplsHdr.label ):
         labelHandler = entropyLabelHandler
      else:
         # Other reserved labels are not handled here.
         break

      ( data, port, dropPkt ) = labelHandler( bridge, destMac, data, srcPort )
      if not data or dropPkt:
         return ( data, port, dropPkt, None )
      pktModified = True

      ( pkt, headers, _ ) = PktParserTestLib.parsePktStr( data )
      mplsHdr = PktParserTestLib.findHeader( headers, "MplsHdr" )

   return ( data, port, dropPkt, None ) if pktModified else noMatchWithHighVlan

def mplsL3EvpnDecapPreTunnelHandler( bridge, dstMacAddr, data, srcPort ):
   '''
   Terminate any frame with an L3 EVPN label, and route it in the correct VRF.

   After finding an ipLookup for the MPLS label, the correct VRF name
   is present in the ipLookup VIA.  To have ETBA do the routing in the correct
   VRF, we fake the source interface (by returning a new chgSrcPort interface)
   as any interface in the correct VRF, see 'findIntfPortInVrf'.

   The MPLS TTL from the MPLS header will be copied (unchanged) to the
   underlying IP TTL or IPv6 hop-limit field.  It may be decremented by the
   kernel during any applicable routing action.

   Non-BOS labels are handled by VRF:
     - if the decap route points at the default VRF, only the top label (and
       any Dot1Q tag) is removed, and the remaining label stack is left for
       downstream handlers;
     - a non-BOS label in any other VRF is not handled.

   This is the plugin signature for preTunnelHdlr methods in this module:
      ( chgData, chgSrcPort, drop, extra ) = preTunnelHdlr( bridge, dstMacAddr,
                                                            data, srcPort )
   where `extra` is always None here (presumably a high-VLAN override, per
   noMatchWithHighVlan -- confirm against the EbraTestBridge contract).
   '''
   func = 'mplsL3EvpnDecapPreTunnelHandler'
   tfunc( func, PacketFormat( data ) )

   if not bridge.mplsRoutingInfo_.mplsRouting:
      tinfo( func, 'mplsRouting is disabled' )
      return noMatchWithHighVlan

   if not isMpls( data ):
      tinfo( func, "Non-MPLS frame (MplsPktHelper), skip" )
      return noMatchWithHighVlan

   ( pkt, headers, _ ) = PktParserTestLib.parsePktStr( data, parseBeyondMpls=True )
   if isMplsOamHdr( headers ):
      tinfo( func, "MPLS OAM packet, skip" )
      return noMatchWithHighVlan

   if not isValidMplsDecapPacket( bridge, pkt, headers, func ):
      return noMatchWithHighVlan

   ethHdr = PktParserTestLib.findHeader( headers, "EthHdr" )
   mplsHdr = PktParserTestLib.findHeader( headers, "MplsHdr" )

   label = mplsHdr.label
   if isReservedLabel( label ):
      return noMatchWithHighVlan

   via = findDecapRoute( bridge, label )
   if via is None:
      tinfo( func, 'no decap route for label', label )
      return noMatchWithHighVlan
   vrfName = via.vrfName
   addressFamily = via.payloadType

   if not mplsHdr.bos:
      if vrfName != DEFAULT_VRF:
         tinfo( func, 'non-BOS label', label, 'on VRF', vrfName )
         return noMatchWithHighVlan
      # If the VRF label points to the default VRF, and it is not the BOS label,
      # we will just pop the top label and leave the rest of the packet as is
      # for other downstream handlers to process.
      # The slice below keeps only the Ethernet header, so a Dot1Q tag is
      # dropped too. For a tagged frame the Ethernet ethType still says Dot1Q,
      # but MPLS now follows it directly, so set it to MPLS. This is a no-op
      # for untagged frames.
      ethHdr.ethType = 'ethTypeMpls'
      newData = ( pkt.stringValue[ : PktParserTestLib.EthHdrSize ] +
                  pkt.stringValue[ mplsHdr.offset + 4 : ] )
      tverb( func, PacketFormat( newData ) )
      ( chgData, chgSrcPort, chgDrop ) = decappedPktAndPort(
         bridge, vrfName, newData, label, func )
      return ( chgData, chgSrcPort, chgDrop, None )

   if not isValidAF( addressFamily, vrfName, func ):
      return noMatchWithHighVlan

   if not isValidMplsIpPayload( mplsHdr, pkt, addressFamily, func ):
      return dropWithHighVlan

   # If the payloadType is autoDecide, set it to IP or IPv6 here
   if addressFamily == PayloadType.autoDecide:
      ipPacket = pkt.stringValue[ mplsHdr.offset + 4 : ]
      ipVersion = ipPacket[ 0 ] >> 4
      addressFamily = PayloadType.ipv4 if ipVersion == 4 else PayloadType.ipv6

   # Construct new packet with the Dot1Q and MPLS header removed and reparse
   # i.e., only Ethernet Header with IP payload.
   ethHdr.ethType = ( 'ethTypeIp' if addressFamily == PayloadType.ipv4 else
                      'ethTypeIp6' )
   newData = ( pkt.stringValue[ : PktParserTestLib.EthHdrSize ] +
               pkt.stringValue[ mplsHdr.offset + 4 : ] )

   ( pkt, headers, _ ) = PktParserTestLib.parsePktStr( newData )

   if addressFamily == PayloadType.ipv4:
      ipHdr = PktParserTestLib.findHeader( headers, 'IpHdr' )
      if not ipHdr:
         tinfo( func, 'drop, cannot parse IP header' )
         return dropWithHighVlan
      # The IP TTL will be decremented in the routing phase, not here
      ipHdr.ttl = mplsHdr.ttl
      # Clear checksum and recompute
      ipHdr.checksum = 0
      ipHdr.checksum = ipHdr.computedChecksum
   else:
      ipHdr = PktParserTestLib.findHeader( headers, 'Ip6Hdr' )
      if not ipHdr:
         tinfo( func, 'drop, cannot parse IPv6 header' )
         return dropWithHighVlan
      # The IPv6 hop limit will be decremented in the routing phase, not here
      ipHdr.hopLimit = mplsHdr.ttl

   tverb( func, PacketFormat( newData ) )
   ( chgData, chgSrcPort, chgDrop ) = decappedPktAndPort(
      bridge, vrfName, pkt.stringValue, label, func )
   return ( chgData, chgSrcPort, chgDrop, None )

def doMplsL3EvpnDecapBridgeInit( bridge, decapLfib, ipStatus ):
   '''
   Attach the L3 EVPN decap state to `bridge`.

   The steps are:
     - store the decap LFIB and the IP status;
     - create the VRF -> interfaces map (a defaultdict of sets) and hand it to
       an IpStatusReactor on `ipStatus`;
     - mark static VRF labels as supported in the bridge's MPLS HW capability.
   The bridge attributes are assigned before the reactor is built, since the
   reactor is given `bridge` itself.
   '''
   vrfIntfs = defaultdict( set )
   bridge.decapLfib_ = decapLfib
   bridge.mplsL3IpStatus_ = ipStatus
   bridge.mplsL3VrfInterfaces_ = vrfIntfs
   bridge.mplsL3IpStatusReactor_ = IpStatusReactor( ipStatus, bridge, vrfIntfs )
   bridge.mplsHwCapability_.mplsStaticVrfLabelSupported = True

def mplsL3EvpnDecapBridgeInit( bridge ):
   '''
   Bridge-init handler: look up the mounted 'mpls/decapLfib' (shared memory)
   and 'ip/status' entities, and wire them into `bridge` via
   doMplsL3EvpnDecapBridgeInit.
   '''
   traceName = 'mplsL3EvpnDecapBridgeInit'
   tinfo( traceName )
   doMplsL3EvpnDecapBridgeInit(
      bridge,
      decapLfib=bridge.sEm().getTacEntity( 'mpls/decapLfib' ),
      ipStatus=bridge.em().entity( 'ip/status' ) )
   tinfo( traceName, 'finished' )

def mplsL3EvpnDecapAgentInit( em ):
   '''
   Agent-init handler: mount the entities the L3 EVPN decap path reads.

   It mounts two things:
     - 'mpls/decapLfib' (Mpls::LfibStatus) through a shared-memory entity
       manager, with a 'keyshadow' Smash mount;
     - the IP status entities, via Ira::IraIpStatusMounter on the active
       mount group.
   '''
   # Mount the decapLfib for VRF ipLookup action
   smashEm = SharedMem.entityManager( sysdbEm=em )
   lfibMountInfo = Smash.mountInfo( 'keyshadow' )
   smashEm.doMount( 'mpls/decapLfib', 'Mpls::LfibStatus', lfibMountInfo )

   # Mount IP status to maintain a mapping of VRF:interface
   mountGroup = em.activeMountGroup()
   ipStatusMounter = Tac.Type( "Ira::IraIpStatusMounter" )
   ipStatusMounter.doMountEntities( mountGroup.cMg_, True, False )

def pluginHelper( ctx ):
   '''
   Register this module's handlers with the MplsEtba plugin context.

   Outside ARFA mode, two preTunnel handlers are registered in this order:
     - reservedLabelHandler, at HANDLER_PRIO_HIGH;
     - mplsL3EvpnDecapPreTunnelHandler, at the default priority.
   The bridge-init and agent-init handlers are always registered.
   '''
   #
   # TODO XXX: The preTunnel handler registrations below were enabled only to
   # support explicit NULL labels. reservedLabelHandler can do more decap work
   # using "decapLfib"; when that support is needed, update the corresponding
   # BPF filters accordingly.
   # NOTE(review): reservedLabelHandler now also handles the entropy label
   # indicator, so this TODO may be partly stale -- confirm.
   #
   preTunnelHandlers = ()
   if not ctx.inArfaMode():
      # Reserved label handler first, at high priority
      preTunnelHandlers = (
         ( reservedLabelHandler, ( HANDLER_PRIO_HIGH, ) ),
         ( mplsL3EvpnDecapPreTunnelHandler, () ),
      )
   for handler, extraArgs in preTunnelHandlers:
      ctx.registerPreTunnelHandler( handler, *extraArgs )

   ctx.registerBridgeInitHandler( mplsL3EvpnDecapBridgeInit )
   ctx.registerAgentInitHandler( mplsL3EvpnDecapAgentInit )
