# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-in
# pylint: disable=consider-using-f-string
# pylint: disable=unnecessary-comprehension

import Tac
import Tracing
import BothTrace
import socket
import sys
from TypeFuture import TacLazyType

from ClientCommonLib import (
   getThreadLocalData,
   lspPingRetCodeStr,
   LspPingReturnCode,
   LspPingTypeGeneric,
   LspPingTypeLdp,
   LspPingTypeMldp,
   LspPingTypeRsvp,
   LspPingTypeSr,
   receiveMessage,
   sendViaSocket,
   setThreadLocalData,
   fecTlvRenderMap,
)
from TracerouteModel_pb2 import ( # pylint: disable=no-name-in-module
   ErrInfo as Err,
   MplsTraceroute,
)

# Enum values of LspPing::LspPingAddrType. The render helpers below compare an
# addrType against these to choose between printing an interface address and an
# interface index.
Ipv4Numbered = Tac.Type( 'LspPing::LspPingAddrType' ).ipv4Numbered
Ipv6Numbered = Tac.Type( 'LspPing::LspPingAddrType' ).ipv6Numbered
# ConnectorKey and DynTunnelIntfId are not referenced in the visible part of this
# file -- presumably kept for importers of this module; verify before removing.
ConnectorKey = Tac.Type( 'Pseudowire::ConnectorKey' )
DynTunnelIntfId = Tac.Type( 'Arnet::DynamicTunnelIntfId' )
# Used by tracerouteDownstreamInfoRender to recognise DSMAP entries.
LspPingDownstreamMappingType = TacLazyType(
      'LspPing::LspPingDownstreamMappingType' )
# NOTE: this rebinds the LspPingReturnCode name that is already imported from
# ClientCommonLib above; from here on the Tac.Type lookup result is what is used.
LspPingReturnCode = Tac.Type( 'LspPing::LspPingReturnCode' )
# Keys of the protocol-name table in getIgpProtocolStr.
igpProtocolType = Tac.Type( 'LspPing::LspPingSrIgpProtocol' )

# Warning text about 1-Way delay accuracy. It is not printed anywhere in the
# visible part of this file.
NtpWarning = ( 'Warning: NTP synchronization is required for '
               '1-Way time measurement accuracy.' )
#--------------------------------
# BothTrace Short hand variables
#--------------------------------
__defaultTraceHandle__ = Tracing.Handle( "MplsTracerouteClientLib" )
bv = BothTrace.Var       # wraps a value passed to a trace call
bt8 = BothTrace.trace8   # trace function used throughout this file

# ---------------------------------------------------------
#           Traceroute render helpers
# ---------------------------------------------------------

def sendOrRenderTracerouteErr( err, sock=None ):
   '''Report a traceroute error.

   err  -- the error text to report.
   sock -- optional socket. When truthy, err is wrapped in an ErrInfo inside an
           MplsTraceroute message and handed to sendViaSocket; otherwise err is
           printed to stdout.
   '''
   if not sock:
      print( err )
      return
   sendViaSocket( sock, MplsTraceroute( errInfo=Err( err=err ) ) )

def getIgpProtocolStr( igpProtocol ):
   '''Return the display name for an SR IGP protocol value.

   Gives "IS-IS" or "OSPF" for the two known igpProtocolType values and an
   empty string for anything else.
   '''
   displayNames = {
      igpProtocolType.isis : "IS-IS",
      igpProtocolType.ospf : "OSPF",
   }
   return displayNames.get( igpProtocol, "" )

def renderHdr( tracerouteHdr ):
   '''Print the banner line that introduces a traceroute.

   The wording depends on tracerouteHdr.protocol and on the matching sub-model
   (srModel, ldpModel or rsvpModel). If no branch matches, an empty line is
   printed.
   '''
   protocol = tracerouteHdr.protocol
   banner = ""
   if protocol == LspPingTypeSr:
      srModel = tracerouteHdr.srModel
      target = srModel.prefix
      algorithm = srModel.algorithm
      algoSuffix = ', algorithm {}'.format( algorithm ) if algorithm else ""
      if srModel.igpProtocolType:
         igpName = getIgpProtocolStr( srModel.igpProtocolType )
         banner = 'LSP traceroute {} segment-routing to {} {}'.format(
                     igpName, target, algoSuffix )
      else:
         banner = 'LSP traceroute to {}{}'.format( target, algoSuffix )
   elif protocol == LspPingTypeLdp:
      banner = 'LSP traceroute to {}'.format( tracerouteHdr.ldpModel.prefix )
   elif protocol == LspPingTypeGeneric:
      # Generic traceroute calls either the sr or the ldp traceroute directly,
      # so only one of ldpModel / srModel is expected to be present.
      if tracerouteHdr.ldpModel:
         banner = 'LSP traceroute to {}'.format( tracerouteHdr.ldpModel.prefix )
      elif tracerouteHdr.srModel:
         banner = 'LSP traceroute to {}'.format( tracerouteHdr.srModel.prefix )
   elif protocol == LspPingTypeRsvp:
      rsvpModel = tracerouteHdr.rsvpModel
      session = rsvpModel.session
      lspId = rsvpModel.lspId
      tunnel = rsvpModel.tunnel
      subTunnelId = rsvpModel.subTunnelId
      if session:
         # an all-digit session is shown as an id, anything else as a name
         sessionLabel = "#{}".format( session ) if session.isdigit() else session
         lspSuffix = " LSP #{}".format( lspId ) if lspId else ""
         banner = "LSP traceroute to RSVP session {}{}".format( sessionLabel,
                                                                lspSuffix )
      else:
         subTunnelSuffix = ""
         if subTunnelId:
            subTunnelSuffix = " sub-tunnel #{}".format( subTunnelId )
         banner = f"LSP traceroute to RSVP tunnel {tunnel}{subTunnelSuffix}"
   elif protocol == LspPingTypeMldp:
      banner = 'Lsp traceroute to MLDP route {}'.format(
                  tracerouteHdr.ldpModel.prefix )
   print( banner )

def renderVia( viaInfo ):
   '''Print the next hop and label stack of a via, indented by two spaces.'''
   print( f'  via {viaInfo.nextHopIp}, label stack: {viaInfo.labelStack!s}' )

def renderStatsVia( viaInfo ):
   '''Print next hop, interface and label stack of a statistics via.'''
   nextHop = viaInfo.nextHopIp
   intf = viaInfo.interface
   print( f'Via {nextHop}, {intf}, label stack: {viaInfo.labelStack!s}' )

def renderStatisticsHeader( tracerouteModel ):
   '''Print the statistics banner once, guarded by the 'printStats' flag.

   Nothing happens when the thread-local 'printStats' flag is already set.
   Otherwise the banner is printed for MLDP only, and the flag is set so that
   later calls are no-ops.
   '''
   alreadyPrinted = getThreadLocalData( 'printStats' )
   if not alreadyPrinted:
      # we only support mldp statistics summary for now
      if tracerouteModel.protocol == LspPingTypeMldp:
         banner = ( "\n--- MLDP target fec %s : lsptraceroute statistics ---" %
                    tracerouteModel.ldpModel.prefix )
         print( banner )
      setThreadLocalData( 'printStats', True )

def renderStatisticsSummary( summary ):
   '''Print transmitted / received packet counts and the reply time.'''
   print( f'   {summary.packetsTransmitted!s} packets transmitted, '
          f'{summary.packetsReceived!s} received, '
          f'time {summary.replyTime!s}ms' )

def renderStatistics( stats ):
   '''Print per-destination reply count and min/max/avg round-trip times.

   The roundTripTime*Us fields are divided by 1000 (us to ms) before printing.
   '''
   usPerMs = 1000
   received = stats.packetsReceived
   destination = stats.destination
   rttMin = stats.roundTripTimeMinUs / usPerMs
   rttMax = stats.roundTripTimeMaxUs / usPerMs
   rttAvg = stats.roundTripTimeAvgUs / usPerMs
   print( f'   {received} received from {destination}, '
          f'rtt min/max/avg {rttMin}/{rttMax}/{rttAvg} ms' )

def _usToMsStr( delayUs ):
   '''Render a microsecond count as milliseconds, e.g. 1005 -> "1.005".

   The sub-millisecond remainder is zero-padded to three digits; without the
   padding 1005us would read as "1.5". The sign is handled separately because
   floor division and modulo on a negative value would otherwise turn -5us
   into "-1.995".
   '''
   sign = '-' if delayUs < 0 else ''
   wholeMs, remainderUs = divmod( abs( delayUs ), 1000 )
   return '{}{}.{}'.format( sign, wholeMs, str( remainderUs ).zfill( 3 ) )

def tracerouteReplyHdrRender( hopPktInfo ):
   '''Print the one-line summary for a traceroute hop.

   When hopPktInfo.replyHost is empty only the TTL followed by '***' is
   printed. Otherwise the line holds the TTL, the replying host, the MTU (only
   when positive), the RTT, the 1-Way delay and the return code. The return
   code is labelled 'return code', 'success' or 'error' depending on its value.
   '''
   if not hopPktInfo.replyHost:
      print( f'{hopPktInfo.ttl:>3} ***' )
      return

   output = f'{hopPktInfo.ttl:>3}  '
   output += f'{hopPktInfo.replyHost:<16}  '

   if int( hopPktInfo.hopMtu ) > 0:
      output += f'MTU {hopPktInfo.hopMtu:<4}  '

   # NOTE(review): roundTripUs is printed unpadded after the '.'; if it is the
   # 0..999 microsecond remainder it needs the same zero padding as the 1-Way
   # field below -- confirm against the sender before changing.
   output += '{:<12}  '.format( 'RTT:' + str( hopPktInfo.roundTrip ) +
             '.' + str( hopPktInfo.roundTripUs ) + 'ms' )
   # oneWayDelay is split with // 1000 and % 1000, so the fraction must be
   # zero-padded to stay numerically correct.
   output += '{:<14}  '.format( '1-Way:' +
                                _usToMsStr( hopPktInfo.oneWayDelay ) + 'ms' )

   retCodeStr = lspPingRetCodeStr( hopPktInfo.retCode )
   if hopPktInfo.retCode == LspPingReturnCode.seeDdmTlv:
      output += '  return code: {}'.format( retCodeStr )
   elif ( hopPktInfo.retCode == LspPingReturnCode.labelSwitchedAtStackDep or
          hopPktInfo.retCode == LspPingReturnCode.repRouterEgress ):
      output += '  success: {}'.format( retCodeStr )
   else:
      output += '  error: {}'.format( retCodeStr )
   print( output )

def getFecTypeStr( fecTypeInfo ):
   '''Return the comma-separated display names of the FEC types in fecTypeInfo.

   Each entry of fecTypeInfo.fecTypeList is looked up in fecTlvRenderMap;
   unknown entries are shown as 'Unknown'. Returns "" when fecTypeInfo is
   falsy or its fecTypeList is empty.
   '''
   if not fecTypeInfo or not fecTypeInfo.fecTypeList:
      return ""
   return ", ".join( fecTlvRenderMap.get( tlvType, 'Unknown' )
                     for tlvType in fecTypeInfo.fecTypeList )

def getFscStr( popFecTypeList, pushFecTypeList ):
   '''Return the "Pop:[...], Push:[...]" summary of the two FEC type lists.

   A side whose rendered FEC string is empty is left out entirely; if both are
   empty the result is "".
   '''
   sections = []
   for tag, fecTypes in ( ( "Pop", popFecTypeList ),
                          ( "Push", pushFecTypeList ) ):
      rendered = getFecTypeStr( fecTypes )
      if rendered:
         sections.append( "{}:[{}]".format( tag, rendered ) )
   return ", ".join( sections )

def tracerouteDownstreamInfoRender( downstreamInfos ):
   '''Print every downstream mapping entry of a hop, numbered from 1.

   DSMAP entries get only a heading; all other entries are headed DDMAP and
   also show their return code. Each entry then lists the interface address or
   index (depending on addrType), the IP address, the label stack and, when
   present, the pop/push switching protocols.
   '''
   for position, entry in enumerate( downstreamInfos, start=1 ):
      dsmapType = str( LspPingDownstreamMappingType.downstreamMap )
      if entry.dsType == dsmapType:
         print( f'     downstream information (DSMAP) {position}:' )
      else:
         retCodeStr = lspPingRetCodeStr( entry.retCode )
         print( f'     downstream information (DDMAP) {position}:' )
         print( '        return code: {}'.format( retCodeStr ) )

      isNumbered = ( entry.addrType == Ipv4Numbered or
                     entry.addrType == Ipv6Numbered )
      if isNumbered:
         print( '        interface address: {}'.format( entry.dsIntfAddr ) )
      else:
         print( '        interface index: {}'.format( entry.dsIntfIndex ) )
      print( '        IP address: {}'.format( entry.dsIpAddr ) )

      joinedLabels = ', '.join( entry.labelStack )
      print( '        label stack: [{}]'.format( joinedLabels ) )

      switching = getFscStr( entry.popFecTypeList, entry.pushFecTypeList )
      if switching:
         print( f'        switching protocols: {switching}' )

def tracerouteLabelStackRender( labelStackInfo ):
   '''Print the interface-and-label-stack (ILS) details of a hop.

   Shows the interface address for numbered IPv4/IPv6 address types and the
   interface index otherwise, followed by the IP address and the label stack.
   '''
   print( "     interface and label stack tlv (ILS):" )
   isNumbered = ( labelStackInfo.addrType == Ipv4Numbered or
                  labelStackInfo.addrType == Ipv6Numbered )
   if isNumbered:
      intfLine = '        interface address: {}'.format( labelStackInfo.intfAddr )
   else:
      intfLine = '        interface index: {}'.format( labelStackInfo.intfIndex )
   print( intfLine )
   print( '        IP address: {}'.format( labelStackInfo.addr ) )
   joinedLabels = ', '.join( labelStackInfo.labelStack )
   print( '        label stack: [{}]'.format( joinedLabels ) )

def cleanTracerouteSocket():
   '''Close the client and server sockets and clear their thread-local slots.

   The client socket ('cs') is closed explicitly instead of merely being
   dropped. The server socket ('s') is shut down and closed; an OSError from
   shutdown() is traced and does not prevent close(). The 'cs' and 's'
   thread-local entries are reset to None even if closing fails.
   '''
   clientSock = getThreadLocalData( 'cs' )
   serverSock = getThreadLocalData( 's' )
   try:
      if clientSock:
         clientSock.close()
      if serverSock:
         bt8( "Clean up server socket ", bv( serverSock.getsockname() ) )
         try:
            serverSock.shutdown( socket.SHUT_RDWR )
         except OSError as e:
            # best effort: shutdown() can fail on a socket that is not connected
            bt8( "Server socket shutdown failed ", bv( str( e ) ) )
         finally:
            serverSock.close()
   finally:
      # never leave stale socket references behind
      setThreadLocalData( 'cs', None )
      setThreadLocalData( 's', None )

# ---------------------------------------------------------
#           Traceroute model generation helpers
# ---------------------------------------------------------

def getTracerouteModel():
   '''Receive the next traceroute message and store / render it.

   Does nothing while the thread-local 'cache' still holds a message.
   Otherwise a client connection is accepted on the server socket 's' (only if
   'cs' is not already set), one message is read into 'cache' and passed on
   to storeAndRenderModel. Both sockets use a 1 second timeout.

   On timeout (or an empty message) the thread-local 'cmdBreak' flag is set
   when the producer has gone away: neither a subprocess 'p' nor a thread 't'
   exists, the subprocess has exited, or the thread is no longer alive.

   Returns None.
   '''
   if getThreadLocalData( 'cache' ) is not None:
      return
   try:
      s = getThreadLocalData( 's' )
      s.settimeout( 1 ) # set timeout to return
      clientsocket = getThreadLocalData( 'cs' )
      if not clientsocket:
         clientsocket, _ = s.accept()
         bt8( "Create client socket ", bv( clientsocket.getsockname() ) )
         setThreadLocalData( 'cs', clientsocket )
      clientsocket.settimeout( 1 ) # set timeout to return
      msg = receiveMessage( clientsocket )
      # NOTE(review): the message is cached before the emptiness check. If
      # receiveMessage can return a falsy value other than None, the guard at
      # the top would then skip all later reads -- confirm with the caller.
      setThreadLocalData( 'cache', msg )
      if not msg:
         # treat "nothing received" like a timeout
         raise socket.timeout
      storeAndRenderModel( render=getThreadLocalData( 'render' ) )
   except socket.timeout:
      # either a process is spawned by the cli or a thread is
      # spawned by binary LspTraceroute.
      p = getThreadLocalData( 'p' )
      t = getThreadLocalData( 't' )
      if not p and not t:
         # either process or thread should be present
         bt8( "No process or thread present" )
         setThreadLocalData( 'cmdBreak', True )
      if p and p.poll() is not None: # subprocess is not running
         bt8( "Subprocess was created and is not running anymore" )
         setThreadLocalData( 'cmdBreak', True )
      if t and not t.is_alive(): # thread is not running
         bt8( "Thread was created and is not running anymore" )
         setThreadLocalData( 'cmdBreak', True )

def storeAndRenderModel( render=False ):
   '''Fold the cached traceroute message into the CLI model and optionally print.

   The thread-local 'cache' is parsed as an MplsTraceroute protobuf message and
   handled according to which field it carries:

   - mplsTracerouteHdr creates a new MplsTracerouteHdr, stores it as the
     thread-local 'tracerouteModel' and resets 'printStats' to False.
   - every other field reads the thread-local 'tracerouteModel' and attaches a
     new sub-model to it (or to its most recent via / path / hop).

   render -- when truthy, the piece that was just stored is also printed via
             the render helpers of this module.

   stdout is flushed before returning. Returns None.

   NOTE(review): all non-header branches dereference the thread-local model and
   most index viaModel[ -1 ] / statsModel[ -1 ], so the sender presumably
   emits the header and a via first -- confirm against the sender.
   '''
   # pylint: disable=no-member
   # pylint: disable-next=import-outside-toplevel
   from CliDynamicSymbol import CliDynamicPlugin
   # CLI model classes are looked up on the dynamically loaded plugin
   MplsUtilModel = CliDynamicPlugin( "MplsUtilModel" )
   ErrInfo = MplsUtilModel.ErrInfo
   MplsPath = MplsUtilModel.MplsPath
   MplsTracerouteHdr = MplsUtilModel.MplsTracerouteHdr
   MplsVia = MplsUtilModel.MplsVia
   MplsTracerouteStatistics = MplsUtilModel.MplsTracerouteStatistics
   MplsTracerouteStatisticsSummary = MplsUtilModel.MplsTracerouteStatisticsSummary
   HopPktInfo = MplsUtilModel.HopPktInfo
   DownstreamInfo = MplsUtilModel.DownstreamInfo
   IntfAndLabelStackInfo = MplsUtilModel.IntfAndLabelStackInfo
   GenericModel = MplsUtilModel.GenericModel
   RsvpModel = MplsUtilModel.RsvpModel
   SrModel = MplsUtilModel.SrModel
   FecTypeList = MplsUtilModel.FecTypeList

   # decode the message that getTracerouteModel left in the thread-local cache
   traceroute = MplsTraceroute()
   traceroute.ParseFromString( getThreadLocalData( 'cache' ) )
   if traceroute.HasField( 'mplsTracerouteHdr' ):
      mplsTracerouteHdr = traceroute.mplsTracerouteHdr
      # Parent model instance for traceroute which contains all other submodel
      # instances. There can be only one tracerouteModel instance per cli cmd output.
      tracerouteModel = (
         MplsTracerouteHdr( protocol=str( mplsTracerouteHdr.protocol ),
                            dsInfoType=str( mplsTracerouteHdr.dsInfoType ) ) )
      setThreadLocalData( 'tracerouteModel', tracerouteModel )
      # a new traceroute starts: allow the statistics banner to be printed again
      setThreadLocalData( 'printStats', False )
   elif traceroute.HasField( 'ldpModel' ):
      # ldpModel covers ldp, mldp and generic traceroute case
      ldpModel = traceroute.ldpModel
      tracerouteModel = getThreadLocalData( 'tracerouteModel' )
      tracerouteModel.ldpModel = GenericModel( prefix=str( ldpModel.prefix ) )
      if render:
         renderHdr( tracerouteModel )
   elif traceroute.HasField( 'srModel' ):
      srModel = traceroute.srModel
      tracerouteModel = getThreadLocalData( 'tracerouteModel' )
      tracerouteModel.srModel = SrModel( prefix=str( srModel.prefix ),
                                         algorithm=str( srModel.algorithm ),
                                    igpProtocolType=str( srModel.igpProtocolType ) )
      if render:
         renderHdr( tracerouteModel )
   elif traceroute.HasField( 'rsvpModel' ):
      rsvpModel = traceroute.rsvpModel
      tracerouteModel = getThreadLocalData( 'tracerouteModel' )
      tracerouteModel.rsvpModel = RsvpModel( session=str( rsvpModel.session ),
                                             lspId=rsvpModel.lspId,
                                             tunnel=str( rsvpModel.tunnel ),
                                             subTunnelId=rsvpModel.subTunnelId )
      if render:
         renderHdr( tracerouteModel )
   elif traceroute.HasField( 'errInfo' ):
      # errors are accumulated on the model and echoed verbatim when rendering
      errInfo = traceroute.errInfo
      tracerouteModel = getThreadLocalData( 'tracerouteModel' )
      tracerouteModel.errInfo += [ ErrInfo( error=errInfo.err ) ]
      if render:
         print( errInfo.err )
   elif traceroute.HasField( 'mplsPath' ):
      # the path result is recorded on the latest path of the latest via
      mplsPath = traceroute.mplsPath
      tracerouteModel = getThreadLocalData( 'tracerouteModel' )
      tracerouteModel.viaModel[ -1 ].path[ -1 ].pathResult = (
         str( mplsPath.pathResult ) )
      if render:
         print( tracerouteModel.viaModel[ -1 ].path[ -1 ].pathResult )
   elif traceroute.HasField( 'mplsVia' ):
      mplsVia = traceroute.mplsVia
      tracerouteModel = getThreadLocalData( 'tracerouteModel' )
      labels = [ label for label in mplsVia.labelStack ]
      # statistics vias and regular vias are kept in separate lists
      if mplsVia.statisticsVia:
         tracerouteModel.statsModel += [ MplsVia( nextHopIp=str( mplsVia.nextHopIp ),
                                                  labelStack=labels,
                                                  interface=mplsVia.interface ) ]
      else:
         tracerouteModel.viaModel += [ MplsVia( nextHopIp=str( mplsVia.nextHopIp ),
                                                labelStack=labels ) ]
      if render:
         # render the most recently added via from the list
         if mplsVia.statisticsVia:
            renderStatisticsHeader( tracerouteModel )
            renderStatsVia( tracerouteModel.statsModel[ -1 ] )
         else:
            renderVia( tracerouteModel.viaModel[ -1 ] )
   elif traceroute.HasField( 'mplsTracerouteStatisticsSummary' ):
      # the summary becomes the stats of the most recent statistics via
      summary = traceroute.mplsTracerouteStatisticsSummary
      tracerouteModel = getThreadLocalData( 'tracerouteModel' )
      tracerouteModel.statsModel[ -1 ].stats = MplsTracerouteStatisticsSummary(
         packetsTransmitted=summary.packetsTransmitted,
         packetsReceived=summary.packetsReceived,
         packetLoss=summary.packetLoss,
         replyTime=summary.replyTime )
      if render:
         renderStatisticsSummary( tracerouteModel.statsModel[ -1 ].stats )
   elif traceroute.HasField( 'mplsTracerouteStatistics' ):
      # per-destination statistics are appended to the summary stored above
      statistics = traceroute.mplsTracerouteStatistics
      tracerouteModel = getThreadLocalData( 'tracerouteModel' )
      tracerouteModel.statsModel[ -1 ].stats.mplsStatistics += (
         [ MplsTracerouteStatistics(
            destination=statistics.destination,
            roundTripTimeMinUs=statistics.roundTripTimeMinUs,
            roundTripTimeMaxUs=statistics.roundTripTimeMaxUs,
            roundTripTimeAvgUs=statistics.roundTripTimeAvgUs,
            packetsReceived=statistics.packetsReceived )
         ] )
      if render:
         renderStatistics(
            tracerouteModel.statsModel[ -1 ].stats.mplsStatistics[ -1 ] )
   elif traceroute.HasField( 'hopPktInfo' ):
      tracerouteModel = getThreadLocalData( 'tracerouteModel' )
      hopPktInfo = traceroute.hopPktInfo
      # If no path create one to store the hopPktInfo for the path.
      # Also if previous path already has path result create one to store
      # new hopPktInfos as this is a new path then.
      if ( not tracerouteModel.viaModel[ -1 ].path or
           tracerouteModel.viaModel[ -1 ].path[ -1 ].pathResult ):
         tracerouteModel.viaModel[ -1 ].path += [ MplsPath() ]
      # store latest received hopPktInfo to the
      # current viaModel path being explored.
      tracerouteModel.viaModel[ -1 ].path[ -1 ].hopPkt += (
         [ HopPktInfo( replyHost=str( hopPktInfo.replyHost ),
                       hopMtu=hopPktInfo.hopMtu,
                       roundTrip=hopPktInfo.roundTrip,
                       roundTripUs=hopPktInfo.roundTripUs,
                       oneWayDelay=hopPktInfo.oneWayDelay,
                       retCode=str( hopPktInfo.retCode ),
                       ttl=hopPktInfo.ttl )
         ] )
      if render:
         # render the most recently added hopPkt from the most recent via in list
         tracerouteReplyHdrRender(
            tracerouteModel.viaModel[ -1 ].path[ -1 ].hopPkt[ -1 ] )
   elif traceroute.HasField( 'intfAndLabelStackInfo' ):
      intfAndLabelStackInfo = traceroute.intfAndLabelStackInfo
      labels = [ label for label in intfAndLabelStackInfo.labelStack ]
      tracerouteModel = getThreadLocalData( 'tracerouteModel' )
      # store intfAndLabelStackInfo to the current hopPkt being explored
      hopPkt = tracerouteModel.viaModel[ -1 ].path[ -1 ].hopPkt[ -1 ]
      hopPkt.intfAndLabelStackInfo = (
         IntfAndLabelStackInfo( labelStack=labels,
                                addrType=str( intfAndLabelStackInfo.addrType ),
                                intfAddr=str( intfAndLabelStackInfo.intfAddr ),
                                intfIndex=intfAndLabelStackInfo.intfIndex,
                                addr=str( intfAndLabelStackInfo.addr ) ) )
      if render:
         # render the most recently added labelStackInfo
         tracerouteLabelStackRender( hopPkt.intfAndLabelStackInfo )
   elif traceroute.HasField( 'downstreamInfos' ):
      downstreamInfos = traceroute.downstreamInfos
      dsInfos = []
      for dsInfo in downstreamInfos.downstreamInfo:
         labels = [ label for label in dsInfo.labelStack ]
         popFecTypes = [ fec for fec in dsInfo.popFecTypeList.fecTypeList ]
         pushFecTypes = [ fec for fec in dsInfo.pushFecTypeList.fecTypeList ]
         # an empty FEC list is stored as None rather than as an empty FecTypeList
         popFecTypeList = ( FecTypeList( fecTypeList=popFecTypes ) if popFecTypes
                              else None )
         pushFecTypeList = ( FecTypeList( fecTypeList=pushFecTypes ) if pushFecTypes
                               else None )

         dsInfos.append( DownstreamInfo( addrType=str( dsInfo.addrType ),
                                         dsIntfAddr=str( dsInfo.dsIntfAddr ),
                                         dsIntfIndex=dsInfo.dsIntfIndex,
                                         dsIpAddr=str( dsInfo.dsIpAddr ),
                                         labelStack=labels,
                                         retCode=str( dsInfo.retCode ),
                                         dsType=str( dsInfo.dsType ),
                                         popFecTypeList=popFecTypeList,
                                         pushFecTypeList=pushFecTypeList ) )
      tracerouteModel = getThreadLocalData( 'tracerouteModel' )
      # store dsInfo to the current hopPkt being explored
      tracerouteModel.viaModel[ -1 ].path[ -1 ].hopPkt[ -1 ].dsInfo = dsInfos
      if render:
         # render the most recently added dsInfo
         tracerouteDownstreamInfoRender(
            tracerouteModel.viaModel[ -1 ].path[ -1 ].hopPkt[ -1 ].dsInfo )
   else:
      # should never come here as all models covered up.
      bt8( "Received unknown model", bv( traceroute.SerializeToString() ) )
   # push anything printed above out immediately
   sys.stdout.flush()
   # pylint: enable=no-member
