# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
import commands
import socket
import struct
from cli_parser.option_parser import OptionParser


class ShowFlowcacheFlowsOptionParser( OptionParser ):
   """Option parser for 'show flow-cache[v6] flows'.

   Supported options: -max_iterate (uint), -ip_addr (string) and -h (help).
   """

   def register_options( self ):
      """Declare the options accepted by the flows commands."""
      self.add_option( "-max_iterate", "uint" )
      self.add_option( "-ip_addr", "string" )
      self.add_option( "-h", "action", action=self.help )

   @staticmethod
   def help():
      """Print usage information for the flows commands."""
      help_text = [
         "Usage:",
         " show flow-cache flows [OPTIONS]",
         "",
         "Options:",
         " -max_iterate\t\tmax number of flows to print (by default print all)",
         " -ip_addr\t\tflow IP address filter:",
         "\t\t\t   single IP - IP address can appear either as IP_A or IP_B",
         "\t\t\t   two IPs (comma separated) - return only flows containing both IPs",
         " -h\t\t\thelp"
         ]
      print( "\n".join( help_text ) )

   def get_max_iterate( self ):
      """Return the -max_iterate value, or the parser default when not given."""
      # default value is None, which will make all flows be printed if user does not
      # specify a -max_iterate value
      return self.get_value( "-max_iterate" )

   def get_ip_addr( self ):
      """Return the -ip_addr filter string, or "" when not given."""
      return self.get_value( "-ip_addr", "" )

def _convert_filter_ip_address( cli, addr_in, ipv6 ):
   af = socket.AF_INET6 if ipv6 else socket.AF_INET
   addr_out = []
   try:
      addr_list = addr_in.split( ',' )
      if len( addr_list ) > 2:
         raise cli.CommandError( 'Incorrect format of flow IP filter' )
      for addr in addr_list:
         addr = socket.inet_pton( af, addr ) # host order
         if not ipv6:
            addr = socket.htonl( addr ) # net order
         addr_out.append( addr )
   except socket.error:
      raise cli.CommandError(
                      'Incorrect format of flow IP filter passed to inet_aton' )
   # IP addresses in FlowCacheInfo are stored in sorted order - lower IP address
   # is assigned to 'ip_a' and higher to 'ip_b'.
   addr_out.sort()
   return addr_out

def _show_flowcache_flows_args_preproc( cli, args, ipv6 ):
   cmd_args = {}
   ip_addr_filter = []
   cmd_args[ 'ipv6' ] = ipv6
   if 'max_iterate' in args and args[ 'max_iterate' ]:
      cmd_args[ 'max_iterate' ] = args[ 'max_iterate' ]
   if 'iter' in args and args[ 'iter' ]:
      cmd_args[ 'iter' ] = args[ 'iter' ]
   if 'ip_addr' in args and args[ 'ip_addr' ]:
      ip_addr_filter = _convert_filter_ip_address( cli, args[ 'ip_addr' ], ipv6 )
   return ( cmd_args, ip_addr_filter )

def _show_flowcache_flows_parse_options( cli, opts ):
   """Parse the options of 'show flow-cache[v6] flows'.

   Args:
      cli: CLI object; its CommandError reports parse failures.
      opts: list of option tokens, or None for no options.

   Returns:
      Dict with the parsed 'max_iterate' and 'ip_addr' values.

   Raises:
      cli.CommandError: the options could not be parsed.
   """
   try:
      parser = ShowFlowcacheFlowsOptionParser( [] if opts is None else opts )
      parser.parse()
      return { 'max_iterate' : parser.get_max_iterate(),
               'ip_addr' : parser.get_ip_addr() }
   except Exception as error: # pylint: disable=broad-except
      # The error must be raised, not merely constructed, otherwise bad
      # options are silently ignored and the command prints nothing.
      # NOTE(review): if OptionParser signals -h by raising, the help command
      # will now report that exception as an error -- confirm.
      raise cli.CommandError( str( error ) ) from error

def _show_flowcache_flows( cli, mgmtName, args, ipv6=False ):
   """Fetch one chunk of flow-cache entries and print those passing the IP filter.

   Args:
      cli: CLI object providing ``bess`` and ``fout``.
      mgmtName: name of the Management module to query.
      args: parsed options plus the current 'iter' and chunk 'max_iterate'.
      ipv6: query and decode the IPv6 flow cache when True.

   Returns:
      The ``next_iter`` value reported by getFlowCacheFlows.
   """
   cmd_args, ip_addr_filter = _show_flowcache_flows_args_preproc(
         cli, args, ipv6 )
   ret = cli.bess.run_module_command( mgmtName, 'getFlowCacheFlows',
                                      'GetFlowCacheFlowsArg', cmd_args )

   # Pick the address accessor and formatter once rather than per entry.
   if ipv6:
      def keyAddr( genAddr ):
         return genAddr.ipv6_addr

      def fmtAddr( genAddr ):
         return socket.inet_ntop( socket.AF_INET6, genAddr.ipv6_addr )
   else:
      def keyAddr( genAddr ):
         return genAddr.ip_addr

      def fmtAddr( genAddr ):
         return socket.inet_ntoa( struct.pack( '=L', genAddr.ip_addr ) )

   # Decide the filter up front; many flows may be processed in the loop.
   if not ip_addr_filter:
      entries = ret.info
   elif len( ip_addr_filter ) == 1:
      # One address: it may sit on either side of the flow key.
      target = ip_addr_filter[ 0 ]
      entries = ( e for e in ret.info
                  if keyAddr( e.fk.ip_a ) == target or
                     keyAddr( e.fk.ip_b ) == target )
   else:
      # Two addresses: the filter list is sorted, matching the lower/higher
      # ordering of ip_a/ip_b in the flow key, so compare position by position.
      low, high = ip_addr_filter[ 0 ], ip_addr_filter[ 1 ]
      entries = ( e for e in ret.info
                  if keyAddr( e.fk.ip_a ) == low and
                     keyAddr( e.fk.ip_b ) == high )

   for entry in entries:
      fk = entry.fk
      cli.fout.write(
            f"vrfId {fk.vrf_id} "
            f"ipA {fmtAddr( fk.ip_a )} "
            f"portA {socket.ntohs( fk.port_a )} "
            f"ipB {fmtAddr( fk.ip_b )} "
            f"portB {socket.ntohs( fk.port_b )} "
            f"protocol {fk.protocol} "
            f"flags 0x{entry.flags:x} "
            f"age {int( entry.flow_age / 1000000 )} secs "
            f"flowState {entry.flow_state} "
            f"ctsDirection {entry.cts_direction}\n" )

   return ret.next_iter

def show_flowcache_flows_impl( cli, opts, ipv6=False ):
   """Print the flows held in the IPv4 (or, with ipv6=True, IPv6) flow cache.

   Flows come from the first 'Management' module, fetched in chunks of at most
   100 iterates per request. -max_iterate bounds the total number of iterates
   requested; when it is absent, everything is requested.
   """
   mgmtNames = [ m.name for m in cli.bess.list_modules().modules
                 if m.mclass == 'Management' ]
   if not mgmtNames:
      return
   mgmt = mgmtNames[ 0 ]
   info = cli.bess.run_module_command( mgmt, 'getFlowCacheInfo',
                                       'GetFlowCacheInfoArg', { 'ipv6' : ipv6 } )

   args = _show_flowcache_flows_parse_options( cli, opts )
   if args is None:
      return

   # The header reports the total number of flows, not how many get printed.
   suffix = 'v6' if ipv6 else ''
   cli.fout.write( f'show flow-cache{suffix} flows: {info.flow_count} flows\n' )
   if info.flow_count == 0:
      return

   # No -max_iterate means "all": use a budget larger than any real cache.
   requested = args.get( 'max_iterate', None )
   iters_left = 0xffffffff if requested is None else requested

   # Fetch in bounded chunks so the flow cache is never locked for long.
   max_chunk_size = 100
   args[ 'iter' ] = 0
   while iters_left > 0:
      chunk = min( max_chunk_size, iters_left )
      args[ 'max_iterate' ] = chunk
      args[ 'iter' ] = _show_flowcache_flows( cli, mgmt, args, ipv6=ipv6 )
      if args[ 'iter' ] == 0:
         break
      iters_left -= chunk

   cli.fout.write( '\n' )

@commands.cmd( 'show flow-cache flows [FLOWCACHE_CMD_OPTS...]',
     'show flow-cache flows: display flows in flow cache' )
def show_flowcache_flows( cli, opts ):
   """CLI entry point: list flows in the IPv4 flow cache."""
   show_flowcache_flows_impl( cli, opts, ipv6=False )

@commands.cmd( 'show flow-cachev6 flows [FLOWCACHE_CMD_OPTS...]',
     'show flow-cachev6 flows: display flows in IPv6 flow cache' )
def show_flowcache_flowsv6( cli, opts ):
   """CLI entry point: list flows in the IPv6 flow cache."""
   useIpv6 = True
   show_flowcache_flows_impl( cli, opts, ipv6=useIpv6 )

@commands.cmd( 'show flow-cache flows -h',
     'show flow-cache flows help' )
def show_flowcache_flows_help( cli ):
   """CLI entry point: print usage by running the flows parser with -h."""
   _show_flowcache_flows_parse_options( cli, [ '-h' ] )

# NAT USER OBJ COMMANDS
class ShowFlowCacheNatFlowsOptionParser( OptionParser ):
   """Option parser for 'show flow-cache natflows' (-max_iterate and -h)."""

   def register_options( self ):
      """Declare the options accepted by the natflows command."""
      self.add_option( "-max_iterate", "uint" )
      self.add_option( "-h", "action", action=self.help )

   @staticmethod
   def help():
      """Print usage information for the natflows command."""
      print( "Usage:",
             " show flow-cache natflows [OPTIONS]",
             "",
             "Options:",
             " -max_iterate\t\tmax number of flows to print (by default print all)",
             " -h\t\t\thelp",
             sep="\n" )

   def get_max_iterate( self ):
      """Return the -max_iterate value, or the parser default when not given."""
      value = self.get_value( "-max_iterate" )
      return value

def _show_flowcache_natflows_args_preproc( cli, args ):
   cmd_args = {}
   if 'max_iterate' in args and args[ 'max_iterate' ]:
      cmd_args[ 'max_iterate' ] = args[ 'max_iterate' ]
   if 'iter' in args and args[ 'iter' ]:
      cmd_args[ 'iter' ] = args[ 'iter' ]
   return cmd_args

def _show_flowcache_natflows_parse_options( cli, opts ):
   """Parse the options of 'show flow-cache natflows'.

   Args:
      cli: CLI object; its CommandError reports parse failures.
      opts: list of option tokens, or None for no options.

   Returns:
      Dict with the parsed 'max_iterate' value.

   Raises:
      cli.CommandError: the options could not be parsed.
   """
   try:
      parser = ShowFlowCacheNatFlowsOptionParser( [] if opts is None else opts )
      parser.parse()
      return { 'max_iterate' : parser.get_max_iterate() }
   except Exception as error: # pylint: disable=broad-except
      # The error must be raised, not merely constructed, otherwise bad
      # options are silently ignored and the command prints nothing.
      # NOTE(review): if OptionParser signals -h by raising, the help command
      # will now report that exception as an error -- confirm.
      raise cli.CommandError( str( error ) ) from error

def _show_flowcache_natflows( cli, mgmtName, args, iters_left ):
   """Fetch one chunk of NAT flow entries and print them, one line each.

   Args:
      cli: CLI object providing ``bess`` and ``fout``.
      mgmtName: name of the ArNatFlowCache module to query.
      args: dict with the current 'iter' and chunk-size 'max_iterate'.
      iters_left: remaining number of entries to print. It is decremented per
         printed entry, and printing stops once it reaches exactly 0.

   Returns:
      ( next_iter, iters_left ): the iterator reported by getFlowCacheNatFlows
      and the updated remaining-entry count.
   """
   cmd_args = _show_flowcache_natflows_args_preproc( cli, args )
   ret = cli.bess.run_module_command( mgmtName, 'getFlowCacheNatFlows',
                                      'GetFlowCacheNatFlowsArg', cmd_args )
   # Index -> name, mirroring the SfeModules::ArNatTcpState enum.
   tcpStateMap = dict( enumerate( ( 'Invalid', 'SynSentFwd', 'SynSentRvs',
                                    'SynSent2', 'Established', 'FinWaitFwd',
                                    'FinWaitRvs', 'LastAck' ) ) )

   def ntoa( fmt, value ):
      return socket.inet_ntoa( struct.pack( fmt, value ) )

   for entry in ret.info:
      fk = entry.fk
      portA = socket.ntohs( fk.port_a )
      portB = socket.ntohs( fk.port_b )

      # Flow-key part of the line.
      cli.fout.write( f"ipA {ntoa( '=L', fk.ip_a.ip_addr )} portA {portA} "
                      f"ipB {ntoa( '=L', fk.ip_b.ip_addr )} portB {portB} "
                      f"protocol {fk.protocol} " )
      cli.fout.write( "\t" )

      # User-object part: twice-NAT entries show both translations, the others
      # show the single NAT binding plus vrf and age.
      if entry.twice:
         fields = [
            f"srcIp {ntoa( '!L', entry.src_ip_addr )}",
            f"srcPort {entry.src_port}",
            f"dstIp {ntoa( '!L', entry.dst_ip_addr )}",
            f"dstPort {entry.dst_port}",
            f"dynamicIp {ntoa( '!L', entry.dynamic_ip_addr )}",
            f"dynamicPort {entry.dynamic_port}",
            f"staticIp {ntoa( '!L', entry.static_ip_addr )}",
            f"staticPort {entry.static_port}",
         ]
      else:
         fields = [
            f"srcIp {ntoa( '!L', entry.src_ip_addr )}",
            f"srcPort {entry.src_port}",
            f"natIp {ntoa( '!L', entry.nat_ip_addr )}",
            f"natPort {entry.nat_port}",
            f"vrfId {entry.vrf_id}",
            f"age {round( entry.flow_age / 1000000, 2 )}secs",
         ]
      fields.append( f"reverse {entry.reverse}" )
      if fk.protocol == socket.IPPROTO_TCP:
         fields.append(
               f"tcpState {tcpStateMap.get( entry.tcp_state, 'Invalid' )}" )
      cli.fout.write( " ".join( fields ) )
      cli.fout.write( "\n" )

      iters_left -= 1
      if iters_left == 0:
         break
   return ret.next_iter, iters_left

@commands.cmd( 'show flow-cache natflows [FLOWCACHE_CMD_OPTS...]',
     'show flow-cache natflows : display nat user objects in flow cache' )
def show_flowcache_natuser_flows( cli, opts ):
   """Print NAT user objects held by the first 'ArNatFlowCache' module.

   Requests are made in chunks of at most 100 iterates. With -max_iterate, at
   most that many flows are printed; otherwise every flow is printed. Does
   nothing when no ArNatFlowCache module exists.
   """
   mgmt = [ m.name for m in cli.bess.list_modules().modules
               if m.mclass == 'ArNatFlowCache' ]
   if not mgmt:
      return

   # Use the natflows parser so only the options this command supports
   # (-max_iterate, -h) are accepted; the flows parser also took -ip_addr,
   # which this command silently ignored.
   args = _show_flowcache_natflows_parse_options( cli, opts )
   if args is None:
      return

   next_iter = 1
   args[ 'iter' ] = 0
   # Do requests in chunks to avoid locking the flow cache for long periods.
   max_chunk_size = 100
   if args.get( 'max_iterate', None ) is not None:
      # Show exactly as many flows as the user asked for, unless there aren't
      # that many.
      iters_left = args[ 'max_iterate' ]
      while next_iter != 0 and iters_left > 0:
         args[ 'max_iterate' ] = min( max_chunk_size, iters_left )
         next_iter, iters_left = \
             _show_flowcache_natflows( cli, mgmt[ 0 ], args, iters_left )
         args[ 'iter' ] = next_iter
   else:
      # Show all flows. The per-request budget is reset on every call and only
      # needs to exceed what a single chunk can return. A single shared
      # counter (previously 1,000,000) reached 0 mid-chunk and silently
      # dropped the rest of that chunk.
      unlimited_budget = 0xffffffff
      args[ 'max_iterate' ] = max_chunk_size
      while next_iter != 0:
         next_iter, _ = _show_flowcache_natflows(
            cli, mgmt[ 0 ], args, unlimited_budget )
         args[ 'iter' ] = next_iter

   cli.fout.write( '\n' )

@commands.cmd( 'show flow-cache natflows -h',
     'show flow-cache natflows help' )
def show_flowcache_natflows_help( cli ):
   """CLI entry point: print usage by running the natflows parser with -h."""
   _show_flowcache_natflows_parse_options( cli, [ '-h' ] )

def show_flowcache_bucketinfo_impl( cli, ipv6=False ):
   """Print a table of bucket counts keyed by bucket element count.

   Queries getFlowCacheBuckets on the first 'Management' module and writes one
   row per element count, in ascending order. Does nothing when no Management
   module exists.
   """
   mgmtNames = [ m.name for m in cli.bess.list_modules().modules
                 if m.mclass == 'Management' ]
   if not mgmtNames:
      return
   buckets = cli.bess.run_module_command(
         mgmtNames[ 0 ], 'getFlowCacheBuckets', 'GetFlowCacheBucketsArg',
         { 'ipv6' : ipv6 } )
   out = cli.fout
   out.write( 'Element Count     Num Buckets\n' )
   out.write( '-------------     -----------\n' )
   rows = sorted( buckets.num_buckets_by_elems.items() )
   for row in rows:
      out.write( '%3d               %d\n' % row )
   out.write( '\n' )

# Help-text template shared by the bucket-info commands; '{}' receives the
# command suffix ('' or 'v6').
show_flowcache_bucketinfo_desc = (
      'show flow-cache{} bucket-info: '
      'display number of buckets by number of elements' )

@commands.cmd( 'show flow-cache bucket-info',
     show_flowcache_bucketinfo_desc.format( '' ) )
def show_flowcache_bucketinfo( cli ):
   """CLI entry point: bucket histogram for the IPv4 flow cache."""
   show_flowcache_bucketinfo_impl( cli, ipv6=False )

@commands.cmd( 'show flow-cachev6 bucket-info',
     show_flowcache_bucketinfo_desc.format( 'v6' ) )
def show_flowcache_bucketinfov6( cli ):
   """CLI entry point: bucket histogram for the IPv6 flow cache."""
   useIpv6 = True
   show_flowcache_bucketinfo_impl( cli, ipv6=useIpv6 )

def show_flowcache_stats_impl( cli, ipv6=False ):
   """Print flow-cache statistics reported by the first 'Management' module.

   Core counters are always printed. Ephemeral-key details are printed only
   when the module reports ephemeral support, and each purge-related counter
   only when it is non-zero. Does nothing when no Management module exists.
   """
   mgmtNames = [ m.name for m in cli.bess.list_modules().modules
                 if m.mclass == 'Management' ]
   if not mgmtNames:
      return
   info = cli.bess.run_module_command( mgmtNames[ 0 ],
                                       'getFlowCacheInfo',
                                       'GetFlowCacheInfoArg', { 'ipv6' : ipv6 } )
   write = cli.fout.write

   write( 'Flow Count:      %d\n' % info.flow_count )
   write( 'Num Buckets:     %d\n' % info.num_buckets )
   # A max_keys of 0 means the cache has no key limit.
   write( 'Max Keys:        %s\n' % ( str( info.max_keys ) if info.max_keys != 0
                                       else "unlimited" ) )
   write( 'Permanent Keys:  %d\n' % info.perm_keys )
   write( 'Purgeable Keys:   %d\n' % info.purgeable_keys )

   if info.ephemeral_supported:
      ephemeralRows = (
         ( 'Ephemeral Keys:   %d\n', info.total_eph_keys ),
         ( 'Ephemeral Summary Keys:   %d\n', info.summary_eph_keys ),
         ( 'Ephemeral Purge Timeout(ms):   %d\n', info.eph_purge_time_ms ),
         ( 'Ephemeral Source Starting Port:   %d\n',
           info.eph_source_port_start ),
         ( 'Ephemeral Destination Starting Port:   %d\n',
           info.eph_destination_port_start ),
      )
      for fmt, value in ephemeralRows:
         write( fmt % value )

   write( 'Low Memory Mode: %s\n' % ( "True" if info.lowmemory_mode
                                      else "False" ) )

   # Purge-related counters are only shown when they carry information.
   optionalRows = (
      ( 'Purge Timeout(ms)', info.purge_time_ms ),
      ( 'FIN Pending TCP Flows', info.fin_purge_ring_cnt ),
      ( 'RST Pending TCP Flows', info.rst_purge_ring_cnt ),
      ( 'Times Unable To Queue(FIN)', info.fin_purge_ring_full ),
      ( 'Times Unable To Queue(RST)', info.rst_purge_ring_full ),
      ( 'Total Purged', info.total_purged ),
   )
   for label, value in optionalRows:
      if value != 0:
         write( f"{label}: {value}\n" )
   write( '\n' )

# The description was previously copy-pasted from bucket-info; this command
# prints flow-cache statistics, not a bucket histogram.
@commands.cmd( 'show flow-cache stats',
     'show flow-cache stats: display flow cache statistics' )
def show_flowcache_stats( cli ):
   """CLI entry point: print statistics for the IPv4 flow cache."""
   show_flowcache_stats_impl( cli )

# The description was previously copy-pasted from bucket-info; this command
# prints flow-cache statistics, not a bucket histogram.
@commands.cmd( 'show flow-cachev6 stats',
     'show flow-cachev6 stats: display IPv6 flow cache statistics' )
def show_flowcache_statsv6( cli ):
   """CLI entry point: print statistics for the IPv6 flow cache."""
   show_flowcache_stats_impl( cli, ipv6=True )

@commands.var_attrs( '[FLOWCACHE_CMD_OPTS...]' )
def flow_cache_var_attrs():
   """Describe the [FLOWCACHE_CMD_OPTS...] CLI variable.

   Returns a ( var_type, var_desc, var_candidates ) tuple of
   ( str, str, [str] ); no completion candidates are offered.
   """
   varType = 'opts'
   varDesc = '[OPTIONS]( run show flow-cache flows -h for more help )'
   return varType, varDesc, []
