# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import commands
import socket
import struct
from cli_parser.option_parser import OptionParser

def ip2int( ipAddr ):
   """Convert a dotted-quad IPv4 string to an unsigned integer.

   Surrounding whitespace is ignored. The four address bytes are read in network
   (big-endian) order, so "10.0.0.1" becomes 0x0A000001.
   """
   packedAddr = socket.inet_aton( ipAddr.strip() )
   ( asInt, ) = struct.unpack( "!I", packedAddr )
   return asInt

class ShowHaFlowOwnershipOptionParser( OptionParser ):
   """Option parser for 'show ha flow-ownership'.

   Registers the flow 5-tuple options (-src_ip, -dst_ip, -sport, -dport, -proto)
   plus -h. The getters pass "" (IP strings) or 0 (ports, protocol) as the
   get_value default.
   """
   def register_options( self ):
      # ( option name, option type ) registered in this order.
      tupleOpts = ( ( "-src_ip", "string" ),
                    ( "-dst_ip", "string" ),
                    ( "-sport", "uint" ),
                    ( "-dport", "uint" ),
                    ( "-proto", "uint" ) )
      for optName, optType in tupleOpts:
         self.add_option( optName, optType )
      self.add_option( "-h", "action", action=self.help )

   @staticmethod
   def help():
      """Print the usage text for 'show ha flow-ownership'."""
      usage = '''
Usage:
  show ha flow-ownership -src_ip <ip> -dst_ip <ip> -sport <port> -dport <port>
    -proto <protoNum>

  Show the ownership of a flow with particular 5 tuples on the local device

Options:
  -src_ip       src ip of the packet to be looked up
  -dst_ip       dst ip of the packet to be looked up
  -sport        source port of the packet to be looked up
  -dport        destination port of the packet to be looked up
  -proto        protocol number of the packet to be looked up
  -h            help
  '''
      print( usage )

   def get_sport( self ):
      """Value of -sport (0 by default)."""
      return self.get_value( "-sport", 0 )

   def get_src_ip( self ):
      """Value of -src_ip ("" by default)."""
      return self.get_value( "-src_ip", "" )

   def get_dport( self ):
      """Value of -dport (0 by default)."""
      return self.get_value( "-dport", 0 )

   def get_dst_ip( self ):
      """Value of -dst_ip ("" by default)."""
      return self.get_value( "-dst_ip", "" )

   def get_proto( self ):
      """Value of -proto (0 by default)."""
      return self.get_value( "-proto", 0 )

def _show_ha_flow_info( cli, args, modName ):
   def pBool( flag ):
      return "true" if flag else "false"

   try:
      ret = cli.bess.run_module_command( modName,
                                         'getHaFlowCacheInfo',
                                         'HaGetFlowCacheInfoArg', args )
   except Exception as error:
      # pylint: disable=no-member
      print( error.errmsg )
      return 0

   for entry in ret.info:
      ipA = entry.fk.ip_a.ip_addr
      portA = socket.ntohs( entry.fk.port_a )
      ipB = entry.fk.ip_b.ip_addr
      portB = socket.ntohs( entry.fk.port_b )
      cli.fout.write( "vrfId %d ipA %s portA %d ipB %s portB %d protocol %d"
        " owner %s wanSideFlow %s genId %d\n" %
         ( entry.fk.vrf_id, socket.inet_ntoa( struct.pack( '=L', ipA ) ), portA,
           socket.inet_ntoa( struct.pack( '=L', ipB ) ), portB, entry.fk.protocol,
           pBool( entry.owner ), pBool( entry.wan_side_flow ), entry.gen_id ) )
   return ret.next_iter

@commands.cmd( 'show ha flow-info',
      'Show High Availability flowcache info for all flows' )
def show_ha_flow_info( cli ):
   """Print all HA flow-cache entries of the first HA module.

   Prints nothing if there is no HaEgress/HaIngress/HaFlowSync module.
   """
   haClasses = [ 'HaEgress', "HaIngress", "HaFlowSync" ]
   candidates = [ m.name for m in cli.bess.list_modules().modules
                  if m.mclass in haClasses ]
   if not candidates:
      return

   # Page through the flow cache up to 100 entries per request. Each page
   # returns the iterator to resume from, and paging stops when that is 0.
   pageArgs = { 'max_iterate': 100, 'iter': 0 }
   cli.fout.write( 'HA flowcache info:\n' )
   while True:
      pageArgs[ 'iter' ] = _show_ha_flow_info( cli, pageArgs, candidates[ 0 ] )
      if pageArgs[ 'iter' ] == 0:
         break
   cli.fout.write( '\n' )

def _show_ha_counters( cli, modName ):
   cnts = {}
   try:
      ret = cli.bess.run_module_command( modName,
                                         'getStats',
                                         'EmptyArg', {} )
   except Exception as error:
      # pylint: disable=no-member
      print( error.errmsg )
      return cnts
   cli.fout.write( "%s:" % modName )
   cli.fout.write( '''
   Owned packets: %d
   Peer owned packets: %d
   Owned flows: %d
   Peer owned flows: %d
   Missing flowcache packets: %d
   Missing userflowcache packets: %d
   Packets received from the HA peer: %d
   Wan side initiated flows: %d
   PeerOwned->Owned Flows: %d
   Owned->Peer Owned Flows: %d
   update genID flows: %d
   Packets received as transit node: %d
   Packets received from internal dps intf: %d
   Owned packets for which FIB look failed: %d
   Owned packets for which FIB look failed and HA peer is also down: %d
   Non-owned packets for which FIB lookup failed: %d
   Non-owned packets from peer for which FIB lookup failed: %d
   To peer vtep packets: %d
   Packets for which HA user flow cache does not exist: %d
   Self destined packets for which ownership check is skipped: %d
   Internet exit packets: %d
   Internet exit recirc packets: %d
''' % ( ret.owned_pkts,
       ret.non_owned_pkts,
       ret.owned_flows,
       ret.non_owned_flows,
       ret.missing_flowcache_pkts,
       ret.missing_userflowcache_pkts,
       ret.ha_pkts,
       ret.wanside_flows,
       ret.non_owned_to_owned_flows,
       ret.owned_to_non_owned_flows,
       ret.update_gen_id_flows,
       ret.transit_node_pkts,
       ret.dps_cp_pkts,
       ret.owned_fib_lookup_fail_pkts,
       ret.owned_fib_lookup_fail_peer_down_pkts,
       ret.non_owned_fib_lookup_fail_pkts,
       ret.non_owned_from_peer_fib_lookup_fail_pkts,
       ret.to_peer_vtep_pkts,
       ret.no_userflowcache_pkts,
       ret.self_destined_pkts,
       ret.internet_exit_pkts,
       ret.internet_exit_recirc_pkts,
      )
    )
   # return global counters
   cnts[ 'owned_flows' ] = ret.owned_flows
   cnts[ 'non_owned_flows' ] = ret.non_owned_flows
   cnts[ 'wanside_flows' ] = ret.wanside_flows
   cnts[ 'non_owned_to_owned_flows' ] = ret.non_owned_to_owned_flows
   cnts[ 'owned_to_non_owned_flows' ] = ret.owned_to_non_owned_flows
   cnts[ 'update_gen_id_flows' ] = ret.update_gen_id_flows
   return cnts

@commands.cmd( 'show ha counters HA_MODULE...',
        'Show counters for a given or all the HA ingress/egress modules' )
def show_ha_counters( cli, ha_modules ):
   """Print HA counters for the named modules, or for all HA modules with '*'.

   With '*', every HaEgress/HaIngress/HaFlowSync module is shown. If at least one
   exists, a "global" section follows with the flow counters summed across them.
   """
   # Counters summed for the "global" section, in display order.
   aggKeys = ( 'owned_flows', 'non_owned_flows', 'wanside_flows',
               'non_owned_to_owned_flows', 'owned_to_non_owned_flows',
               'update_gen_id_flows' )
   showAll = ha_modules[ 0 ] == "*"

   if showAll:
      mods = [ m.name for m in cli.bess.list_modules().modules if
               m.mclass in [ 'HaEgress', "HaIngress", "HaFlowSync" ] ]
   else:
      mods = [ m.name for m in cli.bess.list_modules().modules if
               m.name in ha_modules ]

   agg = dict.fromkeys( aggKeys, 0 )
   cli.fout.write( 'HA counters:\n' )
   for modName in mods:
      # Display this module's counters and get back its flow counters.
      cnts = _show_ha_counters( cli, modName )
      if showAll:
         for key in aggKeys:
            # _show_ha_counters returns {} when a module's stats cannot be read;
            # count it as zero instead of aborting the command with KeyError.
            agg[ key ] += cnts.get( key, 0 )

   # Display the global counters only when the user asked for all HA modules
   # and HA is actually enabled (at least one HA module exists).
   if showAll and mods:
      cli.fout.write( "global:" )
      cli.fout.write( '''
   Owned flows: %d
   Peer owned flows: %d
   Wan side initiated flows: %d
   PeerOwned->Owned Flows: %d
   Owned->Peer Owned Flows: %d
   update genID flows: %d
''' % tuple( agg[ key ] for key in aggKeys ) )

   cli.fout.write( '\n' )

def _show_ha_state( cli, modName ):
   def pBool( flag ):
      return "true" if flag else "false"

   def pHaOwnership( flag ):
      if flag == 0:
         return "hash"
      elif flag == 1:
         return "alwaysOwn"
      elif flag == 2:
         return "alwaysNotOwn"
      return "hash"

   try:
      ret = cli.bess.run_module_command( modName,
                                         'getHaState',
                                         'EmptyArg',
                                         {} )
   except Exception as error:
      # pylint: disable=no-member
      print( error.errmsg )
      return
   cli.fout.write( 'HA state:\n' )
   cli.fout.write( ' genId: %d\n flowOwnerForEven: %s\n peerDown: %s\n'
                   ' peerVtep: %s\n seed: %d\n'
                   ' ownership: %s\n peerGroupId: %d\n' %
                   ( ret.gen_id, pBool( ret.flow_owner_for_even ),
                      pBool( ret.peer_down ),
                     socket.inet_ntoa( struct.pack( '!L', ret.peer_vtep ) ),
                     ret.seed, pHaOwnership( ret.ownership ), ret.peer_group_id ) )

@commands.cmd( 'show ha state', 'Show High Availability State' )
def show_ha_state( cli ):
   """Print the HA state held by the first HaForward module, if HA is enabled."""
   forwarders = [ m.name for m in cli.bess.list_modules().modules
                  if m.mclass == 'HaForward' ]
   if forwarders:
      _show_ha_state( cli, forwarders[ 0 ] )
   else:
      cli.fout.write( 'HA is not enabled\n' )

@commands.cmd( 'show ha flow-ownership -h',
     'show ha flow-ownership help' )
def show_ha_flow_ownership_help( cli ):
   """Print the usage text for 'show ha flow-ownership'.

   Raises:
      cli.CommandError: if constructing or running the option parser fails.
   """
   try:
      # '-h' is registered with the parser's help() as its action, so parsing it
      # presumably prints the usage text.
      parser = ShowHaFlowOwnershipOptionParser( [ '-h' ] )
      parser.parse()
   except Exception as error:
      # Raise (not merely construct) the CLI error so the failure is reported
      # instead of being silently discarded.
      raise cli.CommandError( error )

@commands.cmd( 'show ha flow-ownership HA_FLOW_OWNERSHIP_OPTS...',
     'Show ownership for a flow based on hash of the 5 tuples' )
def show_ha_flow_ownership( cli, opts ):
   """Look up and print the owner of the flow described by opts.

   opts are parsed by ShowHaFlowOwnershipOptionParser. The resulting 5-tuple is
   sent to the first HaEgress/HaFlowSync module via 'getHaFlowOwnership'.
   Parse and command errors are printed, not raised.
   """
   def toDotted( addr ):
      return socket.inet_ntoa( struct.pack( '=L', addr ) )

   haMods = [ m.name for m in cli.bess.list_modules().modules
              if m.mclass in [ 'HaEgress', "HaFlowSync" ] ]
   if not haMods:
      cli.fout.write( 'HA is not enabled\n' )
      return

   try:
      parser = ShowHaFlowOwnershipOptionParser( opts )
      parser.parse()
      srcIp = parser.get_src_ip()
      dstIp = parser.get_dst_ip()
      # An omitted IP option is sent as 0.
      lookupArgs = {
         'src_ip': ip2int( srcIp ) if srcIp else 0,
         'dst_ip': ip2int( dstIp ) if dstIp else 0,
         'src_port': parser.get_sport(),
         'dst_port': parser.get_dport(),
         'proto': parser.get_proto(),
      }
   except Exception as error:
      print( error )
      return

   try:
      reply = cli.bess.run_module_command( haMods[ 0 ], 'getHaFlowOwnership',
                                           'HaGetFlowOwnershipArg', lookupArgs )
      key = reply.fk
      addrA = key.ip_a.ip_addr
      portA = socket.ntohs( key.port_a )
      addrB = key.ip_b.ip_addr
      portB = socket.ntohs( key.port_b )
      ownerStr = "true" if reply.owner else "false"
      cli.fout.write( "vrfId all ipA %s portA %d ipB %s portB %d protocol %d"
                      " owner %s\n" %
                      ( toDotted( addrA ), portA, toDotted( addrB ), portB,
                        key.protocol, ownerStr ) )
   except Exception as error:
      print( error )
      return

@commands.cmd( 'ha force-owner HA_FORCE_OWNER_OPTS',
               'Force HA flow ownership mode' )
def ha_force_owner( cli, flag ):
   """Set the flow-ownership mode on the HaForward module, then show HA state.

   flag is the keyword produced by ha_owner_type_handler: 'hash', 'own' or
   'not-own'. Any other value falls back to ownership 0 (hash). Command errors
   are printed, not raised.
   """
   mod = [ m.name for m in cli.bess.list_modules().modules
         if m.mclass == 'HaForward' ]
   if not mod:
      cli.fout.write( 'HA is not enabled\n' )
      return

   # Keyword -> 'ownership' value of HaForwardForceFlowOwnerArg. _show_ha_state
   # decodes the same values as hash / alwaysOwn / alwaysNotOwn.
   ownershipByFlag = { 'hash': 0, 'own': 1, 'not-own': 2 }
   args = { 'ownership': ownershipByFlag.get( flag, 0 ) }

   try:
      cli.bess.run_module_command( mod[ 0 ], 'forceFlowOwner',
                                   'HaForwardForceFlowOwnerArg', args )
      _show_ha_state( cli, mod[ 0 ] )
   except Exception as error:
      print( error )

def ha_module_candidates( cli ):
   """Return candidate values for HA_MODULE...: '*' plus all HA module names.

   HA modules are those of class HaIngress, HaEgress or HaFlowSync. This is best
   effort: if the modules cannot be listed, only ['*'] is returned.
   """
   var_candidates = [ '*' ]
   try:
      var_candidates += [ m.name for m in cli.bess.list_modules().modules
          if m.mclass in [ 'HaIngress', 'HaEgress', 'HaFlowSync' ] ]
   except Exception:  # pylint: disable=broad-except
      # Deliberately best effort, but only for ordinary errors: a bare except
      # would also swallow KeyboardInterrupt and SystemExit.
      pass
   return var_candidates

@commands.var_attrs( 'HA_MODULE...' )
def ha_module_var_attrs():
   """Describe HA_MODULE... as (var_type, var_desc, var_candidates function)."""
   varType = 'name+'
   varDesc = 'one or more module names of HA type  (* means all)'
   return ( varType, varDesc, ha_module_candidates )

@commands.var_attrs( 'HA_FLOW_OWNERSHIP_OPTS...' )
def ha_flow_ownership_var_attrs():
   """Describe HA_FLOW_OWNERSHIP_OPTS... as (var_type, var_desc, candidates).

   No fixed candidates are offered.
   """
   varDesc = '[OPTIONS]( run show ha flow-ownership -h for more help )'
   return ( 'opts', varDesc, [] )

@commands.var_attrs( 'HA_FORCE_OWNER_OPTS' )
def ha_force_owner_var_attrs():
   """Describe HA_FORCE_OWNER_OPTS as (var_type, var_desc, candidates).

   Values are validated by the 'haOwner' type handler; no fixed candidates.
   """
   varDesc = 'force HA flow ownership using keywords "hash", "own" or "not-own"'
   return ( 'haOwner', varDesc, [] )

@commands.var_type_handler( 'haOwner' )
def ha_owner_type_handler( cli, val ):
   """Expand val, a prefix of 'hash', 'own' or 'not-own', to the full keyword.

   Keywords are tried in that order and the first one that starts with val
   wins.

   Raises:
      cli.BindError: if val is not a prefix of any keyword.
   """
   for keyword in ( 'hash', 'own', 'not-own' ):
      if keyword.startswith( val ):
         return keyword
   raise cli.BindError( '"haOwner" must be either "hash", "own" or "not-own"' )
