#!/usr/bin/env python3
# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=no-member
from __future__ import absolute_import, division, print_function
try:
   import os
   import argparse
   import PyClient
   import sys

   from Arnet import IpAddr
   import AvtUtilCmdLib
   from AvtUtilCmdLib import (
      checkDest,
      constructPam,
      ErrorCode,
      errorStr,
      isMhPath,
      mhPath,
      mhPathFlag,
      ntpWarning,
      pathStr,
      printError,
      printLabels,
      strPathId,
      TlvCode,
      TlvParser,
      TraceRouteReqTlv,
   )
   from CliPlugin.BessCtlCmdLib import (
      Bessctl,
   )
   from CliPlugin.WanTEShowCli import (
      getVniAvtIdUsingPyClient,
      getSelfVtepUsingPyClient,
   )
   import CliPrint
   from CliPrint import (
      FloatAttrInfo,
      IntAttrInfo,
      Printer,
      StringAttrInfo,
      BoolAttrInfo,
   )
   import socket
   import struct
   import Tac

   # Tac value type for interface ids; printTraceRoute instantiates it to
   # render the raw egress/ingress interface ids of each hop as strings.
   IntfId = Tac.Type("Arnet::IntfId" )
   # 32-bit mask applied to microsecond timestamps. printTraceRoute truncates
   # the local send time with it and uses it to handle counter wraparound
   # (presumably the hop timestamps are 32-bit uSec counters -- confirm).
   TimestampMask = 0xffffffff

   def int2Ip( addr ):
      """Render a 32-bit integer as a dotted-quad IPv4 string.

      The most significant byte of addr becomes the first octet.
      """
      packed = struct.pack( "!I", addr ) # network (big-endian) byte order
      return socket.inet_ntoa( packed )

   def argsParseBuilder():
      """Build the command-line parser for the AVT trace route utility.

      The positional destination VTEP and the required --avt name select what
      to trace. --multihop-path and --direct-path are mutually exclusive and
      both store into 'pathId'; the default of 0 means every path of the AVT
      is traced.
      """
      p = argparse.ArgumentParser( description='Avt TraceRoute Utility' )
      p.add_argument( 'destination', nargs='?',
                      help='Destination to trace route to' )
      p.add_argument( '--vrf', help='trace route to destination in VRF',
                      default='default' )
      p.add_argument( '--avt', required=True,
                      help='Trace route destination in AVT' )
      p.add_argument( '--fast-reroute', action='store_true', dest='frr',
                      help='reroute packet if multihop path is broken' )
      p.add_argument( '--ttl', type=int, default=64, help='TTL of traceroute pkts' )

      group = p.add_mutually_exclusive_group()
      # this tool traces routes; the help text previously said "ping"
      group.add_argument( '--multihop-path', metavar='id', dest='pathId',
                          type=mhPath, default=0,
                          help='multihop path to trace route' )
      group.add_argument( '--direct-path', metavar='id', dest='pathId',
                          type=int, default=0, help='direct path to trace route' )
      p.add_argument( '--json', action='store_true', help='Outputs a json' )
      return p

   class PamPacketReader( Tac.Notifiee ):
      """Collects trace-route reply TLVs received on a UDP PAM.

      Each readable packet whose leading TLV type is traceRouteReply is parsed
      with the supplied TlvParser and stored in self.reply, keyed by the
      reply's pathId. A later reply for the same path overwrites an earlier
      one.
      """
      notifierTypeName = "Arnet::UdpPam"
      def __init__( self, pam, tlvParser ):
         self.reply = {} # pathId -> parsed trace-route reply TLV
         self.tlvParser = tlvParser
         Tac.Notifiee.__init__( self, pam )

      @Tac.handler( "readableCount" )
      def handleReadableCount( self ):
         """Read one packet from the PAM and record it if it is a reply."""
         try:
            n = self.notifier_
            if not n:
               return
            pkt = n.rxPkt()
            if not pkt:
               return
            # we are expecting tracert reply
            pktBytes = pkt.bytesValue
            # Too short to carry a TLV type: drop it instead of letting
            # struct.error escape from the handler.
            if len( pktBytes ) < 2:
               return
            # The TLV type is in network byte order. "!H" decodes it that way
            # on any host; unpacking "<H" and then applying socket.ntohs()
            # only gives the same value on little-endian hosts.
            tlvType = struct.unpack( "!H", pktBytes[ : 2 ] )[ 0 ]
            if tlvType != TlvCode.traceRouteReply: # pylint: disable=no-member
               return

            replyTlv = self.tlvParser.parseAvtUtilTraceRouteReplyBytePtr( pktBytes )
            self.reply[ replyTlv.pathId ] = replyTlv
         except KeyboardInterrupt:
            pass

   def printPathHeader( printer, tlv, hopList ):
      """Emit the header line for one traced path.

      The path counts as rerouted when any traced label exceeds 0xf000; the
      header text is then prefixed with '*'. In JSON mode a "reroutedPath"
      flag is emitted too, followed by the hop labels of a multihop path.
      """
      title = pathStr( tlv.pathId, printLabel=True, hopList=hopList )
      labels = tlv.tracedLabelInfo.values()
      wasRerouted = any( info.label() > 0xf000 for info in labels )
      if wasRerouted:
         title = "*" + title

      if printer.outputFormat == CliPrint.JSON:
         pathFlag = BoolAttrInfo( printer, "reroutedPath" )
         pathFlag.setValue( wasRerouted )
         printer.addAttributes( "%s", pathFlag.getArg() )
         if isMhPath( tlv.pathId ):
            printLabels( printer, tlv.pathId & ~mhPathFlag, hopList )
      printer.addFrills( title + "\n" )

   def printPathTraceEntry( printer, isReroutedNode, pathId, src, egIntf,
                            dst, ingIntf, oneway, errCode, errMsg ):
      """Emit one list entry of a path trace.

      In JSON mode the entry always carries a "reroutedNode" flag and, when
      errCode is non-zero, the error as well. When pathId is truthy the hop
      itself is added: source IP and egress interface, destination IP and
      ingress interface, and the one-way latency in ms.
      """
      with printer.listEntry():
         if printer.outputFormat == CliPrint.JSON:
            nodeFlag = BoolAttrInfo( printer, "reroutedNode" )
            nodeFlag.setValue( isReroutedNode )
            printer.addAttributes( "%s", nodeFlag.getArg() )
            if errCode:
               printError( printer, errCode, errMsg )

         if not pathId:
            return # no hop to report for this entry

         hopAttrs = (
            IntAttrInfo( printer, "directPathId", pathId ),
            StringAttrInfo( printer, "srcIp", src ),
            StringAttrInfo( printer, "egressInterface", egIntf ),
            StringAttrInfo( printer, "dstIp", dst ),
            StringAttrInfo( printer, "ingressInterface", ingIntf ),
            FloatAttrInfo( printer, "onewayTime", oneway ),
         )
         printer.addAttributes( " path%d %s:%s to %s:%s 1-way: %.3f ms\n",
                                *[ attr.getArg() for attr in hopAttrs ] )

   def printTraceRoute( src, egressTime, tlv, printer, hopList=None ):
      """Print the hop-by-hop trace carried in one trace-route reply TLV.

      Args:
        src: IP string of the local VTEP, shown as the source of the first hop.
        egressTime: time the request was sent, in seconds (Tac.utcNow()).
        tlv: parsed trace-route reply; tracedLabelInfo has one entry per hop
          reached, keyed by a 0-based sequence number.
        printer: CliPrint Printer receiving text/JSON output.
        hopList: hop list of a multihop path (None for direct paths), passed
          through to the path header.
      """
      def timeDifference( start, end ):
         # Timestamps are 32-bit uSec counters (masked with TimestampMask).
         # When end is numerically below start the counter wrapped, so the
         # elapsed time is the distance from start up to the wrap point plus
         # end (e.g. start=0xfffffff0, end=0x10 -> 0x20 uSec).
         if end < start:
            diff = TimestampMask - start + 1
            diff += end
         else:
            diff = end - start
         return diff/1000.0 # uSec -> ms

      printPathHeader( printer, tlv, hopList )
      # convert egresstime from seconds to uSec, truncated to 32 bits like the
      # hop timestamps
      egressTime = int( egressTime * 1000 * 1000 ) & TimestampMask
      egressIntf = IntfId()
      ingressIntf = IntfId()

      # save the dst Ip for trace completion check
      dst = None

      with printer.list( "pathTrace" ):
         # print trace details about each hop in a path
         # Start at -1 so that when no hop was reached the "* * *" line below
         # is numbered 1 (hop N is printed as labelSeq+1).
         labelSeq = -1
         # NOTE(review): once a rerouting hop is seen this stays True for every
         # following hop in the JSON "reroutedNode" field, while the text "^"
         # marker is per hop -- confirm this is intended.
         isReroutedNode = False
         for labelSeq, labelInfo in tlv.tracedLabelInfo.items():
            pathId = labelInfo.label()
            if pathId > 0xf000:
               # labels above 0xf000 mark the node that rerouted the packet
               labelOutput = f"  ^{labelSeq+1}"
               isReroutedNode = True
               pathId -= 0xf000
            else:
               labelOutput = f"   {labelSeq+1}"

            # gather label info
            egressIntf.intfId = labelInfo.egressIntfId()
            ingressIntf.intfId = labelInfo.ingressIntfId()
            dst = int2Ip( labelInfo.ingressVtepIp() )
            ingressTime = labelInfo.timestamp()
            oneway = timeDifference( egressTime, ingressTime )

            printer.addFrills( labelOutput )
            printPathTraceEntry( printer, isReroutedNode, pathId, src,
                                 egressIntf.stringValue, dst,
                                 ingressIntf.stringValue, oneway, 0, "" )

            # move to the next item
            src = dst
            egressTime = ingressTime

         # error condition
         # if the last labelSeq ingress info does not contain destination, then
         # we didn't trace the full route
         if tlv.endpoint != dst:
            output = f"   {labelSeq+2} * * *"
            if tlv.error:
               output += ", " + errorStr[ tlv.error ]
               printPathTraceEntry( printer, None, None, None, "", None, "", None,
                                    tlv.error, errorStr[ tlv.error ] )
            printer.addFrills( output + "\n" )
      printer.addFrills( "\n" )

   def printTraceRouteNoReply( pathId, printer, hopList ):
      """Print a path for which no trace-route reply arrived in time.

      The path header is followed by a single '* * *' hop annotated with the
      packet-lost error text; JSON output also records the error.
      """
      lostCode = ErrorCode.pktLost # pylint: disable=no-member
      lostMsg = errorStr[ lostCode ]
      with printer.list( "pathTrace" ):
         title = pathStr( pathId, printLabel=True, hopList=hopList )
         printer.addFrills( title + "\n" )
         with printer.listEntry():
            printer.addFrills( "  1 * * *, %s\n", lostMsg )
            if printer.outputFormat == CliPrint.JSON:
               printError( printer, lostCode, lostMsg )
            printer.addFrills( "\n" )

   def doTraceRoute( pathsDict, header, avtId, vni, src, destination, frr,
                     ttl, printer ):
      """Send a trace-route request on each path and print the results.

      One request packet is sent per path id in pathsDict, in sorted order.
      Replies are then collected for up to 5 seconds and each one is printed
      as soon as it is seen. Paths still without a reply at the deadline are
      printed as lost.

      Args:
        pathsDict: path id -> hop list (None for direct paths).
        header: text printed before the per-path summary.
        avtId, vni, frr: copied into the request TLV.
        src: local VTEP address, used as the packet source.
        destination: destination VTEP address.
        ttl: TTL passed to the UDP PAM.
        printer: CliPrint Printer receiving text/JSON output.
      """
      replyTimeout = 5.0 # seconds

      def replyExpired():
         return Tac.utcNow() - startTime > replyTimeout

      # Do the traceRoute
      reqTlv = TraceRouteReqTlv( avtId, vni, 0, destination, 0, frr )
      tlvParser = TlvParser()
      udpPam = constructPam( src, destination, ttl=ttl )
      rxPktReader = PamPacketReader( udpPam, tlvParser )
      pathList = sorted( pathsDict.keys() )
      startTime = Tac.utcNow()

      for pathId in pathList:
         pkt = Tac.newInstance( 'Arnet::Pkt' )
         reqTlv.pathId = pathId
         tlvBytes = reqTlv.serialize()
         pkt.stringValue = tlvBytes
         # we need to set the src IP before sending out every pkt
         # https://opengrok.infra.corp.arista.io/source/xref/eos-trunk/
         # src/Arnet/UdpPam.tac#56-60
         udpPam.txSrcIpAddr = src
         udpPam.txPkt = pkt

      printer.addFrills( header )
      with printer.dict( "summary" ):
         # Wait for at most replyTimeout seconds. Replies are drained before
         # the deadline is checked. Otherwise a reply delivered during the
         # last Tac.runActivities() slice would be reported as lost.
         remainingPaths = pathList.copy()
         while remainingPaths:
            # iterate over a snapshot in case the reader adds entries
            for pathId, replyTlv in list( rxPktReader.reply.items() ):
               if pathId not in remainingPaths:
                  continue # already printed
               sPathId = strPathId( pathId )
               with printer.dict( sPathId ):
                  printTraceRoute( src, startTime, replyTlv, printer,
                                   pathsDict[ pathId ] )
               remainingPaths.remove( pathId )
            if not remainingPaths or replyExpired():
               break
            Tac.runActivities( 0.1 )

         for pathId in remainingPaths:
            sPathId = strPathId( pathId )
            with printer.dict( sPathId ):
               if printer.outputFormat == CliPrint.JSON:
                  rerouted = BoolAttrInfo( printer, "reroutedPath" )
                  rerouted.setValue( False )
                  printer.addAttributes( "%s", rerouted.getArg() )
                  if isMhPath( pathId ):
                     mhPathId = pathId & ~mhPathFlag
                     printLabels( printer, mhPathId, pathsDict[ pathId ] )
               printTraceRouteNoReply( pathId, printer, pathsDict[ pathId ] )

   def doAvtTraceRoute( avtName, vrfName, destination, pathId, frr, ttl, printer ):
      """Resolve the CLI parameters and trace the requested AVT path(s).

      Looks up the VNI and AVT id of avtName in vrfName, validates the
      destination VTEP, and fetches the best paths towards it from bessctl.
      It then traces every path (pathId == 0) or only the requested one.
      Configuration and lookup problems are reported through printError and
      end the trace early.
      """
      sysName = os.environ.get( 'SYSNAME', 'ar' )
      sysdbClient = PyClient.PyClient( sysName, "Sysdb" )
      sfeClient = PyClient.PyClient( sysName, "Sfe" )

      vni, avtId, lookupErr = getVniAvtIdUsingPyClient( sysdbClient, vrfName,
                                                        avtName,
                                                        printError=False )
      if not ( vni and avtId ):
         printError( printer, ErrorCode.configError, lookupErr )
         return
      if not checkDest( sfeClient, destination ):
         printError( printer, ErrorCode.configError,
                     f"{destination} is not a valid destination VTEP" )
         return
      src = getSelfVtepUsingPyClient( sysdbClient )

      destIp = IpAddr( destination )
      try:
         bestPaths = AvtUtilCmdLib.bessctl.getBestPaths( vni, avtId,
                                                         destIp.value )
         # only multihop paths carry a hop list
         pathsDict = {
            info.path_index: ( list( info.hop ) if isMhPath( info.path_index )
                               else None )
            for info in bestPaths.path_info
         }
      except Exception: # pylint: disable-msg=W0703
         lookupCode = ErrorCode.lbLookupFailure
         printError( printer, lookupCode, errorStr[ lookupCode ] )
         return

      printer.addFrills( ntpWarning )
      header = ( f"\nTracing route to AVT: {avtName}, VRF: {vrfName}, "
                 f"Destination: {destIp}\n"
                 "* - rerouted path(s), ^ - rerouting node(s)\n\n" )

      if pathId == 0:
         # no specific path requested: trace every path in the AVT
         selected = pathsDict
      elif pathId in pathsDict:
         selected = { pathId: pathsDict[ pathId ] }
      else:
         notInAvt = ErrorCode.pathIdNotPartOfAvt
         printError( printer, notInAvt, errorStr[ notInAvt ] )
         return
      doTraceRoute( selected, header, avtId, vni, src, destIp, frr, ttl,
                    printer )

   def main():
      """Parse the command line, set up the printer and run the trace.

      If the Bessctl handle cannot be created, an "internal error" is
      reported, the printer is closed and sys.exit( 1 ) is called.
      """
      # build argument parser
      p = argsParseBuilder()
      args = vars( p.parse_args() )
      destination = args[ 'destination' ]
      vrfName = args[ 'vrf' ]
      avtName = args[ 'avt' ]
      pathId = args[ 'pathId' ]
      frr = args[ 'frr' ]
      ttl = args[ 'ttl' ]
      outputJson = args[ 'json' ]
      outputFormat = CliPrint.JSON if outputJson else CliPrint.TEXT

      sys.stdout.flush()
      printer = Printer( outputFormat )
      printer.start()
      try:
         AvtUtilCmdLib.bessctl = Bessctl()
      except Exception: # pylint: disable-msg=broad-except
         errMsg = "internal error"
         printError( printer, ErrorCode.configError, errMsg )
         # Close the printer opened by start() so the error output (JSON in
         # particular) is complete before exiting.
         printer.end()
         sys.exit( 1 )

      if printer.outputFormat == CliPrint.JSON:
         vrf = StringAttrInfo( printer, "vrf", vrfName )
         avt = StringAttrInfo( printer, "avt", avtName )
         destIp = StringAttrInfo( printer, "destVtep", destination )
         printer.addAttributes( "vrf %s avt %s destination %s\n",
                                 vrf.getArg(), avt.getArg(), destIp.getArg() )
      doAvtTraceRoute( avtName, vrfName, destination, pathId, frr, ttl, printer )
      printer.end()

   # Run the program
   main()

# We catch the exception but skip doing anything to handle it.
# KeyboardInterrupt: Ctrl-C ends the trace quietly, without a traceback.
# Exception: any other error, including failures of the imports at the top
# of this file, is suppressed and the script ends with no error output.
# SystemExit: swallowing it also discards the status passed to sys.exit()
# (e.g. sys.exit( 1 ) in main(), or argparse's exit on a usage error), so the
# interpreter finishes with status 0.
# NOTE(review): confirm no caller relies on a non-zero exit status.
except KeyboardInterrupt:
   pass
except Exception: # pylint: disable-msg=broad-except
   pass
except SystemExit:
   pass
