#!/usr/bin/env python3
# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pkgdeps: rpmwith %{_libdir}/libProbeClientLib.so*

import argparse
import json
import socket
import sys

import Tac
import Tracing

# Short alias for Tracing.trace0; main() uses it to trace address resolution.
t0 = Tracing.trace0

# Tac type handles, looked up once at import time by fully qualified name.
# IpGenAddr is used to build the address values stored in the probe config.
IpGenAddr = Tac.Type( 'Arnet::IpGenAddr' )
# The three ProbeClient types below are used to turn raw code/state values
# from the probe status into their grpcEnum() / humanReadable() forms.
ProbeCode = Tac.Type( 'IcmpResponder::ProbeClient::ProbeCode' )
NeighborState = Tac.Type( 'IcmpResponder::ProbeClient::ProbeNeighborState' )
InterfaceState = Tac.Type( 'IcmpResponder::ProbeClient::ProbeInterfaceState' )

def pct( numerator, denominator ):
   """Format numerator / denominator as a whole-number percentage string.

   The percentage is truncated towards zero, not rounded.  A zero
   denominator yields "0%" instead of raising ZeroDivisionError.
   """
   percent = 0 if denominator == 0 else int( numerator * 100 / denominator )
   return f"{percent}%"

def json_summary( config, status ):
   """Return the end-of-run statistics as a single JSON object string.

   config is accepted so the signature matches summary(); it is not read.
   status supplies the counters: count, txErr, rcvdCount, and the
   codeCount / stateCount / intfStateCount collections.

   The resulting object has the keys 'sent', 'errors', 'received', 'codes',
   'states' and 'actives'.  Each of the last three maps a grpcEnum() value
   to the number of responses that produced it.
   """
   ret = { 'sent': status.count, 'errors': status.txErr,
           'received': status.rcvdCount }
   # The mapping from raw value to enum value can be many-to-one, since
   # any unknown value maps to "unknown".  So every bucket accumulates
   # rather than assigns, to sum all raw values sharing an enum value.
   codes = {}
   for code, count in status.codeCount.items():
      idx = ProbeCode( code ).grpcEnum()
      codes[ idx ] = codes.get( idx, 0 ) + count
   ret[ 'codes' ] = codes
   states = {}
   for state, count in status.stateCount.items():
      idx = NeighborState( state ).grpcEnum()
      states[ idx ] = states.get( idx, 0 ) + count
   ret[ 'states' ] = states
   # intfStateCount is keyed by objects that already provide grpcEnum().
   # Accumulate here too, so two keys with the same enum value do not
   # silently overwrite each other.
   actives = {}
   for active, count in status.intfStateCount.items():
      idx = active.grpcEnum()
      actives[ idx ] = actives.get( idx, 0 ) + count
   ret[ 'actives' ] = actives
   return json.dumps( ret )

def summary( config, status ):
   """Return the human-readable end-of-run statistics as one string.

   config supplies dest (shown in the header) and destLocal (selects which
   per-state breakdown is printed).  status supplies the counters count,
   txErr, rcvdCount and badPkt, plus the codeCount / stateCount /
   intfStateCount collections.

   The text contains, in order: a header line, the transmit / receive /
   loss line, the successful-probe line, an optional per-state breakdown
   (only when at least one probe succeeded), and an optional "Errors:"
   section listing non-zero response codes and unparseable packets.
   Every line ends with a newline, so callers print it with end=''.
   """
   parts = [
      f"--- {config.dest} PROBE statistics ---\n",
      f"{status.count} packets transmitted, ",
   ]
   if status.txErr:
      parts.append( f"{status.txErr} errors, " )
   parts.append( f"{status.rcvdCount} received, " )
   parts.append(
      f"{pct( status.count - status.rcvdCount, status.count )} packet loss\n" )

   noError = Tac.enumValue( 'Arnet::IcmpExtendedEchoCode',
                            'extendedEchoNoError' )
   goodProbes = status.codeCount.get( noError, 0 )
   parts.append(
      f"{goodProbes} ({pct( goodProbes, status.rcvdCount )}) "
      "successful probes\n" )

   if goodProbes:
      # Break the successful probes down by reported state.  Percentages
      # here are relative to the successful probes, not to all responses.
      if config.destLocal:
         for state in sorted( status.intfStateCount ):
            count = status.intfStateCount[ state ]
            desc = state.humanReadable()
            parts.append( f"  {count} ({pct( count, goodProbes )}) {desc}\n" )
      else:
         for state in sorted( status.stateCount ):
            count = status.stateCount[ state ]
            desc = NeighborState( state ).humanReadable()
            parts.append( f"  {count} ({pct( count, goodProbes )}) {desc}\n" )

   # Code 0 is excluded from the error list.  Filter it out by value
   # rather than by position, so the result does not depend on 0 being
   # the first element after sorting.
   errorCodes = [ code for code in sorted( status.codeCount ) if code != 0 ]
   if errorCodes:
      parts.append( "Errors:\n" )
      for code in errorCodes:
         count = status.codeCount[ code ]
         parts.append(
            f"{count} ({pct( count, status.rcvdCount )}) "
            + ProbeCode( code ).humanReadable() + "\n" )
   if status.badPkt:
      # Only emit the section header once.
      if not errorCodes:
         parts.append( "Errors:\n" )
      parts.append( f"{status.badPkt} packets could not be parsed\n" )
   return "".join( parts )

class ProbeCompleteReactor( Tac.Notifiee ):
   """Prints the final statistics and exits when the probe run finishes.

   Reacts to the 'retval' attribute of a ProbeClientSm.  The report is
   produced by json_summary() or summary(), depending on useJson.
   """
   notifierTypeName = 'IcmpResponder::ProbeClient::ProbeClientSm'

   def __init__( self, notifier, useJson ):
      Tac.Notifiee.__init__( self, notifier )
      # True selects the JSON report, False the human-readable one.
      self.json_ = useJson

   @Tac.handler( 'retval' )
   def handleRetval( self ):
      """Print the report, flush stdout, and raise SystemExit( retval )."""
      sm = self.notifier()
      if not self.json_:
         # summary() already terminates every line, so add no newline.
         report, terminator = summary( sm.config, sm.status ), ''
      else:
         report, terminator = json_summary( sm.config, sm.status ), '\n'
      print( report, end=terminator )
      sys.stdout.flush()
      raise SystemExit( sm.retval )

class JsonResultReactor( Tac.Notifiee ):
   """Prints one JSON object per line for each entry added to status.rcvd."""
   notifierTypeName = 'IcmpResponder::ProbeClient::ProbeStatus'

   @Tac.handler( 'rcvd' )
   def handleRcvd( self, key ):
      """Emit the JSON record for the response stored under key.

      key indexes both status.rcvd (the response) and status.sent (the
      matching send time); it is reported as the 'sequence' field.
      """
      status = self.notifier()
      resp = status.rcvd[ key ]
      record = {
         'sequence': key,
         'source': str( resp.srcAddr ),
         # Difference between the response time and the recorded send time.
         'time': resp.time - status.sent[ key ],
         'code': ProbeCode( resp.code ).grpcEnum(),
      }
      # The reply payload is only examined when the response code is 0.
      if resp.code == 0:
         reply = resp.reply
         if reply.state != 0:
            record[ 'state' ] = NeighborState( reply.state ).grpcEnum()
         else:
            # A zero state means the reply is reported via 'active' instead.
            record[ 'active' ] = InterfaceState( reply.active, reply.ipv4,
                                                 reply.ipv6 ).grpcEnum()
      print( json.dumps( record ) )

def _buildArgParser():
   """Build the argparse parser describing the probe client command line."""
   parser = argparse.ArgumentParser()
   parser.add_argument( 'dest', help='The proxy node id' )
   # Exactly one way of identifying the probed interface must be chosen.
   group = parser.add_mutually_exclusive_group( required=True )
   group.add_argument( '--ifindex', type=int,
         help='Specify probed interface by ifIndex' )
   group.add_argument( '--ifname',
         help='Specify probed interface by ifName' )
   group.add_argument( '--addr',
         help='Specify probed interface by address' )
   proto = parser.add_mutually_exclusive_group()
   proto.add_argument( '-4', '--ipv4', action='store_true',
         help='Force IPv4' )
   proto.add_argument( '-6', '--ipv6', action='store_true',
         help='Force IPv6' )
   parser.add_argument( '--remote', action='store_true',
         help='The address being probed is remote from the proxy node' )
   parser.add_argument( '--source',
         help='Source address to use for the probe packets' )
   parser.add_argument( '--count', type=int, default=3,
         help='The number of packets to send (default=%(default)s)' )
   parser.add_argument( '--wait', type=float, default=1.0,
         help='Wait between sending packets (default=%(default)s)' )
   parser.add_argument( '--finalWait', type=float, default=5.0,
         help='Wait after sending all packets to receive all responses '
              '(default=%(default)s)' )
   parser.add_argument( '--quiet', action='store_true',
         help='Only print summary, not information about each packet' )
   parser.add_argument( '--no-padding', action='store_true',
         help=argparse.SUPPRESS )
   parser.add_argument( '--json', action='store_true',
         help='Output result in jsonlines format' )
   return parser

def main( argv=None ):
   """Parse the command line, configure a probe run, and run it.

   argv is the argument list without the program name; None (the default)
   means sys.argv[ 1 : ].  An explicitly passed empty list is honoured
   and is not replaced by sys.argv.

   The function leaves through SystemExit on every path shown here:
   argparse usage errors (status 2), a failed resolution in --json mode
   (status 1), or ProbeCompleteReactor raising SystemExit with the state
   machine's retval.
   """
   if argv is None:
      argv = sys.argv[ 1 : ]
   parser = _buildArgParser()
   args = parser.parse_args( args=argv )
   # argparse only guarantees that one selector option was given, not that
   # its value is usable.  The code below treats 0 / '' as "not specified",
   # so reject those here with a usage error instead of an assertion.
   if not ( args.ifname or args.ifindex or args.addr ):
      parser.error( "--ifindex must be non-zero, and --ifname / --addr "
                    "must be non-empty" )
   if args.remote and ( args.ifname or args.ifindex ):
      parser.error( "With --remote, must specify --addr" )

   config = Tac.newInstance( 'IcmpResponder::ProbeClient::ProbeConfig' )
   status = Tac.newInstance( 'IcmpResponder::ProbeClient::ProbeStatus' )

   # Handle hostname in dest
   try:
      addrs = socket.getaddrinfo( args.dest, None, 0, socket.SOCK_RAW )
   except socket.gaierror as e:
      parser.error( str( e ) )
   t0( args.dest, "=>", addrs )
   afLimit = None
   limitStr = ''
   if args.ipv4:
      afLimit = socket.AF_INET
      limitStr = " into an IPv4 address"
   elif args.ipv6:
      afLimit = socket.AF_INET6
      limitStr = " into an IPv6 address"
   # Starting with the example from the python docs:
   # https://docs.python.org/3/library/socket.html#socket.getaddrinfo
   # >>> socket.getaddrinfo( "example.org", 80, proto=socket.IPPROTO_TCP )
   # [ ( socket.AF_INET6, socket.SOCK_STREAM,
   #     6, '', ( '2606:2800:220:1:248:1893:25c8:1946', 80, 0, 0 ) ),
   #   ( socket.AF_INET, socket.SOCK_STREAM,
   #     6, '', ( '93.184.216.34', 80 ) ) ]
   # if asked for AF_INET6, we reduce the list to only the first
   # element; if asked for AF_INET, we reduce the list to only
   # the second element, by filtering on addr[ 0 ] below.
   # Then we choose the first element of the remaining list, if
   # any, relying on the definition of getaddrinfo to have given us
   # the best address first.
   # The fifth entry (index 4) of that tuple is the address and port,
   # but since we are using raw sockets there is no port, so
   # we just want the address out of that tuple.
   if afLimit is not None:
      addrs = [ addr for addr in addrs if addr[ 0 ] == afLimit ]
      t0( "addrs subset =>", addrs )
   if addrs:
      dest = addrs[ 0 ][ 4 ][ 0 ]
   else:
      errStr = f"Could not resolve {args.dest}{limitStr}"
      if args.json:
         # Let json escape the message so the line is always valid JSON;
         # compact separators keep the usual output byte-identical.
         print( json.dumps( { 'error': errStr }, separators=( ',', ':' ) ) )
         raise SystemExit( 1 )
      parser.error( errStr )

   # Copy args into config
   try:
      config.dest = IpGenAddr( dest )
      if args.source:
         config.source = IpGenAddr( args.source )
      config.count = args.count
      config.wait = args.wait
      config.finalWait = args.finalWait
      config.destLocal = not args.remote
      # JSON mode prints its own per-packet lines via JsonResultReactor.
      config.quiet = args.quiet or args.json
      config.noPadding = args.no_padding
      what = ''
      if args.ifname:
         config.ifName = args.ifname.encode()
         what = 'interface ' + args.ifname
      elif args.ifindex:
         config.ifIndex = args.ifindex
         what = f'ifIndex {args.ifindex}'
      else:
         # The selector check after parsing guarantees --addr is set here.
         config.ifAddress = IpGenAddr( args.addr )
         what = 'remote neighbor' if args.remote else 'interface'
         what += ' with address ' + args.addr
   except IndexError as e:
      # NOTE(review): presumably how Tac reports an invalid address or
      # out-of-range value -- confirm against the Tac bindings.
      parser.error( str( e ) )

   if config.dest.af == 'ipv4':
      pam = Tac.newInstance( 'Arnet::IpSocketPam', 'icmp', 'ipProtoIcmp', 'default' )
   else:
      pam = Tac.newInstance( 'Arnet::Ip6SocketPam',
                             'icmp6', 'ipProtoIcmpv6', 'default' )
   if args.json:
      # First jsonlines record: describes what is being probed.
      topLine = { 'proxy': str( config.dest ), 'local': config.destLocal }
      if config.ifName:
         topLine[ 'ifName' ] = config.ifName.decode( 'utf8' )
      elif config.ifIndex:
         topLine[ 'ifIndex' ] = config.ifIndex
      else:
         topLine[ 'addr' ] = str( config.ifAddress )
      print( json.dumps( topLine ) )
   else:
      print( "PROBE", config.dest, what )
   sys.stdout.flush()

   try:
      sm = Tac.newInstance( 'IcmpResponder::ProbeClient::ProbeClientSm',
            pam, config, status )
   except Exception as e: # pylint: disable-msg=broad-except
      parser.error( str( e ) )
   # The reactors are bound to local names so they stay referenced while
   # the activity loop runs.
   exiter = ProbeCompleteReactor( sm, args.json )
   if args.json:
      _jsonPrinter = JsonResultReactor( status )
   try:
      _ = exiter
      Tac.runActivities()
   except KeyboardInterrupt:
      # NOTE(review): presumably setting retval fires ProbeCompleteReactor
      # so the summary is still printed on Ctrl-C -- confirm.
      sm.retval = -1

if __name__ == "__main__":
   main()
