#!/usr/bin/env python3
# Copyright (c) 2008, 2009, 2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=superfluous-parens

# We allow wildcard includes
# pylint: disable-msg=W0401

import Tac
from Arnet.EthTestLib import *
# pylint: disable=redefined-builtin
from Arnet.IpTestLib import *
import Tracing
from collections import namedtuple

# Shorthand for Tracing.trace0; used below to report diagnostics (e.g. from
# findHeader).
t0 = Tracing.trace0

# Registry of ethType parsers: maps an Arnet::EthType attribute name (or the
# pseudo-type "ethTypeLlc") to the function that parses the header carried by
# that ethType.  Filled in via ethTypeParserIs() and read via ethTypeParser().
ethTypeMap = {}

# Return type of parseDataPacket: the packet object, the offset just past the
# last parsed header, and individual header wrappers picked out of the parsed
# header list (None when a header is absent or ambiguous).
PacketInfo = namedtuple(
   'PacketInfo',
   'pkt offset ethHdr ipHdr ip6Hdr mplsHdr ethDot1QHdr' )

def ethTypeParserIs( ethType, parseFunc ):
   '''
   Register parseFunc as the parser for headers carried by ethType.

   ethType must be an attribute name of Arnet::EthType or the pseudo-type
   "ethTypeLlc", and may be registered only once (both checked with assert).
   parseFunc is called as parseFunc( pkt, offset, headers ) -- plus
   parseBeyondMpls=True for the MPLS/802.1Q/802.1ad parsers -- and must return
   the offset just past whatever it parsed.
   '''
   assert ethType not in ethTypeMap
   assert ( ethType in Tac.Type( "Arnet::EthType" ).attributes
            or ethType == "ethTypeLlc" )
   ethTypeMap[ ethType ] = parseFunc

def ethTypeParser( ethType ):
   '''
   Return the parser registered for ethType, or None if there is none.

   Arnet::EthType's ethTypeUnknown is looked up under the "ethTypeLlc" key,
   which is where the LLC parser is registered.
   '''
   unknownEthType = Tac.Type( "Arnet::EthType" ).ethTypeUnknown
   lookupKey = "ethTypeLlc" if ethType == unknownEthType else ethType
   return ethTypeMap.get( lookupKey )

def parsePkt( pkt, parseBeyondMpls=False ):
   '''
   Parse the headers of pkt, starting with its Ethernet header.

   Returns a ( headers, offset ) tuple: headers is a list of
   ( headerTypeName, headerWrapper ) pairs in packet order, and offset is the
   position just past the last parsed header.  A packet shorter than an
   Ethernet header yields ( [], 0 ).

   parseBeyondMpls is forwarded only to the MPLS / 802.1Q / 802.1ad parsers so
   they can continue past the MPLS label stack.
   '''
   headers = []
   if pkt.bytes < EthHdrSize:
      return ( headers, 0 )
   ethHdr = Tac.newInstance( "Arnet::EthHdrWrapper", pkt, 0 )
   headers.append( ( 'EthHdr', ethHdr ) )
   offset = EthHdrSize

   ethType = ethHdr.ethType
   nextParser = ethTypeParser( ethType )
   if nextParser is not None:
      forwardFlag = parseBeyondMpls and ethType in (
         'ethTypeMpls', 'ethTypeDot1Q', 'ethTypeDot1ad' )
      if forwardFlag:
         offset = nextParser( pkt, offset, headers, parseBeyondMpls=True )
      else:
         offset = nextParser( pkt, offset, headers )
   return ( headers, offset )

def copyStrToPkt( pkt, string, offset=0 ):
   '''
   Write string into pkt starting at byte offset.  A str is UTF-8 encoded
   first (str.encode() default); bytes-like values are written as-is.
   '''
   data = string.encode() if isinstance( string, str ) else string
   pkt.rawBytesAtIdxIs( offset, data )

def createPktFromStr( pktStr ):
   '''
   Return a new Arnet::Pkt whose contents are pktStr.

   If pktStr is longer than the packet's maxSharedData capacity, the end is
   cut off.  It's very unlikely that we really have more packet headers out
   there than we need to handle.
   '''
   p = Tac.newInstance( "Arnet::Pkt" )
   maxLen = p.maxSharedData
   if len( pktStr ) > maxLen:
      pktStr = pktStr[ : maxLen ]
   p.bytesValue = pktStr
   return p

def parsePktStr( pktStr, parseBeyondMpls=False ):
   '''
   Like parsePkt, but takes raw packet contents: pktStr is first copied into a
   new packet via createPktFromStr.

   Returns ( pkt, headers, offset ), where headers and offset are exactly what
   parsePkt returned for pkt.
   '''
   newPkt = createPktFromStr( pktStr )
   parsedHeaders, endOffset = parsePkt( newPkt, parseBeyondMpls )
   return ( newPkt, parsedHeaders, endOffset )

def findHeader( headers, headerType, headerPosition=None ):
   '''
   Find a header of type headerType in a list of ( headerType, header ) pairs
   as returned by parsePkt.

   With headerPosition None, the header is returned only if exactly one header
   of that type is present.  Otherwise the headerPosition-th header of that
   type is returned; negative positions count from the last one, as with list
   indexing.  In every other case a trace message is emitted and None is
   returned.
   '''
   matches = [ hdr for currType, hdr in headers if currType == headerType ]
   numMatches = len( matches )
   if headerPosition is None:
      if numMatches == 1:
         return matches[ 0 ]
   elif -numMatches <= headerPosition < numMatches:
      # Negative positions need bounds checking too: a bare
      # "numMatches > headerPosition" test lets e.g. -1 through on an empty
      # list, which raises IndexError instead of returning None.
      return matches[ headerPosition ]
   t0( f'unsupported number of "{headerType}" type headers discovered: '
       f'{numMatches}' )
   return None

def dot1Parser( pkt, currentOffset, headers, headerId, parseBeyondMpls=False ):
   '''
   Shared parser for 802.1Q and 802.1ad tags.

   Wraps the tag at currentOffset as an Arnet::Eth8021QHdrWrapper, records it
   in headers under headerId, then hands the rest of the packet to the parser
   registered for the tag's inner ethType.  parseBeyondMpls is forwarded only
   when that inner parser is the MPLS/802.1Q/802.1ad one.

   Returns the offset just past the last parsed header, or currentOffset
   unchanged if the tag does not fit in the packet.
   '''
   if pkt.bytes - currentOffset < EthDot1QHdrSize:
      return currentOffset
   tag = Tac.newInstance( "Arnet::Eth8021QHdrWrapper", pkt, currentOffset )
   headers.append( ( headerId, tag ) )
   nextOffset = currentOffset + EthDot1QHdrSize

   innerEthType = tag.ethType
   innerParser = ethTypeParser( innerEthType )
   if innerParser is None:
      return nextOffset
   forwardFlag = parseBeyondMpls and innerEthType in (
      'ethTypeMpls', 'ethTypeDot1Q', 'ethTypeDot1ad' )
   if forwardFlag:
      return innerParser( pkt, nextOffset, headers, parseBeyondMpls=True )
   return innerParser( pkt, nextOffset, headers )

def dot1QParser( pkt, currentOffset, headers, parseBeyondMpls=False ):
   '''Parse an 802.1Q tag, recording it in headers as 'EthDot1QHdr'.'''
   return dot1Parser( pkt, currentOffset, headers, headerId='EthDot1QHdr',
                      parseBeyondMpls=parseBeyondMpls )

def dot1adParser( pkt, currentOffset, headers, parseBeyondMpls=False ):
   '''Parse an 802.1ad tag, recording it in headers as 'EthDot1adHdr'.'''
   return dot1Parser( pkt, currentOffset, headers, headerId='EthDot1adHdr',
                      parseBeyondMpls=parseBeyondMpls )

# 802.1Q and 802.1ad tags share dot1Parser; they differ only in the header name
# they record.
ethTypeParserIs( ethType="ethTypeDot1Q", parseFunc=dot1QParser )
ethTypeParserIs( ethType="ethTypeDot1ad", parseFunc=dot1adParser )

def mplsParser( pkt, currentOffset, headers, parseBeyondMpls=False ):
   '''
   Parses the top label for the MPLS header. Since the MPLS header doesn't specify
   the etherType for its payload, this will not parse beyond the MPLS header by
   default. If parseBeyondMpls is True, this will skip the remaining MPLS labels
   (by looking for the set BOS bit) and will try to parse the payload, choosing
   PW-ACH, IPv4 or IPv6 from the first nibble of the payload.

   Returns the offset just past the last header recorded in headers: the
   payload parser's result if it made progress, otherwise the offset just past
   the top label (skipped labels are not recorded).  currentOffset is returned
   unchanged if there is no room for even one label.

   Notes:
   - This will only return the offset for the last header it successfully parsed.
   - BUG415506 will change this to parse the entire label stack.
   '''
   mplsHeaderSize = 4
   if ( pkt.bytes - currentOffset ) < mplsHeaderSize:
      return currentOffset
   mplsHdr = Tac.newInstance( "Arnet::MplsHdrWrapper", pkt, currentOffset )
   currentOffset += mplsHeaderSize
   headers.append( ( 'MplsHdr', mplsHdr ) )

   if not parseBeyondMpls:
      return currentOffset

   # Offset to fall back to whenever the payload cannot be parsed: only the top
   # label has been recorded in headers.
   topMplsOffset = currentOffset

   # Skip the remaining labels.  BOS is the last bit of the third byte of each
   # 4-byte MPLS label.
   isBos = mplsHdr.bos
   while not isBos:
      if ( pkt.bytes - currentOffset ) < mplsHeaderSize:
         return topMplsOffset
      isBos = pkt.rawByte[ currentOffset + 2 ] & 1
      currentOffset += mplsHeaderSize

   # Assume the payload is IP or IPv6. However, if there aren't even enough bytes
   # for just the IP header, just return the offset for the first MPLS label
   if ( pkt.bytes - currentOffset ) < IpHdrSize:
      return topMplsOffset

   # To determine whether it's IP or IPv6 (or PW-ACH, for pseudowire control word
   # ping), look at the version in the first nibble.
   version = _getIpVersionNibble( pkt, currentOffset )
   if version == 1:
      parseFunc = pwAchParser
   elif version == 4:
      parseFunc = ethTypeParser( 'ethTypeIp' )
   elif version == 6:
      parseFunc = ethTypeParser( 'ethTypeIp6' )
   else:
      return topMplsOffset

   # No payload parser registered: like every other "payload not parsed" case,
   # report the offset past the top label (the only MPLS header recorded), not
   # the offset past the skipped, unrecorded labels.
   if parseFunc is None:
      return topMplsOffset

   # If the offset changed, return the new one. Otherwise, use the offset for the
   # top MPLS label
   newOffset = parseFunc( pkt, currentOffset, headers )
   if newOffset == currentOffset:
      return topMplsOffset
   return newOffset

# Register the mpls parser.
ethTypeParserIs( "ethTypeMpls", mplsParser )

def _getIpVersionNibble( pkt, currentOffset ):
   return ( pkt.rawByte[ currentOffset ] & 0xF0 ) >> 4

def pwAchParser( pkt, currentOffset, headers ):
   '''
   Parse a 4-byte pseudowire Associated Channel Header at currentOffset and
   record it as 'PwAchHdr'.  Then try to parse the IPv4 or IPv6 payload that
   follows, chosen by the payload's version nibble.

   Returns the offset just past the last parsed header, or currentOffset
   unchanged if the PW-ACH header itself does not fit.
   '''
   pwAchHeaderSize = 4

   if ( pkt.bytes - currentOffset ) < pwAchHeaderSize:
      return currentOffset

   pwAchHdr = Tac.newInstance( "Arnet::PwAchHdrWrapper", pkt, currentOffset )
   headers.append( ( 'PwAchHdr', pwAchHdr ) )
   currentOffset += pwAchHeaderSize

   # Nothing follows the PW-ACH: don't read a version nibble past the end of
   # the packet.
   if currentOffset >= pkt.bytes:
      return currentOffset

   version = _getIpVersionNibble( pkt, currentOffset )
   if version == 4:
      parseFunc = ethTypeParser( 'ethTypeIp' )
   elif version == 6:
      parseFunc = ethTypeParser( 'ethTypeIp6' )
   else:
      return currentOffset

   if parseFunc is None:
      return currentOffset
   return parseFunc( pkt, currentOffset, headers )

# Registry mapping an 802.2 LLC DSAP value to the parser for its payload.
sapMap = {}

def llcSapParserIs( sap, parseFunc ):
   '''Register (or replace) the parser used by llcParser for DSAP sap.'''
   sapMap.update( { sap: parseFunc } )
   
def llcParser( pkt, currentOffset, headers ):
   '''
   Parse an 802.2 LLC header and record it as 'EthLlcHdr'.  Then hand the rest
   of the packet to the parser registered for its DSAP via llcSapParserIs, if
   any.

   Returns the offset just past the last parsed header, or currentOffset
   unchanged if the LLC header does not fit.
   '''
   if pkt.bytes - currentOffset < Eth8022LlcHdrSize:
      return currentOffset
   llcHdr = Tac.newInstance( "Arnet::Eth8022LlcHdrWrapper", pkt, currentOffset )
   headers.append( ( 'EthLlcHdr', llcHdr ) )
   payloadOffset = currentOffset + Eth8022LlcHdrSize

   sapParser = sapMap.get( llcHdr.dsap )
   if sapParser is None:
      return payloadOffset
   return sapParser( pkt, payloadOffset, headers )

# ethTypeParser looks up Arnet::EthType's ethTypeUnknown under "ethTypeLlc".
ethTypeParserIs( ethType="ethTypeLlc", parseFunc=llcParser )

# Registry mapping a SNAP ( orgCode, localCode ) pair to the parser for its
# payload.
snapMap = {}

def snapParserIs( orgCode, localCode, parseFunc ):
   '''Register (or replace) the parser used by snapParser for this SNAP code.'''
   key = ( orgCode, localCode )
   snapMap[ key ] = parseFunc
   
def snapParser( pkt, currentOffset, headers ):
   '''
   Parse an 802.2 SNAP header and record it as 'EthSnapHdr'.  Then hand the
   rest of the packet to the parser registered via snapParserIs for its
   ( orgCode, localCode ) pair, if any.

   Returns the offset just past the last parsed header, or currentOffset
   unchanged if the SNAP header does not fit.
   '''
   if pkt.bytes - currentOffset < Eth8022SnapHdrSize:
      return currentOffset
   snapHdr = Tac.newInstance( "Arnet::Eth8022SnapHdrWrapper", pkt, currentOffset )
   headers.append( ( 'EthSnapHdr', snapHdr ) )
   payloadOffset = currentOffset + Eth8022SnapHdrSize

   codeParser = snapMap.get( ( snapHdr.orgCode, snapHdr.localCode ) )
   if codeParser is None:
      return payloadOffset
   return codeParser( pkt, payloadOffset, headers )

# A SNAP header follows an LLC header whose DSAP is Eth8022LlcSnapSap.
llcSapParserIs( sap=Eth8022LlcSnapSap, parseFunc=snapParser )

# Registry mapping an IP protocol number to the parser for that payload.
ipProtoMap = {}

def ipProtoParser( ipProto, parseFunc ):
   '''
   Register parseFunc as the parser for IPv4/IPv6 payloads of protocol ipProto.
   Each protocol may be registered only once (asserted).  parseFunc is called
   as parseFunc( pkt, offset, headers ) and returns the new offset.
   '''
   assert ipProto not in ipProtoMap
   ipProtoMap.update( { ipProto: parseFunc } )

# Registry mapping an IPv4 option number to the parser for that option.
ipOptionMap = {}

def ipOptionParser( ipOptionNum, parseFunc ):
   '''
   Register parseFunc as the parser for IPv4 option number ipOptionNum.  Each
   option may be registered only once (asserted).  parseFunc is called as
   parseFunc( pkt, optionsOffset, ipHdrEndOffset, headers ) and returns the
   offset of the next option.
   '''
   assert ipOptionNum not in ipOptionMap
   ipOptionMap.update( { ipOptionNum: parseFunc } )

# Registry of IPv6 extension-header types.  ip6HdrParser only checks a type for
# presence here before skipping that extension header.
ip6OptionMap = {}

def ip6OptionParser( ip6OptionNum, parseFunc ):
   '''Register parseFunc for IPv6 extension header ip6OptionNum (only once).'''
   assert ip6OptionNum not in ip6OptionMap
   ip6OptionMap.update( { ip6OptionNum: parseFunc } )

def ipHdrParser( pkt, currentOffset, headers ):
   '''
   Parse an IPv4 header and record it as 'IpHdr'.  Then parse its options via
   ipOptionMap, and hand the payload to the parser registered for its protocol
   in ipProtoMap.

   The header length used is the header's own headerBytes, but never less than
   IpHdrSize.  Option parsing stops at the first option with no registered
   parser, or at the first option whose parser makes no progress.

   Returns the offset just past the last parsed header.  With no payload parser
   registered, that is the end of the IPv4 header, options included.  If the
   header does not fit, currentOffset is returned unchanged.
   '''
   if ( pkt.bytes - currentOffset ) < IpHdrSize:
      return currentOffset
   ipHdr = Tac.newInstance( "Arnet::IpHdrWrapper", pkt, currentOffset )
   optionsOffset = currentOffset + IpHdrSize
   headerBytes = IpHdrSize
   if ipHdr.headerBytes > headerBytes:
      headerBytes = ipHdr.headerBytes
   currentOffset += headerBytes
   headers.append( ( 'IpHdr', ipHdr ) )

   # headerBytes comes from the packet itself and may claim more option bytes
   # than were actually captured.  Bound option parsing by the end of the
   # packet, so neither this loop nor the option parsers read past it.
   optionsEnd = currentOffset
   if optionsEnd > pkt.bytes:
      optionsEnd = pkt.bytes

   while optionsOffset < optionsEnd:
      optCommon = Tac.newInstance( "Arnet::IpHdrOptionCommonWrapper",
                                   pkt, optionsOffset )
      parseFunc = ipOptionMap.get( optCommon.number )
      if parseFunc is None:
         break
      newOptionsOffset = parseFunc( pkt, optionsOffset, optionsEnd, headers )
      if newOptionsOffset == optionsOffset:
         # We didn't move the offset, so there must be something very wrong with
         # the current option (most likely it doesn't fit).  Just stop parsing
         # options.
         break
      optionsOffset = newOptionsOffset

   parseFunc = ipProtoMap.get( ipHdr.protocolNum )
   if parseFunc is None:
      return currentOffset
   return parseFunc( pkt, currentOffset, headers )

def ip6HdrParser( pkt, currentOffset, headers ):
   '''
   Parse an IPv6 header and record it as 'Ip6Hdr'.  Then skip over a chain of
   extension headers (HopByHop, Route, Frag, Opts) and hand the upper-layer
   payload to the parser registered for its protocol in ipProtoMap.

   Extension headers are skipped, not recorded.  Their ip6OptionMap entry is
   only checked for presence and never called.  The walk stops at the first
   extension header that has no ip6OptionMap entry or does not fit in the
   packet; the ipProtoMap lookup is then made with that header's type.

   Returns the offset just past the last header consumed, or currentOffset
   unchanged if the IPv6 header does not fit.
   '''
   if pkt.bytes - currentOffset < Ip6HdrSize:
      return currentOffset
   ip6Hdr = Tac.newInstance( "Arnet::Ip6HdrWrapper", pkt, currentOffset )
   nextHdr = ip6Hdr.protocolNum
   headers.append( ( 'Ip6Hdr', ip6Hdr ) )
   offset = currentOffset + Ip6HdrSize

   extHdrTypes = ( 'ipProtoIpv6HopByHop', 'ipProtoIpv6Route',
                   'ipProtoIpv6Frag', 'ipProtoIpv6Opts' )
   while nextHdr in extHdrTypes:
      remaining = pkt.bytes - offset
      if remaining < Ip6HdrOptionCommonSize:
         return offset
      extHdr = Tac.newInstance( "Arnet::Ip6HdrOptionCommonWrapper",
                                pkt, offset )
      # The length field counts 8-octet units, not including the first 8
      # octets (RFC 2460 section 4.3).
      extHdrLen = 8 * ( extHdr.optionLength + 1 )
      if ip6OptionMap.get( nextHdr ) is None:
         break
      if remaining < extHdrLen:
         break
      offset += extHdrLen
      nextHdr = extHdr.nextHdrNum

   payloadParser = ipProtoMap.get( nextHdr )
   if payloadParser is None:
      return offset
   return payloadParser( pkt, offset, headers )

# IPv4 and IPv6 payloads are dispatched by their ethType.
ethTypeParserIs( ethType="ethTypeIp", parseFunc=ipHdrParser )
ethTypeParserIs( ethType="ethTypeIp6", parseFunc=ip6HdrParser )

def routerAlertOptionParser( pkt, optionsOffset, ipHdrEndOffset, headers ):
   '''
   IPv4 option parser for the router alert option.

   Records the option at optionsOffset as 'IpHdrRouterAlertOption' and returns
   the offset of the next option.  If the option would extend past
   ipHdrEndOffset, nothing is recorded and optionsOffset is returned unchanged,
   which makes ipHdrParser stop parsing options.
   '''
   if optionsOffset + IpHdrRouterAlertOptionSize > ipHdrEndOffset:
      return optionsOffset
   headers.append( ( 'IpHdrRouterAlertOption',
                     Tac.newInstance( "Arnet::IpHdrRouterAlertOptionWrapper",
                                      pkt, optionsOffset ) ) )
   return optionsOffset + IpHdrRouterAlertOptionSize

def unrecognizedOptionParser( pkt, optionsOffset, ipHdrEndOffset, headers ):
   '''
   IPv4 option parser for the option registered as IpHdrUnrecognizedOptionNum.

   Records the option at optionsOffset as 'IpHdrUnrecognizedOption' and returns
   the offset of the next option.  If the option would extend past
   ipHdrEndOffset, nothing is recorded and optionsOffset is returned unchanged.
   '''
   if optionsOffset + IpHdrUnrecognizedOptionSize > ipHdrEndOffset:
      return optionsOffset
   headers.append( ( 'IpHdrUnrecognizedOption',
                     Tac.newInstance( "Arnet::IpHdrUnrecognizedOptionWrapper",
                                      pkt, optionsOffset ) ) )
   return optionsOffset + IpHdrUnrecognizedOptionSize

def hopByHopOptionParser( pkt, currentOffset, headers ):
   '''
   Wrap the IPv6 router alert option at currentOffset and record it as
   'Ip6HdrRouterAlertOption'.

   NOTE(review): unlike the other parsers in this file, this returns a size
   (IpHdrRouterAlertOptionSize + Ip6HdrRouterAlertPad), not an absolute
   offset.  It also asserts instead of returning gracefully on a short packet.
   ip6HdrParser only checks its ip6OptionMap entry for presence and never
   calls it, so confirm the intended contract before invoking it from a parse
   loop.
   '''
   remaining = pkt.bytes - currentOffset
   if remaining < IpHdrRouterAlertOptionSize:
      assert 0  # mal-formed packet
      return currentOffset
   headers.append( ( 'Ip6HdrRouterAlertOption',
                     Tac.newInstance( "Arnet::Ip6HdrRouterAlertOptionWrapper",
                                      pkt, currentOffset ) ) )
   return IpHdrRouterAlertOptionSize + Ip6HdrRouterAlertPad

# IPv4 options are registered by option number; the IPv6 hop-by-hop extension
# header by its protocol name.
ipOptionParser( ipOptionNum=IpHdrRouterAlertOptionNum,
                parseFunc=routerAlertOptionParser )
ipOptionParser( ipOptionNum=IpHdrUnrecognizedOptionNum,
                parseFunc=unrecognizedOptionParser )
ip6OptionParser( ip6OptionNum='ipProtoIpv6HopByHop',
                 parseFunc=hopByHopOptionParser )
   
def udpHeaderParser( pkt, currentOffset, headers ):
   '''
   Record the UDP header at currentOffset as 'UdpHdr'.

   The wrapper type follows the innermost IP header already in headers:
   Arnet::UdpHdrWrapperIpv6 over IPv6, Arnet::UdpHdrWrapper otherwise.
   Returns the offset just past the UDP header, or currentOffset unchanged if
   it does not fit.
   '''
   if ( pkt.bytes - currentOffset ) < UdpHdrSize:
      return currentOffset
   # Look at the most recently parsed IP header directly.  findHeader( headers,
   # 'Ip6Hdr' ) would emit an "unsupported number" trace for every IPv4 packet
   # (zero matches), and would fall back to the IPv4 wrapper if more than one
   # IPv6 header were present.
   overIpv6 = False
   for hdrType, _ in headers[ : : -1 ]:
      if hdrType in ( 'IpHdr', 'Ip6Hdr' ):
         overIpv6 = hdrType == 'Ip6Hdr'
         break
   if overIpv6:
      wrapperType = "Arnet::UdpHdrWrapperIpv6"
   else:
      wrapperType = "Arnet::UdpHdrWrapper"
   udpHeader = Tac.newInstance( wrapperType, pkt, currentOffset )
   headers.append( ( 'UdpHdr', udpHeader ) )
   currentOffset += UdpHdrSize
   return currentOffset

ipProtoParser( "ipProtoUdp", udpHeaderParser )

def tcpHeaderParser( pkt, currentOffset, headers ):
   '''
   Record the TCP header at currentOffset as 'TcpHdr'.

   Only the fixed TcpHdrSize bytes are consumed; the header's own length field
   is not consulted.  Returns the offset just past them, or currentOffset
   unchanged if they do not fit.
   '''
   if pkt.bytes - currentOffset < TcpHdrSize:
      return currentOffset
   headers.append( ( 'TcpHdr',
                     Tac.newInstance( "Arnet::TcpHdrWrapper",
                                      pkt, currentOffset ) ) )
   return currentOffset + TcpHdrSize

ipProtoParser( ipProto="ipProtoTcp", parseFunc=tcpHeaderParser )

def greHeaderParser( pkt, currentOffset, headers ):
   '''
   Record the GRE header at currentOffset as 'GreHdr'.

   Only the minimum GreHdrSize bytes are checked for presence.  The returned
   offset advances by the wrapper's greHeaderSize(), presumably including any
   optional GRE fields (TODO confirm), and may therefore exceed pkt.bytes.
   '''
   if pkt.bytes - currentOffset < GreHdrSize:
      return currentOffset
   greHdr = Tac.newInstance( "Arnet::GreHdrWrapper", pkt, currentOffset )
   headers.append( ( 'GreHdr', greHdr ) )
   return currentOffset + greHdr.greHeaderSize()

ipProtoParser( ipProto="ipProtoGre", parseFunc=greHeaderParser )

def sctpHeaderParser( pkt, currentOffset, headers ):
   '''
   Record the SCTP header at currentOffset as 'SctpHdr' and return the offset
   just past its SctpHdrSize bytes, or currentOffset if it does not fit.
   '''
   if pkt.bytes - currentOffset < SctpHdrSize:
      return currentOffset
   sctpHdr = Tac.newInstance( "Arnet::SctpHdrWrapper", pkt, currentOffset )
   headers.append( ( "SctpHdr", sctpHdr ) )
   return currentOffset + SctpHdrSize

ipProtoParser( ipProto="ipProtoSctp", parseFunc=sctpHeaderParser )

def icmpHeaderParser( pkt, currentOffset, headers ):
   '''
   Record the ICMP header at currentOffset as 'IcmpHdr' and return the offset
   just past its IcmpHdrSize bytes, or currentOffset if it does not fit.
   '''
   if pkt.bytes - currentOffset < IcmpHdrSize:
      return currentOffset
   icmpHdr = Tac.newInstance( "Arnet::IcmpHdrWrapper", pkt, currentOffset )
   headers.append( ( "IcmpHdr", icmpHdr ) )
   return currentOffset + IcmpHdrSize

ipProtoParser( ipProto="ipProtoIcmp", parseFunc=icmpHeaderParser )

def parseDataPacket( data, parseBeyondMpls=False ):
   '''
   Parse raw packet contents and return a PacketInfo.

   ethHdr, ipHdr, ip6Hdr and mplsHdr come from findHeader, so each is None
   unless exactly one header of that type was parsed.  ethDot1QHdr is the
   innermost 802.1Q tag of the outer Ethernet frame: the last 'EthDot1QHdr' seen
   before the first header whose type name does not start with 'Eth' (None if
   there is none).
   '''
   pkt, headers, offset = parsePktStr( data, parseBeyondMpls=parseBeyondMpls )
   ethHdr = findHeader( headers, 'EthHdr' )
   ipHdr = findHeader( headers, "IpHdr" )
   ip6Hdr = findHeader( headers, "Ip6Hdr" )
   mplsHdr = findHeader( headers, "MplsHdr" )

   # The frame's Ethernet-level headers (EthHdr, dot1q/dot1ad tags, LLC/SNAP)
   # all have names starting with 'Eth'; stop at the first one that doesn't.
   innermostDot1Q = None
   for hdrType, hdr in headers:
      if not hdrType.startswith( 'Eth' ):
         break
      if hdrType == 'EthDot1QHdr':
         innermostDot1Q = hdr

   return PacketInfo( pkt=pkt, offset=offset, ethHdr=ethHdr, ipHdr=ipHdr,
                      ip6Hdr=ip6Hdr, mplsHdr=mplsHdr,
                      ethDot1QHdr=innermostDot1Q )
