# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import codecs

from Arnet import EthAddr
from CliCommon import AlreadyHandledError
from CliPlugin.ForwardingDestinationCommon import (
   ArgToLabel,
   BaseFields,
   GreFields,
   GreTypeFields,
   IPv4Fields,
   IPv6Fields,
   IncompatibleFieldValues,
   IncompatibleFields,
   InnerIPv4Fields,
   InnerIPv6Fields,
   InnerL2Fields,
   L2Fields,
   L4Fields,
   NvgreFields,
   OptionalL2Fields,
   PacketInfo,
)
from SysConstants.if_ether_h import ETH_P_IP, ETH_P_IPV6, ETH_P_8021Q, \
   ETH_P_MPLS_UC, ETH_P_MPLS_MC, ETH_P_TEB
from SysConstants.in_h import IPPROTO_GRE, IPPROTO_IPIP, IPPROTO_IPV6, IPPROTO_TCP, \
   IPPROTO_UDP
from TypeFuture import TacLazyType

# Handles for the TAC types used in this file, each obtained from TacLazyType by
# its fully qualified TAC type name. The PacketTracer types are instantiated when
# building a request (updateRequest) and compared against when printing one
# (printRequest).
Dot1qHeader = TacLazyType( 'PacketTracer::Dot1qHeader' )
IpGenAddr = TacLazyType( 'Arnet::IpGenAddr' )
GreHeader = TacLazyType( 'PacketTracer::GreHeader' )
GreType = TacLazyType( 'PacketTracer::GreType' )
L2Header = TacLazyType( 'PacketTracer::L2Header' )
L3Header = TacLazyType( 'PacketTracer::L3Header' )
L4Header = TacLazyType( 'PacketTracer::L4Header' )
Request = TacLazyType( 'PacketTracer::Request' )

def printRequest( request ):
   """Print a one-line summary of each header configured in the request.

   The outer L2 header is always printed. The L3, GRE, inner L2, inner L3 and L4
   headers follow, in that order, when the request flags them as present. Raw
   packet requests (overwrite bytes set, or a positive overwrite offset) print
   nothing.

   Example:
   MAC Source aaaa.bbbb.cccc Destination 1234.4321.abcd Ethertype 0x800 802.1Q 10
   IPv4 Source 10.1.0.1 Destination 20.1.0.1 TTL 23 Protocol 6
   TCP Source Port 10000 Destination Port 20000
   """
   isRawRequest = request.packetOverWriteBytes or request.overwriteOffset > 0
   if isRawRequest:
      # Raw packet requests are not printed out
      return

   def showL2( hdr, innerHeader=False ):
      # The Ethertype shown is the one following the innermost 802.1Q tag, if any
      prefix = 'Inner ' if innerHeader else ''
      payloadEtherType = hdr.etherType
      tags = ''
      if hdr.dot1qHeader != Dot1qHeader():
         tags += f' 802.1Q {hdr.dot1qHeader.vlanId}'
         payloadEtherType = hdr.dot1qHeader.nextEtherType
         if hdr.innerDot1qHeader != Dot1qHeader():
            tags += f' 802.1Q {hdr.innerDot1qHeader.vlanId}'
            payloadEtherType = hdr.innerDot1qHeader.nextEtherType
      template = '{}MAC Source {} Destination {} Ethertype {}{}'
      srcMac = EthAddr( hdr.srcMac ).displayString
      dstMac = EthAddr( hdr.dstMac ).displayString
      print( template.format( prefix, srcMac, dstMac, hex( payloadEtherType ),
                              tags ) )

   def showIp( hdr, innerHeader=False ):
      # Prints the IP header and returns the protocol number it carries
      prefix = 'Inner ' if innerHeader else ''
      version = hdr.ipVersion
      if version == 4:
         nextProto = hdr.protocol
         template = '{}IPv4 Source {} Destination {} TTL {} Protocol {}'
         print( template.format( prefix, hdr.srcIp, hdr.dstIp, hdr.ttl,
                                 nextProto ) )
      elif version == 6:
         nextProto = hdr.nextHeader
         template = ( '{}IPv6 Source {} Destination {} Hop Limit {} Flow Label {} '
                      'Next Header {}' )
         print( template.format( prefix, hdr.srcIp, hdr.dstIp, hdr.hopLimit,
                                 hdr.flowLabel, nextProto ) )
      return nextProto

   showL2( request.l2Header )

   # Protocol of the innermost IP header printed so far; selects the L4 label
   l3Protocol = 0
   if request.hasL3:
      l3Protocol = showIp( request.l3Header )

   if request.hasGreHeader:
      gre = request.greHeader
      if gre.greType == GreType.greTypeNvgre:
         template = 'GRE Protocol {} Virtual Subnet ID {} Flow ID {}'
         print( template.format( hex( gre.protocol ), gre.virtualSubnetId,
                                 gre.flowId ) )
      elif gre.greType == GreType.greTypeGre:
         # Optional GRE fields are only shown when enabled on the header
         keyStr = ''
         if gre.hasKey:
            keyStr = f' GRE Key 0x{gre.key:x}'
         seqNumStr = ''
         if gre.hasSequenceNum:
            seqNumStr = f' GRE Sequence 0x{gre.sequenceNum:x}'
         checksumStr = ''
         if gre.generateChecksum:
            checksumStr = ' GRE Checksum generated'
         template = 'GRE Protocol {}{}{}{}'
         print( template.format( hex( gre.protocol ), keyStr, seqNumStr,
                                 checksumStr ) )
      else:
         assert False, 'Invalid GRE type: %s' % gre.greType

   if request.hasInnerL2Header:
      showL2( request.innerL2Header, innerHeader=True )

   if request.hasInnerL3Header:
      l3Protocol = showIp( request.innerL3Header, innerHeader=True )

   if request.hasL4:
      l4 = request.l4Header
      # Anything that is not TCP is labelled UDP
      if l3Protocol == IPPROTO_TCP:
         l4Name = 'TCP'
      else:
         l4Name = 'UDP'
      template = '{} Source Port {} Destination Port {}'
      print( template.format( l4Name, l4.srcPort, l4.dstPort ) )

def updateRequest( request, treeDict ):
   """Update a request from all the user provided fields.

   Args:
      request: The request to populate; it is modified in place.
      treeDict: Mapping of CLI argument names (e.g. '<srcMac>') to the values
         the user provided.

   Returns:
      The request that was passed in.

   Raises:
      ValueError: If the packet type is 'gre' and '<greType>' holds a value
         other than 'gre' or 'nvgre'.

   When '<rawPacket>' is set, only the ingress interface and the raw bytes are
   stored and every other field is ignored.
   """
   # Note: a missing '<ingressIntf>' is stored as the string 'None'
   request.ingressIntf = str( treeDict.get( '<ingressIntf>' ) )

   rawPacket = treeDict.get( '<rawPacket>' )
   if rawPacket:
      # We permit spaces, so strip them out before processing it
      rawPacket = rawPacket.replace( ' ', '' )
      request.packetOverWriteBytes = codecs.decode( rawPacket, 'hex' )
      request.overwriteOffset = 0
      return request

   l2Header = L2Header()
   l2Header.srcMac = treeDict.get( '<srcMac>' )
   l2Header.dstMac = treeDict.get( '<dstMac>' )
   l2Header.etherType = treeDict.get( '<etherType>' )

   vlan = treeDict.get( '<vlan>' )
   if vlan:
      # The tag takes over the header's Ethertype and carries the payload
      # Ethertype as its nextEtherType.
      dot1qHeader = Dot1qHeader()
      dot1qHeader.vlanId = vlan
      dot1qHeader.nextEtherType = l2Header.etherType

      innerVlan = treeDict.get( '<innerVlan>' )
      if innerVlan:
         # Double tagged: the inner tag carries the payload Ethertype and the
         # outer tag points at the inner tag.
         innerDot1qHeader = Dot1qHeader()
         innerDot1qHeader.vlanId = innerVlan
         innerDot1qHeader.nextEtherType = dot1qHeader.nextEtherType
         l2Header.innerDot1qHeader = innerDot1qHeader
         dot1qHeader.nextEtherType = ETH_P_8021Q

      l2Header.dot1qHeader = dot1qHeader
      l2Header.etherType = ETH_P_8021Q
   request.l2Header = l2Header

   packetType = treeDict.get( '<packetType>' )
   if packetType in ( 'ipv4', 'ipv6', 'gre' ):
      # Protocol of the innermost IP header; copied into the L4 header below
      protocol = 0
      request.hasL3 = True
      # For GRE the outer IP version comes from its own field
      packetIpVersion = treeDict.get( '<ipVersion>' ) if packetType == 'gre' \
                        else packetType
      ipVersion = 4 if packetIpVersion == 'ipv4' else 6
      l3Header = L3Header( ipVersion )
      if ipVersion == 4:
         l3Header.srcIp = IpGenAddr( treeDict.get( '<srcIpv4>' ) )
         l3Header.dstIp = IpGenAddr( treeDict.get( '<dstIpv4>' ) )
         l3Header.ttl = treeDict.get( '<ipTtl>' )
         l3Header.protocol = treeDict.get( '<ipProto>' )
         protocol = treeDict.get( '<ipProto>' )
      elif ipVersion == 6:
         l3Header.srcIp = IpGenAddr( treeDict.get( '<srcIpv6>' ).stringValue )
         l3Header.dstIp = IpGenAddr( treeDict.get( '<dstIpv6>' ).stringValue )
         l3Header.hopLimit = treeDict.get( '<hopLimit>' )
         l3Header.nextHeader = treeDict.get( '<nextHeader>' )
         l3Header.flowLabel = treeDict.get( '<flowLabel>' )
         protocol = treeDict.get( '<nextHeader>' )
      request.l3Header = l3Header

      if packetType == 'gre':
         request.hasGreHeader = True
         # Fall back to plain GRE both when the key is absent and when it is
         # present with a None value.
         packetGreType = treeDict.get( '<greType>' ) or 'gre'
         if packetGreType == 'gre':
            greHeader = GreHeader( GreType.greTypeGre )
            key = treeDict.get( '<greKey>' )
            checksum = treeDict.get( 'gre-checksum' )
            sequence = treeDict.get( '<greSequence>' )
            # A key or sequence number of 0 is a configured value, so test for
            # presence rather than truthiness.
            if key is not None:
               greHeader.hasKey = True
               greHeader.key = key
            if checksum:
               greHeader.generateChecksum = True
            if sequence is not None:
               greHeader.hasSequenceNum = True
               greHeader.sequenceNum = sequence
            greHeader.protocol = treeDict.get( '<greProto>' )
         elif packetGreType == 'nvgre':
            greHeader = GreHeader( GreType.greTypeNvgre )
            greHeader.virtualSubnetId = treeDict.get( '<nvgreVirtualSubnetId>' )
            greHeader.flowId = treeDict.get( '<nvgreFlowId>' )
            greHeader.protocol = ETH_P_TEB

            # An NVGRE request also gets an inner L2 header
            request.hasInnerL2Header = True
            innerL2Header = L2Header()
            innerL2Header.srcMac = treeDict.get( '<innerL2SrcMac>' )
            innerL2Header.dstMac = treeDict.get( '<innerL2DstMac>' )
            innerL2Header.etherType = treeDict.get( '<innerL2EtherType>' )
            request.innerL2Header = innerL2Header
         else:
            raise ValueError( f'Invalid GRE type: {packetGreType}' )
         request.greHeader = greHeader

      innerPacketType = treeDict.get( '<innerPacketType>' )
      if innerPacketType in [ 'ipv4', 'ipv6' ]:
         request.hasInnerL3Header = True
         ipVersion = 4 if innerPacketType == 'ipv4' else 6
         l3Header = L3Header( ipVersion )
         if ipVersion == 4:
            l3Header.srcIp = IpGenAddr( treeDict.get( '<innerSrcIpv4>' ) )
            l3Header.dstIp = IpGenAddr( treeDict.get( '<innerDstIpv4>' ) )
            l3Header.ttl = treeDict.get( '<innerIpTtl>' )
            l3Header.protocol = treeDict.get( '<innerIpProto>' )
            protocol = treeDict.get( '<innerIpProto>' )
         elif ipVersion == 6:
            l3Header.srcIp = IpGenAddr(
               treeDict.get( '<innerSrcIpv6>' ).stringValue )
            l3Header.dstIp = IpGenAddr(
               treeDict.get( '<innerDstIpv6>' ).stringValue )
            l3Header.hopLimit = treeDict.get( '<innerHopLimit>' )
            l3Header.nextHeader = treeDict.get( '<innerNextHeader>' )
            l3Header.flowLabel = treeDict.get( '<innerFlowLabel>' )
            protocol = treeDict.get( '<innerNextHeader>' )
         request.innerL3Header = l3Header

      if treeDict.get( '<l4Type>' ) in [ 'tcp', 'udp' ]:
         request.hasL4 = True
         l4Header = L4Header()
         l4Header.protocol = protocol
         l4Header.srcPort = treeDict.get( '<srcL4Port>' )
         l4Header.dstPort = treeDict.get( '<dstL4Port>' )
         request.l4Header = l4Header

   return request

def generatePacketType( treeDict ):
   """Return the packet type, inferring it from the configured fields if needed.

   A raw packet ('<packetType>' of 'raw' or a '<rawPacket>' key) always yields
   'raw'. Otherwise an explicitly provided '<packetType>' is returned as is. When
   neither applies the type is deduced from which field keys are present, with
   the precedence GRE > IPv6 > IPv4 > ethernet. Returns None if nothing matches.
   """
   explicitType = treeDict.get( '<packetType>' )
   if '<rawPacket>' in treeDict or explicitType == 'raw':
      return 'raw'
   if explicitType:
      return explicitType

   def anyConfigured( fieldNames ):
      return any( name in treeDict for name in fieldNames )

   # From here on the type is inferred purely from the fields that are present.
   # GRE, NVGRE or inner L2 fields win over everything else. Note that currently,
   # an inner L2 header is supported only with NVGRE packets.
   if anyConfigured( GreFields + NvgreFields + InnerL2Fields ):
      return 'gre'

   # IPv4/v6 fields overrule a plain ethernet classification
   if anyConfigured( IPv6Fields ):
      return 'ipv6'
   if anyConfigured( IPv4Fields ):
      return 'ipv4'

   # Only L2 fields (if anything at all) have been configured
   if anyConfigured( L2Fields + OptionalL2Fields ):
      return 'ethernet'
   return None

def generateInnerPacketType( treeDict, packetType=None ):
   """Return the inner packet type, inferring it from the inner IP fields.

   An explicitly provided '<innerPacketType>' is returned unchanged. Otherwise
   the result is 'ipv4' or 'ipv6' when any inner IPv4 (checked first) or inner
   IPv6 field key is present, and None when there are none. The packetType
   argument is accepted but not used.
   """
   explicitInnerType = treeDict.get( '<innerPacketType>' )
   if explicitInnerType:
      return explicitInnerType

   # A packet explicitly typed as 'ethernet' cannot have inner fields, so any
   # that are present are a misconfiguration and no inner type is reported.
   if treeDict.get( '<packetType>' ) == 'ethernet':
      return None

   candidates = ( ( 'ipv4', InnerIPv4Fields ), ( 'ipv6', InnerIPv6Fields ) )
   for innerType, fieldNames in candidates:
      if any( name in treeDict for name in fieldNames ):
         return innerType
   return None

def generateL4Type( treeDict ):
   """Return the L4 type, inferring it from the IP protocol field if needed.

   An explicitly provided '<l4Type>' is returned unchanged. Otherwise the
   protocol is read from the inner IP header when inner IP fields are present,
   or from the outer IP header when only outer IP fields are. The result is
   'tcp' or 'udp' for those protocols, 'none' for any other integer protocol and
   None when no protocol value is available.
   """
   explicitL4Type = treeDict.get( '<l4Type>' )
   if explicitL4Type:
      return explicitL4Type

   def anyConfigured( fieldNames ):
      return any( name in treeDict for name in fieldNames )

   # Pick the key holding the protocol number of the outer IP header
   protocolKey = None
   if anyConfigured( IPv6Fields ):
      protocolKey = '<nextHeader>'
   elif anyConfigured( IPv4Fields ):
      protocolKey = '<ipProto>'

   # An inner IP header, when configured, takes over from the outer one
   if anyConfigured( InnerIPv4Fields ):
      protocolKey = '<innerIpProto>'
   elif anyConfigured( InnerIPv6Fields ):
      protocolKey = '<innerNextHeader>'

   protocol = treeDict.get( protocolKey ) if protocolKey else None

   # Determine the L4 type based on the protocol in the IPv4 or v6 header
   if protocol == IPPROTO_TCP:
      return 'tcp'
   if protocol == IPPROTO_UDP:
      return 'udp'
   if isinstance( protocol, int ):
      # If a protocol that's not TCP nor UDP has been set we report 'none' so we
      # won't prompt the user for the L4 header.
      return 'none'
   return None

def generateGreType( treeDict ):
   """Return the GRE type, inferring it from the configured fields if needed.

   An explicitly provided '<greType>' is returned unchanged. No GRE type is
   reported when '<packetType>' is set to something other than 'gre'. Otherwise
   the result is 'nvgre' when any NVGRE or inner L2 field key is present, 'gre'
   when only GRE field keys are, and None when there are none.
   """
   explicitGreType = treeDict.get( '<greType>' )
   if explicitGreType:
      return explicitGreType

   # If packet type is provided already and it isn't set to gre, don't set a
   # greType
   explicitPacketType = treeDict.get( '<packetType>' )
   if explicitPacketType and explicitPacketType != 'gre':
      return None

   def anyConfigured( fieldNames ):
      return any( name in treeDict for name in fieldNames )

   # NVGRE fields take precedence over plain GRE fields. Currently, an inner L2
   # header is only supported with NVGRE, so it implies NVGRE as well.
   if anyConfigured( NvgreFields ) or anyConfigured( InnerL2Fields ):
      return 'nvgre'
   if anyConfigured( GreFields ):
      return 'gre'
   return None

def generateIpVersion( treeDict ):
   """Return the outer IP version, inferring it from the configured fields.

   An explicitly provided '<ipVersion>' is returned unchanged. Otherwise the
   result is 'ipv6' or 'ipv4' when any IPv6 (checked first) or IPv4 field key is
   present, and None when there are none.
   """
   explicitIpVersion = treeDict.get( '<ipVersion>' )
   if explicitIpVersion:
      return explicitIpVersion

   candidates = ( ( 'ipv6', IPv6Fields ), ( 'ipv4', IPv4Fields ) )
   for version, fieldNames in candidates:
      if any( name in treeDict for name in fieldNames ):
         return version
   return None

def generatePacketTypes( treeDict ):
   """Derive the packet information from the fields that have been provided.

   Returns a PacketInfo built from the packet type, inner packet type, L4 type,
   IP version and GRE type, in that order. Each value is either the one the user
   gave explicitly or the one inferred from the configured fields.
   """
   packetType = generatePacketType( treeDict )
   innerPacketType = generateInnerPacketType( treeDict, packetType )
   l4Type = generateL4Type( treeDict )
   ipVersion = generateIpVersion( treeDict )
   greType = generateGreType( treeDict )

   # An inferred (not user provided) type that looks like IP-in-IP is ambiguous,
   # because a GRE packet may also have outer and inner IP fields. Drop the
   # inferred packet type in that case so that it still gets prompted for.
   typeWasInferred = not treeDict.get( '<packetType>' )
   hasInnerIp = innerPacketType in [ 'ipv4', 'ipv6' ]
   if typeWasInferred and packetType != 'gre' and hasInnerIp:
      packetType = None

   return PacketInfo( packetType, innerPacketType, l4Type, ipVersion, greType )

def generateRequiredFields( packetInfo ):
   """Return the list of fields that are required for the described packet.

   The list depends on the packet type (Ethernet, IPv4, IPv6, GRE, raw), the
   inner packet type (IPv4, IPv6), the GRE type and IP version for GRE packets,
   and the L4 type (TCP, UDP). A raw packet only requires the raw bytes and the
   ingress interface.
   """
   packetType = packetInfo.packetType
   innerPacketType = packetInfo.innerPacketType
   l4Type = packetInfo.l4Type

   if packetType == 'raw':
      return [ '<rawPacket>', '<ingressIntf>' ]

   # Concatenating builds a new list, so the shared field lists are not modified
   # by the in-place edits below.
   required = BaseFields + L2Fields
   if packetType != 'ethernet' and '<etherType>' in required:
      required.remove( '<etherType>' )

   # For GRE the outer IP fields are selected by the separate IP version
   isGre = packetType == 'gre'
   outerIpType = packetInfo.ipVersion if isGre else packetType
   if outerIpType == 'ipv4':
      required.extend( IPv4Fields )
   elif outerIpType == 'ipv6':
      required.extend( IPv6Fields )

   if isGre:
      greType = packetInfo.greType
      if greType is None:
         # The GRE type itself still has to be provided
         required.extend( GreTypeFields )
      elif greType == 'nvgre':
         required.extend( NvgreFields )
         required.extend( InnerL2Fields )
         dropInnerEtherType = ( innerPacketType and
                                innerPacketType != 'ethernet' and
                                '<innerL2EtherType>' in required )
         if dropInnerEtherType:
            required.remove( '<innerL2EtherType>' )

   # Note that the <packetType> field is required in non-interactive mode for
   # IP-in-IP to differentiate them from GRE packets. Since all GRE fields are
   # optional, they may appear identical to IP-in-IP packets
   if innerPacketType in ( 'ipv4', 'ipv6' ):
      if innerPacketType == 'ipv4':
         required.extend( InnerIPv4Fields )
      else:
         required.extend( InnerIPv6Fields )
      required.append( '<packetType>' )

   if l4Type in [ 'tcp', 'udp' ]:
      required.extend( L4Fields )

   return required

def updateEthertypeAndProtocol( treeDict ):
   """Fill in the Ethertype and protocol fields implied by the packet types.

   treeDict is updated in place from its '<packetType>', '<innerPacketType>',
   '<l4Type>', '<ipVersion>' and '<greType>' entries. Nothing is changed for raw
   packets or for packets without an outer IP header. '<packetType>' and
   '<innerPacketType>' must be present; '<l4Type>' must be present whenever the
   packet has an outer IP header.
   """
   packetType = treeDict[ '<packetType>' ]
   # IP version and GRE type fields are only populated if the packet type is "gre"
   ipVersion = treeDict.get( '<ipVersion>' )
   greType = treeDict.get( '<greType>' )
   innerPacketType = treeDict[ '<innerPacketType>' ]

   if packetType == 'raw':
      return

   # The outer IP version is the packet type itself, except for GRE
   isGre = packetType == 'gre'
   outerIpType = ipVersion if isGre else packetType
   if outerIpType == 'ipv4':
      treeDict[ '<etherType>' ] = ETH_P_IP
      outerProtocolKey = '<ipProto>'
   elif outerIpType == 'ipv6':
      treeDict[ '<etherType>' ] = ETH_P_IPV6
      outerProtocolKey = '<nextHeader>'
   else:
      # No outer IP header, so there is nothing to fill in
      return

   # Key of the protocol field that describes the L4 header. This is the outer
   # IP header's unless an inner IP header is present.
   l4ProtocolKey = outerProtocolKey
   if innerPacketType == 'ipv4':
      treeDict[ outerProtocolKey ] = IPPROTO_IPIP
      l4ProtocolKey = '<innerIpProto>'
   elif innerPacketType == 'ipv6':
      treeDict[ outerProtocolKey ] = IPPROTO_IPV6
      l4ProtocolKey = '<innerNextHeader>'

   if isGre:
      # GRE overrides any IP-in-IP protocol set on the outer header above
      treeDict[ outerProtocolKey ] = IPPROTO_GRE
      innerEtherType = { 'ipv4': ETH_P_IP,
                         'ipv6': ETH_P_IPV6 }.get( innerPacketType )
      if greType == 'nvgre':
         treeDict[ '<greProto>' ] = ETH_P_TEB
         if innerEtherType is not None:
            treeDict[ '<innerL2EtherType>' ] = innerEtherType
      elif greType == 'gre':
         if innerEtherType is not None:
            treeDict[ '<greProto>' ] = innerEtherType

   l4Protocol = { 'tcp': IPPROTO_TCP,
                  'udp': IPPROTO_UDP }.get( treeDict[ '<l4Type>' ] )
   if l4Protocol is not None:
      treeDict[ l4ProtocolKey ] = l4Protocol

def fetchConfiguredPacket( mode ):
   """Fetch the previously configured packet from the session data.

   Returns a ( treeDict, newConfiguration ) tuple. When the session holds no
   packet, or an empty one, a fresh dict is stored in the session and returned
   with newConfiguration set to True.
   """
   stored = mode.session.sessionData( 'PacketTracer.Packet' )
   if stored:
      return ( stored, False )

   fresh = {}
   mode.session.sessionDataIs( 'PacketTracer.Packet', fresh )
   return ( fresh, True )

def clearConfiguredPacket( mode ):
   """Drop any previously configured packet by storing None in the session data"""
   session = mode.session
   session.sessionDataIs( 'PacketTracer.Packet', None )

def checkForIncompatibleFields( mode, treeDict ):
   """Reject a packet whose configured fields conflict with one another.

   Two tables are consulted. IncompatibleFields maps a field and its value to
   the fields that must not be configured (hold a non-None value) alongside it.
   IncompatibleFieldValues maps a field and its value to specific values of
   other fields that it cannot be combined with. On the first conflict found an
   error is added to the mode, the configured packet is cleared and
   AlreadyHandledError is raised.
   """
   def reject( field, value, conflictStr ):
      template = 'Invalid option specified: {} {} incompatible with {}'
      mode.addError( template.format( ArgToLabel[ field ], value, conflictStr ) )
      clearConfiguredPacket( mode )
      raise AlreadyHandledError

   # Fields that may not be configured at all together with this field's value
   for field, conflictsByValue in IncompatibleFields.items():
      value = treeDict.get( field )
      if value not in conflictsByValue:
         continue
      conflictingFields = conflictsByValue[ value ]
      labels = []
      for name in treeDict:
         if name in conflictingFields and treeDict[ name ] is not None:
            labels.append( ArgToLabel[ name ].lower() )
      if labels:
         reject( field, value, ', '.join( labels ) )

   # Specific values of other fields that clash with this field's value
   for field, conflictsByValue in IncompatibleFieldValues.items():
      value = treeDict.get( field )
      if value not in conflictsByValue:
         continue
      for otherField, otherValue in conflictsByValue[ value ].items():
         if otherValue == treeDict.get( otherField ):
            conflictStr = '{} {}'.format( ArgToLabel[ otherField ].lower(),
                                          otherValue )
            reject( field, value, conflictStr )

def validateFields( mode, treeDict, packetTracerHwStatus, packetTracerSwStatus ):
   """Validate the user's specified fields against what is supported.

   Args:
      mode: CLI mode used to report errors and to clear the configured packet.
      treeDict: Mapping of CLI argument names to the values the user provided.
      packetTracerHwStatus: Provides rawPacketSupported, maximumPacketSize and
         minimumPacketSize.
      packetTracerSwStatus: Provides mplsPacketSupported and
         extendedTunnelTypesSupported.

   Raises:
      AlreadyHandledError: On the first validation failure, after the error has
         been added to the mode and the configured packet has been cleared.
   """
   def addErrorAndClearPacket( errorStr ):
      mode.addError( errorStr )
      clearConfiguredPacket( mode )
      raise AlreadyHandledError

   etherType = treeDict.get( '<etherType>' )
   if ( not packetTracerSwStatus.mplsPacketSupported and
        etherType in [ ETH_P_MPLS_UC, ETH_P_MPLS_MC ] ):
      addErrorAndClearPacket( 'Invalid request: MPLS packets are unsupported. '
                              'Clearing packet configuration.' )

   rawPacket = treeDict.get( '<rawPacket>' )
   if rawPacket:
      if not packetTracerHwStatus.rawPacketSupported:
         addErrorAndClearPacket( 'Invalid request: Raw packets are unsupported.' )

      # We permit spaces, so strip them out before processing it
      rawPacket = rawPacket.replace( ' ', '' )
      rawPacketLen = len( rawPacket )
      if ( rawPacketLen % 2 ) == 1:
         # All bytes need to be specified in full, 2 characters each
         addErrorAndClearPacket( 'Hexadecimal string is an odd length' )

      # The length is even at this point, so floor division is exact. It keeps
      # the byte count an integer, which true division would report as a float
      # (e.g. "32.0 bytes") in the messages below.
      packetLen = rawPacketLen // 2
      if packetLen > packetTracerHwStatus.maximumPacketSize:
         errorStr = ( 'Raw packet of {} bytes is larger than the maximum '
                      'supported {} bytes'.format(
                         packetLen,
                         packetTracerHwStatus.maximumPacketSize ) )
         addErrorAndClearPacket( errorStr )
      elif packetLen < packetTracerHwStatus.minimumPacketSize:
         errorStr = ( 'Raw packet of {} bytes is smaller than the minimum '
                      'supported {} bytes'.format(
                         packetLen,
                         packetTracerHwStatus.minimumPacketSize ) )
         addErrorAndClearPacket( errorStr )

   if not packetTracerSwStatus.extendedTunnelTypesSupported:
      # GRE packets and IP packets with an inner IP header count as tunnels
      tunnelType = False
      packetType = treeDict.get( '<packetType>' )
      if packetType == 'gre':
         tunnelType = True
      elif packetType in [ 'ipv4', 'ipv6' ]:
         tunnelType = treeDict.get( '<innerPacketType>' ) in [ 'ipv4', 'ipv6' ]
      if tunnelType:
         addErrorAndClearPacket( 'Invalid request: Tunnel packets are unsupported. '
                                 'Clearing packet configuration.' )

   checkForIncompatibleFields( mode, treeDict )

def checkForMissingFields( mode, requiredFields, treeDict ):
   """Report the required fields that have not been configured.

   A field counts as configured when it is present in treeDict with a value
   other than None. When any required field is missing, a single error listing
   the human readable labels of the missing fields, in the order they appear in
   requiredFields, is added to the mode.

   Returns:
      True if at least one required field is missing, False otherwise.
   """
   configured = { key for key, value in treeDict.items() if value is not None }
   # dict.fromkeys drops duplicates while keeping the order of requiredFields,
   # so the error message is the same from run to run (a set difference would
   # list the fields in arbitrary order).
   missingFields = [ field for field in dict.fromkeys( requiredFields )
                     if field not in configured ]
   if not missingFields:
      return False

   # Update the fields with their human readable values
   labels = [ ArgToLabel[ field ] for field in missingFields ]
   mode.addError( 'Missing field(s): ' + ', '.join( labels ) )
   return True
