# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import errno
import collections
from IpLibConsts import DEFAULT_VRF
from ClientCore import traceroute
from ClientCommonLib import (
   getNhgId, 
   labelStackToList,
   resolveHierarchical,
   isIpv6Addr, 
   resolveNexthop,
   getTunnelNhgName,
   _resolveSrTePolicyTunnels,
   getStaticFec,
   isNexthopGroupVia, 
   getNexthopGroupId,
   getNhgIdToName,
   isNexthopGroupTunnelVia,
   DynTunnelIntfId,
   resolveNhgTunnelFibEntry,
   getRsvpTunnelInfo,
   getRsvpFec,
   getDsMappingInfo,
   getLdpFec,
   getMldpFec,
   getProtocolIpFec, 
   getSrFec,
   IPV4, IPV6,
   MplsLabel, MldpInfo,
   getBgpLuTunnelFibEntry,
   getNhAndLabelFromTunnelFibVia,
   validateBgpLuResolvedPushVia,
   LspPingTypeBgpLu,
   LspPingTypeOspfSr,
   sendViaSocket,
   generateMplsEntropyLabel,
   getIntfPrimaryIpAddr,
   fillEntropyLabelPlaceholders,
   getNhgTunnelLabelStack,
)
from TracerouteModel_pb2 import ( # pylint: disable=no-name-in-module
   GenericModel,
   MplsTraceroute,
   MplsTracerouteHdr,
   MplsTracerouteStatistics,
   MplsTracerouteStatisticsSummary,
   MplsVia,
   RsvpModel,
   SrModel,
)
from ClientState import getGlobalState
from MplsTracerouteClientLib import sendOrRenderTracerouteErr
from ForwardingHelper import getNhgSize

import random
import sys
import Tac
from TypeFuture import TacLazyType

# ---------------------------------------------------------
#                   Local Utils 
# ---------------------------------------------------------

# TacLazyType handles for the TAC types and values referenced in this module.
IpGenAddr = TacLazyType( 'Arnet::IpGenAddr' )
IpGenPrefix = TacLazyType( 'Arnet::IpGenPrefix' )
# NOTE(review): MplsLabel is also imported from ClientCommonLib at the top of this
# file; this assignment rebinds the name -- confirm the shadowing is intended.
MplsLabel = TacLazyType( 'Arnet::MplsLabel' )
# NOTE(review): the name says "stack" but the TAC type is MplsLabelOperation --
# confirm the name matches the intended type.
MplsLabelStack = TacLazyType( 'Arnet::MplsLabelOperation' )
NexthopGroupType = TacLazyType( 'Routing::NexthopGroup::NexthopGroupType' )
TunnelId = TacLazyType( 'Tunnel::TunnelTable::TunnelId' )
igpProtocolType = TacLazyType( 'LspPing::LspPingSrIgpProtocol' )
# The implicitNull attribute of Arnet::MplsLabel; compared against label stack
# entries in handleLspTracerouteSr.
implicitNull = TacLazyType( "Arnet::MplsLabel" ).implicitNull

def getL3IntfMtu( l3Intf, mount ):
   '''Return the mtu attribute of the interface status entry for l3Intf.

   The entry is read from mount.allIntfStatusDir.intfStatus; whatever error that
   collection raises for an unknown interface is propagated to the caller.
   '''
   intfStatus = mount.allIntfStatusDir.intfStatus[ l3Intf ]
   return intfStatus.mtu

# ---------------------------------------------------------
#                   LspTraceroute BGP LU
# ---------------------------------------------------------

def handleLspTracerouteBgpLu( prefix, mount, src=None, dst=None, smac=None,
                              dmac=None, vrf=None, interface=None, interval=1,
                              count=None, label=None, verbose=False, entry=None,
                              tc=None, nexthop=None, standard=None, size=None,
                              padReply=False, egressValidateAddress=None,
                              multiLabel=1, tos=None, dstype=None,
                              **kwargs ):
   '''Run an LSP traceroute towards a BGP LU prefix.

   The vias of the BGP LU tunnel FIB entry for prefix are walked in order and the
   first one whose nexthop resolves to an interface and MAC address is traced.
   When nexthop and/or label are supplied, vias that do not carry exactly that
   nexthop / label stack are skipped, and the resolved result is additionally
   checked by validateBgpLuResolvedPushVia() when nexthop is set.

   dmac, vrf, interface, interval, count, verbose, entry and kwargs are accepted
   but not used in this function; traceroute() is always invoked with count=1
   and interval=1.

   Returns errno.EINVAL, after printing the reason, when a lookup, the via
   selection or the validation fails; otherwise the retVal of traceroute().
   '''
   state = getGlobalState()
   # Resolved vias, grouped by ( output interface, label stack, nexthop MAC ).
   resolvedVias = collections.defaultdict( list )
   # Vias whose nexthop did not resolve; not consulted later in this function.
   unresolvedVias = []
   if isIpv6Addr( str( prefix ) ):
      ipv = IPV6
   else:
      ipv = IPV4

   tunnelFibEntry, err = getBgpLuTunnelFibEntry( mount, prefix )
   if err:
      print( err )
      return errno.EINVAL

   for tunnelVia in tunnelFibEntry.tunnelVia.values():
      nhAndLabel, err = getNhAndLabelFromTunnelFibVia( mount, tunnelVia )
      if err:
         print( err )
         return errno.EINVAL
      viaNexthopIp = nhAndLabel.nextHopIp
      viaLabels = nhAndLabel.label
      viaIntfId = nhAndLabel.intfId

      # A user-supplied nexthop and/or label stack restricts the candidates to
      # the via carrying exactly that nexthop / label stack.
      if ( ( nexthop and IpGenAddr( nexthop ) != viaNexthopIp ) or
           ( label and label != viaLabels ) ):
         continue

      if ipv == IPV4:
         nexthopAddr = viaNexthopIp.v4Addr
      else:
         nexthopAddr = viaNexthopIp.v6Addr
      # Map the L3 nexthop to its L2 adjacency.
      outIntf, outMac = resolveNexthop( mount, state, nexthopAddr,
                                        intf=viaIntfId )
      if outIntf and outMac:
         viaKey = ( outIntf, tuple( viaLabels ), outMac )
         resolvedVias[ viaKey ].append( ( viaNexthopIp, viaLabels, outIntf ) )
         # Only the first resolved tunnel via is used for traceroute.
         break
      unresolvedVias.append( ( viaNexthopIp, viaLabels ) )

   if nexthop:
      err = validateBgpLuResolvedPushVia( resolvedVias, nexthop, label )
      if err:
         print( err )
         return errno.EINVAL

   if not resolvedVias:
      # No resolved interface
      print( 'Failed to find a valid output interface' )
      return errno.EINVAL

   setFecValidateFlag = bool( egressValidateAddress )
   # Trace the first ( key, vias ) pair held by the resolved-via mapping.
   firstKey, firstVias = next( iter( resolvedVias.items() ) )
   nexthopIntf, labelStack, nexthopEthAddr = firstKey
   nexthopIp = firstVias[ 0 ][ 0 ]
   nexthopStr = str( nexthopIp )
   intfMtu = getL3IntfMtu( nexthopIntf, mount )
   dsMappingInfo = getDsMappingInfo( nexthopStr, labelStack, intfMtu,
                                     multipath=False, baseip='127.0.0.0',
                                     numMultipathBits=0 )
   print( f'  via {nexthopIp}, label stack: {list( labelStack )}' )
   result = traceroute( mount, state, nexthopIntf, labelStack,
                        src, dst, smac, nexthopEthAddr, count=1, interval=1,
                        protocol=LspPingTypeBgpLu, prefix=prefix,
                        ipv=ipv, tc=tc, standard=standard, size=size,
                        padReply=padReply, setFecValidateFlag=setFecValidateFlag,
                        tos=tos, egressValidateAddress=egressValidateAddress,
                        dsMappingInfo=dsMappingInfo, dstype=dstype,
                        lookupLabelCount=multiLabel )
   return result.retVal

# ---------------------------------------------------------
#                   LspTraceroute nexthop-group
# ---------------------------------------------------------

def handleLspTracerouteNhg( nhgName, mount, entry, src=None, dst=None, smac=None,
                            dmac=None, count=None, interval=1, verbose=False,
                            vrf=None, tc=None, nhgTunnel=None, standard=None,
                            size=None, padReply=False, egressValidateAddress=None,
                            multiLabel=1, tos=None, backup=False, **kwargs ):
   '''Run an LSP traceroute over one entry of an MPLS nexthop-group.

   nhgName   -- name of the nexthop-group; it must exist, be of MPLS type and be
                present in the hardware nexthop-group adjacencies of the VRF.
   entry     -- index of the entry to trace. When None, an index is picked at
                random and announced on stdout.
   vrf       -- VRF whose hardware status is consulted; DEFAULT_VRF when unset.
   nhgTunnel -- only used in the announcement printed for a random entry.
   backup    -- when true, the backup vias / backup entries are used.
   egressValidateAddress -- passed to traceroute(); its truthiness also sets
                setFecValidateFlag.

   dmac, count, interval, verbose and kwargs are accepted but not used here;
   traceroute() is always called with 1, 1 for the two positional arguments
   that follow the nexthop MAC.

   Returns errno.EINVAL, after printing the reason, when the nexthop-group or
   the entry cannot be used; the first negative retVal reported by traceroute();
   otherwise 0.
   '''
   nhgId = getNhgId( nhgName, mount )
   if nhgId is None:
      print( 'Nexthop-group %s does not exist.' % nhgName )
      return errno.EINVAL

   # we are expecting a programmed MPLS nexthop-group
   nhgConfig = mount.routingNhgConfig.nexthopGroup.get( nhgName )
   if not nhgConfig:
      # NOTE(review): this message prints the id while the others print the
      # name -- confirm that is intended.
      print( f'Nexthop-group {nhgId} cannot be found' )
      return errno.EINVAL
   if nhgConfig.type != NexthopGroupType.mpls:
      print( 'Nexthop-group %s is not in MPLS type.' % nhgName )
      return errno.EINVAL

   entrySize = getNhgSize( nhgConfig, isBackup=backup )

   if entry is None:
      if entrySize <= 0:
         # random.randint( 0, -1 ) raises ValueError on an empty range, so an
         # empty group is reported as an error instead of being sampled.
         print( f'Nexthop-group {nhgName} has no '
                f'{"backup " if backup else ""}entries' )
         return errno.EINVAL
      entry = random.randint( 0, entrySize - 1 )
      if nhgTunnel is not None:
         print( ( "Traceroute over nexthop-group tunnel index %s,"
                 " nexthop-group %s Entry %d" ) % ( nhgTunnel, nhgName, entry ) )
      else:
         print( "Traceroute over nexthop-group %s Entry %d" % ( nhgName, entry ) )

   # validate entry
   if entry < 0 or entry >= entrySize:
      print( 'Invalid nexthop-group entry: %d' % entry )
      return errno.EINVAL

   # entry resolution
   vrf = vrf or DEFAULT_VRF
   hwVrfStatus = mount.routingHwNexthopGroupStatus.vrfStatus.get( vrf )
   if not hwVrfStatus:
      print( 'Cannot find next-hop group status.' )
      return errno.EINVAL
   if nhgId not in hwVrfStatus.nexthopGroupAdjacency:
      print( 'Nexthop-group %s is not programmed.' % nhgName )
      return errno.EINVAL

   nhgAdj = hwVrfStatus.nexthopGroupAdjacency.get( nhgId )
   nhgAdjVia = nhgAdj.backupVia if backup else nhgAdj.via

   if not nhgAdjVia:
      print( f'Via for Nexthop-group {nhgName} cannot be resolved' )
      return errno.EINVAL
   via = nhgAdjVia.get( entry )
   # ( interface, label stack, nexthop MAC ) tuples to run traceroute over.
   viaInfoList = []
   if via and via.l2Via.vlanId != 0:
      # resolved tunnel

      # XXX We are not supporting a nexthop-group having both V4 & V6 tunnel dests.
      ipv = IPV6 if isIpv6Addr( via.hop.stringValue ) else IPV4
      labelStack = labelStackToList(
                   getNhgTunnelLabelStack( nhgConfig, entry, backup ) )
      viaInfoList.append( ( via.l3Intf, labelStack, via.l2Via.macAddr ) )
   elif via and via.nextLevelFecIndex:
      # The entry points at another FEC: trace every nexthop of that FEC that
      # resolves to an interface and MAC, all with the entry's label stack.
      labelStack = labelStackToList(
                   getNhgTunnelLabelStack( nhgConfig, entry, backup ) )
      nextHopAndLabels = resolveHierarchical( mount, fecId=via.nextLevelFecIndex )
      ipv = IPV6 if isIpv6Addr( via.route.stringValue ) else IPV4
      state = getGlobalState()
      if nextHopAndLabels:
         for nhAndLabel in nextHopAndLabels:
            nhIntf, nhMac = resolveNexthop( mount, state, nhAndLabel.nextHopIp )
            if not nhIntf or not nhMac:
               continue
            viaInfoList.append( ( nhIntf, labelStack, nhMac ) )
   if not viaInfoList:
      print( f'Nexthop-group {nhgName}: {"backup " if backup else ""}'
             f'entry {entry} not configured or resolved' )

      return errno.EINVAL

   print( f'{"Backup entry" if backup else "Entry"} {entry}' )

   setFecValidateFlag = bool( egressValidateAddress )
   for nhIntf, labelStack, nhMac in viaInfoList:
      state = getGlobalState()
      retVal = traceroute( mount, state, nhIntf, labelStack, src,
                           dst, smac, nhMac, 1, 1, ipv=ipv, tc=tc,
                           standard=standard, size=size, padReply=padReply,
                           setFecValidateFlag=setFecValidateFlag, tos=tos,
                           egressValidateAddress=egressValidateAddress,
                           lookupLabelCount=multiLabel ).retVal
      if retVal < 0:
         return retVal
      print( '\n' )
   return 0

# ---------------------------------------------------------
#           LspTraceroute nexthop-group tunnel
# ---------------------------------------------------------

def handleLspTracerouteNhgTunnel( endpoint, mount, src=None, dst=None, smac=None,
                                  dmac=None, count=None, interval=1, verbose=False,
                                  vrf=None, tc=None, entry=None, standard=None,
                                  size=None, padReply=False, tos=None,
                                  egressValidateAddress=None, multiLabel=1,
                                  **kwargs ):
   '''Run an LSP traceroute over the nexthop-group tunnel for endpoint.

   The tunnel is mapped to its nexthop-group name and tunnel index through
   getTunnelNhgName(), and the request is then handed to
   handleLspTracerouteNhg(). An egressValidateAddress of 'default' is replaced
   by endpoint before delegating. kwargs is accepted but not forwarded.

   Returns errno.EINVAL, after printing the error, when the lookup fails;
   otherwise whatever handleLspTracerouteNhg() returns.
   '''
   nhgName, tunnelIndex, err = getTunnelNhgName( mount, endpoint )
   if err is not None:
      print( err )
      return errno.EINVAL

   validateAddr = ( endpoint if egressValidateAddress == 'default'
                    else egressValidateAddress )

   nhgArgs = dict( src=src, dst=dst, smac=smac, dmac=dmac, vrf=vrf,
                   interval=interval, count=count, verbose=verbose, entry=entry,
                   tc=tc, nhgTunnel=tunnelIndex, standard=standard, size=size,
                   padReply=padReply, tos=tos,
                   egressValidateAddress=validateAddr, multiLabel=multiLabel )
   return handleLspTracerouteNhg( nhgName, mount, **nhgArgs )

# ---------------------------------------------------------
#                   LspTraceroute SR-TE
# ---------------------------------------------------------

def handleLspTracerouteSrTe( endpoint, mount, color, trafficAf, dst=None, smac=None,
                             dmac=None, count=None, interval=1, verbose=False,
                             tc=None, standard=None, size=None, padReply=False,
                             egressValidateAddress=None, tos=None, **kwargs ):
   '''Run an LSP traceroute over every tunnel of an SR-TE policy.

   The policy identified by endpoint and color (and optionally trafficAf) is
   resolved through _resolveSrTePolicyTunnels(); each resulting tunnel is traced
   in sorted key order after printing its segment list label stack. The address
   family comes from trafficAf when it is set ( 'v6' selects IPV6, anything
   else IPV4 ), otherwise from endpoint.

   dmac, count, interval, verbose and kwargs are accepted but not used here;
   None is passed to traceroute() in the position where the other handlers in
   this file pass src.

   Returns errno.EINVAL when the policy resolution reports it, the first
   negative retVal reported by traceroute(), otherwise 0.
   '''
   endpointIsV6 = isIpv6Addr( str( endpoint ) )
   if trafficAf:
      ipv = IPV6 if trafficAf == 'v6' else IPV4
   else:
      ipv = IPV6 if endpointIsV6 else IPV4

   tunnelToAdjacencies = _resolveSrTePolicyTunnels( mount, endpoint, color,
                                                    trafficAf=trafficAf )
   if tunnelToAdjacencies == errno.EINVAL:
      return errno.EINVAL

   # FEC validation is requested whenever an egress address was given, including
   # the 'default' keyword, which stands for the policy endpoint itself.
   setFecValidateFlag = bool( egressValidateAddress )
   validateAddr = ( endpoint if egressValidateAddress == 'default'
                    else egressValidateAddress )

   for tunnelKey in sorted( tunnelToAdjacencies.keys() ):
      state = getGlobalState()
      labelTup, nhMac, l3Intf = tunnelToAdjacencies[ tunnelKey ]
      labelList = list( labelTup )

      print( f'Segment list label stack: {labelList}' )
      outcome = traceroute( mount, state, l3Intf, tuple( labelList ), None,
                            dst, smac, nhMac, 1, 1, ipv=ipv, tc=tc,
                            standard=standard, size=size, padReply=padReply,
                            setFecValidateFlag=setFecValidateFlag, tos=tos,
                            egressValidateAddress=validateAddr )
      retVal = outcome.retVal
      if retVal < 0:
         return retVal
      print( '\n' )

   return 0

# ---------------------------------------------------------
#                   LspTraceroute static
# ---------------------------------------------------------

def handleLspTracerouteStatic( prefix, mount, label, nexthop, src, dst, vrf,
                               smac, dmac, count, interval, verbose, tc,
                               standard=None, size=None, padReply=False,
                               egressValidateAddress=None, multiLabel=1,
                               tos=None, **kwargs ):
   '''Run an LSP traceroute for the static FEC of prefix.

   The vias of the FEC returned by getStaticFec() are classified as
   nexthop-group vias, nexthop-group tunnel vias, or push vias with a valid MPLS
   label. If any nexthop-group (tunnel) via is present, one nexthop-group is
   picked at random and the request is delegated to handleLspTracerouteNhg().
   Otherwise a single push via is traced: the one given by nexthop and
   label[ 0 ] when both are supplied, else a random resolved via.

   Returns the error value from getStaticFec() when no FEC is found,
   errno.EINVAL (after printing the reason) on the other failures, and
   otherwise the result of handleLspTracerouteNhg() or the retVal of
   traceroute().
   '''
   fec, err = getStaticFec( mount, prefix )
   if fec is None:
      return err
   
   if not fec.via:
      print( 'No adjacency found for prefix %s' % prefix )
      return errno.EINVAL

   state = getGlobalState()
   # Push vias are keyed by ( nexthop string, MPLS label ).
   unresolvedPushVias = []
   resolvedPushVias = {}
   # Names of all nexthop-groups reached through NHG or NHG-tunnel vias.
   nhgNames = []
   # Tunnel index -> nexthop-group name, for NHG-tunnel vias only.
   nhgTunIdexToNhgName = {}
   for via in fec.via.values():
      if isNexthopGroupVia( via ):
         nexthopGroupId = getNexthopGroupId( via )
         nhgName = getNhgIdToName( nexthopGroupId, mount )
         if not nhgName:
            print( 'No nexthop-group with id %d' % nexthopGroupId )
            return errno.EINVAL
         # An explicit nexthop together with a label is rejected for
         # nexthop-group routes.
         if nexthop and label is not None:
            print( ( '%s is a nexthop-group route. '
                     'Please use traceroute mpls nexthop-group.' ) % prefix )
            return errno.EINVAL
         print( '{}: nexthop-group route (nexthop-group name: {})'.format( prefix,
                                                                           nhgName ) )
         nhgNames.append( nhgName )
         continue
      if isNexthopGroupTunnelVia( via ):
         tunnelId = DynTunnelIntfId.tunnelId( via.intfId )
         ( nhgName, tunnelIndex, err ) = resolveNhgTunnelFibEntry( mount, tunnelId )
         if err is not None:
            print( err )
            return errno.EINVAL
         nhgTunIdexToNhgName[ tunnelIndex ] = nhgName
         nhgNames.append( nhgName )
         continue
      # Vias without a valid MPLS label are ignored.
      if not MplsLabel( via.mplsLabel ).isValid():
         continue
      l3Intf, nhMac = resolveNexthop( mount, state, via.hop )
      # via.hop is either an object exposing stringValue or already a str.
      if hasattr( via.hop, 'stringValue' ):
         viaNexthop = via.hop.stringValue
      else:
         assert isinstance( via.hop, str )
         viaNexthop = via.hop
      if not ( l3Intf and nhMac ):
         unresolvedPushVias.append( ( viaNexthop, via.mplsLabel ) )
         continue
      resolvedPushVias[ ( viaNexthop, via.mplsLabel ) ] = ( l3Intf, nhMac )

   if nhgNames:
      # We are only choosing one of the NHGs because we want to keep consistent with
      # how traceroute static only picks one of the vias.
      selectedNhg = random.choice( nhgNames )
      selectedNhgTunnel = None
      # When at least one NHG tunnel via exists, the random pick above is
      # replaced by the nexthop-group of a randomly chosen tunnel.
      if nhgTunIdexToNhgName:
         selectedNhgTunnel = random.choice( list( nhgTunIdexToNhgName ) )
         selectedNhg = nhgTunIdexToNhgName[ selectedNhgTunnel ]
      return handleLspTracerouteNhg( selectedNhg, mount, src=src, dst=dst, smac=smac,
                                     dmac=dmac, vrf=vrf, interval=interval,
                                     count=count, verbose=verbose, entry=None,
                                     tc=tc, nhgTunnel=selectedNhgTunnel,
                                     standard=standard, size=size, tos=tos,
                                     padReply=padReply, multiLabel=multiLabel,
                                     egressValidateAddress=egressValidateAddress )
   # A user-specified via uses the first label only; otherwise pick one of the
   # resolved vias at random (None when there is none).
   if label and nexthop:
      selectedPushVia = ( nexthop, label[ 0 ] )
   else:
      selectedPushVia = random.choice( list( resolvedPushVias ) )\
                                     if resolvedPushVias else None
   labelStack = []
   nextHop = ''
   viaStr = '  '
   if selectedPushVia:
      assert isinstance( selectedPushVia[ 1 ], int )
      nextHop = selectedPushVia[ 0 ]
      labelStack = [ selectedPushVia[ 1 ] ]
      viaStr += f'via {nextHop}, label stack: {labelStack}'
   # viaStr is printed even when no via was selected (two spaces only).
   print( viaStr )
   if not selectedPushVia or selectedPushVia in unresolvedPushVias:
      print( 'via not resolved' )
      return errno.EINVAL
   # Reached for a user-specified via that matches none of the FEC's push vias.
   if selectedPushVia not in resolvedPushVias:
      print( 'via not found' )
      return errno.EINVAL

   ipv = IPV6 if isIpv6Addr( prefix ) else IPV4
   l3Intf, nhMac = resolvedPushVias[ selectedPushVia ]
   setFecValidateFlag = bool( egressValidateAddress )

   # Send dsmap if no NHG is being used
   dsMappingInfo = getDsMappingInfo( nextHop, labelStack,
                                     getL3IntfMtu( l3Intf, mount ),
                                     multipath=False, baseip='127.0.0.0',
                                     numMultipathBits=0 )

   retVal = traceroute( mount, state, l3Intf, labelStack, src, dst,
                        smac, nhMac, 1, 1, ipv=ipv, tc=tc, standard=standard,
                        size=size, padReply=padReply, tos=tos,
                        dsMappingInfo=dsMappingInfo,
                        setFecValidateFlag=setFecValidateFlag,
                        egressValidateAddress=egressValidateAddress,
                        lookupLabelCount=multiLabel ).retVal
   return retVal

# ---------------------------------------------------------
#                   LspTraceroute RSVP
# ---------------------------------------------------------
def handleLspTracerouteRsvp( prefix, mount, label, nexthop, src, dst, vrf,
                             smac, dmac, count, interval, verbose, tc,
                             multipath, multipathbase, multipathcount, sock,
                             session=None, lsp=None, standard=None, size=None,
                             padReply=False, tos=None, dstype=None, **kwargs ):
   '''Run an LSP traceroute over an RSVP LSP.

   The LSP is looked up through getRsvpTunnelInfo() when kwargs carries a
   'tunnel' (optionally narrowed by 'sub_tunnel_id'), otherwise through
   getRsvpFec() using session and lsp. The traceroute header, the RSVP model and
   the selected via are sent over sock before the probes are launched.

   The label and nexthop arguments are not consulted: both names are rebound to
   the values of the first nexthop/label entry of the LSP. kwargs must contain
   'type' (the protocol); a missing key raises KeyError.

   Returns the error value from the lookup when nothing was found,
   errno.EINVAL when more than one LSP matches or when the nexthop cannot be
   resolved (both reported via sendOrRenderTracerouteErr), otherwise the retVal
   of traceroute().
   '''
   protocol = kwargs[ 'type' ]
   rsvpTunnel = kwargs.get( 'tunnel' )
   rsvpSubTunnel = None
   if rsvpTunnel:
      rsvpSubTunnel = kwargs.get( 'sub_tunnel_id' )
      nextHopsAndLabels, err, rsvpSpIds, rsvpSenderAddr, _ = \
         getRsvpTunnelInfo( mount, rsvpTunnel, rsvpSubTunnel )
   else:
      nextHopsAndLabels, err, rsvpSpIds, rsvpSenderAddr, _ = \
         getRsvpFec( mount, session, lsp )
   if nextHopsAndLabels is None or rsvpSpIds is None:
      return err
   # The CLI/script enforces that either a specific LSP is specified (session + LSP
   # IDs) or that a name is specified. In the latter case, it is possible that we
   # find more than one matching LSP, in this case we reject the input and will
   # later display the potential matches to the user (BUG286407)
   if len( rsvpSpIds ) > 1:
      if rsvpTunnel:
         err = 'More than one match for subtunnels, use argument `sub-tunnel`'
      else:
         err = ( 'More than one match for session "%s", use arguments `id` and `lsp`'
                 % ( session ) )
      sendOrRenderTracerouteErr( err, sock )
      return errno.EINVAL
   # For traceroute, it can be only one rsvpSpId. The CLI wouldn't allow to get
   # here with more than one LSP.
   rsvpSpId = rsvpSpIds[ 0 ]
   nexthop, label, _ = nextHopsAndLabels[ 0 ]
   dstIp = rsvpSpId.sessionId.dstIp
   # Send traceroute header
   mplsTracerouteHdr = MplsTracerouteHdr( protocol=protocol, dsInfoType=dstype )
   sendViaSocket( sock, MplsTraceroute( mplsTracerouteHdr=mplsTracerouteHdr ) )
   sessionStr = str( session ) if isinstance( session, int ) else session
   rsvpModel = RsvpModel( session=sessionStr,
                          lspId=lsp,
                          tunnel=rsvpTunnel,
                          subTunnelId=rsvpSubTunnel )
   sendViaSocket( sock, MplsTraceroute( rsvpModel=rsvpModel ) )
   state = getGlobalState()
   # Map L3 nexthop to L2 nexthop
   l3Intf, nhMac = resolveNexthop( mount, state, nexthop )
   if not ( l3Intf and nhMac ):
      # Without an output interface and nexthop MAC nothing can be probed.
      # Report the failure instead of carrying on with unusable values (the
      # interface MTU lookup below needs a real interface).
      sendOrRenderTracerouteErr( 'via not found or not resolved', sock )
      return errno.EINVAL

   labelStack = label if isinstance( label, list ) else [ label ]
   ipv = IPV6 if isIpv6Addr( dstIp ) else IPV4
   # send via through socket
   nextHopIp = ( nexthop if isinstance( nexthop, str ) else
                 nexthop.stringValue )
   mplsVia = MplsVia( nextHopIp=nextHopIp,
                      labelStack=labelStack )
   sendViaSocket( sock, MplsTraceroute( mplsVia=mplsVia ) )

   # The label list could potentially contain the implicit null label (3). It is
   # left in the via sent over the socket above because it is nice and
   # consistent to see it in the output, however it is removed before the stack
   # is used for the probes so that the rest of the processing does not have to
   # special case it (the label isn't actually in the stack).
   # However, if the list ONLY contains the implicit null, we leave it in,
   # since the client config needs to have at least one label.
   if any( lbl != 3 for lbl in labelStack ):
      labelStack = [ lbl for lbl in labelStack if lbl != 3 ]

   dsMappingInfo = getDsMappingInfo( nexthop, labelStack,
                                     getL3IntfMtu( l3Intf, mount ),
                                     multipath, multipathbase, multipathcount )
   retVal = traceroute( mount, state, l3Intf, labelStack, src, dst,
                        smac, nhMac, 1, interval=5, ipv=ipv, rsvpSpId=rsvpSpId,
                        rsvpSenderAddr=rsvpSenderAddr, sock=sock, protocol=protocol,
                        dsMappingInfo=dsMappingInfo, tc=tc, standard=standard,
                        size=size, padReply=padReply, tos=tos ).retVal
   return retVal



# ---------------------------------------------------------
#                   LspTraceroute LDP
# ---------------------------------------------------------
def handleLspTracerouteLdp( prefix, mount, label, nexthop, src, dst, vrf,
                            smac, dmac, count, interval, verbose, tc,
                            multipath, multipathbase, multipathcount, sock,
                            maxTtl=None, standard=None, size=None, padReply=False,
                            dstype=None, tos=None, **kwargs ):
   '''Run an LSP traceroute for the LDP FEC of prefix.

   The traceroute header and a GenericModel carrying prefix are sent over sock
   first. The nexthop/label entries returned by getLdpFec() are then walked: for
   each entry whose nexthop resolves, the via is sent over sock and traceroute()
   is run. Without multipath, the function returns after the first resolved
   entry; with multipath, every resolved entry is traced.

   kwargs must contain 'type' (the protocol); a missing key raises KeyError.
   label must be None.

   Returns errno.EINVAL when label is given or (non-multipath only) when no via
   could be resolved, the errVal from getLdpFec() when it yields no entries,
   and otherwise the first element of the last traceroute() result (0 if none
   was run).
   '''
   protocol = kwargs[ 'type' ]
   # Always send mplsTracerouteHdr first before sending any other data.
   mplsTracerouteHdr = MplsTracerouteHdr( protocol=protocol,
                                          dsInfoType=dstype )
   sendViaSocket( sock,
                  MplsTraceroute( mplsTracerouteHdr=mplsTracerouteHdr ) )
   ldpModel = GenericModel( prefix=prefix )
   sendViaSocket( sock,
                  MplsTraceroute( ldpModel=ldpModel ) )

   if label is not None:
      err = 'Non-empty label arg provided'
      sendOrRenderTracerouteErr( err, sock )
      return errno.EINVAL

   # nextHopsAndLabels is a dictionary of NextHopAndLabel objects
   # keyed by a str generated using the getNextHopAndLabelKey function.
   # NOTE(review): the loop below reads .nextHopIp/.intfId/.label from the items
   # it iterates, which would be the keys if this really is a dict -- confirm the
   # actual return type of getLdpFec.
   nextHopsAndLabels, err, errVal = getLdpFec( mount, prefix,
                                               allowEntropyLabel=True )
   if not nextHopsAndLabels:
      sendOrRenderTracerouteErr( err, sock )
      return errVal

   res = 0
   # If we are unable to resolve any vias, we want to print an error at the end
   resolvedAnyVia = False

   # this loop will run multiple times only for multipath LDP/SR traceroute
   # with ECMP at origin
   for nhlEntry in nextHopsAndLabels:
      state = getGlobalState()
      ipv = IPV6 if isIpv6Addr( prefix ) else IPV4
      nextHopIp = nhlEntry.nextHopIp
      intfId = nhlEntry.intfId
      # Normalize a single label into a one-element list.
      labelStack = ( nhlEntry.label if isinstance( nhlEntry.label, list ) else
                     [ nhlEntry.label ] )

      nextHopIntf, nextHopEthAddr = resolveNexthop( mount, state, nextHopIp,
                                                    intf=intfId )
      if not ( nextHopIntf and nextHopEthAddr ):
         # If it's multipath, we'd print the error and try next path. If it's regular
         # traceroute, we'd only print the error at end if we can't resolve any via
         if multipath:
            errOutput = 'via not found or not resolved for next-hop IP {}'.format(
                                                                          nextHopIp )
            sendOrRenderTracerouteErr( errOutput, sock )
         continue
      resolvedAnyVia = True

      # Generate an entropy label from the source address (src, or the primary
      # address of the output interface), the prefix address and the UDP rx
      # port (0 when no UDP pam is available), then hand it to
      # fillEntropyLabelPlaceholders together with labelStack.
      udpPam = state.lspPingClientRootUdpPam
      genPrefix = IpGenPrefix( str( prefix ) )
      el = generateMplsEntropyLabel(
            srcIp=( src or getIntfPrimaryIpAddr( mount, nextHopIntf, ipv ) ),
            dstIp=( genPrefix.v4Addr if ipv == IPV4 else genPrefix.v6Addr ),
            udpSrcPort=( 0 if not udpPam else udpPam.rxPort ) )
      fillEntropyLabelPlaceholders( labelStack, el )

      # Report the via (including any implicit null) before probing.
      mplsVia = MplsVia()
      mplsVia.nextHopIp = ( nextHopIp if isinstance( nextHopIp, str ) else
                            nextHopIp.stringValue )
      # pylint: disable=no-member
      mplsVia.labelStack[ : ] = labelStack
      # pylint: enable=no-member
      sendViaSocket( sock, MplsTraceroute( mplsVia=mplsVia ) )

      # The multipath count is only meaningful for multipath traceroute.
      multipathcount = multipathcount if multipath else 0
      dsMappingInfo = getDsMappingInfo( str( nextHopIp ), labelStack,
                                        getL3IntfMtu( nextHopIntf, mount ),
                                        multipath, multipathbase, multipathcount )
      sys.stdout.flush()

      # Strip the implicit nulls from the label stack that will be going on the wire
      # if there are other labels
      if any( label != 3 for label in labelStack ):
         labelStack = [ l for l in labelStack if l != 3 ]

      # Until FEC validation for stitched entries are implemented, skip FEC
      # validation at transit nodes for protocol-independent FEC using 'generic'
      # protocol. Note that FEC validation at the egress node is always
      # performed and is not affected by this flag.
      setFecValidateFlag = protocol != 'generic'
      res, _, _ = traceroute( mount, state, nextHopIntf, labelStack,
                              src, dst, smac, nextHopEthAddr, 1,
                              interval=interval, hops=maxTtl,
                              prefix=prefix, ipv=ipv, protocol=protocol,
                              dsMappingInfo=dsMappingInfo, tc=tc,
                              multipath=multipath, nextHopIp=nextHopIp,
                              standard=standard, size=size, tos=tos,
                              padReply=padReply, dstype=dstype,
                              setFecValidateFlag=setFecValidateFlag,
                              sock=sock )

      # Regular traceroute stops after the first resolved via.
      if not multipath:
         return res

   if not multipath and not resolvedAnyVia:
      err = 'via not found or not resolved'
      sendOrRenderTracerouteErr( err, sock )
      return errno.EINVAL
   return res


# ---------------------------------------------------------
#                   LspTraceroute SR
# ---------------------------------------------------------
def handleLspTracerouteSr( prefix, mount, label, nexthop, src, dst, vrf,
                           smac, dmac, count, interval, verbose, tc,
                           multipath, sock,
                           maxTtl=None, standard=None, size=None, padReply=False,
                           dstype=None, tos=None, **kwargs ):
   '''Run an LSP traceroute for the segment-routing FEC of prefix.

   The traceroute header is sent over sock first, followed by an SrModel built
   from prefix, kwargs[ 'algorithmName' ] and the IGP protocol reported by
   getSrFec(). The nexthop/label entries returned by getSrFec() are then walked:
   for each entry whose nexthop resolves, the via is sent over sock and
   traceroute() is run. Without multipath, the function returns after the first
   resolved entry.

   kwargs must contain 'type' (the protocol); 'algorithmName', 'algorithm' and
   'srTunnel' are optional. label must be None.

   Returns errno.EINVAL when label is given or when no via could be resolved,
   the errVal from getSrFec() when it yields no entries, and otherwise the
   first element of the last traceroute() result.
   '''
   protocol = kwargs[ 'type' ]
   # Always send mplsTracerouteHdr first before sending any other data.
   mplsTracerouteHdr = MplsTracerouteHdr( protocol=protocol,
                                          dsInfoType=dstype )
   sendViaSocket( sock,
                  MplsTraceroute( mplsTracerouteHdr=mplsTracerouteHdr ) )

   algorithmName = kwargs.get( 'algorithmName' )
   algorithm = kwargs.get( 'algorithm' )
   # srTunnel will be one of these values - 'isis', 'ospf' or None
   srTunnel = kwargs.get( 'srTunnel' )
   # The FEC is looked up before the label argument is validated because the
   # SrModel sent next needs the IGP protocol from the lookup.
   nextHopsAndLabels, err, errVal, igpProtocol = getSrFec( mount, prefix, algorithm,
                                                           srTunnel )
   srModel = SrModel( prefix=prefix, algorithm=algorithmName,
                        igpProtocolType=igpProtocol )
   sendViaSocket( sock, MplsTraceroute( srModel=srModel ) )

   if label is not None:
      err = 'Non-empty label arg provided'
      sendOrRenderTracerouteErr( err, sock )
      return errno.EINVAL

   # FEC validation is turned off for the 'generic' protocol only.
   fecValidateFlag = True
   if protocol == 'generic':
      fecValidateFlag = False

   if not nextHopsAndLabels:
      sendOrRenderTracerouteErr( err, sock )
      return errVal

   # For an OSPF FEC the probes use the OSPF SR ping type rather than the
   # protocol passed in by the caller.
   if igpProtocol == igpProtocolType.ospf:
      protocol = LspPingTypeOspfSr
   res = 0
   # If we are unable to resolve any vias, we want to print an error at the end
   resolvedAnyVia = False

   # this loop will run multiple times only for multipath SR traceroute
   # with ECMP at origin
   for nhlEntry in nextHopsAndLabels:
      state = getGlobalState()
      ipv = IPV6 if isIpv6Addr( prefix ) else IPV4
      nextHopIp = nhlEntry.nextHopIp
      intfId = nhlEntry.intfId
      # Normalize a single label into a one-element list.
      labelStack = ( nhlEntry.label if isinstance( nhlEntry.label, list ) else
                     [ nhlEntry.label ] )

      nextHopIntf, nextHopEthAddr = resolveNexthop( mount, state, nextHopIp,
                                                    intf=intfId )
      if not ( nextHopIntf and nextHopEthAddr ):
         # For regular traceroute, we'd only print the error
         # at end if we can't resolve any via
         continue
      resolvedAnyVia = True

      # Report the via (including any implicit null) before probing.
      mplsVia = MplsVia()
      mplsVia.nextHopIp = ( nextHopIp if isinstance( nextHopIp, str ) else
                            nextHopIp.stringValue )
      # pylint: disable=no-member
      mplsVia.labelStack[ : ] = labelStack
      # pylint: enable=no-member
      sendViaSocket( sock, MplsTraceroute( mplsVia=mplsVia ) )

      dsMappingInfo = getDsMappingInfo( str( nextHopIp ), labelStack,
                                        getL3IntfMtu( nextHopIntf, mount ),
                                        multipath, baseip='127.0.0.0',
                                        numMultipathBits=0 )
      sys.stdout.flush()

      # Strip the implicit nulls from the label stack that will be going on the wire
      # if there are other labels
      if any( label != implicitNull for label in labelStack ):
         labelStack = [ l for l in labelStack if l != implicitNull ]

      res, _, _ = traceroute( mount, state, nextHopIntf, labelStack,
                              src, dst, smac, nextHopEthAddr, 1,
                              interval=interval, hops=maxTtl,
                              prefix=prefix, ipv=ipv,
                              protocol=protocol,
                              dsMappingInfo=dsMappingInfo, tc=tc,
                              multipath=multipath, nextHopIp=nextHopIp,
                              standard=standard, size=size, tos=tos,
                              padReply=padReply, dstype=dstype,
                              setFecValidateFlag=fecValidateFlag,
                              sock=sock, algorithm=algorithm )

      # Regular traceroute stops after the first resolved via.
      if not multipath:
         return res

   # Unlike the per-via case above, this error is reported for multipath too.
   if not resolvedAnyVia:
      err = 'via not found or not resolved'
      sendOrRenderTracerouteErr( err, sock )
      return errno.EINVAL
   return res

# ---------------------------------------------------------
#                   LspTraceroute MLDP
# ---------------------------------------------------------
def sendStatisticSummaryViaSocket( sock, t, txPkts, replyHostRtts, vias ):
   '''Send the traceroute statistics for the given vias over sock.

   Messages are sent, in this order, as MplsTraceroute wrappers:
     1. one MplsVia (flagged with statisticsVia=True) per entry in vias,
     2. a single MplsTracerouteStatisticsSummary,
     3. one MplsTracerouteStatistics per replying host that has at least
        one RTT sample, in sorted order of the replyHostRtts items.

   Args:
      sock: handed to sendViaSocket unchanged.
      t: reported as the summary's replyTime.
      txPkts: number of packets transmitted.
      replyHostRtts: mapping of replying host to the list of its RTTs;
         may be None or empty when nothing was received. Hosts are
         expected to expose a stringValue attribute.
      vias: iterable of ( nextHopIp, labelStack, intf ) tuples.
   '''
   for nextHopIp, labelStack, intf in vias:
      mplsVia = MplsVia( nextHopIp=nextHopIp,
                         labelStack=labelStack,
                         interface=intf,
                         statisticsVia=True )
      sendViaSocket( sock, MplsTraceroute( mplsVia=mplsVia ) )

   # A None or empty mapping both mean that no reply was received.
   recvNum = ( sum( len( rtts ) for rtts in replyHostRtts.values() )
               if replyHostRtts else 0 )
   # Integer percentage of lost packets. When nothing was transmitted there is
   # nothing to lose; report 0 rather than dividing by zero.
   lossRate = 100 - recvNum * 100 // txPkts if txPkts else 0
   mplsTracerouteStatisticsSummary = MplsTracerouteStatisticsSummary(
                                       packetsTransmitted=txPkts,
                                       packetsReceived=recvNum,
                                       packetLoss=lossRate,
                                       replyTime=t )
   sendViaSocket(
      sock, MplsTraceroute(
         mplsTracerouteStatisticsSummary=mplsTracerouteStatisticsSummary ) )

   if not replyHostRtts:
      return
   for host, rtts in sorted( replyHostRtts.items() ):
      if not rtts:
         # No sample for this host, so min/max/avg are undefined.
         continue
      # NOTE(review): the * 1000 scaling into the *Us fields presumably
      # converts milliseconds to microseconds; confirm the unit of rtts.
      mplsTracerouteStatistics = MplsTracerouteStatistics(
               destination=host.stringValue,
               roundTripTimeMinUs=int( min( rtts ) * 1000 ),
               roundTripTimeMaxUs=int( max( rtts ) * 1000 ),
               roundTripTimeAvgUs=int( ( sum( rtts ) / len( rtts ) ) * 1000 ),
               packetsReceived=len( rtts ) )
      sendViaSocket( sock, MplsTraceroute(
         mplsTracerouteStatistics=mplsTracerouteStatistics ) )

def handleLspTracerouteMldp( prefix, mount, label, nexthop, src, dst, vrf,
                             smac, dmac, count, interval, verbose, tc,
                             multipath, sock, genOpqVal=None, sourceAddrOpqVal=None,
                             groupAddrOpqVal=None, jitter=None,
                             responderAddr=None, maxTtl=None,
                             standard=None, size=None, padReply=False,
                             dstype=None, tos=None, **kwargs ):
   '''Run an LSP traceroute for an MLDP FEC and report the result over sock.

   The header and the prefix model are sent first. The FEC is then looked up
   with getMldpFec and a traceroute is run over each resolvable via, followed
   by a statistics summary for that via. Without multipath the function
   returns after the first resolvable via.

   kwargs must contain 'type', which is used as the protocol. The parameters
   nexthop, vrf, dmac, count and verbose are accepted but not used here.

   Returns:
      errno.EINVAL if label is given or if no via could be resolved, the
      errVal from getMldpFec if the FEC lookup yields nothing, otherwise the
      result code of the last traceroute that was run.
   '''
   protocol = kwargs[ 'type' ]
   mplsTracerouteHdr = MplsTracerouteHdr( protocol=protocol,
                                          dsInfoType=dstype )
   sendViaSocket( sock,
                  MplsTraceroute( mplsTracerouteHdr=mplsTracerouteHdr ) )
   ldpModel = GenericModel( prefix=prefix )
   sendViaSocket( sock,
                  MplsTraceroute( ldpModel=ldpModel ) )
   if label is not None:
      err = 'Non-empty label arg provided'
      sendOrRenderTracerouteErr( err, sock )
      return errno.EINVAL

   fecValidateFlag = True
   # NOTE(review): this used to be described as a dictionary keyed by the
   # getNextHopAndLabelKey string, but the loop below iterates it directly and
   # reads nextHopIp / intfId / label from each element, which only works if
   # iteration yields the NextHopAndLabel entries themselves. Confirm the
   # container type returned by getMldpFec.
   nextHopsAndLabels, err, errVal = getMldpFec( mount, prefix, genOpqVal,
                                                sourceAddrOpqVal, groupAddrOpqVal )
   if not nextHopsAndLabels:
      # Pass sock like every other error path so that a socket client is
      # told about the failure too.
      sendOrRenderTracerouteErr( err, sock )
      return errVal

   # MPLS implicit-null label value; it is never put on the wire when the
   # stack contains any other label.
   implicitNullLabel = 3
   # The address family only depends on the prefix, not on the via.
   ipv = IPV6 if isIpv6Addr( prefix ) else IPV4

   res = 0
   # If we are unable to resolve any vias, we want to print an error at the end
   resolvedAnyVia = False

   # this loop will run multiple times only for multipath traceroute
   # with ECMP at origin
   for nhlEntry in nextHopsAndLabels:
      state = getGlobalState()
      nextHopIp = nhlEntry.nextHopIp
      intfId = nhlEntry.intfId
      labelStack = ( nhlEntry.label if isinstance( nhlEntry.label, list ) else
                     [ nhlEntry.label ] )

      nextHopIntf, nextHopEthAddr = resolveNexthop( mount, state, nextHopIp,
                                                    intf=intfId )
      if not ( nextHopIntf and nextHopEthAddr ):
         # For regular traceroute, we'd only print the error
         # at end if we can't resolve any via
         continue
      resolvedAnyVia = True

      # From here on the nexthop is carried around in its string form.
      nextHopIp = ( nextHopIp if isinstance( nextHopIp, str ) else
                    nextHopIp.stringValue )
      # The via is reported with the unstripped label stack.
      mplsVia = MplsVia( nextHopIp=nextHopIp,
                         labelStack=labelStack )
      sendViaSocket( sock, MplsTraceroute( mplsVia=mplsVia ) )

      dsMappingInfo = getDsMappingInfo( str( nextHopIp ), labelStack,
                                        getL3IntfMtu( nextHopIntf, mount ),
                                        multipath, baseip='127.0.0.0',
                                        numMultipathBits=0 )
      mldpInfo = MldpInfo( genOpqVal, sourceAddrOpqVal, groupAddrOpqVal,
                           jitter, responderAddr )
      sys.stdout.flush()

      # Strip the implicit nulls from the label stack that will be going on the wire
      # if there are other labels
      if any( l != implicitNullLabel for l in labelStack ):
         labelStack = [ l for l in labelStack if l != implicitNullLabel ]

      startTime = Tac.now()
      res, txPkts, replyHostRtts = traceroute( mount, state, nextHopIntf, labelStack,
                                               src, dst, smac, nextHopEthAddr, 1,
                                               interval=interval, hops=maxTtl,
                                               prefix=prefix, ipv=ipv,
                                               protocol=protocol,
                                               dsMappingInfo=dsMappingInfo, tc=tc,
                                               multipath=multipath,
                                               nextHopIp=nextHopIp,
                                               mldpInfo=mldpInfo, standard=standard,
                                               size=size, tos=tos,
                                               padReply=padReply, dstype=dstype,
                                               setFecValidateFlag=fecValidateFlag,
                                               sock=sock )
      # NOTE(review): this was labelled "microsecond", but if Tac.now() is in
      # seconds then * 1000 yields milliseconds; confirm the unit expected by
      # the summary's replyTime.
      t = int( ( Tac.now() - startTime ) * 1000 )
      sys.stdout.flush()

      # traceroute statistics support only for mldp for now.
      sendStatisticSummaryViaSocket( sock, t, txPkts, replyHostRtts,
                                     [ ( nextHopIp, labelStack, nextHopIntf ) ] )

      if not multipath:
         return res

   if not resolvedAnyVia:
      err = 'via not found or not resolved'
      sendOrRenderTracerouteErr( err, sock )
      return errno.EINVAL
   return res

# ---------------------------------------------------------
#                   LspTraceroute raw
# ---------------------------------------------------------

def handleLspTracerouteRaw( prefix, mount, label, src, dst, smac, dmac,
                            interface, count, interval, verbose,
                            nexthop=None, tc=None, standard=None, size=None,
                            padReply=False, tos=None, **kwargs ):
   '''Run a traceroute using the label, interface and dmac given by the caller.

   When interface or dmac is missing and a nexthop is given, both are taken
   from resolveNexthop instead. If either is still None afterwards, an error
   is printed and errno.EINVAL is returned.

   The address family is IPv6 as soon as prefix, nexthop or src is an IPv6
   address, and IPv4 otherwise.

   Returns:
      errno.EINVAL on the failure described above, otherwise the retVal
      attribute of what traceroute returns.
   '''
   state = getGlobalState()

   # The nexthop is only consulted when the caller left out the output
   # interface or the destination MAC.
   if ( interface is None or dmac is None ) and nexthop:
      interface, dmac = resolveNexthop( mount, state, nexthop )
   if interface is None or dmac is None:
      print( "Failed to find a valid output interface" )
      return errno.EINVAL

   # Translation from default Ipv4 to default Ipv6 if not given
   candidateAddrs = ( prefix, nexthop, src )
   hasIpv6Addr = any( addr and isIpv6Addr( addr ) for addr in candidateAddrs )
   ipv = IPV6 if hasIpv6Addr else IPV4

   result = traceroute( mount, state, interface, label, src, dst, smac, dmac,
                        count, interval, verbose, ipv=ipv, tc=tc,
                        standard=standard, size=size, padReply=padReply,
                        tos=tos )
   return result.retVal

# ---------------------------------------------------------
#                   LspTraceroute generic
# ---------------------------------------------------------

def handleLspTracerouteGeneric( prefix, mount, **kwargs ):
   '''Run a traceroute for prefix using the first protocol that has a FEC.

   Segment-routing is tried before LDP. For each protocol getProtocolIpFec is
   queried; the first one that returns any nexthops and labels has its
   handler called with ( prefix, mount, **kwargs ) and that handler's result
   is returned.

   If neither protocol yields anything, the err of the last lookup is printed
   and its errVal is returned.
   '''
   # Traceroute on stitched LSP, currently we support SR and LDP
   handlersInOrder = ( ( 'segment-routing', handleLspTracerouteSr ),
                       ( 'ldp', handleLspTracerouteLdp ) )

   # Positional arguments that follow the protocol in every FEC lookup.
   fecLookupArgs = ( kwargs.get( 'genOpqVal' ),
                     kwargs.get( 'sourceAddrOpqVal' ),
                     kwargs.get( 'groupAddrOpqVal' ),
                     kwargs.get( 'algorithm' ) )

   err = None
   errVal = errno.EINVAL
   for protocol, handler in handlersInOrder:
      nextHopsAndLabels, err, errVal, _ = getProtocolIpFec(
         mount, prefix, protocol, *fecLookupArgs, allowEntropyLabel=True )
      if nextHopsAndLabels:
         return handler( prefix, mount, **kwargs )

   # NOTE(review): the other handlers report errors through
   # sendOrRenderTracerouteErr with the socket; this one only prints. Confirm
   # whether socket clients are expected to see this error.
   print( err )
   return errVal
