# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=import-error
import commands
import socket
import struct
from cli_parser.option_parser import OptionParser
from collections import namedtuple

# pylint: disable=broad-except

# High bit of a path index / mh_path_id. When it is set the id refers to a
# multi-hop path, and ( value & ~mhPathFlag ) is the multi-hop path index.
mhPathFlag = 0x80000000

def _getAvtModule( cli, mclass='Avt', ipv6=False ):
   """Find the first active BESS module of a given class.

   Args:
      cli: CLI context; ``cli.bess.list_modules().modules`` is scanned.
      mclass: module class to look for (defaults to 'Avt').
      ipv6: when True, ``mclass`` is ignored and 'AvtIpv6' is used instead.

   Returns:
      The name of the first module whose ``mclass`` matches. If none match,
      it prints "<class> module is not active" and returns None.
   """
   if ipv6:
      mclass = 'AvtIpv6'
   names = [ mod.name for mod in cli.bess.list_modules().modules
             if mod.mclass == mclass ]
   if names:
      return names[ 0 ]
   print( f'{mclass} module is not active' )
   return None

class ShowAvtBestPathOptionParser( OptionParser ):
   """Option parser for 'show avt bestpath'.

   Accepts the optional "uint" filters -vni, -avt and -peerGroupId, plus -h,
   which prints the usage text.
   """

   def register_options( self ):
      """Register the three "uint" filter options followed by -h."""
      optionSpecs = (
         ( "-vni", "uint", {} ),
         ( "-avt", "uint", {} ),
         ( "-peerGroupId", "uint", {} ),
         ( "-h", "action", { 'action': self.help } ),
      )
      for flag, valueType, extraKwargs in optionSpecs:
         self.add_option( flag, valueType, **extraKwargs )

   @staticmethod
   def help():
      """Print the usage text for 'show avt bestpath'."""
      print( '''
Usage:
   show avt bestpath [ -vni <vniId> ] [ -avt <avtID> ] [ -peerGroupId <peerGroupId> ]
   Print all avt bestpaths. Use vni avt or peerGroupId
Options:
   -vni         show entries with the vniID of the avt path table
   -avt         show entries with the specific avtID of the avt path table
   -peerGroupId show entries with the specific peerGroupId of the avt path table
   -h           help
   ''' )

   def get_vni( self ):
      """Return the parsed -vni value (callers treat None as "no filter")."""
      return self.get_value( "-vni" )

   def get_avt( self ):
      """Return the parsed -avt value (callers treat None as "no filter")."""
      return self.get_value( "-avt" )

   def get_peerGroupId( self ):
      """Return the parsed -peerGroupId value (callers treat None as "no filter")."""
      return self.get_value( "-peerGroupId" )

class ShowAvtFlowInfoOptionParser( OptionParser ):
   """Option parser for 'show avt flow-info'.

   Accepts the boolean -6 switch, which selects IPv6 flows, and -h, which
   prints the usage text.
   """

   def register_options( self ):
      """Register the -6 "bool" switch followed by -h."""
      optionSpecs = (
         ( "-6", "bool", {} ),
         ( "-h", "action", { 'action': self.help } ),
      )
      for flag, valueType, extraKwargs in optionSpecs:
         self.add_option( flag, valueType, **extraKwargs )

   @staticmethod
   def help():
      """Print the usage text for 'show avt flow-info'."""
      print( '''
Usage:
   show avt flow-info [ -6 ]
   Print Avt flowcache info for all IP or IPv6 flows.
Options:
   -6           show IPv6 flows
   -h           help
   ''' )

   def get_ipv6( self ):
      """Return the parsed value of the -6 switch."""
      return self.get_value( "-6" )

class ShowAvtCountersOptionParser( OptionParser ):
   """Option parser for 'show avt counters'.

   Accepts -6 (IPv6 counters), the "uint" filters -vni and -avt, the "string"
   filter -path, and -h, which prints the usage text.
   """

   def register_options( self ):
      """Register -6, -vni, -avt, -path and -h, in that order."""
      optionSpecs = (
         ( "-6", "bool", {} ),
         ( "-vni", "uint", {} ),
         ( "-avt", "uint", {} ),
         ( "-path", "string", {} ),
         ( "-h", "action", { 'action': self.help } ),
      )
      for flag, valueType, extraKwargs in optionSpecs:
         self.add_option( flag, valueType, **extraKwargs )

   @staticmethod
   def help():
      """Print the usage text for 'show avt counters'."""
      print( '''
Usage:
   show avt counters [ -6 ] [ -vni <vniId> ] [ -avt <avtID> ] [ -path <pathID> ]
   Print all avt counters entries corresponding to the matching filters.
Options:
   -6           show IPv6 counter entries
   -vni         show counter entries in the vni
   -avt         show counter entries in the specific avtID
   -path        show counter entries with the specific pathID
   -h           help
   ''' )

   def get_ipv6( self ):
      """Return the parsed value of the -6 switch."""
      return self.get_value( "-6" )

   def get_vni( self ):
      """Return the parsed -vni filter value."""
      return self.get_value( "-vni" )

   def get_avt( self ):
      """Return the parsed -avt filter value."""
      return self.get_value( "-avt" )

   def get_path( self ):
      """Return the parsed -path filter value (registered as a "string")."""
      return self.get_value( "-path" )

@commands.cmd( 'show avt avt-lookup',
      'Show the mapping of <Vni, appProfile> mapping to <avtId, tc>' )
def show_avt_avtlookup( cli ):
   """Print the AvtGlobal lookup table, one row per ( vni, app profile ) pair.

   Rows are sorted by ( vni, app_profile_id ). Each row shows the avt id,
   traffic class and dscp the pair maps to. The function returns early if no
   AvtGlobal module is active, or if the 'showAvtLookup' command fails (the
   error's ``errmsg`` is printed).
   """
   modName = _getAvtModule( cli, mclass='AvtGlobal' )
   if not modName:
      return

   try:
      reply = cli.bess.run_module_command( modName, 'showAvtLookup',
                                           'EmptyArg', {} )
   except Exception as error:
      # pylint: disable=no-member
      print( error.errmsg )
      return

   out = cli.fout
   out.write( 'Avt Lookup Table:\n' )
   for row in sorted( reply.avt_info,
                      key=lambda info: ( info.vni, info.app_profile_id ) ):
      out.write( f'   vni { row.vni } app_pro { row.app_profile_id }'
                 f' -> avt { row.avt_id } tc { row.traffic_class }'
                 f' dscp { row.dscp }\n' )
   out.write( '\n' )

@commands.cmd( 'show avt bestpath [AVT_BESTPATH_OPTS...]',
      'Show best paths for all or given vni / avt / peerGroupId ' )
def show_avt_bestpath( cli, opts ):
   """Print the Avt module's best paths, optionally filtered.

   Args:
      cli: CLI context providing ``bess`` and the ``fout`` output stream.
      opts: option tokens parsed by ShowAvtBestPathOptionParser
         (-vni, -avt, -peerGroupId, -h). A filter left unset (None) matches
         every entry.

   Each matching entry is written as one line listing its paths. A path index
   with ``mhPathFlag`` set is a multi-hop path and is printed as
   ``<mh_index>*[hops]``. If nothing matches, a notice line is written.
   """
   try:
      parser = ShowAvtBestPathOptionParser( opts )
      parser.parse()
   except Exception as error:
      # Stop here: continuing would read options from an unbound or only
      # partially parsed ``parser``. Parser errors are not guaranteed to carry
      # ``errmsg``, so fall back to the exception itself.
      print( getattr( error, 'errmsg', error ) )
      return

   vniId = parser.get_vni()
   avtId = parser.get_avt()
   peerGroupId = parser.get_peerGroupId()

   modName = _getAvtModule( cli )
   if not modName:
      return

   # Fetch every best-path entry; the -vni/-avt/-peerGroupId filters are
   # applied below, not by the module.
   try:
      ret = cli.bess.run_module_command( modName, 'showBestPath', 'EmptyArg', {} )
   except Exception as error:
      # pylint: disable=no-member
      print( error.errmsg )
      return

   cnt = 0
   cli.fout.write( 'Best Paths: ( * - <mh_index>*[path stack] multihop paths )\n' )
   for entry in sorted( ret.bestpath_info,
          key=lambda lbkey: ( lbkey.vni, lbkey.avt_id, lbkey.peer_group_id ) ):
      if ( ( vniId is None or entry.vni == vniId ) and
           ( avtId is None or entry.avt_id == avtId ) and
           ( peerGroupId is None or entry.peer_group_id == peerGroupId ) ):
         cli.fout.write( f'   vni { entry.vni } avt { entry.avt_id } '
                         f'peerGroupId { entry.peer_group_id } ->' )
         for path in entry.path_info:
            pathId = path.path_index
            if pathId & mhPathFlag:
               # Multi-hop path: strip the flag to get the multi-hop index and
               # show its hop stack.
               mp_index = pathId & ~mhPathFlag
               cli.fout.write( f' { mp_index }*' )
               cli.fout.write( f'{ list( path.hop ) }' )
            else:
               cli.fout.write( f' { path.path_index }' )

         cli.fout.write( '\n' )
         cnt += 1
   if cnt == 0:
      cli.fout.write( '    no corresponding avt path for given condition(s)\n\n' )

def _getMultipathHops( cli, modName, args ):
   """Fetch multi-hop path data from an Avt module.

   Runs the module's 'showMultihopPath' command with ``args`` as an
   'AvtMultihopPathArg'. show_avt_mhpath sets ``mh_path_index`` in ``args``
   to ask for a single path.

   Returns:
      A dict mapping each multi-hop path index to a 2-tuple
      ``( hops, path_mtu )``. ``hops`` is the entry's hop sequence as
      returned by the module and ``path_mtu`` is its path MTU, e.g.
      ``{ 1: ( [ 1, 2, 3 ], 1492 ) }``. If the command raises, the error's
      ``errmsg`` is printed and an empty dict is returned.
   """
   try:
      reply = cli.bess.run_module_command( modName, 'showMultihopPath',
                                           'AvtMultihopPathArg', args )
   except Exception as error:
      # pylint: disable=no-member
      print( error.errmsg )
      return {}
   return { mh.mh_path_index: ( mh.hop, mh.path_mtu )
            for mh in reply.mhpath_info }

@commands.cmd( 'show avt mh-path [MHPATHID]',
      'Show the hops of multi-hop paths for all paths or for a given path id' )
def show_avt_mhpath( cli, mhp ):
   """Print each multi-hop path's hops and path MTU.

   Args:
      cli: CLI context providing ``bess`` and the ``fout`` output stream.
      mhp: optional multi-hop path index. When truthy it is sent as
         ``mh_path_index`` to restrict the query; otherwise all paths are
         requested.

   Nothing is written to ``cli.fout`` if no Avt module is active or no path
   data comes back.
   """
   modName = _getAvtModule( cli )
   if not modName:
      return

   queryArgs = { 'mh_path_index': mhp } if mhp else {}
   mhopData = _getMultipathHops( cli, modName, queryArgs )
   if not mhopData:
      return

   cli.fout.write( 'Multihop Paths: \n' )
   for mhIndex, ( hops, pathMtu ) in mhopData.items():
      hopText = ''.join( f' { hop }' for hop in hops )
      cli.fout.write( f'  Index { mhIndex } ->{ hopText }, Path Mtu { pathMtu }\n' )

def pbIpGenAddrToStr( ipGenAddr ):
   """Render an sfe::pb::IpGenAddr message as a printable address string.

   If the ``ipv6_addr`` field is set, it is formatted with
   ``socket.inet_ntop( AF_INET6, ... )``. Otherwise ``ip_addr`` is packed as a
   native-order 32-bit unsigned value ("=L") and formatted as dotted-quad
   IPv4.

   NOTE(review): this presumes ``ip_addr`` holds the address bytes in the
   host's native integer layout; confirm against the producer.
   """
   if not ipGenAddr.HasField( "ipv6_addr" ):
      packedV4 = struct.pack( "=L", ipGenAddr.ip_addr )
      return socket.inet_ntoa( packedV4 )
   return socket.inet_ntop( socket.AF_INET6, ipGenAddr.ipv6_addr )

def _show_avt_flow_info( cli, args, modName ):
   """Print one batch of Avt flow-cache entries and return the next iterator.

   Runs the module's 'getAvtFlowCacheInfo' command ('AvtGetFlowCacheArg')
   with ``args``. For each returned entry it writes the flow key, then the
   ipA->ipB (dir1) state, then the ipB->ipA (dir2) state.

   Returns:
      ``next_iter`` from the module reply. If the command raised, the error's
      ``errmsg`` is printed and 0 is returned.
   """
   try:
      ret = cli.bess.run_module_command( modName,
                                         'getAvtFlowCacheInfo',
                                         'AvtGetFlowCacheArg', args )
   except Exception as error:
      # pylint: disable=no-member
      print( error.errmsg )
      return 0

   def boolText( flag ):
      return "true" if flag else "false"

   def mhIndexOf( mhPathId ):
      # Only ids carrying mhPathFlag are multi-hop; report 0 for the rest.
      if mhPathId & mhPathFlag:
         return mhPathId & ~mhPathFlag
      return 0

   keyFmt = "vrfId %d ipA %s portA %d ipB %s portB %d protocol %d\n"
   dirFmt = ( "   %s: avtId %d mhIndex %d pathIndex %d haPathIndex %d"
              " classified %s transit %s\n"
              "             pathValid %s pathRefCnt %d pathPtr %s nhIndex %d"
              " onlyFrrTraffic %s\n" )

   def dirText( label, state ):
      # Format one direction's state using the two-line dirFmt layout.
      return dirFmt % ( label, state.avt_id, mhIndexOf( state.mh_path_id ),
                        state.direct_path_index, state.to_ha_peer_path_index,
                        boolText( state.classified ), boolText( state.transit ),
                        boolText( state.path_valid ), state.path_ref_cnt,
                        hex( state.path_ptr ), state.nh_index,
                        boolText( state.only_frr_traffic ) )

   for entry in ret.info:
      fk = entry.fk
      # Ports are converted with ntohs before printing.
      keyText = keyFmt % ( fk.vrf_id,
                           pbIpGenAddrToStr( fk.ip_a ), socket.ntohs( fk.port_a ),
                           pbIpGenAddrToStr( fk.ip_b ), socket.ntohs( fk.port_b ),
                           fk.protocol )
      cli.fout.write( keyText +
                      dirText( "ipA->ipB", entry.dir1 ) +
                      dirText( "ipB->ipA", entry.dir2 ) )
   return ret.next_iter

@commands.cmd( 'show avt flow-info [IPV6]',
      'Show Avt flowcache info for all flows' )
def show_avt_flow_info( cli, opts ):
   """Print the Avt flow cache, or the AvtIpv6 flow cache with -6.

   Options are parsed with ShowAvtFlowInfoOptionParser. A parse error is
   printed and the command returns. Entries are fetched in batches
   (``max_iterate`` 100) until the module reports an iterator of 0.
   """
   try:
      parser = ShowAvtFlowInfoOptionParser( opts )
      parser.parse()
   except Exception as error:
      print( error )
      return

   ipv6 = parser.get_ipv6()
   modName = _getAvtModule( cli, ipv6=ipv6 )
   if not modName:
      return

   family = " IPv6" if ipv6 else ""
   cli.fout.write( f'Avt{family} flowcache info:\n' )
   # Resume each batch from the iterator returned by the previous one; an
   # iterator of 0 ends the walk.
   args = { 'max_iterate': 100, 'iter': 0 }
   while True:
      nextIter = _show_avt_flow_info( cli, args, modName )
      args[ 'iter' ] = nextIter
      if nextIter == 0:
         break
   cli.fout.write( '\n' )

def _show_avt_counters_info( cli, args, modName, filters=None ):
   """Print one batch of Avt counter entries and return the module's iter_key.

   Args:
      cli: CLI context providing ``bess`` and the ``fout`` output stream.
      args: fields of the 'AvtGetCountersArg' passed to 'showAvtCounters'.
      modName: name of the Avt module to query.
      filters: optional object with ``vni``, ``avt`` and ``path`` attributes,
         e.g. a namedtuple. Each non-None value is converted with ``int()``
         and must equal the entry's vni / avt_id / path_id for the entry to be
         printed. ``None`` (the default) prints every entry.

   Returns:
      ``iter_key`` from the module reply.

   Raises:
      Anything ``run_module_command`` raises (it is not caught here), and
      ValueError if a filter value is not an integer string.
   """
   ret = cli.bess.run_module_command( modName,
                                      'showAvtCounters',
                                      'AvtGetCountersArg', args )

   # Convert the filter values once instead of once per entry; None means
   # "don't filter on this field".
   if filters is None:
      wantVni = wantAvt = wantPath = None
   else:
      wantVni, wantAvt, wantPath = (
         None if value is None else int( value )
         for value in ( filters.vni, filters.avt, filters.path ) )

   keyFmtStr = "vni {} avt {} pathId {}\n"
   counterInfoFmtStr = "   flows {} bytes {} packets {}\n"
   for entry in ret.info:
      if ( ( wantVni is None or wantVni == entry.vni ) and
           ( wantAvt is None or wantAvt == entry.avt_id ) and
           ( wantPath is None or wantPath == entry.path_id ) ):
         cli.fout.write(
             keyFmtStr.format( entry.vni, entry.avt_id, entry.path_id ) +
             counterInfoFmtStr.format( entry.flows, entry.bytes_count,
                                       entry.pkts_count )
         )
   return ret.iter_key

@commands.cmd( 'show avt counters [AVT_COUNTERS_OPTS...]',
      'Show all Avt Counter entries' )
def show_avt_counter_info( cli, opts ):
   """Print Avt (or, with -6, AvtIpv6) counter entries matching the filters.

   Options (-6, -vni, -avt, -path) are parsed with
   ShowAvtCountersOptionParser. Entries are fetched in batches of up to 100.
   Each batch resumes from the ( vni, avt_id, path_id ) key returned by the
   previous batch. The walk stops when that key no longer advances, or when
   the module returns an all-None key, which is reported as a lookup error.
   Command errors are printed instead of escaping the CLI command.
   """
   try:
      parser = ShowAvtCountersOptionParser( opts )
      parser.parse()
   except Exception as error:
      print( error )
      return

   ipv6 = parser.get_ipv6()
   vniId = parser.get_vni()
   avtId = parser.get_avt()
   pathId = parser.get_path()

   modName = _getAvtModule( cli, ipv6=ipv6 )
   if not modName:
      return

   args = {}
   args[ 'max_iterate' ] = 100
   args[ 'iter_lookup' ] = False
   args[ 'iter_vni' ] = 0
   args[ 'iter_avt_id' ] = 0
   args[ 'iter_path_id' ] = 0
   # Issue the first query once up front so that a failing command is reported
   # before the header is written. Its reply is discarded; the loop below
   # fetches the first batch again.
   try:
      cli.bess.run_module_command( modName, 'showAvtCounters',
                                   'AvtGetCountersArg', args )
   except Exception as error:
      # pylint: disable=no-member
      print( error.errmsg )
      return

   AvtCounterEntry = namedtuple( "AvtCounterEntry", "vni avt path" )
   filters = AvtCounterEntry( vniId, avtId, pathId )

   cli.fout.write( 'Avt Counters info:\n' )
   while True:
      try:
         iter_key = _show_avt_counters_info( cli, args, modName, filters )
      except Exception as error:
         # A later batch (or a non-integer -path filter) can still fail.
         # Report it like the initial probe instead of letting the exception
         # escape the command. Not every exception carries ``errmsg``.
         print( getattr( error, 'errmsg', error ) )
         break
      if ( iter_key.vni, iter_key.avt_id, iter_key.path_id ) == ( None, None,
            None ):
         cli.fout.write( 'Error: Lookup failed with the given iterator key\n' )
         break
      # The module hands back the key we resumed from once there is nothing
      # further to return.
      if ( iter_key.vni, iter_key.avt_id, iter_key.path_id ) == \
         ( args[ 'iter_vni' ], args[ 'iter_avt_id' ], args[ 'iter_path_id' ] ):
         break
      args[ 'iter_lookup' ] = True
      args[ 'iter_vni' ] = iter_key.vni
      args[ 'iter_avt_id' ] = iter_key.avt_id
      args[ 'iter_path_id' ] = iter_key.path_id
   cli.fout.write( '\n' )

@commands.cmd( 'show avt counter memory',
      'Show avt counter memory allocation count' )
def show_avt_counter_mem_use( cli ):
   """Print the in-use counter-object count reported by the Avt module.

   Queries 'showAvtCountersMemUsage' on the first 'Avt'-class module. If no
   such module is active, or the command fails (its ``errmsg`` is printed),
   nothing is written to ``cli.fout``.
   """
   modName = _getAvtModule( cli )
   if not modName:
      return
   try:
      usage = cli.bess.run_module_command( modName, 'showAvtCountersMemUsage',
                                           'EmptyArg', {} )
   except Exception as error:
      # pylint: disable=no-member
      print( error.errmsg )
   else:
      cli.fout.write( f"Avt Counters Memory Usage: { usage.in_use_count } objects\n" )

@commands.var_attrs( '[AVT_BESTPATH_OPTS...]' )
def avt_bppath_var_attrs():
   """Describe the [AVT_BESTPATH_OPTS...] CLI variable.

   Returns:
      ( var_type, var_desc, var_candidates ): type 'opts', a description
      pointing at 'show avt bestpath -h', and an empty candidate list.
   """
   varType = 'opts'
   varDesc = '[OPTIONS]( run show avt bestpath -h for more help )'
   return ( varType, varDesc, [] )

@commands.var_attrs( '[AVT_COUNTERS_OPTS...]' )
def avt_counters_var_attrs():
   """Describe the [AVT_COUNTERS_OPTS...] CLI variable.

   Returns:
      ( var_type, var_desc, var_candidates ): type 'opts', a description
      pointing at 'show avt counters -h', and an empty candidate list.
   """
   varType = 'opts'
   varDesc = '[OPTIONS]( run show avt counters -h for more help )'
   return ( varType, varDesc, [] )

@commands.var_attrs( 'AVTID' )
def avt_id_var_attrs():
   """Describe the AVTID CLI variable.

   Returns:
      ( var_type, var_desc, var_candidates ): type 'int', an empty
      description and an empty candidate list.
   """
   varType, varDesc = 'int', ''
   return ( varType, varDesc, [] )

@commands.var_attrs( '[IPV6]' )
def avt_ipv6_var_attrs():
   """Describe the [IPV6] CLI variable used by 'show avt flow-info'.

   Returns:
      ( var_type, var_desc, var_candidates ): type 'opts', an empty
      description and an empty candidate list.
   """
   varType, varDesc = 'opts', ''
   return ( varType, varDesc, [] )

@commands.var_attrs( 'VNI' )
def vni_var_attrs():
   """Describe the VNI CLI variable.

   Returns:
      ( var_type, var_desc, var_candidates ): type 'int', an empty
      description and an empty candidate list.
   """
   varType, varDesc = 'int', ''
   return ( varType, varDesc, [] )

@commands.var_attrs( '[MHPATHID]' )
def mhpath_id_opt_var_attrs():
   """Describe the optional [MHPATHID] CLI variable used by 'show avt mh-path'.

   Returns:
      ( var_type, var_desc, var_candidates ): type 'int', an empty
      description and an empty candidate list.
   """
   varType, varDesc = 'int', ''
   return ( varType, varDesc, [] )
