# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import commands
import ipaddress
from google.protobuf.message import Message
from google.protobuf import text_format
from cli_parser.option_parser import OptionParser
import pybess

def mapt_formatter( message, indent, as_one_line ):
   """text_format message_formatter for MAP-T domain/rule messages.

   For MapDomainArg and MapRuleArg messages, every set field is rendered as
   ``name: value`` on its own line, indented by *indent* spaces. The fields
   v4_prefix, v6_prefix and br_prefix are passed through
   ``ipaddress.ip_address`` first. Nested messages become ``name { ... }`` blocks
   whose contents are indented two further spaces.

   Returns None for any other message type so that text_format falls back to
   its default rendering. *as_one_line* is accepted for the formatter
   signature but ignored.
   """
   if message.DESCRIPTOR.name not in ( 'MapDomainArg', 'MapRuleArg' ):
      return None

   pad = ' ' * indent
   rendered = []
   for desc, value in message.ListFields():
      name = desc.name
      if isinstance( value, Message ):
         # Recurse so nested MAP-T messages get the same treatment.
         nested = text_format.MessageToString(
            value, indent=indent + 2, message_formatter=mapt_formatter )
         rendered.append( f'{name} {{\n{nested}{pad}}}' )
      elif name in ( 'v4_prefix', 'v6_prefix', 'br_prefix' ):
         rendered.append( f'{name}: {ipaddress.ip_address( value )}' )
      else:
         rendered.append( f'{name}: {value}' )
   # text_format writes the indent before the first line; later lines need it too.
   return ( '\n' + pad ).join( rendered )

@commands.cmd( 'show nat map-t domains',
     'Show MAP-T domains details' )
def show_mapt_domains( cli ):
   """Print the MAP-T domains reported by the first module of class 'Map'.

   Runs the module's 'getDomains' command and writes the reply, formatted
   with mapt_formatter, to ``cli.fout``. Does nothing when no 'Map' module
   exists.
   """
   map_mod = next( ( m.name for m in cli.bess.list_modules().modules
                     if m.mclass == 'Map' ), None )
   if map_mod is None:
      return
   domains = cli.bess.run_module_command( map_mod, 'getDomains',
                                          'MapDomainArg', {} )
   body = text_format.MessageToString( domains,
                                       message_formatter=mapt_formatter )
   cli.fout.write( f'MAP-T domains:\n{body}' )

@commands.cmd( 'show nat map-t rules',
     'Show MAP-T rules details' )
def show_mapt_rules( cli ):
   """Print the MAP-T rules reported by the first module of class 'Map'.

   Runs the module's 'getRules' command and writes the reply, formatted with
   mapt_formatter, to ``cli.fout``. Does nothing when no 'Map' module exists.
   """
   map_mod = next( ( m.name for m in cli.bess.list_modules().modules
                     if m.mclass == 'Map' ), None )
   if map_mod is None:
      return
   rules = cli.bess.run_module_command( map_mod, 'getRules',
                                        'MapRuleArg', {} )
   body = text_format.MessageToString( rules,
                                       message_formatter=mapt_formatter )
   cli.fout.write( f'MAP-T rules:\n{body}' )

class ShowNatFlowTableOptionParser( OptionParser ):
   """Option parser for 'show nat flow-table'.

   Registers '-max_iterate' (type "uint") and '-h' (an "action" option bound
   to help()).
   """

   def register_options( self ):
      # (flag, option type, extra keyword arguments for add_option)
      option_specs = (
            ( "-max_iterate", "uint", {} ),
            ( "-h", "action", { 'action': self.help } ),
            )
      for flag, opt_type, extra in option_specs:
         self.add_option( flag, opt_type, **extra )

   @staticmethod
   def help():
      """Print the usage text for 'show nat flow-table' to stdout."""
      print( "Usage:\n"
             " show nat flow-table [OPTIONS]\n"
             "\n"
             "Options:\n"
             " -max_iterate\tmax number of entries to print (by default print all)\n"
             " -h\thelp" )

   def get_max_iterate( self ):
      """Return whatever the parser recorded for '-max_iterate'."""
      max_iterate = self.get_value( "-max_iterate" )
      return max_iterate

def _show_nat_flow_table_parse_options( cli, opts ):
   """Parse 'show nat flow-table' option tokens into an argument dict.

   Args:
      cli: the CLI instance; its ``CommandError`` is raised on parse failure.
      opts: list of option tokens, or None when no options were given.

   Returns:
      A dict with the single key 'max_iterate', holding whatever
      ShowNatFlowTableOptionParser.get_max_iterate() reports.

   Raises:
      cli.CommandError: if constructing or running the parser fails. The
         original exception is chained as the cause.
   """
   try:
      parser = ShowNatFlowTableOptionParser( opts if opts is not None else [] )
      parser.parse()
      max_iterate = parser.get_max_iterate()
   except Exception as error: # pylint: disable=broad-except
      # CommandError must be raised, not just constructed; otherwise the
      # parse failure is silently discarded and the user sees no message.
      raise cli.CommandError( error ) from error

   return { 'max_iterate': max_iterate }

def _show_nat_flow_table_args_preproc( cli, args ):
   """Build the getArNatFlowTable command arguments from *args*.

   Copies 'max_iterate' and then 'iter' from *args*, but only when present
   and truthy. A zero or None value is left out entirely. *cli* is unused.
   """
   return { key: args[ key ] for key in ( 'max_iterate', 'iter' )
            if args.get( key ) }

def _show_nat_flow_table( cli, modName, args ):
   """Fetch one chunk of the NAT flow table from *modName* and print it.

   Args:
      cli: CLI instance; output is written to ``cli.fout``.
      modName: name of the ArNatFlowCache module to query.
      args: dict that may hold 'max_iterate' (number of entries to request)
         and 'iter' (cursor to resume from). A missing or zero 'iter' marks
         the first chunk, for which the total entry count is printed.

   Returns:
      ``next_iter`` from the module's reply, or 0 if the module command
      failed. The caller stops paging when this is 0.
   """
   def getIpStr( ip ):
      # Dotted quad of the low 32 bits, most significant byte first.
      return str( ipaddress.IPv4Address( ip & 0xffffffff ) )

   def formatRow( row ):
      # Build the entire 'Key: ... Value: ...' line so each entry is one write.
      key, val = row.nfk, row.nfv
      # key.port is printed with its two bytes swapped (natPort/srcPort are
      # not); presumably it is stored in network byte order -- TODO confirm.
      port = ( ( key.port & 0xff ) << 8 ) + ( ( key.port >> 8 ) & 0xff )
      # flowAge is divided by 1e6 (presumably usec -> sec; confirm).
      return ( f'Key: ipAddr {getIpStr( key.ipAddr )} port {port} '
               f'proto {key.protocol} profile {key.profile} '
               f'srcVrfId {key.srcVrfId} '
               f'Value: natIpAddr {getIpStr( val.natIpAddr )} '
               f'srcIpAddr {getIpStr( val.srcIpAddr )} '
               f'natPort {val.natPort} srcPort {val.srcPort} '
               f'flowAge {val.flowAge / 1000000:.2f} vrfId {val.vrfId}\n' )

   cmd_args = _show_nat_flow_table_args_preproc( cli, args )
   try:
      ret = cli.bess.run_module_command( modName, 'getArNatFlowTable',
                                         'GetNatFlowTableArg', cmd_args )
   except pybess.bess.BESS.Error as error:
      # Terminate the message so following output does not run into it.
      cli.fout.write( f'{error.errmsg}\n' )
      return 0
   # 'iter' is optional (the preproc treats it that way); absent means first chunk.
   if args.get( 'iter', 0 ) == 0:
      cli.fout.write( f'NAT flow table entries: {ret.numEntries}\n' )
   for row in ret.entry:
      cli.fout.write( formatRow( row ) )
   return ret.next_iter

@commands.cmd( 'show nat flow-table [FLOWTABLE_CMD_OPTS...]',
      'show nat flow-table: Display Nat flow table details' )
def show_nat_flow_table( cli, opts ):
   """Print the NAT flow table of the first ArNatFlowCache module.

   Entries are requested in chunks of at most 100. With '-max_iterate N',
   paging stops once N entries have been requested or the module returns a
   next cursor of 0. Without it, paging continues until the cursor is 0.
   Does nothing if no ArNatFlowCache module exists or option parsing yields
   None.
   """
   cache_mods = [ m.name for m in cli.bess.list_modules().modules
                  if m.mclass == 'ArNatFlowCache' ]
   if not cache_mods:
      return

   args = _show_nat_flow_table_parse_options( cli, opts )
   if args is None:
      return

   chunk_limit = 100  # large tables are fetched in chunks of this many entries
   # None means the user gave no limit: page until the table is exhausted.
   remaining = args.get( 'max_iterate' )
   args[ 'iter' ] = 0
   cursor = 1
   while cursor != 0 and ( remaining is None or remaining > 0 ):
      chunk = chunk_limit if remaining is None else min( chunk_limit, remaining )
      args[ 'max_iterate' ] = chunk
      cursor = _show_nat_flow_table( cli, cache_mods[ 0 ], args )
      args[ 'iter' ] = cursor
      if remaining is not None:
         remaining -= chunk

@commands.cmd( 'show nat flow-table -h',
      'show nat flow-table help' )
def show_nat_flow_table_help( cli ):
   """Show usage for 'show nat flow-table' by parsing a lone '-h' option."""
   _show_nat_flow_table_parse_options( cli, [ '-h' ] )

@commands.var_attrs( '[FLOWTABLE_CMD_OPTS...]' )
def nat_flow_table_var_attrs():
   """Describe the [FLOWTABLE_CMD_OPTS...] command variable.

   Returns a 3-tuple, presumably (variable type/name, help text, completion
   candidates) as expected by commands.var_attrs -- confirm against it.
   """
   var_kind = 'opts'
   help_text = '[OPTIONS]( run show nat flow-table, -h for more help )'
   candidates = []
   return ( var_kind, help_text, candidates )
