#!/usr/bin/env python3
# Copyright (c) 2022 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
# pkgdeps: library WanTEUtils

from __future__ import absolute_import, division, print_function

import BasicCli
from CliCommand import (
   CliExpressionFactory,
   CliExpression,
)
import ShowCommand
import Tac
import sys
import signal
import CmdExtension
from CliMatcher import (
   IntegerMatcher,
   FloatMatcher,
)
from CliPlugin import NetworkToolsCli

from CliToken.AvtCliToken import (
   avtNodeForPing,
   AvtVrfExprFactory,
   AvtNameExprFactory,
   AvtDestinationExprFactory,
)

from CliModel import (
   Dict,
   Int,
   Float,
   List,
   Model,
   Str,
   Bool,
)
from CliModel import cliPrinted

# Result reported for one path within a single ping sequence.
class AvtPathSequenceModel( Model ):
   # Which path answered and how much TTL the request still had on arrival.
   pathId = Str(
      help="Path ID" )
   ttlRemaining = Int(
      help="TTL remaining in the request packet before it was received at the "
           "destination" )
   # Timing measurements for the packet.
   rtt = Float(
      help="RTT for the packet" )
   onewayTime = Float(
      help="oneway time for the packet to reach destination" )
   # Optional error report for this path.
   errorCode = Int(
      optional=True,
      help="Error code if any error happens along the path" )
   errorMessage = Str(
      optional=True,
      help="Error message, if any error happens along the path" )

# One ping sequence: a list with an AvtPathSequenceModel entry per path.
class AvtPingSequenceModel( Model ):
   paths = List(
      valueType=AvtPathSequenceModel,
      help="Avt ping details for a sequence on a path" )

# Aggregate ping statistics for one path.
class AvtPathSummaryModel( Model ):
   # Packet counters and the loss rate.
   pktsTx = Int(
      help="Packets sent along the path" )
   pktsRx = Int(
      help="Packets received along the path" )
   lossRate = Float(
      help="loss rate for the path" )
   # Round-trip time statistics.
   rttMin = Float(
      help="Minimum RTT for a packet along the path" )
   rttMax = Float(
      help="Maximum RTT for a packet along the path" )
   rttAvg = Float(
      help="Average RTT for each packet along the path" )
   rttMdev = Float(
      help="Standard deviation of RTT for packets along the path" )
   # One-way time statistics.
   onewayTimeMin = Float(
      help="Minimum oneway time for a packet along the path" )
   onewayTimeMax = Float(
      help="Maximum oneway time for a packet along the path" )
   onewayTimeAvg = Float(
      help="Average oneway time for packets along the path" )
   onewayTimeMdev = Float(
      help="Standard deviation of oneway times for packets along the path" )
   totalTime = Float(
      help="Total time taken to ping along the path" )
   # Optional label stack, reported for multihop paths.
   labelStack = List(
      valueType=int,
      optional=True,
      help="LabelStack of the multihop path" )

# Top-level CLI model of the AVT ping command.
class AvtPingCliModel( Model ):
   # Parameters the ping was run with.
   vrf = Str(
      help="VRF name in which the ping command is being executed" )
   avt = Str(
      help="AVT name in which the ping command is being executed" )
   destVtep = Str(
      help="Destination VTEP IP" )
   repeat = Int(
      help="Number of times ping is repeated" )
   # Optional error report for the request as a whole.
   errorCode = Int(
      optional=True,
      help="Error code if any while sending ping requests" )
   errorMessage = Str(
      optional=True,
      help="Error in sending ping requests" )
   # Optional results, both keyed by string.
   sequences = Dict(
      keyType=str,
      valueType=AvtPingSequenceModel,
      optional=True,
      help="Sequence wise detail for every path" )
   summary = Dict(
      keyType=str,
      valueType=AvtPathSummaryModel,
      optional=True,
      help="summary of every ping across every path" )

# One entry of a path trace; every field is optional.
class AvtPathTraceInfoModel( Model ):
   # Position in the label stack and the direct path behind that label.
   labelSeq = Int(
      optional=True,
      help="label sequence number" )
   directPathId = Int(
      optional=True,
      help="Direct path ID corresponding to the current label" )
   # Endpoints of this entry and the interfaces used on each side.
   srcIp = Str(
      optional=True,
      help="Source IP of the path" )
   egressInterface = Str(
      optional=True,
      help="Egress interface from the source IP" )
   dstIp = Str(
      optional=True,
      help="Destination IP of the path" )
   ingressInterface = Str(
      optional=True,
      help="Ingress interface at the destination IP" )
   onewayTime = Float(
      optional=True,
      help="Oneway time taken to traverse the current label" )
   reroutedNode = Bool(
      optional=True,
      help="Flag to convey the node where reroute happens" )
   # Optional error report for this entry.
   errorCode = Int(
      optional=True,
      help="Error code if any error happens along the path" )
   errorMessage = Str(
      optional=True,
      help="Error message, if any error happens along the path" )

# Traceroute result for one path: optional label stack, a rerouted flag that
# defaults to False, and the list of trace entries.
class AvtPathTraceModel( Model ):
   labelStack = List(
      valueType=int,
      optional=True,
      help="LabelStack of the multihop path" )
   reroutedPath = Bool(
      help="Flag to convey if there is rerouting along the path",
      default=False )
   pathTrace = List(
      valueType=AvtPathTraceInfoModel,
      help="Traceroute details for a path" )

# Top-level CLI model of the AVT traceroute command.
# The help texts previously described the "ping" command (copy-paste from
# AvtPingCliModel) and misspelled "traceroute"; they now name the right command.
class AvtTracerouteCliModel( Model ):
   # Parameters the traceroute was run with.
   vrf = Str( help="VRF name in which the traceroute command is being executed" )
   avt = Str( help="AVT name in which the traceroute command is being executed" )
   destVtep = Str( help="Destination VTEP IP" )
   # Optional error report for the request as a whole.
   errorCode = Int( help="Error code if any while sending traceroute requests",
                    optional=True )
   errorMessage = Str( help="Error in sending traceroute requests", optional=True )
   # Optional per-path results, keyed by string.
   summary = Dict( keyType=str,
                   valueType=AvtPathTraceModel,
                   help="Path-wise traceroute summary",
                   optional=True )

# Largest path IDs, read from the Dps and Avt constant types; used below as
# the upper arguments of the two path-id matchers.
maxDirectPath = Tac.Type( "Dps::DpsConstants" ).maxPathIndex
maxMhPath = Tac.Type( "Avt::AvtConstants" ).maxMultihopPathId

# Help strings handed to the AVT expression factories in the command data.
vrfHelpStr = 'Ping in a VRF'
avtHelpStr = 'AVT to ping'
destHelpStr = 'Destination to ping'
pathIdHelpStr = 'Path ID to ping'

# Matchers for the numeric command options.
# NOTE(review): the two leading arguments are presumably the lower and upper
# bounds of the accepted value -- confirm against CliMatcher.
countMatcher = IntegerMatcher( 1, 255, helpdesc='ping pkt count' )
sizeMatcher = IntegerMatcher( 64, 0xffff, helpdesc='ping pkt size' )
ttlMatcher = IntegerMatcher( 1, 64, helpdesc='utility pkt TTL' )
intervalMatcher = FloatMatcher( 1.0, 3600, helpdesc='interval in 1ms increment',
                                precisionString='%.3f' )

# Path-id matchers for direct and multihop paths respectively.
directPathIdMatcher = IntegerMatcher( 1, maxDirectPath, helpdesc='path-id' )
mhPathIdMatcher = IntegerMatcher( 1, maxMhPath, helpdesc='path-id' )

# Factory for the 'direct path-id PATH-ID' sub-expression. The help string
# given at construction is attached to the 'path-id' keyword.
class AvtDirectPathIdExprFactory( CliExpressionFactory ):
   def __init__( self, helpStr ):
      self.helpStr = helpStr
      CliExpressionFactory.__init__( self )

   def generate( self, name ):
      '''Return a new CliExpression class for the direct path-id selector.

      name is accepted but not used by this factory.
      '''
      pathIdHelp = self.helpStr

      class AvtPathIdDefinitionExpr( CliExpression ):
         expression = 'direct path-id PATH-ID'
         data = {
            'direct': 'direct path',
            'path-id': pathIdHelp,
            'PATH-ID': directPathIdMatcher,
         }

      return AvtPathIdDefinitionExpr

# Factory for the 'multihop path-id MH-PATH-ID' sub-expression. The help
# string given at construction is attached to the 'path-id' keyword.
class AvtMhPathIdExprFactory( CliExpressionFactory ):
   def __init__( self, helpStr ):
      self.helpStr = helpStr
      CliExpressionFactory.__init__( self )

   def generate( self, name ):
      '''Return a new CliExpression class for the multihop path-id selector.

      name is accepted but not used by this factory.
      '''
      pathIdHelp = self.helpStr

      class AvtPathIdDefinitionExpr( CliExpression ):
         expression = 'multihop path-id MH-PATH-ID'
         data = {
            'multihop': 'multihop path',
            'path-id': pathIdHelp,
            'MH-PATH-ID': mhPathIdMatcher,
         }

      return AvtPathIdDefinitionExpr

# invoke the Avt utility script
# currently supported: AvtPing, AvtTraceRoute
def callAvtUtilScript( mode, args ):
   '''Spawn an AVT utility and block until it exits.

   mode -- CLI mode; mode.session is passed to the spawner and spawn failures
           are reported through mode.addError().
   args -- argv list of the utility to run.

   The child shares this process's stdout/stderr. Returns None. A failure to
   spawn is reported as a CLI error; an interrupt while waiting is forwarded
   to the child as SIGINT and the child is then reaped.
   '''
   cliCmdExt = CmdExtension.getCmdExtender()

   try:
      # Flush stdout before spawning the subprocess: each process has its own
      # stdout buffer, so anything still buffered here could otherwise show up
      # after (out of order with) the child's output.
      sys.stdout.flush()
      procInfo = cliCmdExt.subprocessPopen( args, mode.session,
                                            stdout=sys.stdout, stderr=sys.stderr )
   except EnvironmentError as e:
      # No child was created, so there is nothing to wait for. strerror may be
      # None for errors raised without an errno, hence the fallback.
      mode.addError( e.strerror or str( e ) )
      return
   except KeyboardInterrupt:
      # Interrupted before the child existed: nothing to signal or reap.
      return

   # wait for the process to terminate
   try:
      procInfo.wait()
   except KeyboardInterrupt:
      # Pass the interrupt on to the child, then reap it so it is not left
      # behind as a zombie.
      procInfo.kill( sig=signal.SIGINT )
      procInfo.wait()

# --------------------------------------------------------------------------------
# ping adaptive-virtual-topology vrf VRFNAME avt AVTNAME destination IP_ADDRESS
#           [ [ multihop | direct ] path-id PATH_ID ] [ ttl NUM ]
#           [ repeat NUM ] [ interval SECONDS ] [ size NUM ] [ fast-reroute ]
# --------------------------------------------------------------------------------
class AvtPingCmd( ShowCommand.ShowCliCommandClass ):
   syntax = '''ping adaptive-virtual-topology VRF AVT DEST
               [ DIR_PATH | MH_PATH ] [ ttl TTL ]
               [ repeat NUM ] [ interval FLOAT ] [ size SIZE ][ fast-reroute ]'''
   data = {
      'ping': NetworkToolsCli.pingMatcher,
      'adaptive-virtual-topology': avtNodeForPing,
      'VRF': AvtVrfExprFactory( vrfHelpStr ),
      'AVT': AvtNameExprFactory( avtHelpStr ),
      'DEST': AvtDestinationExprFactory( destHelpStr ),
      'DIR_PATH': AvtDirectPathIdExprFactory( pathIdHelpStr ),
      'MH_PATH': AvtMhPathIdExprFactory( pathIdHelpStr ),
      'ttl': 'Specify the TTL of the ping packet',
      'TTL': ttlMatcher,
      'repeat': 'Specify repeat count. Count will be set to 1 if output format '
                'is json',
      'NUM': countMatcher,
      'interval': 'Specify ping packet interval in ( seconds.millisecond format )',
      'FLOAT': intervalMatcher,
      'size': 'Specify ping packet size',
      'SIZE': sizeMatcher,
      'fast-reroute': 'Fast reroute around failing intermediate links',
   }

   noMore = True
   cliModel = AvtPingCliModel

   @staticmethod
   def handler( mode, args ):
      '''Build the AvtPing command line from the parsed arguments and run it.

      The utility writes its own output, so the model is returned through
      cliPrinted().
      '''
      # Fixed part: VRF, AVT and the destination.
      cmd = [ 'sudo', 'AvtPing',
              '--vrf', args[ 'VRF' ],
              '--avt', args[ 'AVT' ],
              args[ 'DEST' ] ]

      # At most one path selector is present.
      if 'direct' in args:
         cmd.extend( [ '--direct-path', str( args[ 'PATH-ID' ] ) ] )
      elif 'multihop' in args:
         cmd.extend( [ '--multihop-path', str( args[ 'MH-PATH-ID' ] ) ] )

      # JSON output forces a single ping, overriding any 'repeat' value.
      repeat = str( args[ 'NUM' ] ) if 'repeat' in args else None
      if mode.session_.outputFormat_ == 'json':
         cmd.append( '--json' )
         repeat = str( 1 )
      if repeat:
         cmd.extend( [ '--repeat', repeat ] )

      # Remaining options, emitted in the order the utility has always seen.
      for keyword, valueKey, flag in ( ( 'interval', 'FLOAT', '--interval' ),
                                       ( 'size', 'SIZE', '--size' ) ):
         if keyword in args:
            cmd.extend( [ flag, str( args[ valueKey ] ) ] )
      if 'fast-reroute' in args:
         cmd.append( '--fast-reroute' )
      if 'ttl' in args:
         cmd.extend( [ '--ttl', str( args[ 'TTL' ] ) ] )

      callAvtUtilScript( mode, cmd )
      return cliPrinted( AvtPingCliModel )

# --------------------------------------------------------------------------------
# traceroute adaptive-virtual-topology vrf VRFNAME avt AVTNAME destination IP_ADDRESS
#           [ [ multihop | direct ] path-id PATH_ID ] [ ttl NUM ] [ fast-reroute ]
# --------------------------------------------------------------------------------
class AvtTraceRouteCmd( ShowCommand.ShowCliCommandClass ):
   syntax = '''traceroute adaptive-virtual-topology VRF AVT DEST
               [ DIR_PATH | MH_PATH ] [ ttl TTL ] [ fast-reroute ]'''
   # The expression factories get traceroute-specific help strings; the
   # command previously reused the ping-worded ones ("Ping in a VRF", ...),
   # which showed misleading help under 'traceroute'.
   data = {
      'traceroute': NetworkToolsCli.tracerouteKwMatcher,
      'adaptive-virtual-topology': avtNodeForPing,
      'VRF': AvtVrfExprFactory( 'Traceroute in a VRF' ),
      'AVT': AvtNameExprFactory( 'AVT to traceroute' ),
      'DEST': AvtDestinationExprFactory( 'Destination to traceroute' ),
      'DIR_PATH': AvtDirectPathIdExprFactory( 'Path ID to traceroute' ),
      'MH_PATH': AvtMhPathIdExprFactory( 'Path ID to traceroute' ),
      'ttl': 'Specify the TTL of the traceroute packet',
      'TTL': ttlMatcher,
      'fast-reroute': 'Fast reroute around failing intermediate links',
   }

   noMore = True
   cliModel = AvtTracerouteCliModel

   @staticmethod
   def handler( mode, args ):
      '''Build the AvtTraceRoute command line from the parsed arguments and
      run it.

      The utility writes its own output, so the model is returned through
      cliPrinted().
      '''
      destination = args[ 'DEST' ]
      vrfName = args[ 'VRF' ]
      avtName = args[ 'AVT' ]

      utilArgs = [ 'sudo', 'AvtTraceRoute' ]
      utilArgs += [ '--vrf', vrfName, '--avt', avtName, destination ]
      # At most one path selector is present.
      if 'direct' in args:
         utilArgs += [ '--direct-path', str( args[ 'PATH-ID' ] ) ]
      elif 'multihop' in args:
         utilArgs += [ '--multihop-path', str( args[ 'MH-PATH-ID' ] ) ]
      if 'fast-reroute' in args:
         utilArgs += [ '--fast-reroute' ]
      if 'ttl' in args:
         utilArgs += [ '--ttl', str( args[ 'TTL' ] ) ]
      if mode.session_.outputFormat_ == 'json':
         utilArgs += [ '--json' ]

      callAvtUtilScript( mode, utilArgs )
      return cliPrinted( AvtTracerouteCliModel )

# Register both utility commands with BasicCli.EnableMode as show-command
# classes.
BasicCli.EnableMode.addShowCommandClass( AvtPingCmd )

BasicCli.EnableMode.addShowCommandClass( AvtTraceRouteCmd )

def Plugin( entityManager ):
   '''Plugin hook that does nothing; entityManager is ignored.

   NOTE(review): presumably called by the CLI plugin loader -- confirm.
   '''
   return None
