# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=import-error
import commands
import socket
from ipaddress import IPv4Address, IPv6Address
from pybess.plugin_pb.arModule_msg_pb2 import AegisActionType

def action_to_string( action_type ):
   """Return the display label for an AegisActionType value.

   A value equal to none of the four actions listed below yields "UNKNOWN".
   """
   # Ordered ( enum value, label ) pairs; the first pair whose value compares
   # equal to action_type decides the result.
   labels = (
      ( AegisActionType.AEGIS_ACTION_FIREWALL, "ACTION_FIREWALL" ),
      ( AegisActionType.AEGIS_ACTION_PBR, "ACTION_PBR" ),
      ( AegisActionType.AEGIS_ACTION_POLICY_BASED_VPN, "ACTION_POLICY_BASED_VPN" ),
      ( AegisActionType.AEGIS_ACTION_NONE, "ACTION_NONE" ),
   )
   for candidate, label in labels:
      if action_type == candidate:
         return label
   return "UNKNOWN"

def _show_aegis_policy_flow_info( cli, args, modName, ipv6=False ):
   """Print one batch of Aegis policy flow entries and return the next iterator.

   Runs the 'getAegisPolicyFlowInfo' command on module modName with args,
   writes one line per entry of the reply's `info` to cli.fout, and returns
   the reply's `next_iter`.

   ipv6 selects how the flow-key addresses are rendered: IPv6Address over the
   `ipv6_addr` field when True, otherwise IPv4Address over the `ip_addr`
   field after socket.ntohl.
   """
   def _fmtAddr( addr ):
      if ipv6:
         return str( IPv6Address( addr.ipv6_addr ) )
      return str( IPv4Address( socket.ntohl( addr.ip_addr ) ) )

   reply = cli.bess.run_module_command( modName, 'getAegisPolicyFlowInfo',
                                        'GetAegisPolicyFlowInfoArg', args )
   out = cli.fout
   for flow in reply.info:
      key = flow.fk
      addrA = _fmtAddr( key.ip_a )
      addrB = _fmtAddr( key.ip_b )
      # Ports are passed through ntohs before being printed.
      hostPortA = socket.ntohs( key.port_a )
      hostPortB = socket.ntohs( key.port_b )

      # Fields printed for every entry; the trailing space separates them from
      # the action-specific fields written next.
      common = ( f"vrfId {key.vrf_id} ipA {addrA} portA {hostPortA} "
                 f"ipB {addrB} portB {hostPortB} protocol {key.protocol} "
                 f"aclVersion {flow.acl_version} "
                 f"lookupRetCode {flow.lookup_ret_code} "
                 f"hwAclId {flow.hw_aclid} "
                 f"actionType {action_to_string( flow.action_type )} "
                 f"countAction {flow.count_action} " )
      out.write( common )

      # Firewall entries show permit/reject/log; every other action type shows
      # fecid/encapid/ttl. Either way this ends the line.
      if flow.action_type == AegisActionType.AEGIS_ACTION_FIREWALL:
         detail = f"permit {flow.permit} reject {flow.reject} log {flow.log}\n"
      else:
         detail = f"fecid {flow.fecid} encapid {flow.encapid} ttl {flow.ttl}\n"
      out.write( detail )

   return reply.next_iter

def show_aegis_policy_flow_info_internal( cli, mclassName, ipv6=False ):
   """Print every Aegis policy flow entry of the first mclassName module.

   Looks up the modules whose mclass equals mclassName and returns without
   output when there are none. Otherwise the first match is queried
   repeatedly, starting at iterator 0 with max_iterate 100, feeding each
   returned iterator into the next request until 0 comes back. A final
   newline is then written to cli.fout.

   ipv6 is forwarded to the per-batch printer to pick the address format.
   """
   matching = [ module.name for module in cli.bess.list_modules().modules
                if module.mclass == mclassName ]
   if not matching:
      return

   target = matching[ 0 ]
   request = { 'max_iterate': 100, 'iter': 0 }
   while True:
      cursor = _show_aegis_policy_flow_info( cli, request, target, ipv6=ipv6 )
      request[ 'iter' ] = cursor
      if cursor == 0:
         # An iterator of 0 from the module ends the walk.
         break
   cli.fout.write( '\n' )

@commands.cmd( 'show aegis-policy-flow-info ipv4',
               'Show Aegis Policy rules matched for all ipv4 flows' )
def show_aegis_policy_flow_info_ipv4( cli ):
   """Print flow info from the first 'AegisPolicyLookupIpv4' module, if any."""
   show_aegis_policy_flow_info_internal( cli, 'AegisPolicyLookupIpv4',
                                         ipv6=False )

@commands.cmd( 'show aegis-policy-flow-info ipv6',
               'Show Aegis Policy rules matched for all ipv6 flows' )
def show_aegis_policy_flow_info_ipv6( cli ):
   """Print flow info from the first 'AegisPolicyLookupIpv6' module, if any."""
   show_aegis_policy_flow_info_internal( cli, 'AegisPolicyLookupIpv6',
                                         ipv6=True )
