# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
import json
import socket
import struct
import Tac
from ctypes import ( CDLL, POINTER, c_char_p )

# Load the C library so IP protocol numbers can be translated to names with
# getprotobynumber(3).
libc = CDLL( 'libc.so.6' )
getprotobynumber = libc.getprotobynumber
# getprotobynumber(3) returns a `struct protoent *` (or NULL when the number has
# no entry). Declaring the result as a pointer to `char *` lets callers
# dereference it with `.contents.value` to read the struct's first member
# (p_name) without modelling the whole struct in ctypes.
# NOTE(review): this relies on glibc's layout, where p_name is the first field
# of struct protoent. Callers must check for a NULL result before using
# `.contents`.
getprotobynumber.restype = POINTER( c_char_p )

class ArMwhtFlagsHelp:
   """Helpers for decoding SfeModules::ArMwhtFlags bitmasks into flag names.

   flagsMap caches { single-bit value : flag name } for the enum. It is filled
   lazily by getFlagMap() the first time it is found empty.
   """
   flagsMap = {}

   @staticmethod
   def getFlagMap():
      """Return the cached { bit value : flag name } map, building it if empty.

      Only enum members whose value has at most one bit set are recorded, so
      multi-bit (composite) members are skipped.
      """
      cache = ArMwhtFlagsHelp.flagsMap
      if cache:
         return cache
      enumName = 'SfeModules::ArMwhtFlags'
      for flagName in Tac.Type( enumName ).attributes:
         flagVal = Tac.enumValue( enumName, flagName )
         # x & ( x - 1 ) clears the lowest set bit, so the result is zero
         # exactly when x has at most one bit set.
         isSingleBit = ( flagVal & ( flagVal - 1 ) ) == 0
         if isSingleBit:
            cache[ flagVal ] = flagName
      return cache

   @staticmethod
   def getFlagsInfo( flags ):
      """Return the names of the known flags set in `flags`.

      Bits 31 down to 0 are examined, most significant first. Set bits that
      have no entry in the flag map are ignored.
      """
      flagMap = ArMwhtFlagsHelp.getFlagMap()
      names = []
      for bitPos in range( 31, -1, -1 ):
         masked = ( 1 << bitPos ) & flags
         if masked and masked in flagMap:
            names.append( flagMap[ masked ] )
      return names

def getIntfShortNameFromId( intfId ):
   """Return the short name of the Arnet::IntfId with numeric value `intfId`.

   The ids 0 and 0xffffffff are reported as 'N/A' (presumably unset/invalid
   ids) without consulting Tac.
   """
   if intfId == 0 or intfId == 0xffffffff:
      return 'N/A'
   intf = Tac.Value( 'Arnet::IntfId' )
   intf.intfId = intfId
   return intf.shortName

def get_purge_time_ms( cli ):
   """Return the flow-cache purge time (ms) reported by the Management module.

   Sends 'getFlowCacheInfo' to the first BESS module whose class is
   'Management' and returns the reply's purge_time_ms field. Returns 0 when no
   Management module exists.
   """
   for module in cli.bess.list_modules().modules:
      if module.mclass == 'Management':
         mgmtName = module.name
         break
   else:
      return 0
   info = cli.bess.run_module_command( mgmtName, 'getFlowCacheInfo',
                                       'EmptyArg', {} )
   return info.purge_time_ms

def getIpStr( genAddr, ipv6=False ):
   """Return the printable form of an address object.

   With ipv6 false, genAddr.ip_addr is packed as a native-byte-order 32-bit
   unsigned int and those 4 bytes are decoded as dotted-quad IPv4. (Presumably
   the integer already holds network-order bytes; confirm against the
   producer.) With ipv6 true, genAddr.ipv6_addr is decoded as packed 16-byte
   IPv6.
   """
   if not ipv6:
      packed = struct.pack( '=L', genAddr.ip_addr )
      return socket.inet_ntoa( packed )
   return socket.inet_ntop( socket.AF_INET6, genAddr.ipv6_addr )

def getPortStr( port ):
   """Return the decimal string of a 16-bit port given in network byte order."""
   hostPort = socket.ntohs( port )
   return str( hostPort )

def getProtocolStr( protocol ):
   """Return the upper-case name of IP protocol number `protocol`.

   The name is looked up in the C library's protocol database with
   getprotobynumber(3) (e.g. 6 -> 'TCP' with a standard /etc/protocols).
   When the number has no entry, the lookup returns NULL. In that case the
   decimal number itself is returned instead of raising.
   """
   entry = getprotobynumber( protocol )
   # A NULL ctypes pointer is falsy. Dereferencing it via .contents would raise
   # ValueError ("NULL pointer access") for unknown or unassigned protocols.
   if not entry:
      return str( protocol )
   return entry.contents.value.decode().upper()

def getFlowKeyStr( flowKey, ipv6=False ):
   """Return a one-line string describing an sfe::pb::FlowKey, in the format:

   vrfId <vrfId> ipA <ipA> portA <portA> ipB <ipB> portB <portB> protocol <protocol>

   ipv6 selects whether ip_a / ip_b are rendered as IPv6 or IPv4 addresses.
   Ports are converted from network byte order.
   """
   fields = [
      f'vrfId {flowKey.vrf_id}',
      f'ipA {getIpStr( flowKey.ip_a, ipv6=ipv6 )}',
      f'portA {getPortStr( flowKey.port_a )}',
      f'ipB {getIpStr( flowKey.ip_b, ipv6=ipv6 )}',
      f'portB {getPortStr( flowKey.port_b )}',
      f'protocol {getProtocolStr( flowKey.protocol )}',
   ]
   return ' '.join( fields )

def getAgeStr( lastTime ):
   """Return '<N> Secs ago' for an age `lastTime` expressed in microseconds.

   The value is divided by 10**6 and truncated toward zero.
   """
   seconds = int( lastTime / 10**6 )
   return str( seconds ) + ' Secs ago'

def getKeyValueStr( d, printZero=True ):
   """Format mapping `d` as 'key: value' pairs joined by ', '.

   Keys must be strings; values are rendered with str(). When printZero is
   false, entries whose value is falsy (0, '', None, ...) are left out.
   """
   pairs = ( key + ': ' + str( value )
             for key, value in d.items()
             if printZero or value )
   return ', '.join( pairs )

def printJson( cli, data ):
   """Write `data` to cli.fout as JSON indented by 4 spaces (no trailing newline)."""
   text = json.dumps( data, indent=4 )
   cli.fout.write( text )

class TacEnum:
   """Expose the members of a Tac enum type as ordinary instance attributes.

   Each member name of `typeName` becomes an attribute holding the value
   returned by Tac.enumValue(), so callers can use attribute access instead
   of calling Tac.enumValue() directly.

   Copied from SfCommonTestLib, since that library does not exist on physical
   DUTs.
   """
   def __init__( self, typeName ):
      enumType = Tac.Type( typeName )
      for memberName in enumType.attributes:
         memberValue = Tac.enumValue( enumType, memberName )
         setattr( self, memberName, memberValue )

def TacEnumDict( typeName ):
   """Return { enumValue : 'memberName' } for every member of Tac enum `typeName`.

   If several members share a value, the member listed last in the type's
   attributes wins. Copied from SfCommonTestLib.
   """
   enumType = Tac.Type( typeName )
   return { Tac.enumValue( enumType, name ): name
            for name in enumType.attributes }
