# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import CliGlobal
import ConfigMount
import LazyMount
import Shark
import SharkLazyMount
import Tac

from SfeUtils import flowCacheWalkerThreadsDefault
from TypeFuture import TacLazyType
from CliDynamicPlugin.SfeFlowCacheWalkerModel import (
   Counter,
   ModelShowCounters,
   ModelShowCountersDetail,
   ModelShowStatus,
   WalkerCounter,
   WalkerStatus
)
from CliPlugin.SfeFlowCacheWalkerCliLib import (
   CounterPrefix,
   maxThreads,
   WalkerCounterLib,
)

# Lazily-resolved Tac type of the per-thread walker counter Shark table; Plugin()
# uses it to compute each walker thread's Shark mount path.
WalkerCounterShark = TacLazyType( "SfeModulesApi::WalkerCounterShark" )

# Module-wide handles to mounted state. Every field starts as None and is filled
# in by Plugin() when the CLI plugin is loaded.
gv = CliGlobal.CliGlobal(
   bessCliConfig=None,
   cpuUtilDir=None,
   sharkTables=None,
   veosConfig=None,
)

# Warning emitted after the walker thread configuration is changed; the new
# setting only takes effect once the Sfe agent is restarted.
restartWarning = ( "Please save your config and restart the Sfe agent"
                   " for your changes to take effect." )

def handlerConfig( mode, args ):
   """Configure the number of flow-cache walker threads.

   The requested count comes from args[ 'INPUT_THREADS' ]. If it equals the
   platform default, 0 is stored instead so that the thread count keeps being
   chosen based on the platform. Either way the user is warned that a restart
   is required for the change to take effect.
   """
   requested = args[ 'INPUT_THREADS' ]
   coreUtil = gv.cpuUtilDir.cpuUtil
   # Control-plane cores are assumed to be contiguous, from 0 up to the first
   # dataplane core. When dataplane core information is present, the lowest
   # entry of cpuUtil (presumably keyed by dataplane core index) therefore
   # equals the number of control-plane cores. Without that information, pass
   # 0 so that the minimum configuration is chosen.
   if coreUtil:
      cpCoreCount = min( coreUtil )
   else:
      cpCoreCount = 0
   platformDefault = flowCacheWalkerThreadsDefault( gv.veosConfig.platformRuby,
                                                    cpCoreCount ).value
   # 0 in Sysdb means "use the platform default".
   stored = 0 if requested == platformDefault else requested
   gv.bessCliConfig.flowCacheWalkerThreads = Tac.Value(
         "Sfe::FlowCacheWalkerThreads", stored )
   mode.addWarning( restartWarning )

def handlerConfigDefault( mode, args ):
   """Reset the flow-cache walker thread count to the platform default.

   Storing 0 in Sysdb means no explicit count is configured, so the thread
   count is chosen based on the platform. handlerConfig stores the same
   sentinel when the requested count equals the platform default.
   """
   # Build the value the same way handlerConfig does, rather than relying on
   # implicit coercion of a bare int into Sfe::FlowCacheWalkerThreads.
   gv.bessCliConfig.flowCacheWalkerThreads = \
         Tac.Value( "Sfe::FlowCacheWalkerThreads", 0 )
   # This writes the same attribute as handlerConfig, so it likewise only takes
   # effect after the Sfe agent is restarted.
   mode.addWarning( restartWarning )

class WalkerSharkHelper:
   """Read flow-cache walker status and counters from the per-thread Shark
   tables in gv.sharkTables and fill them into CLI show models.

   CLI thread ids are 1-based, while gv.sharkTables is indexed from 0. Model
   dictionaries are therefore keyed by threadIndex + 1.
   """

   def __init__( self ):
      self.counterLib = WalkerCounterLib()

   @staticmethod
   def _threadIndices( inputThreadId ):
      """Return the 0-based Shark table indices to inspect.

      inputThreadId is a 1-based CLI thread id, or 0/falsy for all threads.
      An id outside 1..maxThreads selects no thread. Previously such an id
      raised IndexError, or, for a negative id, wrapped around to an unrelated
      table from the end of the list.
      """
      if not inputThreadId:
         return range( maxThreads )
      threadIndex = inputThreadId - 1
      if 0 <= threadIndex < maxThreads:
         return [ threadIndex ]
      return []

   def getWalkerStatus( self, status, inputThreadId=0 ):
      """Fill `status` (thread id -> WalkerStatus) for the selected thread(s).

      Threads without a scanCount counter are omitted.
      """
      for threadIndex in self._threadIndices( inputThreadId ):
         self._getWalkerThreadStatus( status, threadIndex )

   def _getWalkerThreadStatus( self, status, threadIndex ):
      """Add the WalkerStatus of one thread to `status`, if it has a scanCount."""
      tbl = gv.sharkTables[ threadIndex ]

      scanCount = self.counterLib.getCounter( tbl, CounterPrefix.general,
                                              "scanCount" )

      if not scanCount:
         return

      ws = WalkerStatus()
      ws.bucketStart = tbl.walkerCounter.bucketStartIdx
      ws.bucketEnd = tbl.walkerCounter.bucketEndIdx
      ws.lastScanDuration = tbl.walkerCounter.lastWalkDurationInMs
      ws.scanCount = scanCount.count

      status[ threadIndex + 1 ] = ws

   def getWalkerCounterDetail( self, counterDetail, inputThreadId=0 ):
      """Fill `counterDetail` (thread id -> WalkerCounter) for the selected
      thread(s). Threads with no available counters are omitted.
      """
      for threadIndex in self._threadIndices( inputThreadId ):
         self._getWalkerThreadCounter( counterDetail, threadIndex )

   def _getWalkerThreadCounter( self, counterDetail, threadIndex ):
      """Collect every known counter of one thread, grouped by counter prefix.

      Empty groups are skipped. The thread is added to `counterDetail` only if
      at least one group has counters.
      """
      tbl = gv.sharkTables[ threadIndex ]

      wc = WalkerCounter()
      for cPrefix, counters in self.counterLib.allCounters.items():
         cliCounter = Counter()
         for cName in counters:
            cVal = self.counterLib.getCounter( tbl, cPrefix, cName )
            if cVal:
               cliCounter.counters[ cName ] = cVal.count

         if cliCounter.counters:
            wc.walkerCounters[ cPrefix ] = cliCounter

      if wc.walkerCounters:
         counterDetail[ threadIndex + 1 ] = wc

   def getAggrWalkerCounter( self ):
      """Return a WalkerCounter whose counters are summed across all threads,
      keeping the per-group (counter prefix) structure.
      """
      modelDetail = ModelShowCountersDetail()
      self.getWalkerCounterDetail( modelDetail.counterDetail )

      aggrWc = WalkerCounter()
      aggr = aggrWc.walkerCounters
      for wc in modelDetail.counterDetail.values():
         for grpName, grpCounter in wc.walkerCounters.items():
            if not grpCounter:
               continue
            total = aggr.setdefault( grpName, Counter() )
            for cName, cVal in grpCounter.counters.items():
               total.counters.setdefault( cName, 0 )
               total.counters[ cName ] += cVal
      return aggrWc

def handlerShowCounters( mode, args ):
   """Return the show model holding walker counters summed over all threads."""
   aggregated = WalkerSharkHelper().getAggrWalkerCounter()
   return ModelShowCounters( counter=aggregated )

def handlerShowCountersDetail( mode, args ):
   """Return the show model holding per-thread walker counters.

   If args carries 'THREAD_ID' (1-based), only that thread is reported.
   Otherwise every thread with available counters is reported.
   """
   threadId = args.get( 'THREAD_ID', 0 )
   detail = ModelShowCountersDetail()
   helper = WalkerSharkHelper()
   helper.getWalkerCounterDetail( detail.counterDetail, threadId )
   return detail

def handlerShowStatus( mode, args ):
   """Return the show model holding per-thread walker status.

   If args carries 'THREAD_ID' (1-based), only that thread is reported.
   Threads without a scanCount counter are omitted.
   """
   threadId = args.get( 'THREAD_ID', 0 )
   statusModel = ModelShowStatus()
   helper = WalkerSharkHelper()
   helper.getWalkerStatus( statusModel.status, threadId )
   return statusModel

def Plugin( entityManager ):
   """Mount the state this CLI plugin needs and store the handles in gv.

   - bessCliConfig: "bess/cli/config", writable, via ConfigMount.
   - cpuUtilDir and veosConfig: read-only, via LazyMount.
   - sharkTables: one SharkLazyMount of the walker counter table per thread
     index in range( maxThreads ), with a 'shadow' mount info.
   """
   gv.bessCliConfig = ConfigMount.mount( entityManager, "bess/cli/config",
                                         "Sfe::BessdCliConfig", "w" )
   gv.cpuUtilDir = LazyMount.mount( entityManager, "dp/cpu/util",
                                    "Sfe::CpuUtilDir", "r" )
   gv.veosConfig = LazyMount.mount( entityManager, "hardware/sfe/veosConfig",
                                    "Sfe::VeosConfig", "r" )
   gv.sharkTables = [
      SharkLazyMount.mount( entityManager,
                            WalkerCounterShark.getSharkMountPath( threadIndex ),
                            'SfeModulesApi::WalkerCounterShark',
                            Shark.mountInfo( 'shadow' ), True )
      for threadIndex in range( maxThreads )
   ]
