# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import commands
import socket
import struct
from cli_parser.option_parser import OptionParser

# pylint: disable=broad-except

class DpsPathMetricsOptionParser( OptionParser ):
   """Option parser for the `dps path PATH_INDEX metrics` CLI command.

   Registers one "uint" option per path metric plus a `-h` help action, and
   exposes one accessor per metric that returns `get_value` for that option.
   """

   def register_options( self ):
      """Register the metric options (in a fixed order) and `-h`."""
      metricOpts = ( "-latency", "-jitter", "-lossRate",
                     "-totalBandwidth", "-availBandwidth" )
      for opt in metricOpts:
         self.add_option( opt, "uint" )
      # -h is an action option: it invokes help() rather than storing a value.
      self.add_option( "-h", "action", action=self.help )

   @staticmethod
   def help():
      """Print the usage text for the metrics command to stdout."""
      usage = '''
Usage:
   dps path [ pathId ] metrics [ -latency <latency> ] [ -jitter <jitter> ]
                               [ -lossRate <lossRate> ]
                               [ -totalBandwidth <totalBandwidth> ]
                               [ -availBandwidth <availBandwidth> ]
   For a path set metrics like latency, jitter, lossRate,
                               total bandwidth, available bandwidth
Options:
   -latency         configure latency in microseconds
   -jitter          configure jitter in microseconds
   -lossRate        configure lossRate scaled by 10000 for %
   -totalBandwidth  configure total bandwidth in Mbps
   -availBandwidth  configure avail bandwidth in Mbps
   -h               help
   '''
      print( usage )

   def get_latency( self ):
      """Return the parsed -latency value (microseconds, per help())."""
      return self.get_value( "-latency" )

   def get_jitter( self ):
      """Return the parsed -jitter value (microseconds, per help())."""
      return self.get_value( "-jitter" )

   def get_lossRate( self ):
      """Return the parsed -lossRate value (percent scaled by 10000)."""
      return self.get_value( "-lossRate" )

   def get_totalBandwidth( self ):
      """Return the parsed -totalBandwidth value (Mbps, per help())."""
      return self.get_value( "-totalBandwidth" )

   def get_availBandwidth( self ):
      """Return the parsed -availBandwidth value (Mbps, per help())."""
      return self.get_value( "-availBandwidth" )

@commands.cmd( 'show dps bestpath  [LBID]',
   'Show best paths for all load balance profiles or for a given load-balance id' )
def show_dps_bestpath( cli, lbId ):
   """Print the best path of every load-balance profile, or of one LBID.

   Does nothing if no module of class 'Dps' is loaded.

   Args:
      cli: CLI context providing `bess` and `fout`.
      lbId: optional load-balance id; when falsy no `lb_id` filter is sent.
   """
   dpsModules = [ mod.name for mod in cli.bess.list_modules().modules
                  if mod.mclass == 'Dps' ]
   if not dpsModules:
      return
   cmdArgs = { 'lb_id': lbId } if lbId else {}
   bestPaths = cli.bess.run_module_command( dpsModules[ 0 ], 'showLbBestPath',
                                            'DpsLbProfileArg', cmdArgs )
   cli.fout.write( 'Load Balance Profiles and their Best Paths: \n%s\n'
                   % repr( bestPaths ) )

@commands.cmd( 'show dps rules',
      'Show the mapping of <vtepIp, vrf, appProfile> mapping to LBID' )
def show_dps_rules( cli ):
   """Dump the DPS rule table (<vtepIp, vrf, appProfile> -> LBID).

   Does nothing if no module of class 'Dps' is loaded.
   """
   dpsModules = [ mod.name for mod in cli.bess.list_modules().modules
                  if mod.mclass == 'Dps' ]
   if dpsModules:
      rules = cli.bess.run_module_command( dpsModules[ 0 ], 'showDpsRules',
                                           'EmptyArg', {} )
      cli.fout.write( 'Dps Rules: \n%s\n' % repr( rules ) )

def _show_dps_flows( cli, args ):
   modName = [ m.name for m in cli.bess.list_modules().modules if m.mclass == 'Dps' ]
   if not modName:
      return 0
   ret = cli.bess.run_module_command( modName[ 0 ],
                                      'getDpsFlows',
                                      'GetDpsFlowsArg', args )
   # may not have entries if DPS flows are not found
   if not ret.info:
      return ret.next_iter
   for entry in ret.info:
      src_ip = socket.ntohl( entry.src_ip )
      src_port = socket.ntohs( entry.src_port )
      dest_ip = socket.ntohl( entry.dest_ip )
      dest_port = socket.ntohs( entry.dest_port )
      path_src_ip = socket.ntohl( entry.path_src_ip )
      path_dest_ip = socket.ntohl( entry.path_dest_ip )
      cli.fout.write( "FLOW : vrf %d src_ip %s port %d dest_ip %s "
          "port %d proto %d age %ds PATH : src_ip %s dest_ip %s \n" % (
          entry.vrf_id,
          socket.inet_ntoa( struct.pack( '!L', src_ip ) ), src_port,
          socket.inet_ntoa( struct.pack( '!L', dest_ip ) ), dest_port,
          entry.protocol, entry.flow_age / ( 1000000 ),
          socket.inet_ntoa( struct.pack( '!L', path_src_ip ) ),
          socket.inet_ntoa( struct.pack( '!L', path_dest_ip ) ) ) )
   return ret.next_iter

@commands.cmd( 'show dps-flows [DESTIP]',
     'show dps flows: display all dps flows in flow cache for a dest IP' )
def show_dps_flows( cli, destIp ):
   """Print the flow-cache total, then page through and print all DPS flows.

   Args:
      cli: CLI context providing `bess` and `fout`.
      destIp: optional dotted-quad string; when given, it is sent as the
         `dest_ip` filter of every page request.
   """
   cacheInfo = cli.bess.run_module_command( 'Management',
                                            'getFlowCacheInfo',
                                            'EmptyArg', {} )
   cli.fout.write( 'Total flowcache-flows: %d flows. DPS flows :\n'
                   % cacheInfo.flow_count )
   if cacheInfo.flow_count:
      # Request up to 100 flows per page, starting at cursor 0.
      args = { 'max_iterate': 100, 'iter': 0 }
      if destIp is not None:
         # Converted with htonl, mirroring the ntohl decoding of addresses
         # in _show_dps_flows.
         args[ 'dest_ip' ] = socket.htonl(
            struct.unpack( "!I", socket.inet_aton( destIp ) )[ 0 ] )
      # Keep fetching pages until the module reports a next cursor of 0.
      while True:
         args[ 'iter' ] = _show_dps_flows( cli, args )
         if args[ 'iter' ] == 0:
            break
   cli.fout.write( '\n' )

def show_path( cli, p ):
   """Write a human-readable summary of one DPS path to `cli.fout`.

   The index/state/vtep line and the addressing line are always written.
   The metrics, MTU/active-flows and load/bandwidth lines are written only
   when `p.its_pc[ 0 ].txstate` is truthy. A blank line ends the entry.

   Load is the percentage of total bandwidth that is not available, rounded
   to two decimals; it is 0.0 when total bandwidth is 0.

   Args:
      cli: CLI context providing `fout`.
      p: path info entry (as returned by 'showGlobalPaths').
   """
   pc = p.its_pc[ 0 ]

   def dotted( ip ):
      return socket.inet_ntoa( struct.pack( '!L', ip ) )

   # A plain truthiness test instead of a {True, False} dict lookup, so a
   # non-bool txstate value cannot raise KeyError.
   cli.fout.write( "index: %d %s vtep %s\n" % (
                   p.path_index,
                   "active" if pc.txstate else "inactive",
                   dotted( p.vtep_ip ) ) )
   cli.fout.write( "   src %s dst %s sport %d dport %d egress %s\n" % (
                   dotted( p.src_ip ), dotted( p.dst_ip ),
                   p.src_port, p.dst_port, p.src_intf ) )
   if pc.txstate:
      cli.fout.write( "   latency %s jitter %s pktloss %s throughput %s\n" % (
                      str( pc.latency ), str( pc.jitter ),
                      str( pc.packetloss ), str( pc.throughput ) ) )
      cli.fout.write( "   mtu %d active-flows %d\n" % (
                      p.ip_mtu, p.active_flows ) )
      load = 0.0
      if pc.total_bandwidth:
         # Scale to a percentage *before* rounding: rounding the ratio and
         # then multiplying by 100 leaves float noise (a 29% load printed as
         # 28.999999999999996). The 100.0 literal also forces true division
         # under Python 2, where int/int would truncate to 0.
         load = round( 100.0 * ( pc.total_bandwidth - pc.avail_bandwidth ) /
                       pc.total_bandwidth, 2 )
      cli.fout.write( "   load %s total bandwidth %s available bandwidth %s\n" % (
                      str( load ), str( pc.total_bandwidth ),
                      str( pc.avail_bandwidth ) ) )
   cli.fout.write( '\n' )

@commands.cmd( 'show dps paths [PATH_INDEX]',
     'show dps paths: display path information for all/specific Dps paths' )
def show_dps_paths( cli, pathIndex ):
   """Print path information for DPS paths via the first 'DpsGlobal' module.

   Does nothing if no 'DpsGlobal' module is loaded. A blank line follows
   the list.

   Args:
      cli: CLI context providing `bess` and `fout`.
      pathIndex: optional path index. None is sent as 0, which the
         'showGlobalPaths' handler presumably treats as "all paths" (per
         the command help) — TODO confirm.
   """
   modName = [ m.name for m in cli.bess.list_modules().modules
         if m.mclass == 'DpsGlobal' ]
   if not modName:
      return
   arg = { 'path_index': 0 if pathIndex is None else pathIndex }
   ret = cli.bess.run_module_command( modName[ 0 ],
                                      'showGlobalPaths',
                                      'DpsGlobalGetPathArg', arg )
   for pathInfo in ret.path_info:
      show_path( cli, pathInfo )
   cli.fout.write( '\n' )

def get_pmtu_tx_state_str( state ):
   """Return the display name of a PMTU Tx state code.

   Codes 0-3 map to INIT, PROBING, PROBING_COMPLETE and ERROR. Any other
   value is rendered as "UNKNOWN(<state>)" instead of raising IndexError.
   Negative values also get this treatment, because indexing with them
   would silently return a wrong name from the end of the table.
   """
   stateStr = ( "INIT", "PROBING", "PROBING_COMPLETE", "ERROR" )
   if 0 <= state < len( stateStr ):
      return stateStr[ state ]
   return "UNKNOWN(%s)" % state

def show_mtu_details( cli, p ):
   """Print PMTU discovery details and Tx/Rx stats/status for one DPS path.

   Queries the first 'ItsCtrl' module for the path's PMTU Rx stats, Rx
   status, Tx stats and Tx status (with `tc` 0). It then writes the path
   identity followed by the Tx Stats, Tx Status, Rx Stats and Rx Status
   sections. Each section header is always written; its body is skipped
   when the corresponding query returned None. Does nothing if no
   'ItsCtrl' module is loaded.

   Args:
      cli: CLI context providing `bess` and `fout`.
      p: path info entry (as returned by 'showGlobalPaths').
   """
   modName = [ m.name for m in cli.bess.list_modules().modules
         if m.mclass == 'ItsCtrl' ]
   if not modName:
      return
   write = cli.fout.write

   # query the stats and status for Rx & Tx (same argument for all four)
   arg = { 'path_index': p.path_index, 'tc': 0 }

   def query( cmd ):
      return cli.bess.run_module_command( modName[ 0 ], cmd,
                                          'ItsPathIndexArg', arg )

   RxStats = query( 'getPmtuRxStats' )
   RxStatus = query( 'getPmtuRxStatus' )
   TxStats = query( 'getPmtuTxStats' )
   TxStatus = query( 'getPmtuTxStatus' )

   write( "Path Index: %d\n" % p.path_index )
   write( "Src Ip: %s\n" % socket.inet_ntoa( struct.pack( '!L', p.src_ip ) ) )
   write( "Dst Ip: %s\n" % socket.inet_ntoa( struct.pack( '!L', p.dst_ip ) ) )
   write( "Dst Port: %d\n" % p.dst_port )
   write( "Tx state: %s\n" % str( p.its_pc[ 0 ].txstate ) )
   write( "Ip MTU: %d\n" % p.ip_mtu )

   write( "Tx Stats:\n" )
   if TxStats is not None:
      write( "    Total probes sent: %d\n" % TxStats.total_probes_sent )
      write( "    Total report req sent: %d\n" % TxStats.total_report_req_sent )
      write( "    Total report resp received: %d\n" %
             TxStats.total_report_resp_rcvd )
      write( "    Err code report resp received: %d\n" %
             TxStats.err_code_report_resp_rcvd )
      write( "    Out of ctxt report resp received: %d\n" %
             TxStats.out_of_ctx_report_resp_rcvd )
      write( "    Invalid report resp received: %d\n" %
             TxStats.invalid_report_resp_rcvd )
      write( "    Non matching report resp received: %d\n" %
             TxStats.non_matching_report_resp_rcvd )
      write( "    Already proc report resp received: %d\n" %
             TxStats.already_proc_report_resp_rcvd )
      write( "    Total txns: %d\n" % TxStats.total_txns )
      write( "    Pkt alloc err: %d\n" % TxStats.pkt_alloc_err )

   write( "Tx Status:\n" )
   if TxStatus is not None:
      write( "    State: %s\n" % get_pmtu_tx_state_str( TxStatus.state ) )
      write( "    Round one Pmtu detected: %d\n" %
             TxStatus.round_one_pmtu_detected )
      write( "    Final Pmtu detected: %d\n" % TxStatus.final_pmtu_detected )
      write( "    Num rounds completed: %d\n" % TxStatus.num_rounds_done )
      for txn in TxStatus.txn_info:
         write( "    TxnId: %d\n" % txn.txn_id )
         write( "    Probes sent: [ %s]\n" % ''.join(
            "%d " % size for size in txn.probe_size_info ) )
      write( "    Discovery time taken (us): %d\n" % TxStatus.disc_time_taken_us )
      write( "    Discovery interval (secs): %d\n" %
             TxStatus.mtu_disc_interval_secs )
      write( "    Next Discovery due in (secs): %d\n" % TxStatus.disc_due_in_secs )

   write( "Rx Stats:\n" )
   if RxStats is not None:
      write( "    Total probes received: %d\n" % RxStats.total_probes_rcvd )
      write( "    Total report req received: %d\n" %
             RxStats.total_report_req_rcvd )
      write( "    Err code report resp sent: %d\n" %
             RxStats.err_code_report_resp_sent )
      write( "    Total report resp sent: %d\n" % RxStats.total_report_resp_sent )
      write( "    Total txns received: %d\n" % RxStats.total_txns_rcvd )
      write( "    Dropped Invalid pkts: %d\n" % RxStats.dropped_invalid_pkts )
      write( "    Dropped Invalid probe pkts: %d\n" %
             RxStats.dropped_invalid_probe_pkts )
      write( "    Dropped Old probe pkts: %d\n" % RxStats.dropped_old_probe_pkts )
      write( "    Duplicate probe pkts: %d\n" % RxStats.dup_probe_pkts )
      write( "    Dropped Invalid report req pkts: %d\n" %
             RxStats.dropped_invalid_report_req_pkts )

   write( "Rx Status:\n" )
   if RxStatus is not None:
      for txn in RxStatus.txn_info:
         write( "    TxnId: %d\n" % txn.txn_id )
         write( "    Num probes recv: %d\n" % txn.recv_probe_index )
         # recv_probe_index is the reported probe count; clamp it to the
         # sizes actually present so a larger count cannot raise IndexError.
         numRecv = min( txn.recv_probe_index, len( txn.probe_size_info ) )
         write( "    Probes recv: [ %s]\n" % ''.join(
            "%d " % txn.probe_size_info[ j ] for j in range( numRecv ) ) )
         write( "    Max probe size recv: %d\n" % txn.max_probe_size_recv )
         write( "    Latest timestamp: %d\n" % txn.latest_timestamp )

@commands.cmd( 'show dps mtu [PATH_INDEX]',
     'show dps mtu: display mtu information for all/specific Dps paths' )
def show_dps_path_mtu( cli, pathIndex ):
   """Print PMTU details for DPS paths listed by the first 'DpsGlobal' module.

   Does nothing if no 'DpsGlobal' module is loaded. A blank line follows
   the output.

   Args:
      cli: CLI context providing `bess` and `fout`.
      pathIndex: optional path index. None is sent as 0, which the
         'showGlobalPaths' handler presumably treats as "all paths" (per
         the command help) — TODO confirm.
   """
   modName = [ m.name for m in cli.bess.list_modules().modules
         if m.mclass == 'DpsGlobal' ]
   if not modName:
      return
   arg = { 'path_index': 0 if pathIndex is None else pathIndex }
   ret = cli.bess.run_module_command( modName[ 0 ],
                                      'showGlobalPaths',
                                      'DpsGlobalGetPathArg', arg )
   for pathInfo in ret.path_info:
      show_mtu_details( cli, pathInfo )
   cli.fout.write( '\n' )

@commands.cmd( 'dps path PATH_INDEX ENABLE_DISABLE',
'dps path : disable a Dps path' )
def set_dps_path_disabled( cli, pathIndex, adminDisabled ):
   """Set the Tx admin-disabled state of one DPS path.

   Sends 'setPathTxAdminDisable' to the first 'ItsCtrl' module. Nothing is
   sent if either argument is None.

   Args:
      cli: CLI context providing `bess` and `fout`.
      pathIndex: index of the path to change.
      adminDisabled: the ENABLE_DISABLE token. The literal 'disable' sets
         admin_disabled=True; any other value sets it to False.
   """
   modName = [ m.name for m in cli.bess.list_modules().modules
         if m.mclass == 'ItsCtrl' ]
   if not modName:
      # Newline-terminated like the file's other CLI output, so following
      # output does not run onto the same line.
      cli.fout.write( 'DpsPathDisable command handler module unavailable\n' )
      return
   if pathIndex is None or adminDisabled is None:
      return

   arg = { 'path_index': pathIndex,
           'admin_disabled': adminDisabled == 'disable' }
   cli.bess.run_module_command( modName[ 0 ],
                                'setPathTxAdminDisable',
                                'ItsPathTxAdminStateArg', arg )

@commands.cmd( 'dps path PATH_INDEX metrics [DPS_METRICS_OPTS...]',
'dps path : set metrics for a Dps path' )
def set_dps_path_metrics( cli, pathIndex, opts ):
   """Set latency/jitter/loss/bandwidth Tx metrics of one DPS path.

   Sends 'setPathTxMetrics' to the first 'ItsCtrl' module. Every metric that
   is not given (or whose parsed value is falsy) is sent as 0. If parsing
   the options raises, the error is printed and nothing is sent. Without
   that check, a mistyped option would silently overwrite the path's
   metrics with zeros or a partially parsed set.

   Args:
      cli: CLI context providing `bess` and `fout`.
      pathIndex: index of the path to update.
      opts: raw option tokens, parsed by DpsPathMetricsOptionParser.
   """
   modName = [ m.name for m in cli.bess.list_modules().modules
         if m.mclass == 'ItsCtrl' ]
   if not modName:
      cli.fout.write( 'DpsPathMetrics command handler module unavailable\n' )
      return

   # NOTE(review): `-h` prints help via the parser's action; whether parse()
   # then stops is not visible here, so `... metrics -h` may still send
   # all-zero metrics — confirm against OptionParser.
   try:
      parser = DpsPathMetricsOptionParser( opts )
      parser.parse()
      arg = {
         'path_index': pathIndex,
         'latency': parser.get_latency() or 0,
         'jitter': parser.get_jitter() or 0,
         'loss_rate': parser.get_lossRate() or 0,
         'total_bandwidth': parser.get_totalBandwidth() or 0,
         'avail_bandwidth': parser.get_availBandwidth() or 0,
      }
   except Exception as error:
      # pylint: disable=no-member
      print( error )
      return

   cli.bess.run_module_command( modName[ 0 ],
                                'setPathTxMetrics',
                                'ItsPathTxMetricsArg', arg )

@commands.cmd( 'show dps vtepmtutable',
     'display vtep mtu table contents' )
def show_dps_vtep_mtu_table( cli ):
   """Print the remote MTU table of the first 'DpsGlobal' module.

   Writes one "Index: ... MTU: ..." line per entry returned by
   'getRemoteMtuTable', followed by a blank line. Does nothing if no
   'DpsGlobal' module is loaded.
   """
   modName = [ m.name for m in cli.bess.list_modules().modules
         if m.mclass == 'DpsGlobal' ]
   if not modName:
      return
   ret = cli.bess.run_module_command( modName[ 0 ],
                                      'getRemoteMtuTable',
                                      'EmptyArg', {} )
   for entry in ret.mtu_entry:
      cli.fout.write( "Index: %u\tMTU: %u\n" % ( entry.remote_index, entry.mtu ) )
   cli.fout.write( '\n' )

def show_dps_icmp_frag_needed_stats( cli, stats ):
   """Write the ICMP fragmentation-needed generation counters to `cli.fout`.

   Prints a titled block with the count sent to FIB and the counts
   suppressed for each reason, then a blank line.

   Args:
      cli: CLI context providing `fout`.
      stats: counters object (as returned by the 'getStats' command).
   """
   out = cli.fout
   out.write( "ICMP fragmentation needed message generation statistics:\n" )
   out.write( "--------------------------------------------------------\n" )
   out.write( "Number of messages sent to FIB: %d\n" % ( stats.sent_to_fib ) )
   leadStr = "Number of messages suppressed due to"
   # (format, counter attribute) pairs; each counter is read just before its
   # line is written.
   suppressed = (
      ( "%s rate limit: %d\n", 'dropped_throttle' ),
      ( "%s no source address: %d\n", 'dropped_no_src_ip_addr' ),
      ( "%s prepend header failure: %d\n", 'dropped_prepend_ip_hdr_fail' ),
   )
   for fmt, counter in suppressed:
      out.write( fmt % ( leadStr, getattr( stats, counter ) ) )
   out.write( '\n' )

@commands.cmd( 'show dps icmp statistics',
     'Show DPS ICMP fragmentation needed message generation statistics' )
def show_dps_icmp_statistics( cli ):
   """Fetch and print stats from the first 'DpsIcmpFragNeeded' module.

   Does nothing if no such module is loaded.
   """
   icmpModules = [ mod.name for mod in cli.bess.list_modules().modules
                   if mod.mclass == 'DpsIcmpFragNeeded' ]
   if icmpModules:
      counters = cli.bess.run_module_command( icmpModules[ 0 ], 'getStats',
                                              'EmptyArg', {} )
      show_dps_icmp_frag_needed_stats( cli, counters )

@commands.var_attrs( '[DPS_METRICS_OPTS...]' )
def dps_path_metrics_var_attrs():
   """Describe the DPS_METRICS_OPTS CLI variable.

   Returns:
      ( var_type, var_desc, var_candidates ): the type name 'opts', a help
      string pointing users at `-h`, and an empty candidate list.
   """
   varType = 'opts'
   varDesc = '[OPTIONS]( run dps path PATHINDEX metrics -h for more help )'
   varCandidates = []
   return ( varType, varDesc, varCandidates )
