#!/usr/bin/env python3
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from argparse import ArgumentParser
from IpLibConsts import DEFAULT_VRF
import Arnet
from enum import Enum
import socket, sys
import struct
import selectors
from Ethernet import (
      convertMacAddrToPackedString
      )
from dpkt.ethernet import(
      Ethernet,
      ETH_HDR_LEN,
      ETH_TYPE_IP,
      ETH_TYPE_IP6,
      ETH_CRC_LEN
      )
from dpkt.ip import (
      IP,
      IP_HDR_LEN,
      IP_PROTO_UDP,
      IP_PROTO_TCP,
      )
from dpkt.ip6 import (
      IP6,
      )
from dpkt.udp import (
      UDP,
      UDP_HDR_LEN
       )
from dpkt.tcp import (
      TCP,
       )
from dpkt.icmp import (
      ICMP_TIMEXCEED,
      ICMP_UNREACH
       )
from dpkt.icmp6 import (
      ICMP6,
      ICMP6_TIME_EXCEEDED,
      ICMP6_DST_UNREACH
      )
from Arnet.NsLib import socketAt
import Tac

# Header lengths (bytes) that are not imported from dpkt above
TCP_HDR_LEN = 20 # TCP header without options
IP6_HDR_LEN = 40 # fixed IPv6 header
VXLAN_HDR_LEN = 8
# VXLAN UDP destination port used in the pseudo header; also the default for -p
UDP_DPORT_VXLAN = 4789
# Socket option level/name values used by createRxSocket() to install the
# kernel ICMP type filter (Linux values)
SOL_RAW = 255
SOL_ICMPV6 = 58
ICMP_FILTER = 1 # option name used at both the SOL_RAW and SOL_ICMPV6 levels

# pylint: disable=E1101
# pylint: disable=W0212

# This script is intended to trace the path that a packet with the specified
# five tuple would take across the network, similar to 'traceroute'. Unlike
# 'traceroute', it prints each hop in both the underlay and overlay networks.
# This is done by encapsulating the trace packet in a pseudo VXLAN header and
# injecting it into the fwd kernel interface, to be routed by hardware with an
# incrementing TTL value. The ICMP replies are then captured to determine the
# next hop.

# The final trace packet has the following headers
#  ----------------------------------
#  |  bridgeMacAddr  bridgeMacAddr  |    Ethernet Hdr 2
#  ----------------------------------
#  | sourceIp | destination |  UDP  |    IP Hdr 2
#  ----------------------------------
#  |     1234   |  UDP_DPORT_VXLAN  |    L4 Hdr 2
#  ----------------------------------
#  |             VXLAN              |    VXLAN header
#  ----------------------------------
#  |  bridgeMacAddr  bridgeMacAddr  |    Ethernet Hdr 1
#  ----------------------------------
#  | source | destination | proto   |    IP Hdr 1
#  ----------------------------------
#  |       sport   |     dport      |    L4 Hdr 1
#  ----------------------------------

# Inputs to the script:
# ( source, destination, sport, dport, proto ) - the five tuple used for the
#       hash calculation, so that the trace packet takes the same path as
#       regular traffic for that flow
# vrf - overlay VRF in which the `destination` route lives
# fwdInterface - kernel interface where the trace packet is injected to be
#       routed by hardware ( must be in the specified VRF )
# intfSourceIp - the source IP used in the pseudo header so that ICMP replies
#       are captured in the correct VRF
# Command-line interface. Positional arguments are the five tuple; options
# select the VRF, injection interface and probe parameters.
parser = ArgumentParser()
parser.add_argument( 'source', metavar='ADDR', help="Source IP address" )
parser.add_argument( 'destination', metavar='ADDR', help="Destination IP address" )
parser.add_argument( 'sport', type=int, help="Source L4 port" )
parser.add_argument( 'dport', type=int, help="Destination L4 port" )
parser.add_argument( 'proto', type=int, help="Protocol ID" )
parser.add_argument( '--vrf', dest='vrf', type=str, default=DEFAULT_VRF,
                        help="Desired VRF" )
parser.add_argument( '-i', dest='fwdInterface', type=str, required=True,
                        help="Forward interface" )
parser.add_argument( '--bridgeMac', dest='bridgeMacAddr', type=str,
                        default='0.0.0',
                        help="Bridge MAC address" )
parser.add_argument( '-s', dest='intfSourceIp', type=str, default='0.0.0.0',
                        help="Source IP of trace packet" )
parser.add_argument( '-p', dest='vxlanPort', type=int, default=UDP_DPORT_VXLAN,
                        help="UDP destination port for VXLAN" )
parser.add_argument( '-m', dest='maxTtl', type=int, default=30,
                        help="Maximum number of hops" )
parser.add_argument( '-l', dest='pktSize', type=int, default=192,
                        help="Packet length" )
# The rest of the script treats any value other than 4 as IPv6, so only
# accept the two meaningful values
parser.add_argument( '--ip-version', dest='ipVer', type=int, default=4,
                        choices=( 4, 6 ),
                        help="IP version" )
parser.add_argument( '-n', dest='nslookup', default=False, action='store_true',
                        help="Do ns lookup on hop addresses before printing" )
parser.add_argument( '-u', dest='underlayOnly', default=False, action='store_true',
                        help="Only trace underlay" )

args = parser.parse_args()

# The overlay trace probes TTLs up to maxTtl + 1, which must fit the 8-bit
# TTL / hop-limit field. A negative value would never satisfy the stop
# condition in traceroute(), so recursion would not end. Reject both here.
if not 1 <= args.maxTtl <= 254:
   parser.error( f'argument -m: must be between 1 and 254, got {args.maxTtl}' )

# Derived, module-wide settings computed once from the parsed arguments.
bridgeMac = convertMacAddrToPackedString( args.bridgeMacAddr )

# The -s default is the IPv4 wildcard; use the IPv6 wildcard for v6 traces
if args.ipVer == 6 and args.intfSourceIp == '0.0.0.0':
   args.intfSourceIp = '0::0'

# Packed address bytes, converted in the order: pseudo-header source,
# five-tuple source, five-tuple destination
intfSourceIp, src, dst = ( Arnet.IpGenAddr( addr ).bytes() for addr in
                           ( args.intfSourceIp, args.source, args.destination ) )
fwdInterface = args.fwdInterface

# Sockets for a non-default VRF are opened in namespace 'ns-<vrf>'
nsName = None if args.vrf == DEFAULT_VRF else f'ns-{args.vrf}'
maxTtl, pktSize = args.maxTtl, args.pktSize
timeout = 1 # seconds

PathType = Enum( 'PathType', [ 'UNDERLAY', 'OVERLAY' ] )

etherType = ETH_TYPE_IP6 if args.ipVer != 4 else ETH_TYPE_IP

def buildPacket( outerTtl, innerTtl ):
   """Build the Ethernet frame for one trace probe.

   outerTtl -- TTL for the underlay path. It is carried in the pseudo VXLAN
               IP header's `id` (IPv4) or the low bits of the flow label
               (IPv6), and is copied into the outer header when the probe is
               encapsulated.
   innerTtl -- TTL / hop limit of the pseudo VXLAN IP header, used to trace
               the overlay path.

   The frame is sized so that it plus ETH_CRC_LEN bytes of CRC totals
   pktSize. Returns the outermost dpkt Ethernet object. Exits the script if
   pktSize leaves no room for a payload.
   """
   ipHdrLen = IP_HDR_LEN if args.ipVer == 4 else IP6_HDR_LEN
   # Size the inner L4 header the same way it is chosen below: UDP for the
   # UDP protocol, TCP for everything else
   l4HdrLen = UDP_HDR_LEN if args.proto == IP_PROTO_UDP else TCP_HDR_LEN
   payloadLen = ( pktSize - ETH_HDR_LEN * 2 - ETH_CRC_LEN - ipHdrLen * 2 -
                  UDP_HDR_LEN - VXLAN_HDR_LEN - l4HdrLen )
   if payloadLen < 1:
      print( 'Trace packet size is too small' )
      sys.exit( -1 )

   # payload is an incrementing byte pattern: 0x00, 0x01, ..., 0xff, 0x00, ...
   pkt = bytes( b % 256 for b in range( payloadLen ) )

   # construct inner L4 header
   if args.proto == IP_PROTO_UDP:
      # set the UDP length explicitly, as is done for the VXLAN UDP header
      pkt = UDP( sport=args.sport, dport=args.dport, data=pkt,
            ulen=( len( pkt ) + UDP_HDR_LEN ) )
   else:
      pkt = TCP( sport=args.sport, dport=args.dport, data=pkt )

   # construct inner IP header
   if args.ipVer == 4:
      pkt = IP( dst=dst, src=src, ttl=64, p=args.proto, data=pkt )
   else:
      pkt = IP6( dst=dst, src=src, hlim=64,
            nxt=args.proto, data=pkt, plen=len( pkt ) )

   # construct inner Ethernet header
   pkt = Ethernet( src=bridgeMac, dst=bridgeMac, type=etherType, data=pkt )

   # construct VXLAN header: flags byte 0x08 (VNI-valid bit), then the
   # 24-bit VNI (1) followed by a reserved zero byte
   vxlanHdr = struct.pack( '>BBH', 0x08, 0, 0 ) # flags
   vxlanHdr += struct.pack( '>I', 0x0100 ) # VNI

   pkt = vxlanHdr + bytes( pkt )

   # construct UDP header for VXLAN
   pkt = UDP( sport=1234, dport=UDP_DPORT_VXLAN, data=pkt,
         ulen=( len( pkt ) + UDP_HDR_LEN ) )

   # construct IP header for VXLAN
   # * ttl - will be used to trace overlay path
   # * id - copied as ttl of outer header when trace packet is encapsulated,
   #    used to trace underlay path
   # * v - set to 5/7 ( IP/IP6 ) to indicate that this is
   #    a trace packet, interpreted by DMA driver and then overwritten to
   #    correct value
   if args.ipVer == 4:
      pkt = IP( dst=dst, src=intfSourceIp, ttl=innerTtl,
            p=IP_PROTO_UDP, id=outerTtl, data=pkt )
      pkt._v_hl = 0x55
   else:
      pkt = IP6( dst=dst, src=intfSourceIp, hlim=innerTtl,
            nxt=IP_PROTO_UDP, data=pkt, plen=len( pkt ) )
      pkt._v_fc_flow = 0x70000000 | outerTtl

   # construct outer Ethernet header
   pkt = Ethernet( src=bridgeMac, dst=bridgeMac, type=etherType, data=pkt )

   return pkt

def createRxSocket( pType, ipVer, ns=None ):
   """Open a non-blocking raw ICMP (ipVer 4) or ICMPv6 socket for replies.

   pType -- PathType the socket serves (not used to build it)
   ipVer -- 4 for ICMP; any other value selects ICMPv6
   ns    -- namespace passed to socketAt(); None for the caller's default

   A kernel ICMP filter is installed so that only time-exceeded and
   destination-unreachable messages are delivered. In the filter mask a set
   bit blocks that message type and a clear bit allows it.
   """
   # underlay ICMP replies will always be coming in default VRF
   blockAll = 0xFFFFFFFF
   if ipVer == 4:
      family, icmpProto, sockLevel = socket.AF_INET, socket.IPPROTO_ICMP, SOL_RAW
      passMask = ( 1 << ICMP_TIMEXCEED ) | ( 1 << ICMP_UNREACH )
      # one 32-bit word covers every ICMP type the filter handles
      icmpFilter = struct.pack( 'I', blockAll ^ passMask )
   else:
      family = socket.AF_INET6
      icmpProto = socket.IPPROTO_ICMPV6
      sockLevel = SOL_ICMPV6
      passMask = ( 1 << ICMP6_TIME_EXCEEDED ) | ( 1 << ICMP6_DST_UNREACH )
      # ICMPv6 has 256 message types, so the filter is eight 32-bit words;
      # both wanted types live in the first word, all others stay blocked
      icmpFilter = struct.pack( '8I', blockAll ^ passMask, *( [ blockAll ] * 7 ) )

   rxSock = socketAt( family, socket.SOCK_RAW, icmpProto, ns )
   rxSock.setsockopt( sockLevel, ICMP_FILTER, icmpFilter )
   rxSock.setblocking( False )
   return rxSock

# Selector keys of the ICMP receive sockets, grouped by PathType
selKey = {}
sel = selectors.DefaultSelector()

try:
   # packet socket used to inject trace frames on the forwarding interface
   txSocket = socketAt( family=socket.PF_PACKET, type=socket.SOCK_RAW, ns=nsName )
   txSocket.bind( ( fwdInterface, etherType ) )

   for pt in PathType:
      if pt == PathType.OVERLAY:
         # overlay replies use the trace IP version, in the trace VRF namespace
         rxSockets = [ createRxSocket( pt, args.ipVer, nsName ) ]
      else:
         # The underlay IP version is not known to us ahead of time so
         # we have to monitor both IPv4 and IPv6 ICMP replies
         rxSockets = [ createRxSocket( pt, ver ) for ver in ( 4, 6 ) ]
      selKey[ pt ] = [ sel.register( rx, selectors.EVENT_READ )
                       for rx in rxSockets ]
except OSError as e:
   print( e )
   sys.exit( -1 )
except ValueError:
   print( 'Error creating socket' )
   sys.exit( -1 )

def getsrc( pktBuf, sKey ):
   """Decide whether an ICMP error message was caused by one of our probes.

   pktBuf -- bytes read from an ICMP receive socket. For AF_INET sockets it
             is parsed as a full IPv4 packet; for AF_INET6 sockets it is
             parsed as a bare ICMPv6 message.
   sKey   -- selector key of the socket pktBuf was read from

   Returns ( isOurs, unreach ):
   - isOurs is True when the packet quoted in the ICMP message is UDP to a
     VXLAN port and its inner IP header is addressed to `dst`. It is also
     True when the quote is too short to check the inner header; in that
     case the reply is assumed to be ours.
   - unreach is True when the message is destination-unreachable. Both are
     False for replies that are not ours.
   """
   if sKey.fileobj.family == socket.AddressFamily.AF_INET6:
      icmpHdr = ICMP6( pktBuf )
      unreach = ( icmpHdr.type == ICMP6_DST_UNREACH )
   else:
      icmpHdr = IP( pktBuf ).data
      unreach = ( getattr( icmpHdr, 'type', None ) == ICMP_UNREACH )

   # Strip the ICMP header to get to the quoted original packet. dpkt may
   # leave `data` as raw bytes when it cannot decode a layer, so check the
   # types instead of assuming every layer was decoded.
   ipHdr = getattr( getattr( icmpHdr, 'data', None ), 'data', None )
   if not isinstance( ipHdr, ( IP, IP6 ) ) or \
         not isinstance( ipHdr.data, UDP ):
      # not our reply
      return False, False

   udpHdr = ipHdr.data
   if udpHdr.dport not in ( UDP_DPORT_VXLAN, args.vxlanPort ):
      # not our reply
      return False, False

   # The inner IP header follows the VXLAN and inner Ethernet headers, and
   # its version matches the traced address family
   innerIpOffset = VXLAN_HDR_LEN + ETH_HDR_LEN
   if args.ipVer == 4:
      innerIpCls, innerIpHdrLen = IP, IP_HDR_LEN
   else:
      innerIpCls, innerIpHdrLen = IP6, IP6_HDR_LEN

   if len( udpHdr.data ) < innerIpOffset + innerIpHdrLen:
      # quote too short to check the inner destination -- assume it is ours
      return True, unreach

   # keep every quoted byte: dropping the last one could truncate the header
   innerIpHdr = innerIpCls( udpHdr.data[ innerIpOffset : ] )
   if innerIpHdr.dst == dst:
      # our reply
      return True, unreach

   return False, False

def nslookup( addr ):
   """Return the reverse-DNS name of addr when -n was given, else addr.

   Also returns addr unchanged when the reverse lookup raises socket.herror.
   """
   if not args.nslookup:
      return addr
   try:
      hostname, _aliases, _addrs = socket.gethostbyaddr( addr )
   except socket.herror:
      return addr
   return hostname

def tracenext( ttl, pType, underlayTtl=255 ):
   """Send one trace probe and wait for the ICMP reply it triggers.

   ttl         -- TTL probed on the `pType` path
   pType       -- PathType.UNDERLAY or PathType.OVERLAY
   underlayTtl -- underlay TTL used while tracing the overlay
                  (the overlay TTL is fixed at 2 while tracing the underlay)

   Waits at most `timeout` seconds in total. Returns ( hop, unreach, vtep ):
     hop     -- address of the replying node (a hostname with -n), or None
                when no matching reply arrived in time
     unreach -- True when the reply was destination-unreachable
     vtep    -- True when an overlay reply arrived while tracing the
                underlay, i.e. the probe reached the VTEP. The underlay
                receive sockets are then unregistered and closed.
   """
   underlayTtl = ttl if pType is PathType.UNDERLAY else underlayTtl
   overlayTtl = ttl if pType is PathType.OVERLAY else 2

   pkt = buildPacket( underlayTtl, overlayTtl )

   status = txSocket.send( bytes( pkt ) )
   if not status:
      print( 'Failed to send trace packet' )
      sys.exit( -1 )

   # bound the whole wait by `timeout`, not each individual select() call
   deadline = Tac.now() + timeout

   while True:
      remaining = deadline - Tac.now()
      if remaining <= 0:
         return None, False, False

      events = sel.select( timeout=remaining )
      if not events:
         return None, False, False

      for sKey, _ in events:
         # Always drain a readable socket. A level-triggered selector would
         # otherwise report an unread socket (e.g. underlay while tracing
         # the overlay) again immediately and busy-spin until the deadline.
         try:
            pkt, replySrc = sKey.fileobj.recvfrom( 16384 )
         except BlockingIOError:
            # spurious readiness on the non-blocking socket
            continue

         if sKey in selKey[ pType ]:
            nh, unreach = getsrc( pkt, sKey )
            if nh:
               return nslookup( replySrc[ 0 ] ), unreach, False
         elif sKey in selKey[ PathType.OVERLAY ] and pType == PathType.UNDERLAY:
            # if tracing underlay check if we got a reply in overlay
            # in case we reached VTEP
            nh, unreach = getsrc( pkt, sKey )
            if nh:
               # reached VTEP -- stop monitoring both underlay rx sockets;
               # unregister them from the selector before closing them
               for key in selKey[ PathType.UNDERLAY ]:
                  sel.unregister( key.fileobj )
                  key.fileobj.close()
               selKey[ PathType.UNDERLAY ].clear()
               return nslookup( replySrc[ 0 ] ), unreach, True
         # datagrams on any other socket are not relevant to this probe and
         # have been discarded by the read above

# One-letter path tag printed in the last column of each hop line
pTypeMap = {
   PathType.UNDERLAY: 'U',
   PathType.OVERLAY: 'O',
}

def printRoute( route, pType, routes, hopCount ):
   """Print one hop line and record it.

   The line holds the hop number, the hop address (or '* * *') padded to 60
   columns, and the U/O path tag. ( route, pType ) is then appended to
   `routes`, and the same list is returned.
   """
   tag = pTypeMap[ pType ]
   print( f'{hopCount:<5}   {route:<60}   {tag}' )
   routes.append( ( route, pType ) )
   return routes

def traceroute( pType=PathType.UNDERLAY,
                ttl=1, routes=None, underlayTtl=None, hopCount=1 ):
   """Trace the `pType` path from `ttl` onwards, printing one line per hop.

   pType       -- path being traced. Tracing starts in the underlay. Unless
                  -u was given, it continues in the overlay (from TTL 3) once
                  the VTEP is reached or the underlay TTL budget runs out.
   ttl         -- TTL probed by this call
   routes      -- list collecting ( route, PathType ) for each printed hop;
                  a new list is used when None
   underlayTtl -- TTL at which the underlay trace ended. It is passed along
                  the recursion but not given to tracenext().
   hopCount    -- hop number shown on the printed line

   Returns `routes`.
   """
   if routes is None:
      routes = []

   if ( ttl == maxTtl and pType == PathType.UNDERLAY ) or \
      ttl == maxTtl + 2:
      # TTL budget exhausted on this path
      routes = printRoute( '* * *', pType, routes, hopCount )
      if pType == PathType.UNDERLAY and not args.underlayOnly:
         # never found VTEP -- try tracing overlay
         routes = traceroute( PathType.OVERLAY, 3, routes, ttl, hopCount + 1 )
      return routes

   route, dstFound, vtep = tracenext( ttl, pType )

   if dstFound:
      # found trace destination
      if pType == PathType.UNDERLAY:
         routes = printRoute( route, PathType.OVERLAY, routes, hopCount )
      else:
         routes = printRoute( route, pType, routes, hopCount )
      return routes
   elif vtep:
      # reached vtep -- trace overlay
      routes = printRoute( route, PathType.OVERLAY, routes, hopCount )
      if not args.underlayOnly:
         routes = traceroute( PathType.OVERLAY, 3, routes, ttl, hopCount + 1 )
      return routes
   elif not route:
      routes = printRoute( '* * *', pType, routes, hopCount )
   else:
      routes = printRoute( route, pType, routes, hopCount )

   routes = traceroute( pType, ttl + 1, routes, underlayTtl, hopCount + 1 )
   return routes

# Script entry: print the banner, run the trace, and release all sockets.
print( f'traceroute to {args.destination}, {maxTtl}'
       f' hops max, {pktSize} byte packets' )

try:
   traceroute( routes=[] )
except KeyboardInterrupt:
   sys.exit( 0 )
finally:
   txSocket.close()
   # sel.close() does not close the sockets registered with it, so close the
   # receive sockets first. Closing an already-closed socket is a no-op.
   for key in list( sel.get_map().values() ):
      key.fileobj.close()
   sel.close()

sys.exit( 0 )
