# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
'''
MplsEtbaLib.py

This module provides methods that may be useful for all MPLS-related ETBA plugins.
Previously, there was some code duplication between the EvpnMpls and MplsEtba ETBA
plugins, e.g. constructMplsHeader.
'''

from collections import namedtuple
import hashlib

import Tac
from TypeFuture import TacLazyType

# Handles to the TAC types referenced by the helpers in this module, obtained
# through TacLazyType (presumably resolved on first use -- TODO confirm).
EthAddr = TacLazyType( 'Arnet::EthAddr' )
EthIntfId = TacLazyType( 'Arnet::EthIntfId' )
FecId = TacLazyType( 'Smash::Fib::FecId' )
PortChannelIntfId = TacLazyType( 'Arnet::PortChannelIntfId' )
TunnelType = TacLazyType( 'Tunnel::TunnelTable::TunnelType' )
VlanIntfId = TacLazyType( 'Arnet::VlanIntfId' )

# Special MPLS label values, read once from Arnet::MplsLabel at import time:
# the entropy label indicator and the implicit-NULL label.
ELI = Tac.Type( "Arnet::MplsLabel" ).entropyLabelIndicator
IMP_NULL = Tac.Type( "Arnet::MplsLabel" ).implicitNull
# String marker; not referenced by the helpers visible in this part of the file
# (presumably consumed by the plugins importing this module -- TODO confirm).
BUD_ROLE_SPL_PKT = "BUD_ROLE_SPL_PKT"

# This is created and passed to the resolver by the client of the resolver.
# Within this module it is built by getHashInfo() from the L2, MPLS and L3
# headers of a packet.
HashInfo = namedtuple( 'HashInfo', [
   'srcMac',
   'dstMac',
   'mplsLabels', # A tuple of MPLS labels ( Top, ..., Bottom )
   'srcIp',
   'dstIp',
   ] )

def ecmpHash( coll, *values ):
   '''
   Simulate ECMP hashing by picking one member of 'coll'.

   A collection holding exactly one member short-circuits to that member
   without hashing anything. Otherwise the str() form of every item in
   'values' is fed, in order, into a SHA256 digest and the first byte of that
   digest, modulo len( coll ), is used as the index into 'coll'.
   '''
   size = len( coll )
   if size == 1:
      return coll[ 0 ]
   # Hashing the concatenation of the encoded pieces is the same as updating
   # the digest with each piece in turn.
   material = b''.join( str( value ).encode() for value in values )
   firstByte = hashlib.sha256( material ).digest()[ 0 ]
   return coll[ firstByte % size ]

def removeImpNullFromLabelStack( labelStack ):
   '''
   Strip every implicit-NULL label out of labelStack, modifying it in place.

   When the label that directly followed a removed implicit-NULL is the entropy
   label indicator (ELI), the ELI and the entropy label (EL) after it are
   removed as well. If the ELI turns out to be the last label (no EL follows),
   it is left where it is, and an assertion fires when the stack then holds
   fewer than two labels.
   '''
   pos = 0
   while pos < len( labelStack ):
      isImpNull = labelStack[ pos ] == IMP_NULL
      if not isImpNull:
         pos += 1
         continue

      # Drop the implicit-NULL; 'pos' now refers to the label that followed it,
      # so it is deliberately not advanced.
      del labelStack[ pos ]
      followedByEli = bool( labelStack ) and pos < len( labelStack ) \
                      and labelStack[ pos ] == ELI
      if not followedByEli:
         continue

      remaining = len( labelStack ) - pos
      if remaining >= 2:
         del labelStack[ pos ]        # Pop ELI
         del labelStack[ pos ]        # Pop EL
      else:
         assert len( labelStack ) >= 2, "ELI without EL in labelStack"

def getHashInfo( pkt, ethHdr, mplsHdr, ipHdr ):
   '''
   Creates a HashInfo tuple object using the L2, MPLS, and L3 headers.

   pkt is kept for interface compatibility and is not read here; the packet
   used to walk the label stack is taken from mplsHdr.pkt.
   ethHdr supplies the source and destination MAC addresses.
   mplsHdr is the wrapper for the top label stack entry, or a false value when
   there is no MPLS header, in which case mplsLabels is ( 0, ).
   ipHdr supplies the source and destination IPs; when it is a false value both
   are reported as 0.

   Returns a HashInfo whose mplsLabels is a tuple ordered ( Top, ..., Bottom ).
   '''
   srcIp = ipHdr.src if ipHdr else 0
   dstIp = ipHdr.dst if ipHdr else 0
   # BUG415506: Get the whole label stack once the PktParserTestLib is enhanced
   mplsLabels = [ mplsHdr.label if mplsHdr else 0 ]
   if mplsHdr:
      # Every label stack entry is 4 bytes long, so entry number 'idx' sits
      # exactly idx * 4 bytes after the top entry. The offset must always be
      # computed from the top entry: adding idx * 4 to the offset of the
      # previously visited entry would skip entries once the stack holds more
      # than two labels.
      stackPkt = mplsHdr.pkt
      topOffset = mplsHdr.offset
      entry = mplsHdr
      idx = 1
      while entry and not entry.bos:
         entry = Tac.newInstance( "Arnet::MplsHdrWrapper",
                                  stackPkt,
                                  topOffset + idx * 4 )
         mplsLabels.append( entry.label )
         idx += 1
   hashInfo = HashInfo( ethHdr.src, ethHdr.dst, tuple( mplsLabels ), srcIp, dstIp )
   return hashInfo

def computeFlowOrEntropyLabel( pkt, ethHdr, mplsHdr, ipHdr ):
   '''
   Derive a load balance (flow / entropy) label for the packet.

   The label is the Python hash() of the packet's HashInfo folded into the
   range [ 0, 2 ** 20 ). Results below 16 are moved up by 16, so the value
   returned is never smaller than 16.
   '''
   candidate = hash( getHashInfo( pkt, ethHdr, mplsHdr, ipHdr ) ) % ( 1 << 20 )
   # Load balance label value can't be in reserved label range
   return candidate + 16 if candidate < 16 else candidate

def getTunnelEntry( tunnelTables, policyForwardingStatus, ethHdr, tunnelId ):
   '''
   Look up the tunnel table entry for tunnelId.

   tunnelTables is indexed by tunnel type and yields the tables to search; the
   first true entry found for the tunnel ID is returned.

   An SR-TE policy tunnel ID is first translated into one of its segment-list
   tunnels: the policy FEC is fetched from policyForwardingStatus.fec and one of
   its vias is chosen with ecmpHash() over the source and destination MAC
   addresses of ethHdr.

   Returns None when the policy FEC is missing, when the chosen via does not
   carry a valid tunnel ID, or when no table has an entry for the tunnel.
   '''
   lookupType = tunnelId.tunnelType()
   if lookupType == TunnelType.srTePolicyTunnel:
      # The policy itself has no tunnel entry; the segment-list tunnel behind
      # the SR-TE policy FEC is what has to be looked up.
      policyFec = policyForwardingStatus.fec.get(
         FecId.tunnelIdToFecId( tunnelId ) )
      if not policyFec:
         return None
      chosenVia = ecmpHash( policyFec.via, ethHdr.src, ethHdr.dst )
      # From here on both the ID and the type describe the segment-list tunnel.
      tunnelId = Tac.Value( 'Tunnel::TunnelTable::TunnelId', chosenVia.tunnelId )
      if not tunnelId.isValid():
         return None
      lookupType = tunnelId.tunnelType()

   for table in tunnelTables[ lookupType ]:
      match = table.entry.get( tunnelId )
      if match:
         return match
   return None

def getIntfVlan( bridgingConfig, intfId ):
   '''
   Return the VLAN associated with intfId, or None if there is none.

   For a VLAN interface ID this is the VLAN ID encoded in the ID itself. For
   any other interface it is the nativeVlan of its entry in
   bridgingConfig.switchIntfConfig; a missing entry or a nativeVlan of 0 gives
   None.
   '''
   if VlanIntfId.isVlanIntfId( intfId ):
      return VlanIntfId.vlanId( intfId )

   intfConfig = bridgingConfig.switchIntfConfig.get( intfId )
   if intfConfig is not None and intfConfig.nativeVlan != 0:
      return intfConfig.nativeVlan
   return None

def isPortChannel( intfName ):
   '''
   Report whether intfName is a port-channel interface ID, as decided by
   PortChannelIntfId.isPortChannelIntfId().
   '''
   # Adapted from Ale/AleHelper.cpp
   result = PortChannelIntfId.isPortChannelIntfId( intfName )
   return result

def isRoutedPort( intfName ):
   '''
   Report whether intfName is either a port-channel interface ID or an
   Ethernet interface ID. The Ethernet check is only made when the
   port-channel check gives a false result.
   '''
   # Adapted from Ale/AleHelper.cpp
   portChannelResult = isPortChannel( intfName )
   if portChannelResult:
      return portChannelResult
   return EthIntfId.isEthIntfId( intfName )

def getL2Intf( bridgingStatus, vlan, macAddr ):
   '''
   Return the interface on which macAddr has been learned in vlan.

   The lookup goes through bridgingStatus.fdbStatus[ vlan ].learnedHost. When
   either the VLAN has no FDB status or the MAC has no learned-host entry, a
   default-constructed Arnet::IntfId is returned instead.
   '''
   noIntf = Tac.ValueConst( 'Arnet::IntfId' )
   vlanFdb = bridgingStatus.fdbStatus.get( vlan )
   learnedEntry = vlanFdb.learnedHost.get( macAddr ) if vlanFdb else None
   return learnedEntry.intf if learnedEntry else noIntf
