# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
import commands
import socket
import Tac
import Arnet
import struct

from bessctlPluginLib import *
from cli_parser.option_parser import OptionParser

# Tac enum providing the wildcard sentinels (vrfWildcard, ipWildcard,
# l4PortWildcard, protocolWildcard) that the option parser below returns
# for flow-key filters the user did not supply.
FilterWildcard = TacEnum( "SfeModules::FilterWildcard" )

class ShowIpfixFlowInfoOptionParser( OptionParser ):
   """Option parser for the 'show ipfix flows' command.

   Each get_* accessor returns one parsed option. The flow-key filters
   (vrf id, src/dst ip, src/dst port, protocol) fall back to the matching
   FilterWildcard sentinel when the option is not given (the default handed
   to get_value, or a falsy value for the IP options).
   """

   # -max_iterate fallback when the option is 0 or absent: effectively
   # "print all flows".
   ALL_FLOWS_MAX_ITERATE = 32 * 10**6
   # -max_query_size fallback when the option is 0 or absent.
   DEFAULT_MAX_QUERY_SIZE = 100

   def register_options( self ):
      """Declare every option accepted by 'show ipfix flows'."""
      self.add_option( '-vrf_id', 'uint' )
      self.add_option( '-src_ip', 'string' )
      self.add_option( '-src_port', 'uint' )
      self.add_option( '-dst_ip', 'string' )
      self.add_option( '-dst_port', 'uint' )
      self.add_option( '-protocol', 'uint' )
      self.add_option( '-max_iterate', 'uint' )
      self.add_option( '-max_query_size', 'uint' )
      self.add_option( '-count_only', 'bool' )
      self.add_option( '-detail', 'bool' )
      self.add_option( '-flow_key_only', 'bool' )
      self.add_option( '-print_zero', 'bool' )
      self.add_option( '-json', 'bool' )
      self.add_option( '-h', 'action', action=self.help )

   # pylint: disable=no-member
   @staticmethod
   def help():
      """Print the usage text; registered as the action for '-h'."""
      help_text = [
         'Usage:',
         ' show ipfix flows [OPTIONS]',
         '',
         'Options:',
         ' -vrf_id\t\tvrf id to filter the ipfix flows (default: 0x00FFFFFF)',
         ' -src_ip\t\tsrc ip to filter the ipfix flows (default: "")',
         ' -src_port\t\tsrc port to filter the ipfix flows (default: 0xFFFFFFFF)',
         ' -dst_ip\t\tdst ip to filter the ipfix flows (default: "")',
         ' -dst_port\t\tdst port to filter the ipfix flows (default: 0xFFFFFFFF)',
         ' -protocol\t\tprotocol to filter the ipfix flows (default: 0)',
         ' -max_iterate\t\tmax number of flows to print (default: print all)',
         ' -max_query_size\tmax number of flows in one grpc call (default: 100)',
         ' -count_only\t\tprint the ipfix flow counts alone (default: False)',
         ' -detail\t\tprint the detailed view for flags, etc (default: False)',
         ' -flow_key_only\t\tprint flow keys alone (default: False)',
         ' -print_zero\t\tprint zero value fields (default: False)',
         ' -json\t\t\tprint in the json format (default: False)',
         ' -h\t\t\thelp'
         ]
      print( '\n'.join( help_text ) )

   def get_max_iterate( self ):
      """Return the max number of flows to print (0/absent means all flows)."""
      return self.get_value( '-max_iterate' ) or self.ALL_FLOWS_MAX_ITERATE

   def get_max_query_size( self ):
      """Return the max number of flows fetched per request (0/absent -> 100)."""
      return self.get_value( '-max_query_size' ) or self.DEFAULT_MAX_QUERY_SIZE

   def get_json( self ):
      """Return True if output should be emitted as JSON."""
      return self.get_value( '-json', False )

   def get_detail( self ):
      """Return True if flags should be decoded into their names."""
      return self.get_value( '-detail', False )

   def get_flow_key_only( self ):
      """Return True if only the flow keys should be printed."""
      return self.get_value( '-flow_key_only', False )

   def get_count_only( self ):
      """Return True if only the flow count should be printed."""
      return self.get_value( '-count_only', False )

   def get_print_zero( self ):
      """Return True if zero-valued fields should be printed."""
      return self.get_value( '-print_zero', False )

   def get_vrf_id( self ):
      """Return the vrf id filter, or the vrf wildcard."""
      return self.get_value( '-vrf_id', FilterWildcard.vrfWildcard )

   def _get_ip( self, opt ):
      """Return the numeric IP value of option `opt`, or the IP wildcard."""
      val = self.get_value( opt )
      if val:
         return Arnet.IpAddr( val ).value
      return FilterWildcard.ipWildcard

   def get_src_ip( self ):
      """Return the source IP filter as a number, or the IP wildcard."""
      return self._get_ip( '-src_ip' )

   def get_src_port( self ):
      """Return the source L4 port filter, or the port wildcard."""
      return self.get_value( '-src_port', FilterWildcard.l4PortWildcard )

   def get_dst_ip( self ):
      """Return the destination IP filter as a number, or the IP wildcard."""
      return self._get_ip( '-dst_ip' )

   def get_dst_port( self ):
      """Return the destination L4 port filter, or the port wildcard."""
      return self.get_value( '-dst_port', FilterWildcard.l4PortWildcard )

   def get_protocol( self ):
      """Return the IP protocol filter, or the protocol wildcard."""
      return self.get_value( '-protocol', FilterWildcard.protocolWildcard )

# pylint: disable=no-member
def _show_ipfix_flow_info_args_preproc( cli, args ):
   cmd_args = {}

   cmd_args[ 'vrf_id' ] = args[ 'vrf_id' ]
   cmd_args[ 'src_ip' ] = args[ 'src_ip' ]
   cmd_args[ 'src_port' ] = args[ 'src_port' ]
   cmd_args[ 'dst_ip' ] = args[ 'dst_ip' ]
   cmd_args[ 'dst_port' ] = args[ 'dst_port' ]
   cmd_args[ 'protocol' ] = args[ 'protocol' ]
   cmd_args[ 'max_iterate' ] = args[ 'max_iterate' ]
   cmd_args[ 'iter' ] = args[ 'iter' ]
   return cmd_args

def _show_ipfix_flow_info_parse_options( cli, opts ):
   """Parse the 'show ipfix flows' options into a plain dict.

   Args:
      cli: bessctl CLI handle; errors are reported via cli.CommandError.
      opts: list of option tokens, or None (treated as no options).

   Returns:
      dict mapping option names (without the leading '-') to their values,
      or None if parsing or any accessor raised.
   """
   try:
      parser = ShowIpfixFlowInfoOptionParser( [] if opts is None else opts )
      parser.parse()
      # Accessors are invoked in this order; each entry becomes one dict key.
      accessors = (
         ( 'vrf_id', parser.get_vrf_id ),
         ( 'src_ip', parser.get_src_ip ),
         ( 'src_port', parser.get_src_port ),
         ( 'dst_ip', parser.get_dst_ip ),
         ( 'dst_port', parser.get_dst_port ),
         ( 'protocol', parser.get_protocol ),
         ( 'max_iterate', parser.get_max_iterate ),
         ( 'max_query_size', parser.get_max_query_size ),
         ( 'json', parser.get_json ),
         ( 'detail', parser.get_detail ),
         ( 'flow_key_only', parser.get_flow_key_only ),
         ( 'count_only', parser.get_count_only ),
         ( 'print_zero', parser.get_print_zero ),
      )
      return { name: accessor() for name, accessor in accessors }
   except Exception as error:
      cli.CommandError( error )
      return None

#
# SfeModules/arModule_msg.proto
# message BiFlowInfo {
#    KvContainerInfo c_info = 1;
#    FlowDirInfo dir_a = 2;
#    FlowDirInfo dir_b = 3;
#    ExportInfo e_info = 4;
# }
#

def getFlatIpfixFlowInfo( args, purge_time_ms, bi_info ):
   """Flatten one BiFlowInfo message into nested dicts for display.

   Args:
      args: parsed command options; reads 'json', 'detail', 'flow_key_only'
         and 'print_zero'.
      purge_time_ms: base time (ms) added to the flow's purge_delta_timeout
         to obtain its purge timeout.
      bi_info: BiFlowInfo message (c_info, dir_a, dir_b, e_info).

   Returns:
      dict with 'keyInfo'. Unless flow_key_only is set in text mode, it also
      holds 'commonInfo', 'dirInfoA' and 'dirInfoB'. In json mode values
      are left raw. In text mode, flags and times are pre-formatted strings
      and direction keys carry an 'A'/'B' suffix.
   """
   d = {}
   c_info = bi_info.c_info
   fk = c_info.fk
   dir_a = bi_info.dir_a
   dir_b = bi_info.dir_b
   e_info = bi_info.e_info

   kI = {}
   kI[ 'vrfId' ] = fk.vrf_id
   kI[ 'protocol' ] = fk.protocol
   kI[ 'ipA' ] = getIpStr( fk.ip_a )
   kI[ 'ipB' ] = getIpStr( fk.ip_b )
   # ports are converted with ntohs before display
   kI[ 'portA' ] = socket.ntohs( fk.port_a )
   kI[ 'portB' ] = socket.ntohs( fk.port_b )

   d[ 'keyInfo' ] = kI

   # JSON output always carries the full record, even with -flow_key_only.
   if args[ 'flow_key_only' ] and not args[ 'json' ]:
      return d

   cI = {}
   flags = c_info.flags
   if args[ 'detail' ]:
      fInfo = ArMwhtFlagsHelp.getFlagsInfo( flags )
      if args[ 'json' ]:
         cI[ 'flags' ] = flags
         cI[ 'flagsDetail' ] = fInfo
      else:
         pInfo = '|'.join( flagName.replace( 'MWHT_', '' ) for flagName in fInfo )
         cI[ 'flags' ] = hex( flags ) + f'({pInfo})'
   else:
      if args[ 'json' ]:
         cI[ 'flags' ] = flags
      else:
         cI[ 'flags' ] = hex( flags )

   if args[ 'json' ]:
      cI[ 'lastAccessed' ] = c_info.last_accessed
   else:
      cI[ 'lastAccessed' ] = getAgeStr( c_info.last_accessed )

   purgeDeltaTimeout = c_info.purge_delta_timeout
   purgeTimeout = purgeDeltaTimeout + purge_time_ms

   if args[ 'json' ]:
      cI[ 'purgeDeltaTimeout' ] = purgeDeltaTimeout
      cI[ 'purgeTimeout' ] = purgeTimeout
   else:
      purgeDeltaTimeoutS = int( purgeDeltaTimeout / 1000 )
      purgeTimeoutS = int( purgeTimeout / 1000 )
      if purgeDeltaTimeout:
         cI[ 'purgeTimeout' ] = f'{purgeTimeoutS} Secs '\
                                f'({purgeDeltaTimeoutS} Secs Delta)'
      else:
         cI[ 'purgeTimeout' ] = f'{purgeTimeoutS} Secs'
   cI[ 'flowState' ] = c_info.flow_state
   cI[ 'ctsDir' ] = c_info.cts_dir
   cI[ 'exportedFlows' ] = e_info.exported_flows
   cI[ 'exportedEstablishedFlows' ] = e_info.exported_established_flows

   d[ 'commonInfo' ] = cI

   def dirFlowInfo( dirN, info ):
      """Build the per-direction dict for direction `dirN` ('A' or 'B')."""
      # In text mode all directions are printed side by side, so keys get a
      # direction suffix; JSON keeps them apart under dirInfoA/dirInfoB.
      if args[ 'json' ]:
         dirSuffix = ''
      else:
         dirSuffix = dirN

      dI = {}
      dI[ 'ftrVersion' + dirSuffix ] = info.ftr_version
      dI[ 'ftID' + dirSuffix ] = info.ft_id
      dI[ 'intfId' + dirSuffix ] = info.intf_id
      dI[ 'countPkts' + dirSuffix ] = info.count_pkts
      dI[ 'countBytes' + dirSuffix ] = info.count_bytes
      if args[ 'json' ]:
         dI[ 'lastPktTime' + dirSuffix ] = info.last_pkt_time
         dI[ 'startTime' + dirSuffix ] = info.start_time
         dI[ 'epochStartTime' + dirSuffix ] = info.epoch_start_time
         dI[ 'epochLastPktTime' + dirSuffix ] = info.epoch_last_pkt_time
      else:
         # start time and last packet time not relevant without any packet
         if info.count_pkts:
            # Pair each age with its own epoch timestamp (these were
            # previously swapped: last-packet age showed the start epoch).
            dI[ 'lastPktTime' + dirSuffix ] = \
               f'{getAgeStr( info.last_pkt_time )} ( {info.epoch_last_pkt_time} )'
            dI[ 'startTime' + dirSuffix ] = \
               f'{getAgeStr( info.start_time )} ( {info.epoch_start_time} )'
         elif args[ 'print_zero' ]:
            dI[ 'lastPktTime' + dirSuffix ] = info.last_pkt_time
            dI[ 'startTime' + dirSuffix ] = info.start_time

      dI[ 'tcpFlags' + dirSuffix ] = info.tcp_flags
      dI[ 'appServiceId' + dirSuffix ] = info.app_service_id
      dI[ 'appId' + dirSuffix ] = info.app_id
      dI[ 'avtId' + dirSuffix ] = info.avt_id
      dI[ 'dpsPathId' + dirSuffix ] = info.dps_path_id
      dI[ 'endPointDpsPathId' + dirSuffix ] = info.end_point_dps_path_id
      dI[ 'countFlows' + dirSuffix ] = info.count_flows
      dI[ 'countEstablishedFlows' + dirSuffix ] = info.count_established_flows
      # src_vtep is packed little-endian before formatting as dotted quad;
      # a zero VTEP is left as 0 so zero-suppression can hide it.
      dI[ 'srcVtepIp' + dirSuffix ] = socket.inet_ntoa(
                        struct.pack( '<I', info.src_vtep ) ) if info.src_vtep else 0

      if dirN == 'A':
         dI[ 'exportedPkts' + dirSuffix ] = e_info.exported_pkts_a
         dI[ 'exportedBytes' + dirSuffix ] = e_info.exported_bytes_a
         dI[ 'exportEpochStartTime' + dirSuffix ] = e_info.epoch_start_time_a
      else:
         dI[ 'exportedPkts' + dirSuffix ] = e_info.exported_pkts_b
         dI[ 'exportedBytes' + dirSuffix ] = e_info.exported_bytes_b
         dI[ 'exportEpochStartTime' + dirSuffix ] = e_info.epoch_start_time_b
      return dI
   d[ 'dirInfoA' ] = dirFlowInfo( 'A', dir_a )
   d[ 'dirInfoB' ] = dirFlowInfo( 'B', dir_b )

   return d

def _show_ipfix_flow_info( cli, modName, args, purge_time_ms, iters_left, allFlows ):
   """Fetch one batch of bidirectional flows, then print or collect them.

   Args:
      cli: bessctl CLI handle (uses cli.bess and cli.fout).
      modName: IpfixFlowTracker module instance to query.
      args: parsed options; 'iter' is the resume cursor and 'max_iterate'
         the batch size forwarded to getBiFlowInfo.
      purge_time_ms: forwarded to getFlatIpfixFlowInfo.
      iters_left: how many more flows the caller wants handled.
      allFlows: list that collects flattened flows in json mode (mutated).

   Returns:
      (next_iter, iters_left, count): the cursor reported by the module, the
      remaining flow budget, and the number of flows handled in this batch.
   """
   reqArgs = _show_ipfix_flow_info_args_preproc( cli, args )
   reply = cli.bess.run_module_command( modName, 'getBiFlowInfo',
                                        'IpfixGetBiFlowInfoReqArg', reqArgs )
   handled = 0
   for biInfo in reply.bi_info:
      handled += 1
      flat = getFlatIpfixFlowInfo( args, purge_time_ms, biInfo )
      if args[ 'json' ]:
         allFlows.append( flat )
      elif not args[ 'count_only' ]:
         out = cli.fout
         # The flow key is always shown, zero fields included.
         out.write( getKeyValueStr( flat[ 'keyInfo' ], printZero=True ) )
         if args[ 'flow_key_only' ]:
            sections = ()
         else:
            sections = ( 'commonInfo', 'dirInfoA', 'dirInfoB' )
         for section in sections:
            text = getKeyValueStr( flat[ section ], printZero=args[ 'print_zero' ] )
            # skip sections that render empty so no blank lines appear
            if not text:
               continue
            out.write( "\n\t" )
            out.write( text )
         out.write( "\n" )
      iters_left -= 1
      if iters_left == 0:
         break

   return reply.next_iter, iters_left, handled

@commands.cmd( 'show ipfix flows -h', 'show ipfix flows help' )
def show_ipfix_flow_info_help( cli ):
   """Print the 'show ipfix flows' usage text.

   Feeds a lone '-h' token through the option parser, whose '-h' action
   prints the help.
   """
   _show_ipfix_flow_info_parse_options( cli, [ '-h' ] )

@commands.cmd( 'show ipfix flows [IPFIXFLOWS_CMD_OPTS...]',
     'show ipfix flows : display ipfix user objects in the flow cache' )
def show_ipfix_flow_info( cli, opts ):
   """Display the IPFIX flows held by the IpfixFlowTracker module.

   Pages through the first IpfixFlowTracker module's flows in batches of at
   most max_query_size. It stops when the module reports a zero cursor or
   max_iterate flows have been handled. It then prints either the JSON list,
   the bare count, or a summary line.
   """
   trackers = [ m.name for m in cli.bess.list_modules().modules
                if m.mclass.startswith( 'IpfixFlowTracker' ) ]
   if not trackers:
      cli.fout.write( 'IpfixFlowTracker module is not installed' )
      cli.fout.write( '\n' )
      return

   args = _show_ipfix_flow_info_parse_options( cli, opts )
   if args is None:
      cli.fout.write( 'could not parse the command options' )
      cli.fout.write( '\n' )
      return

   args[ 'iter' ] = 0
   # Total flow budget; each request is capped at max_query_size flows.
   iters_left = args[ 'max_iterate' ]
   args[ 'max_iterate' ] = min( args[ 'max_iterate' ], args[ 'max_query_size' ] )

   purge_time_ms = get_purge_time_ms( cli )
   allFlows = []
   totalFlows = 0
   cursor = 1  # any non-zero value so the first request is issued
   while cursor != 0 and iters_left:
      cursor, iters_left, batchCount = _show_ipfix_flow_info(
         cli, trackers[ 0 ], args, purge_time_ms, iters_left, allFlows )
      args[ 'iter' ] = cursor
      totalFlows += batchCount

   if args[ 'json' ]:
      printJson( cli, allFlows )
   elif args[ 'count_only' ]:
      cli.fout.write( f'{totalFlows}' )
   else:
      cli.fout.write( f'show ipfix flows: {totalFlows} flows' )
   cli.fout.write( '\n' )

@commands.var_attrs( '[IPFIXFLOWS_CMD_OPTS...]' )
def flow_cache_var_attrs():
   """Describe the variable-length option token of 'show ipfix flows'.

   Returns a 3-tuple. Presumably it is (argument name, help text, default
   value); 'opts' matches the parameter name of show_ipfix_flow_info.
   """
   helpText = '[OPTIONS]( run show ipfix flows -h for more help )'
   return 'opts', helpText, []
