#!/usr/bin/env python3
# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import re

import Tac
from TypeFuture import TacLazyType

# Name string "Clb.counter"; not referenced in the code visible here --
# presumably used by callers as a session-data / counter key (verify).
SESSION_DATA_COUNTER = "Clb.counter"

# TAC types wrapped with TacLazyType (presumably deferring the type lookup
# until first use). FlowKey is constructed below as
# FlowKey( vrfId, srcAddr, dstAddr, queuePair ).
IpGenAddr = TacLazyType( 'Arnet::IpGenAddr' )
FlowKey = TacLazyType( 'ClbFlow::FlowKey' )

def getMatchPolicyCommand( match ):
   """Render the RoCEv2 IB-BTH opcode match as a CLI command string.

   The opcode value and mask are printed as two-digit lowercase hex. If the
   match has a non-empty exclude collection, its values are appended after
   the keyword 'exclude', separated by spaces.

   Returns None when match.valid is false.
   """
   if not match.valid:
      return None
   command = ( 'match rocev2 ib-bth opcode value '
               f'0x{match.value:02x} mask 0x{match.mask:02x}' )
   excluded = match.exclude
   if not excluded:
      return command
   excludedText = ' '.join( map( str, excluded.values() ) )
   return f'{command} exclude {excludedText}'

def getActionCommand( action ):
   """Render the CLI 'action ...' command for a CLB policy action.

   Includes 'dscp <n>' when action.dscp differs from action.dscpInvalid and
   'traffic-class <n>' when action.trafficClass differs from
   action.trafficClassInvalid, in that order.

   Returns None when neither field is set.
   """
   candidates = (
      ( 'dscp', action.dscp, action.dscpInvalid ),
      ( 'traffic-class', action.trafficClass, action.trafficClassInvalid ),
   )
   settings = [ f'{keyword} {value}'
                for keyword, value, invalid in candidates
                if value != invalid ]
   return 'action ' + ' '.join( settings ) if settings else None

def getDirectFlowName( clbDirectFlowStatus, flowKey ):
   """Return the DirectFlow name used for flowKey, with a '-clb' suffix.

   Looks flowKey up in clbDirectFlowStatus.flowName. If a non-empty name is
   found, it is returned with '-clb' appended. Otherwise the raw lookup result
   is returned (None when the key is absent).

   Returns None when clbDirectFlowStatus is falsy. Previously that case fell
   through to 'return dfName' with dfName unbound and raised UnboundLocalError.
   """
   if not clbDirectFlowStatus:
      return None
   dfName = clbDirectFlowStatus.flowName.get( flowKey )
   if dfName:
      dfName += '-clb'
   return dfName

def getDirectFlowCounters( clbDirectFlowStatus, directFlowStatusDir,
                           flowKey ):
   """Look up DirectFlow hardware counters for a CLB flow.

   Returns a ( packetCount, byteCount ) tuple. Returns None when any of the
   following is missing: the 'FixedSystem' entity in directFlowStatusDir, the
   CLB DirectFlow status, the flow's DirectFlow name, or its flowStats entry.
   """
   fixedSystemStatus = directFlowStatusDir.entity.get( 'FixedSystem' )
   if not ( fixedSystemStatus and clbDirectFlowStatus ):
      return None
   flowName = getDirectFlowName( clbDirectFlowStatus, flowKey )
   flowStats = fixedSystemStatus.flowStats.get( flowName ) if flowName else None
   if not flowStats:
      return None
   return flowStats.packetCount, flowStats.byteCount

def getDirectHwCounters( directHwFlowCounterStatus, directHwSmStatus, flowKey ):
   """Look up DirectHw counters for a CLB flow.

   Maps flowKey to a flow id via directHwSmStatus.flowId, then reads that
   flow's stats from directHwFlowCounterStatus.flow.

   Returns a ( packetCount, byteCount ) tuple, or None if either status object
   is missing, the flow id is missing or falsy, or the flow entry is missing.
   """
   if not ( directHwSmStatus and directHwFlowCounterStatus ):
      return None
   hwFlowId = directHwSmStatus.flowId.get( flowKey )
   hwFlow = directHwFlowCounterStatus.flow.get( hwFlowId ) if hwFlowId else None
   if not hwFlow:
      return None
   flowStats = hwFlow.stats
   return flowStats.packetCount, flowStats.byteCount

# Parses traffic-policy class (rule) names of the form
# '<vrfId>_<srcAddr>_<dstAddr>_<queuePair in lowercase hex>', as consumed by
# trafficPolicyFlowStatsCache. It is applied with re.match, which anchors only
# at the start of the name, so characters after the hex digits are ignored.
flowRuleRe = re.compile( r"(\d+)_(.+)_(.+)_([\da-f]+)" )

def trafficPolicyFlowStatsCache( capabilities, trafficPolicyCounters ):
   """Build a { FlowKey: ( pktHits, byteHits ) } map from traffic-policy counters.

   Returns None unless capabilities.hwOffloadMethod is 'trafficPolicy'. Only
   policy maps whose name starts with '__clb_' are scanned. Within those, only
   class (rule) names matching flowRuleRe are used. The captured groups are
   converted to a FlowKey( vrfId, srcAddr, dstAddr, queuePair ), with
   queuePair parsed as hex.
   """
   if capabilities.hwOffloadMethod != 'trafficPolicy':
      return None
   clbPolicyPrefix = '__clb_'
   flowStats = {}
   policyMaps = trafficPolicyCounters.policyMapCounters
   for policyKey, policyCounters in policyMaps.items():
      policyName: str = policyKey.name
      if not policyName.startswith( clbPolicyPrefix ):
         continue
      for ruleName, ruleCounters in policyCounters.classCounters.items():
         parsed = flowRuleRe.match( ruleName )
         if parsed is None:
            continue
         vrfText, srcText, dstText, queuePairHex = parsed.groups()
         key = Tac.const( FlowKey( int( vrfText ),
                                   IpGenAddr( srcText ),
                                   IpGenAddr( dstText ),
                                   int( queuePairHex, 16 ) ) )
         hits = ruleCounters.countData
         flowStats[ key ] = ( hits.pktHits, hits.byteHits )
   return flowStats
