#!/usr/bin/env python3
# Copyright (c) 2015 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

try:
   import argparse
   import EntityManager
   import CliCommon
   import json
   import os
   import subprocess
   import sys
   import Tac
   from threading import Thread
   from CliDynamicSymbol import CliDynamicPlugin
   from CliPlugin.MplsUtilCli import genTraceroute
   from CliPlugin.MplsCli import maxIngressMplsTopLabels

   # Handle to the "MplsUtilModel" CLI plugin; its MplsTraceroute model is
   # rendered further down when the traceroute output is not requested as JSON.
   MplsUtilModel = CliDynamicPlugin( "MplsUtilModel" )
   # Tac type handles. MplsOamStandard supplies the choices of the --standard
   # option; the address and prefix types are used by the multipathbase /
   # multipathcount range validation.
   AddressFamily = Tac.Type( 'Arnet::AddressFamily' )
   IpAddr = Tac.Type( 'Arnet::IpAddr' )
   Ip6Addr = Tac.Type( 'Arnet::Ip6Addr' )
   Ip6Prefix = Tac.Type( 'Arnet::Ip6Prefix' )
   IpGenAddr = Tac.Type( 'Arnet::IpGenAddr' )
   MplsOamStandard = Tac.Type( 'MplsUtils::MplsOamStandard' )
   Prefix = Tac.Type( 'Arnet::Prefix' )

   from ClientDispatcher import lspUtilHandler
   from MplsTracerouteClientLib import (
      cleanTracerouteSocket,
   )
   from ClientCommonLib import (
      createSocket,
      getThreadLocalData,
      setThreadLocalData,
      LspPingDSMap,
      LspPingDSTypes,
      LspPingTypes,
      LspTraceroute,
      igpProtocolTypes,
      setProductCodeGlobals,
      tracerouteUseCapi,
   )
   setProductCodeGlobals()

   def strToBool( value ):
      '''Convert the string given for a boolean command-line option to a bool.

      argparse's type=bool simply calls bool() on the raw string, so every
      non-empty value - including "False" and "0" - was parsed as True. Here the
      empty string and the usual spellings of "false" yield False; any other
      value is still accepted (this converter never raises) and yields True.
      '''
      return value.lower() not in ( '', '0', 'false', 'no', 'off' )

   # Command-line interface of the utility. The parsed options end up in the
   # plain dict 'args', keyed by option destination (e.g. 'session_id').
   p = argparse.ArgumentParser( description='LSP Traceroute Utility' )
   p.add_argument( 'destination', nargs='?',
                   help='Destination to Traceroute to, optional for RSVP' )
   p.add_argument( '--type', metavar='destination-type', choices=LspPingTypes,
                   help='Type of the destination to Traceroute to' )
   p.add_argument( '--srTunnel', metavar='segment-tunnel-type',
                   choices=igpProtocolTypes,
                   help='Type of segment routing tunnel to use in Segment Routing' )
   p.add_argument( '--vrf', metavar='VRF',
                   help='Traceroute to destination in VRF' )
   p.add_argument( '--src', metavar='source-address',
                   help='Source address for Traceroute packet' )
   p.add_argument( '--dst', metavar='destination-address',
                   help='Destination address for Traceroute packet' )
   p.add_argument( '--entry', metavar='nexthop-group-entry', type=int,
                   help='Nexthop group entry index' )
   p.add_argument( '--nexthop', metavar='nexthop-address',
                   help='Nexthop address' )
   p.add_argument( '--label', metavar='Label-or-label-stack',
                   help='Label or comma-separated label stack (Top of stack first)' )
   p.add_argument( '--smac', metavar='source-MAC',
                   help='Source MAC Address of Traceroute packet' )
   p.add_argument( '--dmac', metavar='destination-MAC',
                   help='Destination MAC address of Traceroute packet' )
   p.add_argument( '--count', metavar='count', type=int, default=1,
                   help='Number of Traceroute packets' )
   p.add_argument( '--interval', metavar='interval', type=int, default=1,
                   help='Interval of Traceroute packets in seconds' )
   p.add_argument( '--interface', metavar='interface',
                   help='Egress interface of Traceroute packets' )
   # strToBool instead of bool: "--verbose False" must not enable verbose output.
   p.add_argument( '--verbose', metavar='verbose-output', type=strToBool,
                   help='Enable verbose output' )
   p.add_argument( '--tos', metavar='tos', choices=list( range( 0, 256 ) ),
                   type=int, help='Type of service' )
   p.add_argument( '--tc', metavar='traffic-class', type=int,
                   default=0, choices=list( range( 0, 8 ) ),
                   help='MPLS traffic class field' )
   p.add_argument( '--session-id', metavar='session-id', type=int,
                   help='RSVP session ID to Traceroute to' )
   p.add_argument( '--tunnel', metavar='tunnel', type=str,
                   help='Tunnel name to Traceroute to' )
   p.add_argument( '--sub-tunnel-id', metavar='sub-tunnel-id', type=int,
                   help='Sub-tunnel id to Traceroute to' )
   p.add_argument( '--session-name', metavar='session-name', type=str,
                   help='RSVP session name to Traceroute to' )
   p.add_argument( '--lsp', metavar='lsp-rsvp', type=int,
                   help='RSVP LSP ID to Traceroute to' )
   p.add_argument( '--color', metavar='color', type=int, help='SR-TE Policy color '
                   'value' )
   p.add_argument( '--trafficAf', metavar='trafficAf', type=str, help='The type '
                   'of traffic steered through the policy' )
   p.add_argument( '--multipath', action='store_true',
                   help='Enable multipath diagnosis' )
   p.add_argument( '--dstype', metavar='dstype',
                   choices=LspPingDSTypes,
                   default=LspPingDSMap,
                   help='TLV type to use for the downstream mapping' )
   p.add_argument( '--standard', metavar='standard', type=str,
                   choices=( MplsOamStandard.arista, MplsOamStandard.ietf ),
                   help='The OAM standard to comply with for the Traceroute' )
   p.add_argument( '--multipathbase', metavar='multipathbase',
                   default='127.0.0.0',
                   help='Base IP to use in the multipath traceroute' )
   p.add_argument( '--multipathcount', metavar='multipathcount',
                   type=int, default=64,
                   help='Number of IPs to use in the multipath traceroute' )
   # NOTE(review): the only value accepted on the command line is
   # maxIngressMplsTopLabels, while the default (which argparse does not check
   # against 'choices') is 1 - confirm that this is intended.
   p.add_argument( '--multiLabel', metavar='multiLabel', type=int, default=1,
                   choices=[ maxIngressMplsTopLabels ],
                   help=( 'Perform validation using the specified number of '
                          'top labels during route lookups' ) )
   p.add_argument( '--size', metavar='size', type=int,
                   choices=list( range( 120, 10001 ) ),
                   help='Size of ping packet in bytes' )
   # strToBool instead of bool: "--padReply False" must leave the flag unset.
   p.add_argument( '--padReply', metavar='padReply', type=strToBool,
                   default=False,
                   help='Indicates that ping reply should copy the Pad TLV' )

   p.add_argument( '--genOpqVal', metavar='opaque-value', type=int,
                   help='Generic Opaque Value' )
   p.add_argument( '--sourceAddrOpqVal', metavar='source-address-opaque-value',
                   help='Source Address Opaque Value' )
   p.add_argument( '--groupAddrOpqVal', metavar='group-address-opaque-value',
                   help='Group Address Opaque Value' )
   p.add_argument( '--jitter', metavar='jitter-value', type=int,
                   help='Echo Jitter Value' )
   p.add_argument( '--responderAddr', metavar='responder-address',
                   help='Node Responder Address' )
   p.add_argument( '--maxTtl', metavar='Maximum TTL Value',
                   type=int, default=64, help='Maximum hops' )
   p.add_argument( '--egressValidateAddress', metavar='egressValidateAddress',
                   type=str, help=( 'Perform egress validation' ) )
   p.add_argument( '--print-args', action='store_true',
                   help='print given arguments and exit( 0 )' )
   p.add_argument( '--cli', action='store_true',
                   help='Indicates that binary was invoked by cli' )
   p.add_argument( '--json', action='store_true',
                   help='Indicates that json output is requested' )
   p.add_argument( '--servSockPort', type=int, help='Server socket port' )
   p.add_argument( '--algorithm', type=int, help='Flexible algorithm ID' )
   p.add_argument( '--algorithmName', type=str, help='Flexible algorithm Name' )
   p.add_argument( '--backup', action='store_true',
                   help='Traceroute backup entry of NHG' )

   # Work with a plain dict from here on; it is updated and handed to the
   # handler later in the script.
   args = vars( p.parse_args() )
   # --print-args: dump the parsed arguments and exit. This is handled ahead of
   # lspUtilHandler so that no entityManager gets created for it.
   if args[ 'print_args' ]:
      print( args )
      for stream in ( sys.stdout, sys.stderr ):
         stream.flush()
      sys.exit( 0 )
   # The flag has served its purpose; remove it before the arguments are
   # passed on.
   del args[ 'print_args' ]

   # Option consistency checks. session_id and lsp may legitimately be 0, which
   # is why they are compared against None rather than tested for truthiness.
   if args[ 'type' ] != 'rsvp':
      # Every type other than rsvp needs the positional destination.
      if not args[ 'destination' ]:
         p.error( 'missing destination argument' )
   else:
      hasSessionId = args[ 'session_id' ] is not None
      hasSessionName = bool( args[ 'session_name' ] )
      hasLsp = args[ 'lsp' ] is not None
      # Without --tunnel exactly one of --session-id / --session-name must be
      # given, i.e. "both" and "neither" are rejected.
      if args[ 'tunnel' ] is None and hasSessionId == hasSessionName:
         p.error( '--type rsvp requires either --tunnel or'
                  ' exclusively --session-id or --session-name' )
      elif hasSessionName and hasLsp:
         p.error( '--lsp cannot be used with --session-name, use --session-id' )
      elif hasSessionId and not hasLsp:
         p.error( '--session-id requires --lsp' )

   # Verify the multipathcount
   if args[ 'multipathcount' ] < 1 or args[ 'multipathcount' ] > 512:
      p.error( 'multipathcount must be in the range 1..512' )

   # Multipath option with mldp is not allowed
   if args[ 'type' ] == 'mldp' and args[ 'multipath' ]:
      p.error( 'mldp type cannot be used with --multipath' )

   # Verify the multipathbase, and multipathbase + multipathcount: both the base
   # address and the address 'multipathcount' above it have to lie inside the
   # loopback prefix of the base address's family.
   try:
      testGenAddr = IpGenAddr( args[ 'multipathbase' ] )

      if testGenAddr.af == AddressFamily.ipv4:
         # presumably 127.0.0.0/8 - confirm the value of ipAddrLoopbackBase
         loopbackPrefix = Prefix( IpAddr.ipAddrLoopbackBase, 8 )
         testAddr = IpAddr()
         testAddr.stringValue = testGenAddr.v4Addr
         # NOTE(review): base + count is one past base + count - 1; if the
         # addresses actually used are base .. base + count - 1 this rejects a
         # range that ends exactly at the top of the prefix - confirm.
         testTopAddr = IpAddr( testAddr.value + args[ 'multipathcount' ] )
      else:
         # Any non-IPv4 base is handled as IPv6. The prefix looks like the
         # IPv4-mapped loopback range ::ffff:127.0.0.0/104 (assuming Ip6Addr
         # takes four 32-bit words - TODO confirm).
         loopbackPrefix = Ip6Prefix( Ip6Addr( 0, 0, 0xffff, 0x7f000000 ), 104 )
         testAddr = testGenAddr.v6Addr
         # Only the last word of the base is carried over; the first three words
         # of the top address are fixed to those of the loopback prefix.
         testTopAddr = Ip6Addr( 0, 0, 0xffff,
                                testAddr.word3 + args[ 'multipathcount' ] )

      if not loopbackPrefix.contains( testAddr ):
         p.error( 'multipathbase must be within %s' % loopbackPrefix )

      if not loopbackPrefix.contains( testTopAddr ):
         p.error( 'multipathcount extends out of the valid IP range of %s' \
                  % loopbackPrefix )

   # presumably IndexError is what the Tac address types raise for a value they
   # cannot represent - confirm. The SystemExit raised by p.error() above is not
   # an IndexError and therefore passes through this handler.
   except IndexError:
      p.error( 'multipathbase must be a v4 or v6 address' )

   # LSP traceroute may only be issued from the default VRF: the inode reported
   # by stat for this process's network namespace has to equal the one reported
   # for the default namespace.
   ownNsCmd = [ 'stat', '-L', '--format=%i', '/proc/self/ns/net' ]
   defaultNsCmd = [ 'stat', '--format=%i', '/var/run/netns/default' ]
   inNonDefaultVrf = ( subprocess.check_output( ownNsCmd ) !=
                       subprocess.check_output( defaultNsCmd ) )
   if inNonDefaultVrf:
      CliCommon.printErrorMessage(
         "LSP traceroute is not supported in non-default VRF" )
      sys.exit( 0 )
   # The entity manager is built here, not inside lspUtilHandler, which keeps
   # lspUtilHandler easy to test.
   entityManager = EntityManager.Sysdb( os.environ.get( 'SYSNAME', 'ar' ),
                                        agentName='LspTraceroute' )

   def handler( util, em, arguments ):
      '''Thread target: run lspUtilHandler and publish its result.

      The value returned by lspUtilHandler( util, em, arguments ) is stored in
      the module-level retCode, where the main thread picks it up once this
      thread has been joined.
      '''
      global retCode
      status = lspUtilHandler( util, em, arguments )
      retCode = status

   # Two ways of running the handler:
   #  - Invoked via the cli ( --cli ), or for a type that tracerouteUseCapi
   #    rejects: this binary is already a subprocess and the render/genTraceroute
   #    function lives in the CliPlugin thread, so lspUtilHandler is simply run
   #    here and its result becomes the exit status.
   #  - Invoked directly: lspUtilHandler runs in a second thread while this
   #    thread does the rendering ( or the JSON dump ).
   if args.get( 'cli' ) or not tracerouteUseCapi( args.get( 'type' ) ):
      sys.exit( lspUtilHandler( LspTraceroute, entityManager, args ) )

   retCode = None
   servSockPort = createSocket()
   args[ 'servSockPort' ] = servSockPort
   worker = Thread( target=handler,
                    args=( LspTraceroute, entityManager, args ) )
   setThreadLocalData( 't', worker )
   worker.start()
   if not args.get( 'json' ):
      MplsUtilModel.MplsTraceroute().render()
   else:
      genTraceroute()
      tracerouteModel = getThreadLocalData( 'tracerouteModel' )
      print( json.dumps( tracerouteModel.toDict(), sort_keys=False, indent=4 ) )
   worker.join()
   # The socket is only cleaned up once the handler thread has returned/exited.
   cleanTracerouteSocket()
   sys.exit( retCode )

# We catch the exception but skip doing anything to handle it: an interrupt, any
# ordinary error and every sys.exit() raised in the script body above all end
# here and are silently dropped.
# NOTE(review): because SystemExit is swallowed too, the status handed to
# sys.exit() above (including retCode and argparse's error exit) does not appear
# to reach the caller as the process exit status - confirm that this is intended.
except KeyboardInterrupt:
   pass
except Exception: # pylint: disable-msg=broad-except
   pass
except SystemExit:
   # pylint: disable-msg=protected-access
   pass

