# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
import commands
import socket
import struct
from cli_parser.option_parser import OptionParser

def ip2int( ipAddr ):
   """Convert a dotted IPv4 address string to an unsigned 32-bit integer.

   Surrounding whitespace is ignored.  The first octet becomes the most
   significant byte (the bytes are read in network order, "!I").
   """
   packed = socket.inet_aton( ipAddr.strip() )
   ( value, ) = struct.unpack( "!I", packed )
   return value

def v6ToBytes( ipv6Addr ):
   """Pack a textual IPv6 address into its 16-byte binary representation."""
   family = socket.AF_INET6
   return socket.inet_pton( family, ipv6Addr )

class ShowFibLookupOptionParser( OptionParser ):
   """Option parser for 'show fib lookup' and 'show fibv6 lookup'.

   Each get_* accessor returns its option's value via OptionParser.get_value,
   passing the fallback used when the option was not supplied.
   """

   def register_options( self ):
      """Declare the lookup options and the -h help action."""
      self.add_option( "-src_ip", "string" )
      self.add_option( "-dst_ip", "string" )
      self.add_option( "-sport", "uint" )
      self.add_option( "-dport", "uint" )
      self.add_option( "-proto", "uint" )
      self.add_option( "-vrf", "uint" )
      self.add_option( "-h", "action", action=self.help )

   @staticmethod
   def help():
      """Print the usage text for the lookup command."""
      helpStr = '''
Usage:
  show fib lookup -dst_ip <ip> [ -vrf <vrfId> ]
  show fib lookup -dst_ip <ip> [ [ -vrf <vrfId> ] OPTIONS ]

  The command can be used in two ways:
    1. If just -dst_ip and optionally -vrf are the only parameters, then it will
       perform lookup of that IP and return all the nexthops.
    2. In the second form, it will create a packet header with the option
       values and return the exact nh that would be selected for the packet hash.

Options:
  -src_ip       src ip of the packet to be looked up
  -dst_ip       dst ip of the packet to be looked up
  -sport        source port of the packet to be looked up
  -dport        destination port of the packet to be looked up
  -proto        protocol number of the packet to be looked up
  -vrf          vrfId of the packet to be looked up
  -h            help
  '''
      print( helpStr )

   def get_sport( self ):
      """Source port (fallback 0)."""
      return self.get_value( "-sport", 0 )

   def get_src_ip( self ):
      """Source IP string (fallback "")."""
      return self.get_value( "-src_ip", "" )

   def get_dport( self ):
      """Destination port (fallback 0)."""
      return self.get_value( "-dport", 0 )

   def get_dst_ip( self ):
      """Destination IP string (fallback "")."""
      return self.get_value( "-dst_ip", "" )

   def get_proto( self ):
      """IP protocol number (fallback 0)."""
      return self.get_value( "-proto", 0 )

   def get_vrf( self ):
      """VRF id (fallback 0)."""
      return self.get_value( "-vrf", 0 )

class ShowFibEntriesOptionParser( OptionParser ):
   """Option parser for 'show fib entries' and 'show fibv6 entries'."""

   def register_options( self ):
      """Declare -count and -vrf (both uint), then the -h help action."""
      for optName in ( "-count", "-vrf" ):
         self.add_option( optName, "uint" )
      self.add_option( "-h", "action", action=self.help )

   @staticmethod
   def help():
      """Print the usage text for the entries command."""
      print( '''
Usage:
  show fib entries [ -count <num> ] [ -vrf <vrfId> ]

  iterate through the lpm table in Fib module and dump all the rules ( prefixes )
  and the fec it is pointing to ( l2 rewrite info )

Options:
  -count        max number of prefix->fec entries to read ( default 1024 )
  -vrf          vrfId of the lpm table.
  -h            help
  ''' )

   def get_count( self ):
      """Maximum number of prefix->fec entries to read (fallback 1024)."""
      defaultCount = 1024
      return self.get_value( "-count", defaultCount )

   def get_vrf( self ):
      """VRF id of the lpm table to walk (fallback 0)."""
      defaultVrf = 0
      return self.get_value( "-vrf", defaultVrf )

def _get_fib_mod_name( cli, moduleClass ):
   """Return the names of all BESS modules whose class contains moduleClass.

   Substring matching on m.mclass (rather than equality) is deliberate: with
   LPMng the IPv4 FIB module class is FibIpv4Lpmng, so "FibIpv4" has to match
   both FibIpv4 and FibIpv4Lpmng.  The result is a (possibly empty) list.
   """
   matches = []
   for module in cli.bess.list_modules().modules:
      if moduleClass in module.mclass:
         matches.append( module.name )
   return matches

def show_fib_lookup_internal( cli, opts, ipv6 ):
   """Run a route lookup in the FIB module and print the fec and next hops.

   Shared implementation of 'show fib lookup' (ipv6=False, FibIpv4 module)
   and 'show fibv6 lookup' (ipv6=True, FibIpv6 module).

   Per the command help: with only -dst_ip (and optionally -vrf) all next hops
   of the matching route are returned; if any of -src_ip, -proto, -sport or
   -dport is given, 'exact_nh' is set so that the next hop a packet with that
   header would be hashed to is returned.

   Args:
      cli: CLI context; cli.bess issues the module command and cli.fout
         receives the result.
      opts: raw option tokens, parsed by ShowFibLookupOptionParser.
      ipv6: True to query the FibIpv6 module, False for FibIpv4.

   Option, address, and lookup errors are printed and the function returns
   None.
   """
   moduleClass = "FibIpv6" if ipv6 else "FibIpv4"
   argType = "FibIpv6LookupArg" if ipv6 else "FibIpv4LookupArg"

   try:
      parser = ShowFibLookupOptionParser( opts )
      parser.parse()
      dstIp = parser.get_dst_ip()
      srcIp = parser.get_src_ip()
      proto = parser.get_proto()
      sport = parser.get_sport()
      dport = parser.get_dport()
      vrfId = parser.get_vrf()
   except Exception as error:
      print( error )
      return

   # Any header field beyond the destination turns this into an exact
   # next-hop (hash) lookup.
   exactNh = bool( srcIp or proto or sport or dport )

   # Nothing to look up: show usage.  Test the option string, not the
   # converted address: the IPv6 form of "::" is 16 zero bytes, which is
   # truthy, so checking the converted value never fired for fibv6.
   if not dstIp and not exactNh:
      parser.help()
      return

   try:
      if ipv6:
         dip = v6ToBytes( dstIp or "::" )
         sip = v6ToBytes( srcIp or "::" )
      else:
         dip = ip2int( dstIp ) if dstIp else 0
         sip = ip2int( srcIp ) if srcIp else 0
   except Exception as error:
      print( error )
      return

   modName = _get_fib_mod_name( cli, moduleClass )
   if not modName:
      print( "No %s module found" % moduleClass )
      return

   args = {
      'vrf': vrfId,
      'dst_ip': dip,
      'src_ip': sip,
      'src_port': sport,
      'dst_port': dport,
      'proto': proto,
      # NOTE(review): meaning of 'direction' is not visible here; always True.
      'direction': True,
      'exact_nh': exactNh,
   }

   try:
      ret = cli.bess.run_module_command( modName[ 0 ], 'routeLookup',
                                         argType, args )
   except Exception as error:
      # BESS command errors are expected to carry the request in
      # error.info['command_arg']; other exceptions are printed as-is
      # instead of crashing this handler with AttributeError.
      info = getattr( error, 'info', None )
      if isinstance( info, dict ) and 'command_arg' in info:
         print( "Error in lpm lookup; arguments passed: " )
         print( info[ 'command_arg' ] )
      else:
         print( "Error in lpm lookup: %s" % error )
      return

   cli.fout.write( 'fecId: %d\n' % ret.fec_id )
   for nh in ret.nh_info:
      cli.fout.write( '   nh_index: %s, gate: %d, vlan_id: %d, vni: %d\n' %
                      ( nh.nh_index, nh.gate, nh.vlan_id, nh.vni ) )

@commands.cmd( 'show fib lookup [FIB_LOOKUP_OPTS...]',
     'lookup the given dst ip in fib lpm' )
def show_fib_lookup( cli, opts ):
   """CLI handler for 'show fib lookup' (IPv4 FIB)."""
   useIpv6 = False
   show_fib_lookup_internal( cli, opts, useIpv6 )

@commands.cmd( 'show fibv6 lookup [FIBV6_LOOKUP_OPTS...]',
     'lookup the given dst ip in fibv6 lpm' )
def show_fibv6_lookup( cli, opts ):
   """CLI handler for 'show fibv6 lookup' (IPv6 FIB)."""
   useIpv6 = True
   show_fib_lookup_internal( cli, opts, useIpv6 )

def _show_fib_entries_helper( fibMod, cli, args, ipv6 ):
   """Fetch one page of FIB entries from fibMod and write them to cli.fout.

   Args:
      fibMod: name of the FIB module to send 'getFibEntries' to.
      cli: CLI context; cli.bess issues the command, cli.fout gets output.
      args: the getFibEntries request dict (vrf, entry_count, next_iter and,
         after the first page, node_keys).
      ipv6: True if rule prefixes are packed 16-byte IPv6 addresses, False
         if they are 32-bit integer IPv4 addresses.

   Returns:
      ( rulesRead, nextIter, nodeKeys ): the number of rules printed, the
      iterator value to send back as 'next_iter', and the list of
      { "node_key": ... } dicts to send back as 'node_keys'.  On a command
      error ( 0, 0, [] ) is returned, so the result always has the same shape.
   """
   argType = 'FibIpv6GetFibEntriesArg' if ipv6 else 'FibIpv4GetFibEntriesArg'
   try:
      ret = cli.bess.run_module_command( fibMod, 'getFibEntries',
                                         argType, args )
   except Exception as error:
      # BESS command errors are expected to carry the request in
      # error.info['command_arg']; other exceptions are printed as-is
      # instead of crashing this handler with AttributeError.
      info = getattr( error, 'info', None )
      if isinstance( info, dict ) and 'command_arg' in info:
         print( "Error in lpm iterate; arguments passed: " )
         print( info[ 'command_arg' ] )
      else:
         print( "Error in lpm iterate: %s" % error )
      # Return an empty list (not 0) for node keys so callers can store it
      # as 'node_keys' unchanged.
      return 0, 0, []

   for rule in ret.rules:
      if ipv6:
         prefixStr = socket.inet_ntop( socket.AF_INET6, rule.prefix )
      else:
         prefixStr = socket.inet_ntoa( struct.pack( '!L', rule.prefix ) )
      fecInfo = rule.fec
      cli.fout.write( "prefix: %s/%d -> fec: %d\n" %
                      ( prefixStr, rule.prefix_len, fecInfo.fec_id ) )
      for nh in fecInfo.nh_info:
         cli.fout.write( '   nh_index: %s, gate: %d, vlan_id: %d, vni: %d\n' %
                         ( nh.nh_index, nh.gate, nh.vlan_id, nh.vni ) )

   # Hand the walk state (node keys) back so the next request can resume
   # where this one stopped.
   nodeKeys = [ { "node_key": node.node_key } for node in ret.node_keys
                if node.node_key is not None ]
   return len( ret.rules ), ret.next_iter, nodeKeys

def show_fib_entries_internal( cli, opts, ipv6 ):
   """Dump the prefix -> fec entries of a FIB lpm table, one page at a time.

   Shared implementation of 'show fib entries' (ipv6=False, FibIpv4 module)
   and 'show fibv6 entries' (ipv6=True, FibIpv6 module).  Pages are fetched
   with _show_fib_entries_helper, each request resuming from the next_iter /
   node_keys returned by the previous one.  The loop stops when next_iter is
   0, when -count entries have been read, or when the user answers no to
   'continue?'.

   Args:
      cli: CLI context; cli.bess is queried, cli.fout receives the output,
         cli.confirm prompts between pages.
      opts: raw option tokens, parsed by ShowFibEntriesOptionParser.
      ipv6: True for the FibIpv6 module, False for FibIpv4.
   """
   moduleClass = "FibIpv6" if ipv6 else "FibIpv4"

   # Parse options before looking up the module so that -h and option errors
   # are still reported when no FIB module is present.
   try:
      parser = ShowFibEntriesOptionParser( opts )
      parser.parse()
      entryCount = parser.get_count()
      vrfId = parser.get_vrf()
   except Exception as error:
      print( error )
      return

   modName = _get_fib_mod_name( cli, moduleClass )
   if not modName:
      print( "No %s module found" % moduleClass )
      return

   args = {
      'next_iter': 0,
      'entry_count': entryCount,
      'vrf': vrfId,
   }
   readCount = 0
   while True:
      count, nextIter, nodeKeys = _show_fib_entries_helper(
         modName[ 0 ], cli, args, ipv6 )
      readCount += count
      # Resume the next page where this one stopped, asking only for the
      # entries still missing from the requested total.
      args[ 'next_iter' ] = nextIter
      args[ 'entry_count' ] = entryCount - readCount
      args[ 'node_keys' ] = nodeKeys

      # next_iter == 0 means there is nothing more to fetch (the helper also
      # returns 0 on a command error).
      if nextIter == 0 or readCount >= entryCount or \
         not cli.confirm( 'continue?' ):
         break
   cli.fout.write( '\n' )

@commands.cmd( 'show fib entries [FIB_ENTRIES_OPTS...]',
           'show fib entries: display routes in lpm' )
def show_fib_entries( cli, opts ):
   """CLI handler for 'show fib entries' (IPv4 FIB)."""
   useIpv6 = False
   show_fib_entries_internal( cli, opts, useIpv6 )

@commands.cmd( 'show fibv6 entries [FIBV6_ENTRIES_OPTS...]',
     'show fibv6 entries: display routes in lpm' )
def show_fibv6_entries( cli, opts ):
   """CLI handler for 'show fibv6 entries' (IPv6 FIB)."""
   useIpv6 = True
   show_fib_entries_internal( cli, opts, useIpv6 )

@commands.var_attrs( '[FIB_LOOKUP_OPTS...]' )
def fib_lookup_var_attrs():
   """Describe the [FIB_LOOKUP_OPTS...] variable of 'show fib lookup'.

   Returns a ( var_type, var_desc, var_candidates ) tuple; the candidate list
   is empty.
   """
   varType = 'opts'
   varDesc = '[OPTIONS]( run show fib lookup -h for more help )'
   return ( varType, varDesc, [] )

@commands.var_attrs( '[FIBV6_LOOKUP_OPTS...]' )
def fibv6_lookup_var_attrs():
   """Describe the [FIBV6_LOOKUP_OPTS...] variable of 'show fibv6 lookup'.

   Returns a ( var_type, var_desc, var_candidates ) tuple; the candidate list
   is empty.
   """
   varType = 'opts'
   varDesc = '[OPTIONS]( run show fibv6 lookup -h for more help )'
   return ( varType, varDesc, [] )

@commands.var_attrs( '[FIB_ENTRIES_OPTS...]' )
def fib_entries_var_attrs():
   """Describe the [FIB_ENTRIES_OPTS...] variable of 'show fib entries'.

   Returns a ( var_type, var_desc, var_candidates ) tuple; the candidate list
   is empty.
   """
   varType = 'opts'
   varDesc = '[OPTIONS]( run show fib entries -h for more help )'
   return ( varType, varDesc, [] )

@commands.var_attrs( '[FIBV6_ENTRIES_OPTS...]' )
def fibv6_entries_var_attrs():
   """Describe the [FIBV6_ENTRIES_OPTS...] variable of 'show fibv6 entries'.

   Returns a ( var_type, var_desc, var_candidates ) tuple; the candidate list
   is empty.
   """
   varType = 'opts'
   varDesc = '[OPTIONS]( run show fibv6 entries -h for more help )'
   return ( varType, varDesc, [] )
