#!/usr/bin/env python3
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=no-member
from __future__ import absolute_import, division, print_function
try:
   import argparse
   import os
   import PyClient
   import sys
   import socket
   import struct
   import statistics

   from Arnet import IpAddr
   import AvtUtilCmdLib
   from AvtUtilCmdLib import (
      checkDest,
      constructPam,
      ErrorCode,
      errorStr,
      getTimestamp,
      isMhPath,
      mhPath,
      mhPathFlag,
      ntpWarning,
      pathStr,
      strPathId,
      PingReqTlv,
      printError,
      printLabels,
      TlvCode,
      TlvParser,
   )
   from CliPlugin.BessCtlCmdLib import (
      Bessctl,
   )
   from CliPlugin.WanTEShowCli import (
      getVniAvtIdUsingPyClient,
      getSelfVtepUsingPyClient,
   )
   import CliPrint
   from CliPrint import (
      FloatAttrInfo,
      IntAttrInfo,
      Printer,
      StringAttrInfo,
   )
   import Tac

   def argsParseBuilder():
      """Build the command-line parser for the AVT ping utility.

      --multihop-path and --direct-path are mutually exclusive and both store
      into `pathId`; when neither is given `pathId` stays 0, which the caller
      treats as "ping every path of the AVT".

      Returns:
         argparse.ArgumentParser: the configured parser.
      """
      parser = argparse.ArgumentParser( description='Avt Ping Utility' )
      add = parser.add_argument

      add( 'destination', nargs='?', help='Destination to Ping' )
      add( '--vrf', default='default', help='Ping destination in VRF' )
      add( '--avt', required=True, help='Ping destination in AVT' )
      add( '--repeat', type=int, default=sys.maxsize, help='count of ping pkts' )
      add( '--size', type=int, default=64, help='size of ping pkts' )
      add( '--interval', type=float, default=1.0,
           help='interval between ping pkts, min 1.0 sec' )
      add( '--fast-reroute', dest='frr', action='store_true',
           help='reroute packet if multihop path is broken' )
      add( '--ttl', type=int, default=64, help='TTL of ping pkts' )

      pathChoice = parser.add_mutually_exclusive_group()
      pathChoice.add_argument( '--multihop-path', metavar='id', dest='pathId',
                               type=mhPath, default=0,
                               help='multihop path to ping' )
      pathChoice.add_argument( '--direct-path', metavar='id', dest='pathId',
                               type=int, default=0, help='direct path to ping' )

      add( '--json', action='store_true', help='Outputs a json' )
      return parser

   class Reply:
      """A ping reply received on a single path.

      Attributes:
         srcIp: address the reply was received from.
         tlv: the parsed ping-reply TLV.
         receivedTime: receive timestamp supplied by the caller.
         len: length of the received payload, in bytes.
      """
      def __init__( self, srcIp, tlv, timestamp, length ):
         self.srcIp, self.tlv = srcIp, tlv
         self.receivedTime, self.len = timestamp, length

   class Summary:
      """Per-path statistics accumulated over all ping sequences."""
      def __init__( self ):
         # number of successful (non-error) replies seen on the path
         self.recvCount = 0
         # one sample per successful reply, in milliseconds
         self.rtt, self.oneway = [], []

   class PamPacketReader( Tac.Notifiee ):
      """Collect AVT ping replies arriving on a UDP PAM.

      `reply` maps sequence number -> { pathId: Reply }. Replies whose
      sequence number has no entry in `reply` are dropped.
      `requestInitiatedTimestamp` maps sequence number -> transmit time.
      Both dicts start empty here and are filled in by the code that sends
      the requests.
      """
      notifierTypeName = "Arnet::UdpPam"
      def __init__( self, pam, tlvParser ):
         self.reply = {}
         self.tlvParser = tlvParser
         self.requestInitiatedTimestamp = {}
         Tac.Notifiee.__init__( self, pam )

      @Tac.handler( "readableCount" )
      def handleReadableCount( self ):
         """Read one packet from the PAM and record it if it is a ping reply."""
         try:
            n = self.notifier_
            if not n:
               return
            pkt = n.rxPkt()
            if not pkt:
               return
            # we are expecting ping reply
            pktBytes = pkt.bytesValue
            # Too short to carry even the 2-byte TLV type; unpacking it would
            # raise struct.error out of this notification handler.
            if len( pktBytes ) < 2:
               return
            # The TLV type is treated as a network-byte-order u16 (hence ntohs).
            # Reading it in native order before ntohs is correct on both
            # little- and big-endian hosts; a fixed "<H" is only correct on
            # little-endian ones (where both forms give the same value).
            tlvType = socket.ntohs( struct.unpack( "=H", pktBytes[ : 2 ] )[ 0 ] )
            # pylint: disable=no-member
            if tlvType != TlvCode.pingReply:
               return

            replyTlv = self.tlvParser.parseAvtUtilPingReplyBytePtr( pktBytes )
            pathId = replyTlv.pathId

            # sometimes the reply may come much later, after we moved on
            # and deleted the entry
            if replyTlv.seqNum in self.reply:
               self.reply[ replyTlv.seqNum ][ pathId ] = \
                  Reply( n.rxSrcIpAddr, replyTlv, replyTlv.getSystemTime(),
                         len( pktBytes ) )
         except KeyboardInterrupt:
            pass

   class TxPeriodic:
      """Clock-driven transmitter of AVT ping requests.

      Each time the clock fires, `handler` sends sequence `current` once on
      every path in `pathIdList`, registers that sequence with `reader`, and
      re-arms the clock `interval` seconds later. After `count` sequences the
      clock is parked at Tac.endOfTime. The clock is created parked; the owner
      starts it by setting `clock.timeMin`.
      """
      def __init__( self, pam, srcIp, avtId, vni, frr, destination, interval, count,
                    size, pathIdList, reader ):
         self.udpPam = pam
         self.srcIp = srcIp
         self.reader = reader
         self.interval = interval
         self.count = count
         self.size = size
         self.pathIdList = pathIdList
         # next sequence number to transmit (sequences are 1-based)
         self.current = 1
         # Reusable request TLV built with 0 placeholders; makePacket()
         # overwrites pathId, seqNum and senderTimeStamp for every packet.
         self.reqTlv = PingReqTlv( avtId, vni, 0, 0, destination, 0, frr, 0 )
         self.clock = Tac.ClockNotifiee( handler=self.handler,
                                         timeMin=Tac.endOfTime )

      def __del__( self ):
         self.clock = None

      def interrupt( self ):
         """Return True if transmission stopped before all sequences were sent.

         That is the case when the handler parked the clock after catching a
         KeyboardInterrupt.
         """
         parked = self.clock.timeMin == Tac.endOfTime
         return parked and self.current <= self.count

      def makePacket( self, pathId, seqNo ):
         """Return an Arnet::Pkt carrying the request for (pathId, seqNo)."""
         tlv = self.reqTlv
         tlv.pathId = pathId
         tlv.seqNum = seqNo
         tlv.senderTimeStamp = tlv.getSystemTime()
         pkt = Tac.newInstance( 'Arnet::Pkt' )
         payload = tlv.serialize()

         # Pad with 0xff up to `size`, assuming 28 bytes of IP+UDP overhead;
         # nothing is added when the TLV alone already reaches that size.
         shortfall = self.size - 28 - len( payload )
         if shortfall > 0:
            payload = payload + b'\xff' * shortfall
         pkt.stringValue = payload
         return pkt

      def handler( self ):
         """Clock callback: send the next sequence on every path, then re-arm."""
         try:
            seq = self.current
            if seq > self.count:
               # all sequences sent; park the clock
               self.clock.timeMin = Tac.endOfTime
               return

            reader = self.reader
            reader.reply[ seq ] = {}
            reader.requestInitiatedTimestamp[ seq ] = Tac.now()

            pam = self.udpPam
            for pathId in self.pathIdList:
               pkt = self.makePacket( pathId, seq )
               pam.txSrcIpAddr = self.srcIp
               pam.txPkt = pkt

            # sequence `seq` is out; the next tick sends seq + 1
            self.current = seq + 1
            self.clock.timeMin = Tac.now() + self.interval
         except KeyboardInterrupt:
            self.clock.timeMin = Tac.endOfTime

   def doPing( pathsDict, avtId, vni, src, destination, count, interval, size,
               frr, ttl, printer ):
      """Ping `destination` on every path in `pathsDict` and print the results.

      A request is sent on each path every `interval` seconds, for `count`
      sequences. Each sequence is printed as soon as every path has replied,
      or once 2 seconds have passed since it was sent. A per-path summary is
      printed at the end, including after Ctrl-C.

      Args:
         pathsDict: path id -> hop list (None unless the path is multihop).
         avtId, vni: AVT identifiers placed in every request TLV.
         src: local VTEP address used as the transmit source.
         destination: destination VTEP (Arnet IpAddr).
         count: number of sequences to send.
         interval: seconds between sequences.
         size: requested packet size in bytes; requests are padded up to it.
         frr: fast-reroute flag placed in the request TLV.
         ttl: passed to constructPam and shown in each sequence header.
         printer: CliPrint Printer that receives the text/JSON output.
      """
      def echoRequestSent( s ):
         # `txPeriodic.current` is the *next* sequence to send, so this is also
         # true for the one about to go out. The wait loop below keeps running
         # activities until that request is actually sent.
         return s <= txPeriodic.current

      def receivedAllReply( s ):
         # every pinged path has answered sequence `s`
         return s in rxPktReader.reply and len( rxPktReader.reply[ s ] ) == pathCount

      def waitForReplyExpired( s ):
         # more than 2 seconds have passed since sequence `s` was transmitted
         if s not in rxPktReader.requestInitiatedTimestamp:
            return False
         diff = Tac.now() - rxPktReader.requestInitiatedTimestamp[ s ]
         return bool( diff > 2.0 )

      # pylint: disable=no-member
      def printResult( ttl, printer, seqNo ):
         """Print the per-path results of `seqNo`, updating `summary`."""
         def getPathwisePingInfo( path, rxPktReader, seqNo, summary ):
            """Return (rtt, oneway, errCode, message, ttl) for `path` at `seqNo`.

            A successful reply is also recorded in summary[ path ].
            """
            pstr = pathStr( path )
            rtt = 0
            oneway = 0
            errCode = 0
            message = ""
            ttl = 0
            if path in rxPktReader.reply[ seqNo ]:
               r = rxPktReader.reply[ seqNo ][ path ]
               ttl = r.tlv.ttl
               if r.tlv.error:
                  estr = errorStr[ r.tlv.error ]
                  message = f"{pstr} reply from {r.srcIp}: {estr}"
                  errCode = r.tlv.error
               else:
                  # raw timestamps are divided by 1000 to report milliseconds
                  # (presumably they are in microseconds -- confirm)
                  oneway = abs( r.tlv.timestamp2 - r.tlv.timestamp1 ) / 1000.0
                  rtt = abs( r.receivedTime - r.tlv.timestamp1 ) / 1000.0
                  message = f"{pstr} ttl={r.tlv.ttl} rtt={rtt:.3f} ms" \
                            f" 1-way={oneway:.3f} ms"
                  # Record under this function's own `path` argument rather than
                  # a variable captured from the enclosing loop.
                  pathSummary = summary[ path ]
                  pathSummary.recvCount += 1
                  pathSummary.oneway.append( oneway )
                  pathSummary.rtt.append( rtt )
            else:
               message = f"{pstr} { errorStr[ ErrorCode.pktLost ] }"
               errCode = ErrorCode.pktLost
            return rtt, oneway, errCode, message, ttl

         def printPathPingInfo( pathId, rtt, oneway, errCode, message, ttl ):
            errMsg = ""
            if errCode:
               errMsg = message
            sPathId = strPathId( pathId )
            pathId = StringAttrInfo( printer, "pathId", sPathId )
            ttlRemaining = IntAttrInfo( printer, "ttlRemaining" )
            ttlRemaining.setValue( ttl )
            pathRtt = FloatAttrInfo( printer, "rtt" )
            pathRtt.setValue( rtt )
            onewayTime = FloatAttrInfo( printer, "onewayTime" )
            onewayTime.setValue( oneway )
            errorCode = IntAttrInfo( printer, "errorCode" )
            errorCode.setValue( errCode )
            errorMessage = StringAttrInfo( printer, "errorMessage" )
            errorMessage.setValue( errMsg )
            printer.addFrills( "   %s\n", message )
            if printer.outputFormat == CliPrint.JSON:
               printer.addAttributes( "%s %d %f %f %d %s",
                  pathId.getArg(), ttlRemaining.getArg(),
                  pathRtt.getArg(), onewayTime.getArg(),
                  errorCode.getArg(), errorMessage.getArg() )

         printer.addFrills( "%d bytes to %s: ttl=%d seq=%d\n",
                            ( txPeriodic.size, destination, ttl, seqNo ) )
         with printer.dict( "sequences" ):
            with printer.dict( f"seq:{ seqNo }" ):
               with printer.list( "paths" ):
                  for pathId in pathsDict:
                     with printer.listEntry():
                        pathPingInfo = getPathwisePingInfo( pathId, rxPktReader,
                                                            seqNo, summary )
                        rtt, oneway, errCode, message, ttl = pathPingInfo
                        printPathPingInfo( pathId, rtt, oneway, errCode, message,
                                           ttl )

      def printSummary( sentCount, timeDiff, printer ):
         """Print per-path loss and rtt/one-way statistics over `sentCount` seqs."""
         def printPathSummary( printer, pathId, hopList, pathSummary, pktLoss,
                               minRtt, maxRtt, avgRtt, devRtt,
                               minOneway, maxOneway, avgOneway, devOneway ):
            pathKey = StringAttrInfo( printer, "", strPathId( pathId ) )
            pathKey.isKey()
            with printer.dictEntry( pathKey.getArg() ):
               if isMhPath( pathId ) and printer.outputFormat == CliPrint.JSON:
                  mhPathId = pathId & ~mhPathFlag
                  printLabels( printer, mhPathId, hopList )
               pktsTx = IntAttrInfo( printer, "pktsTx", sentCount )
               pktsRx = IntAttrInfo( printer, "pktsRx" )
               pktsRx.setValue( pathSummary.recvCount )
               lossRate = FloatAttrInfo( printer, "lossRate" )
               lossRate.setValue( pktLoss )
               totalTime = FloatAttrInfo( printer, "totalTime", timeDiff )
               rttMin = FloatAttrInfo( printer, "rttMin" )
               rttMin.setValue( minRtt )
               rttMax = FloatAttrInfo( printer, "rttMax" )
               rttMax.setValue( maxRtt )
               rttAvg = FloatAttrInfo( printer, "rttAvg" )
               rttAvg.setValue( avgRtt )
               rttMdev = FloatAttrInfo( printer, "rttMdev" )
               rttMdev.setValue( devRtt )
               onewayTimeMin = FloatAttrInfo( printer, "onewayTimeMin" )
               onewayTimeMin.setValue( minOneway )
               onewayTimeMax = FloatAttrInfo( printer, "onewayTimeMax" )
               onewayTimeMax.setValue( maxOneway )
               onewayTimeAvg = FloatAttrInfo( printer, "onewayTimeAvg" )
               onewayTimeAvg.setValue( avgOneway )
               onewayTimeMdev = FloatAttrInfo( printer, "onewayTimeMdev" )
               onewayTimeMdev.setValue( devOneway )

               printer.addAttributes( "   %d packets transmitted, %d received, "
                  "%.1f%% packet loss, time %.3fms\n", pktsTx.getArg(),
                  pktsRx.getArg(), lossRate.getArg(), totalTime.getArg() )
               printer.addAttributes( "   rtt    min/avg/max/mdev = "
                  "%.3f/%.3f/%.3f/%.3f ms\n", rttMin.getArg(),
                  rttAvg.getArg(), rttMax.getArg(), rttMdev.getArg() )
               printer.addAttributes( "   oneway min/avg/max/mdev = "
                  "%.3f/%.3f/%.3f/%.3f ms\n", onewayTimeMin.getArg(),
                  onewayTimeAvg.getArg(), onewayTimeMax.getArg(),
                  onewayTimeMdev.getArg() )
               printer.addFrills( "\n" )

         with printer.dict( "summary" ):
            printer.addFrills( "\n-- %s ping statistics --\n", destination )

            for pathId, hopList in pathsDict.items():
               s = summary[ pathId ]
               # sentCount is 0 when no sequence completed (e.g. --repeat 0 or
               # an interrupt before the first send): report 0% loss instead of
               # raising ZeroDivisionError.
               if sentCount > 0:
                  pktLoss = ( sentCount - s.recvCount ) * 100 / sentCount
               else:
                  pktLoss = 0.0
               printer.addFrills( pathStr( pathId, printLabel=True,
                                           hopList=hopList ) + "\n" )
               minRtt = maxRtt = avgRtt = devRtt = 0
               minOneway = maxOneway = avgOneway = devOneway = 0
               if s.recvCount > 0:
                  minRtt = min( s.rtt )
                  maxRtt = max( s.rtt )
                  avgRtt = statistics.mean( s.rtt )
                  # stdev needs at least two samples
                  devRtt = 0.0 if s.recvCount == 1 else statistics.stdev( s.rtt )
                  minOneway = min( s.oneway )
                  maxOneway = max( s.oneway )
                  avgOneway = statistics.mean( s.oneway )
                  devOneway = 0.0 if s.recvCount == 1 else \
                              statistics.stdev( s.oneway )
               printPathSummary( printer, pathId, hopList, s, pktLoss,
                                 minRtt, maxRtt, avgRtt, devRtt,
                                 minOneway, maxOneway, avgOneway, devOneway )

      def finishedSeq( finalSeqNo ):
         return receivedAllReply( finalSeqNo ) or waitForReplyExpired( finalSeqNo )

      # Do the Ping
      tlvParser = TlvParser()
      udpPam = constructPam( src, destination, ttl=ttl )
      rxPktReader = PamPacketReader( udpPam, tlvParser )
      txPeriodic = TxPeriodic( udpPam, src, avtId, vni, frr, destination,
                               interval, count, size, list( pathsDict.keys() ),
                               rxPktReader )

      # kick start the tx
      txPeriodic.clock.timeMin = Tac.now()
      Tac.runActivities( 0 )
      seqNo = 1
      pathList = sorted( pathsDict )
      pathCount = len( pathList )
      summary = { pathId: Summary() for pathId in pathList }
      finalSeqNo = 0
      startTime = getTimestamp()
      try:
         while seqNo <= count and not txPeriodic.interrupt():
            if echoRequestSent( seqNo ):
               while not finishedSeq( seqNo ):
                  Tac.runActivities( 0.1 )
               printResult( ttl, printer, seqNo )
               del rxPktReader.reply[ seqNo ]
               seqNo += 1

         finalSeqNo = seqNo - 1
      except KeyboardInterrupt:
         # stop the periodic
         txPeriodic.clock.timeMin = Tac.endOfTime

         # give time for the transmitted packets to be received
         finalSeqNo = txPeriodic.current - 1
         # With nothing sent (finalSeqNo == 0) finishedSeq() can never become
         # true, so skip the wait instead of spinning forever.
         while finalSeqNo > 0 and not finishedSeq( finalSeqNo ):
            Tac.runActivities( 0.1 )
      endTime = getTimestamp()
      totalTime = ( endTime - startTime ) / 1000 # convert in ms
      # print the last reply if we haven't done already
      if finalSeqNo in rxPktReader.reply:
         printResult( ttl, printer, finalSeqNo )
      printSummary( finalSeqNo, totalTime, printer )

   def doAvtPing( avtName, vrfName, destination, pathId,
                  count, interval, size, frr, ttl, printer ):
      """Resolve the AVT's best paths to `destination` and ping them.

      VNI/AVT id are looked up via a Sysdb PyClient and the destination is
      validated via an Sfe PyClient. The AVT's best paths are then fetched
      from bessctl, and either every path (pathId == 0) or only the requested
      path is pinged. Failures are reported through printError.
      """
      sysname = os.environ.get( 'SYSNAME', 'ar' )
      sysdbClient = PyClient.PyClient( sysname, "Sysdb" )
      sfeClient = PyClient.PyClient( sysname, "Sfe" )

      vni, avtId, error = getVniAvtIdUsingPyClient( sysdbClient, vrfName, avtName,
                                                    printError=False )
      if not ( vni and avtId ):
         printError( printer, ErrorCode.configError, error )
         return
      if not checkDest( sfeClient, destination ):
         printError( printer, ErrorCode.configError,
                     f"{destination} is not a valid destination VTEP" )
         return
      src = getSelfVtepUsingPyClient( sysdbClient )

      destIp = IpAddr( destination )
      try:
         bestpaths = AvtUtilCmdLib.bessctl.getBestPaths(
                            vni, avtId, destIp.value )
         # path id -> hop list; only multihop paths carry a hop list
         pathsDict = { p.path_index: ( list( p.hop ) if isMhPath( p.path_index )
                                       else None )
                       for p in bestpaths.path_info }
      # pylint: disable=broad-except
      except Exception:
         printError( printer, ErrorCode.lbLookupFailure,
                     errorStr[ ErrorCode.lbLookupFailure ] )
         return

      if pathId == 0: # we are pinging the entire avt
         selected = pathsDict
      elif pathId in pathsDict:
         selected = { pathId: pathsDict[ pathId ] }
      else:
         printError( printer, ErrorCode.pathIdNotPartOfAvt,
                     errorStr[ ErrorCode.pathIdNotPartOfAvt ] )
         return
      doPing( selected, avtId, vni, src, destIp,
              count, interval, size, frr, ttl, printer )

   def main():
      """Parse and validate arguments, then run the AVT ping.

      An --interval below 1.0, a --size below 64, or a failure to create the
      Bessctl client is reported through the printer, followed by
      sys.exit( 1 ).
      """
      args = argsParseBuilder().parse_args()
      outputFormat = CliPrint.JSON if args.json else CliPrint.TEXT

      # flush stdout and start cliPrinter
      sys.stdout.flush()
      printer = Printer( outputFormat )
      printer.start()

      validations = (
         ( args.interval < 1.0, "minimum interval should be 1.0 s" ),
         ( args.size < 64, "minimum size should be 64 bytes." ),
      )
      for invalid, errMsg in validations:
         if invalid:
            printError( printer, ErrorCode.configError, errMsg )
            sys.exit( 1 )

      try:
         AvtUtilCmdLib.bessctl = Bessctl()
      # pylint: disable=broad-except
      except Exception:
         printError( printer, ErrorCode.configError, "internal error" )
         sys.exit( 1 )

      printer.addFrills( ntpWarning )
      if printer.outputFormat == CliPrint.JSON:
         vrf = StringAttrInfo( printer, "vrf", args.vrf )
         avt = StringAttrInfo( printer, "avt", args.avt )
         destIp = StringAttrInfo( printer, "destVtep", args.destination )
         repeat = IntAttrInfo( printer, "repeat", args.repeat )
         printer.addAttributes( "vrf %s avt %s destination %s repeat %d\n",
                                 vrf.getArg(), avt.getArg(), destIp.getArg(),
                                 repeat.getArg() )
      doAvtPing( args.avt, args.vrf, args.destination, args.pathId, args.repeat,
                 args.interval, args.size, args.frr, args.ttl, printer )
      printer.end()

   # Run the program
   main()

# Every exception escaping the script (including failures of the imports at
# the top, which sit inside the same try) is caught here and ignored.
# NOTE(review): catching SystemExit also discards the status passed to
# sys.exit() -- e.g. sys.exit( 1 ) in main() after an invalid argument -- so
# the process then exits with status 0. Confirm no caller relies on the exit
# status. Also, when an exception escapes main(), printer.end() is never
# reached.
except KeyboardInterrupt:
   pass
except Exception: # pylint: disable-msg=broad-except
   pass
except SystemExit:
   pass
