#!/usr/bin/env python3
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

'''
Add L3 EVPN over MPLS (encapsulation) support to the Etba data plane.

This code depends on the infrastructure set up by mplsBridgeInit and
mplsAgentInit, so it is registered by the MplsEtba plugin as a helper
(through pluginHelper()) instead of through a Plugin function in this
module.
'''

from Arnet import (
      Ip6Prefix,
      IpGenPrefix,
      PktParserTestLib,
      Prefix,
   )
from Arnet.MplsLib import (
   constructMplsHeader,
)
from ForwardingHelper import resolveTunnel
from EbraTestBridgePlugin.FwdIntfEtba import FwdIntfDevice
import FibUtils
import QuickTrace
import SharedMem
import Smash
import Tac
import Tracing

from MplsEtbaLib import (
   computeFlowOrEntropyLabel,
   ecmpHash,
   ELI,
   removeImpNullFromLabelStack,
)
from Toggles.RoutingLibToggleLib import toggleFibGenMountPathEnabled

from IpLibConsts import DEFAULT_VRF

# pkgdeps: library Trie

th = Tracing.Handle( 'MplsL3EvpnEtbaEncap' )
# Named aliases for the trace levels used in this module: errors (level 0),
# informational (1), verbose per-packet detail (2) and function-entry
# tracing (8).
terr = th.trace0
tinfo = th.trace1
tverb = th.trace2
tfunc = th.trace8

# Set by qTrace() after its one-time QuickTrace.initialize() call.
qtInitialized = False

def qTrace():
   '''
   Return the QuickTrace module, calling QuickTrace.initialize() (with the
   'etba-routingHandlers.qt' file) only on the first invocation.
   '''
   global qtInitialized
   if qtInitialized:
      return QuickTrace
   QuickTrace.initialize( 'etba-routingHandlers.qt' )
   qtInitialized = True
   return QuickTrace

def qv( *args ):
   '''
   Wrap the arguments in a QuickTrace Var and return it.

   Callers pass the result straight into qt0(), e.g.
   qt0( func, 'no route for', qv( addr ) ), so the Var object must be
   returned. Without the return, qt0() would receive None instead of the
   value.
   '''
   return qTrace().Var( *args )

def qt0( *args ):
   '''Emit a level-0 QuickTrace message, initializing QuickTrace if needed.'''
   tracer = qTrace()
   tracer.trace0( *args )

# Formatter used to render raw frame data in trace output.
PacketFormat = Tracing.HexDump
# Tac type handles. FecIdIntfId identifies vias whose intfId encodes a FEC id
# (used below for SR-TE policy vias); the TunnelTable types are used to look
# up the SR-TE segment-list tunnel table mount info.
EthIntfId = Tac.Type( 'Arnet::EthIntfId' )
FecIdIntfId = Tac.Type( 'Arnet::FecIdIntfId' )
TunnelTableIdentifier = Tac.Type( "Tunnel::TunnelTable::TunnelTableIdentifier" )
TunnelTableMounter = Tac.Type( "Tunnel::TunnelTable::TunnelTableMounter" )

def mplsL3EvpnEncapFwdIntfRouting( mplsFwdIntf, data ):
   '''
   fwd0 routing handler: MPLS-encapsulate a kernel-routed IP frame when its
   destination resolves to an L3 EVPN tunnel route in the interface's VRF.

   mplsFwdIntf: fwd interface the frame was routed onto; its bridge and
                vrfName are used.
   data:        the raw Ethernet frame.

   Returns ( newData, egressIntfId ) on success. newData is the frame with
   rewritten MACs and ethType, and the MPLS label stack inserted in front of
   the IP header. This handler sees every frame received on fwd0, so any
   frame it does not take returns FwdIntfDevice.NO_MATCH.
   '''
   func = 'mplsL3EvpnEncapFwdIntfRouting'
   noMatch = FwdIntfDevice.NO_MATCH
   bridge = mplsFwdIntf.bridge
   vrfName = mplsFwdIntf.vrfName
   tfunc( func, 'vrf', vrfName, PacketFormat( data ) )

   if not bridge.mplsRoutingInfo_.mplsRouting:
      tinfo( func, 'mplsRouting is disabled' )
      qt0( func, 'mplsRouting is disabled' )
      return noMatch

   # ---- Parse the Ethernet and IP headers. ----
   pkt, headers, _ = PktParserTestLib.parsePktStr( data )
   ethHdr = PktParserTestLib.findHeader( headers, 'EthHdr' )
   if not ethHdr:
      tinfo( func, 'ethHdr not found' )
      qt0( func, 'ethHdr not found' )
      return noMatch

   ethType = ethHdr.ethType
   if ethType == 'ethTypeIp':
      af, ipHdrName = 'ipv4', 'IpHdr'
   elif ethType == 'ethTypeIp6':
      af, ipHdrName = 'ipv6', 'Ip6Hdr'
   else:
      tinfo( func, 'no handling for ethType', ethType )
      qt0( func, 'no handling for ethType', qv( ethType ) )
      return noMatch

   ipHdr = PktParserTestLib.findHeader( headers, ipHdrName )
   if not ipHdr:
      tinfo( func, 'missing', af, 'header' )
      qt0( func, 'missing', qv( af ), 'header' )
      return noMatch

   isIpv4 = af == 'ipv4'
   # On real hardware the ASIC cannot handle packets with IP options. Such a
   # packet is punted back to the CPU and dropped by the DMA driver, which
   # sees it as CPU sourced. Emulate that by ignoring frames with options.
   if isIpv4 and ipHdr.hasOptions:
      tinfo( func, 'ignore frame with IP options' )
      qt0( func, 'ignore frame with IP options' )
      return noMatch

   ttl = ipHdr.ttl if isIpv4 else ipHdr.hopLimit
   if ttl == 0:
      tinfo( func, 'ignore', af, 'frame with ttl', ttl )
      qt0( func, 'ignore', qv( af ), 'frame with ttl', qv( ttl ) )
      return noMatch
   tverb( func, 'vrfName', vrfName, 'ip src', ipHdr.src, 'dst', ipHdr.dst,
          'ttl', ttl )

   # ---- Resolve VRF -> LPM route / FEC -> ECMP via -> tunnel. ----
   vrfInfo = _findCreateVrfInfo( bridge, vrfName )
   if not vrfInfo:
      terr( func, 'findCreateVrfInfo failed for vrf', vrfName )
      qt0( func, 'findCreateVrfInfo failed for vrf', qv( vrfName ) )
      return noMatch

   route, fec = vrfInfo.getRoute( ipHdr.dst )
   if route is None or fec is None:
      tinfo( func, 'no route for', ipHdr.dst )
      qt0( func, 'no route for', qv( ipHdr.dst ) )
      return noMatch

   if len( fec.via ) < 1:
      tinfo( func, 'no via for route', str( route.key ), 'fec', fec.fecId )
      qt0( func, 'no via for route',
           qv( str( route.key ) ), 'fec', qv( fec.fecId ) )
      return noMatch

   via = ecmpHash( fec.via, ipHdr.src, ipHdr.dst )
   tunnelInfo = vrfInfo.getTunnel( bridge, via, ipHdr )
   if not tunnelInfo:
      tinfo( func, 'not a tunnel via' )
      qt0( func, 'not a tunnel via' )
      return noMatch
   if not tunnelInfo.resolved:
      terr( func, tunnelInfo.errMsg )
      qt0( func, tunnelInfo.errMsg )
      return noMatch

   # ---- Build the label stack: tunnel labels, then the via's label. ----
   labelStack = tunnelInfo.labelStack + [ via.mplsLabel ]
   removeImpNullFromLabelStack( labelStack )
   if ELI in labelStack:
      entropyLabel = computeFlowOrEntropyLabel( pkt, ethHdr, None, ipHdr )
      # Overwrite the label right after each ELI with the entropy label.
      # NOTE(review): this assumes the tunnel stack reserves a placeholder
      # slot after every ELI -- confirm against the tunnel resolver.
      elSlots = [ pos + 1 for pos, label in enumerate( labelStack )
                  if label == ELI ]
      for slot in elSlots:
         labelStack[ slot ] = entropyLabel

   # ---- Rewrite L2 and splice the MPLS header in front of the IP header. ----
   mplsHdr = constructMplsHeader( labelStack, mplsTtl=ttl )
   ethHdr.dst = tunnelInfo.dstMac
   ethHdr.src = bridge.bridgeMac()
   ethHdr.ethType = 'ethTypeMpls'
   frame = pkt.stringValue
   ipStart = ipHdr.offset
   encapFrame = frame[ : ipStart ] + mplsHdr + frame[ ipStart : ]
   return encapFrame, tunnelInfo.intfId

class VrfInfo:
   '''
   Per-VRF lookup state for routing IP frames over L3 EVPN MPLS tunnels.

   Mounts the VRF's IPv4/IPv6 route status and forwarding (FEC) status from
   shared memory. Routing/Routing6 TrieBuilders fill tries from the route
   status, so destinations can be resolved by longest-prefix match.
   '''
   def __init__( self, vrfName, l3Config, sEm ):
      '''
      vrfName:  VRF served by this instance. DEFAULT_VRF uses the '/status'
                paths; any other VRF uses '/vrf/status/<vrfName>'.
      l3Config: the 'l3/config' entity, only used here to derive the FEC mode.
      sEm:      shared-memory entity manager used for every mount.
      '''
      self.vrfName = vrfName

      fecModeStatus = Tac.newInstance( 'Smash::Fib::FecModeStatus', 'fms' )
      # fecModeSm is instantiated to generate the correct fecMode in
      # fecModeStatus; only the resulting fecMode is consulted below.
      _ = Tac.newInstance( 'Ira::FecModeSm', l3Config, fecModeStatus )

      routeInfo = FibUtils.routeStatusInfo( 'keyshadow' )
      fwdInfo = FibUtils.forwardingStatusInfo( 'keyshadow' )
      if vrfName == DEFAULT_VRF:
         vrfStatus = '/status'
      else:
         vrfStatus = '/vrf/status/%s' % vrfName
      self.rs = sEm.doMount( 'routing' + vrfStatus,
                             'Smash::Fib::RouteStatus', routeInfo )
      self.rs6 = sEm.doMount( 'routing6' + vrfStatus,
                              'Smash::Fib6::RouteStatus', routeInfo )

      # In unified FEC mode the forwarding status is mounted from
      # '.../unifiedStatus' instead of the per-VRF status path used for routes.
      if fecModeStatus.fecMode == 'fecModeUnified':
         fwdSuffix = '/unifiedStatus'
      else:
         fwdSuffix = vrfStatus
      if toggleFibGenMountPathEnabled():
         # One FibGen table serves FEC lookups for both address families
         # (see getRoute).
         self.fs = self.fs6 = None
         self.fsGen = sEm.doMount( 'forwardingGen' + fwdSuffix,
                                   'Smash::FibGen::ForwardingStatus', fwdInfo )
      else:
         self.fs = sEm.doMount( 'forwarding' + fwdSuffix,
                                'Smash::Fib::ForwardingStatus', fwdInfo )
         self.fs6 = sEm.doMount( 'forwarding6' + fwdSuffix,
                                 'Smash::Fib6::ForwardingStatus', fwdInfo )
         self.fsGen = None

      # SR-TE policy FECs, referenced by vias whose intfId is a FecIdIntfId
      # (see getTunnel).
      self.pfs = sEm.doMount( 'forwarding/srte/status',
                               'Smash::Fib::ForwardingStatus',
                               FibUtils.forwardingStatusInfo( 'reader' ) )
      self.trie = Tac.newInstance( 'Routing::Trie', '%s-trie' % vrfName )
      self.trieBuilder = Tac.newInstance( 'Routing::TrieBuilder',
                                          self.rs, self.trie )
      self.trie6 = Tac.newInstance( 'Routing6::Trie', '%s-trie6' % vrfName )
      self.trie6Builder = Tac.newInstance( 'Routing6::TrieBuilder',
                                           self.rs6, self.trie6 )

   def longestMatch( self, addrOrPrefix ):
      '''
      Return the longest matching route prefix for addrOrPrefix from the
      IPv4 or IPv6 trie. Return None if the address family is neither, or if
      the trie returns a null prefix (no match).
      '''
      genPrefix = IpGenPrefix( str( addrOrPrefix ) )
      if genPrefix.af == 'ipv4':
         prefix = Prefix( str( genPrefix ) )
         lpm = self.trie.longestMatch[ prefix ]
      elif genPrefix.af == 'ipv6':
         prefix = Ip6Prefix( str( genPrefix ) )
         lpm = self.trie6.longestMatch( prefix )
      else:
         return None
      if lpm.isNullPrefix:
         return None
      return lpm

   def getRoute( self, addr ):
      '''
      Resolve addr to ( route, fec ) in this VRF.

      Return ( None, None ) when there is no matching prefix, when the
      matched prefix is absent from the route status, or when the route's
      FEC is not found in the forwarding status table(s).
      '''
      func = f'getRoute[vrf={self.vrfName}]'
      noMatch = ( None, None )
      genPrefix = IpGenPrefix( str( addr ) )
      lpm = self.longestMatch( addr )
      if lpm is None:
         return noMatch
      if genPrefix.af == 'ipv4':
         rs = self.rs
      elif genPrefix.af == 'ipv6':
         rs = self.rs6
      else:
         return noMatch
      route = rs.route.get( lpm )
      if not route:
         tinfo( func, 'route for lpm', lpm, 'not present in rs', rs )
         return noMatch

      if toggleFibGenMountPathEnabled():
         fec = self.fsGen.fec.get( route.fecId )
         if not fec:
            tinfo( func, 'FEC for lpm', lpm, 'fecId', route.fecId,
                   'not present in fs', self.fsGen )
            return noMatch
      else:
         # The route's FEC is looked up in the IPv4 table first and then in
         # the IPv6 table, whatever the route's address family.
         fec = self.fs.fec.get( route.fecId )
         if not fec:
            tinfo( func, 'FEC for lpm', lpm, 'fecId', route.fecId,
                   'not present in fs', self.fs )
            fec = self.fs6.fec.get( route.fecId )
            if not fec:
               tinfo( func, 'FEC for lpm', lpm, 'fecId', route.fecId,
                      'not present in fs', self.fs6 )
               return noMatch
      return ( route, fec )

   def getTunnel( self, bridge, via, ipHdr ):
      '''
      Return the tunnel info to use for via. Return None if via is not a
      tunnel via, or if the SR-TE policy it points at cannot be resolved.

      A via whose intfId is a FecIdIntfId refers to an SR-TE policy FEC. It
      is resolved through one of that FEC's vias and then
      bridge.fwdHelper.resolveSrTeTunnel(). A plain tunnel via goes through
      resolveTunnel(). ECMP choices are hashed on the IP src/dst.
      '''
      if FecIdIntfId.isFecIdIntfId( via.intfId ):
         # For SR-TE policies, the intfId would be a NextLevelFecId type.
         # In that case, we need to lookup that fec and resolve its SLs.
         srTePolicyFecId = FecIdIntfId.intfIdToFecId( via.intfId )
         srTePolicyFec = self.pfs.fec.get( srTePolicyFecId )
         # Guard against a policy FEC with no vias before hashing over them,
         # matching the empty-via check callers do for route FECs.
         if not srTePolicyFec or len( srTePolicyFec.via ) < 1:
            return None
         srTePolicyVia = ecmpHash( srTePolicyFec.via, ipHdr.src, ipHdr.dst )
         if srTePolicyVia.tunnelId == 0:
            return None
         tunnelInfoList = bridge.fwdHelper.resolveSrTeTunnel(
            srTePolicyVia.tunnelId )
         if not tunnelInfoList:
            return None
         return ecmpHash( tunnelInfoList, ipHdr.src, ipHdr.dst )
      else:
         if via.tunnelId == 0:
            return None
         return resolveTunnel( bridge.fwdHelper, via.tunnelId )

def _findCreateVrfInfo( bridge, vrfName ):
   '''
   Return the VrfInfo cached on the bridge for vrfName. On first use, build
   it from the 'l3/config' entity and the bridge's shared-memory entity
   manager, and cache it.
   '''
   cache = bridge.mplsL3EvpnEncapVrfInfo_
   if vrfName not in cache:
      cache[ vrfName ] = VrfInfo( vrfName,
                                  bridge.em().entity( 'l3/config' ),
                                  bridge.sEm() )
   return cache[ vrfName ]

def mplsL3EvpnEncapBridgeInit( bridge ):
   '''
   Bridge-init hook. Mount the tunnel FIB and the SR-TE segment-list tunnel
   table from shared memory, and attach them to the bridge together with an
   empty per-VRF VrfInfo cache.
   '''
   func = 'mplsL3EvpnEncapBridgeInit'
   tinfo( func )
   shmemEm = SharedMem.entityManager( sysdbEm=bridge.em() )
   fibMountInfo = Smash.mountInfo( 'keyshadow' )
   tunnelFib = shmemEm.doMount( 'tunnel/tunnelFib',
                                'Tunnel::TunnelFib::TunnelFib',
                                fibMountInfo )
   # State hung off the bridge. The VrfInfo cache is filled lazily by
   # _findCreateVrfInfo().
   bridge.mplsL3EvpnEncapVrfInfo_ = {}
   bridge.tunnelFib_ = tunnelFib
   srTeTableId = TunnelTableIdentifier.srTeSegmentListTunnelTable
   srTeTableInfo = TunnelTableMounter.getMountInfo( srTeTableId ).tableInfo
   tableMountInfo = Smash.mountInfo( 'keyshadow' )
   bridge.srTeSegmentListTunnelTable = shmemEm.doMount(
      srTeTableInfo.mountPath, srTeTableInfo.tableType, tableMountInfo )
   tinfo( func, 'finished' )

def mplsL3EvpnEncapAgentInit( em ):
   '''
   Agent-init hook. Mount 'l3/config' (mode 'r') from Sysdb and the tunnel
   FIB from shared memory.
   '''
   em.mount( 'l3/config', 'L3::Config', 'r' )
   SharedMem.entityManager( sysdbEm=em ).doMount(
      'tunnel/tunnelFib', 'Tunnel::TunnelFib::TunnelFib',
      Smash.mountInfo( 'keyshadow' ) )

def pluginHelper( ctx ):
   '''
   Register this module's handlers: bridge init, agent init, and the fwd0
   routing handler.

   This is invoked as a helper from the MplsEtba plugin rather than as a
   Plugin of its own, because it depends on the MPLS bridge/agent init
   infrastructure.
   '''
   registrations = (
      ( ctx.registerBridgeInitHandler, mplsL3EvpnEncapBridgeInit ),
      ( ctx.registerAgentInitHandler, mplsL3EvpnEncapAgentInit ),
      ( FwdIntfDevice.registerFwdIntfRoutingHandler,
        mplsL3EvpnEncapFwdIntfRouting ),
   )
   for register, handler in registrations:
      register( handler )
   FwdIntfDevice.registerFwdIntfRoutingHandler( mplsL3EvpnEncapFwdIntfRouting )
