#!/usr/bin/env python3
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Ark
import BasicCli
import CliPlugin.EthIntfCli as EthIntfCli
import CliPlugin.IntfCli as IntfCli
import CliPlugin.IntfQueueCounterLib as IntfQueueCounterLib
from CliPlugin.QueueCountersModel import QueueCountersRate, EgressQueueCounters, \
     EgressQueueDestinationTypeCounters, EgressQueueTrafficClassCounters, \
     EgressQueueDropPrecedenceCounters, Counters
from CliPlugin.QosCliCommon import showQosInterface
from CliPlugin.AleCountersModel import LatePollEvent, FastPollStats
from CliPlugin.AleCountersModel import PollInfo, FastPollInfo, FastPollInfoDetail

def vtepCounter( idx, counterTable, snapshotTable ):
   """
   Return the counters of VTEP entry idx relative to the last clear.

   counterTable and snapshotTable are expected to expose a `counter`
   collection: the former with the running counters, the latter with the
   values captured when the clear command was issued. The lookup and the
   subtraction are delegated to vxlanCounter on those two collections, and its
   6-tuple ( pkts, octets, bumPkts, bumOctets, dropPkts, dropOctets ) is
   returned unchanged.
   """
   runningColl = counterTable.counter
   clearedColl = snapshotTable.counter
   return vxlanCounter( idx, runningColl, clearedColl )

def vxlanCounter( idx, counterColl, snapshotColl ):
   """
   Return the counters of entry idx relative to the last clear snapshot.

   counterColl holds the running counters and snapshotColl holds the values
   captured when the clear command was issued; both are looked up with
   .get( idx ). The result is the tuple
      ( pkts, octets, bumPkts, bumOctets, dropPkts, dropOctets )
   where each element is the running value minus the snapshot value.

   - A missing running entry yields all zeros. This is usually a transient
     condition when we are caught in the middle of cleanup.
   - A missing snapshot entry yields the raw running values.
   """
   attrNames = ( 'pkts', 'octets', 'bumPkts', 'bumOctets',
                 'dropPkts', 'dropOctets' )
   running = counterColl.get( idx )
   cleared = snapshotColl.get( idx )
   if running is None:
      return ( 0, ) * len( attrNames )
   return tuple(
      getattr( running, name ) -
      ( 0 if cleared is None else getattr( cleared, name ) )
      for name in attrNames )

def _queueRateIntfs( mode, intf, mod, intfType, counterSupportedIntfsFunc ):
   """
   Return the counter-capable interfaces that queue counter rates apply to.

   Uses IntfCli.counterSupportedIntfs by default, or IntfCli.intfsSupported
   with counterSupportedIntfsFunc when one is given, then keeps only the
   Ethernet, Switch and Port-Channel interfaces. Returns an empty list when
   nothing matches.
   """
   if counterSupportedIntfsFunc is None:
      intfs = IntfCli.counterSupportedIntfs( mode, intf=intf, mod=mod,
                                             intfType=intfType )
   else:
      intfs = IntfCli.intfsSupported( mode, counterSupportedIntfsFunc, "Counters",
                                      intf=intf, mod=mod, intfType=intfType )
   # The lookup helpers may hand back a falsy, non-list value when nothing
   # matches, so normalize that to an empty list.
   if not intfs:
      return []
   return [ i for i in intfs if i.name.startswith( ( "Ethernet", "Switch",
                                                      "Port-Channel", ) ) ]

def _populateIntfQueueRates( interfaces, intfId, counterAccessor,
                             trafficClassLabelFunc ):
   """
   Add the per-traffic-class enqueue/drop rates of interface intfId to the
   interfaces model collection. All rates are reported under drop precedence
   "DP0".
   """
   interfaces[ intfId ] = EgressQueueDestinationTypeCounters()
   intfCounters = interfaces[ intfId ]
   intfCounters.ucastQueues = EgressQueueTrafficClassCounters()
   intfCounters.mcastQueues = EgressQueueTrafficClassCounters()
   counter = counterAccessor.counter( intfId )
   numUnicastQueues = counterAccessor.numUnicastQueues( intfId )
   numMulticastQueues = counterAccessor.numMulticastQueues( intfId )
   # intfQueueRates lists the unicast queues first, followed by the multicast
   # queues. Each group is ( isUnicast, queue count, first index, model dict ).
   queueGroups = (
      ( True, numUnicastQueues, 0, intfCounters.ucastQueues.trafficClasses ),
      ( False, numMulticastQueues, numUnicastQueues,
        intfCounters.mcastQueues.trafficClasses ),
   )
   for isUnicast, numQueues, firstIdx, tcs in queueGroups:
      for queueId in range( numQueues ):
         if trafficClassLabelFunc is None:
            tcId = "TC" + str( queueId )
         else:
            tcId = trafficClassLabelFunc( intfId, queueId, isUnicast )
         if tcId is None or tcId in tcs:
            # A None label means this queue does not exist. A label that is
            # already present is probably a mapped traffic class; the rates
            # of both traffic classes are combined in the mapped queue.
            continue
         # Look the rates up only for queues that are actually reported, so
         # non-existent queues are never indexed.
         intfQueueRates = counter.intfQueueRates[ firstIdx + queueId ]
         tcs[ tcId ] = EgressQueueDropPrecedenceCounters()
         tcs[ tcId ].dropPrecedences[ "DP0" ] = Counters()
         ctrs = tcs[ tcId ].dropPrecedences[ "DP0" ]
         ctrs.enqueuedPacketsRate = intfQueueRates.pktsRate
         ctrs.enqueuedBitsRate = intfQueueRates.bitsRate
         ctrs.droppedPacketsRate = intfQueueRates.pktsDropRate
         ctrs.droppedBitsRate = intfQueueRates.bitsDropRate

def _populateIntfSchedInfo( mode, intfCounters, supportedIntf,
                            trafficClassLabelFunc ):
   """
   Fill in bandwidth, load interval and per-traffic-class scheduling/shaping
   information on intfCounters. The data comes from supportedIntf and from the
   QoS model returned by showQosInterface.
   """
   intfCounters.bandwidth = supportedIntf.bandwidth()
   intfCounters.loadInterval = IntfCli.getActualLoadIntervalValue(
      supportedIntf.config().loadInterval )
   intfQos = showQosInterface( supportedIntf, mode=mode )

   for txQueueQos in intfQos.txQueueQosModel.txQueueList:
      for destType in ( "ucastQueues", "mcastQueues", ):
         tcs = intfCounters[ destType ][ "trafficClasses" ]
         if trafficClassLabelFunc is None:
            tcId = f"TC{txQueueQos.txQueue}"
         else:
            queueId = int( txQueueQos.txQueue )
            tcId = trafficClassLabelFunc( supportedIntf.name, queueId,
                                          destType == "ucastQueues" )
         if tcId is None:
            continue
         # T4/TH4 names tx queues UC7, UC6, MC1, etc., while the ucast and
         # mcast traffic class dicts are each keyed by the bare queue number
         # (7, 6, 1, ...), so strip the UC/MC prefix.
         tcId = tcId.replace( "UC", "" ).replace( "MC", "" )
         if tcId not in tcs:
            continue

         tc = tcs[ tcId ]
         if txQueueQos.operationalSchedMode.schedulingMode == "roundRobin":
            tc.schedMode = "weightedRoundRobin"
            tc.wrrBw = txQueueQos.operationalWrrBw
         else:
            tc.schedMode = "strictPriority"
         tc.shapeRate = txQueueQos.operationalShapeRate

def showIntfQueueCounterRates( mode, intf=None, mod=None, intfType=None,
                               trafficClassLabelFunc=None,
                               counterSupportedIntfsFunc=None,
                               countersRateSupportedIntfsFunc=None ):
   """
   Build the QueueCountersRate model with per-queue egress rates.

   mode: CLI mode, used for the interface lookups and for warnings.
   intf, mod: optional interface/module filters forwarded to IntfCli.
   intfType: interface type to consider; defaults to EthIntfCli.EthPhyIntf.
   trafficClassLabelFunc: optional callable
      ( intfName, queueId, isUnicast ) -> label, or None to skip the queue.
      When omitted, labels are "TC<queueId>".
   counterSupportedIntfsFunc / countersRateSupportedIntfsFunc: optional
      support predicates forwarded to IntfCli.intfsSupported in place of the
      default IntfCli lookups.

   Returns the populated model. When no supported interface matches, a
   warning is added to mode and an empty model is returned.
   """
   counterAccessor = IntfQueueCounterLib.getCounterAccessor()
   queueCounters = QueueCountersRate()
   queueCounters.egressQueueCounters = EgressQueueCounters()
   interfaces = queueCounters.egressQueueCounters.interfaces
   if intfType is None:
      intfType = EthIntfCli.EthPhyIntf

   intfs = _queueRateIntfs( mode, intf, mod, intfType, counterSupportedIntfsFunc )
   if not intfs:
      mode.addWarning(
         f"Queue counter rates not supported on {intf or 'any interface'}" )
      return queueCounters

   for intfObj in intfs:
      _populateIntfQueueRates( interfaces, intfObj.name, counterAccessor,
                               trafficClassLabelFunc )

   if countersRateSupportedIntfsFunc is None:
      supportedIntfs = IntfCli.countersRateSupportedIntfs( mode, intf=intf, mod=mod )
   else:
      supportedIntfs = IntfCli.intfsSupported( mode, countersRateSupportedIntfsFunc,
                                               "Counter rates are", intf=intf,
                                               mod=mod, intfType=intfType )

   # Guard against a falsy result, as is done for intfs above.
   for supportedIntf in supportedIntfs or ():
      if supportedIntf.name not in interfaces:
         continue
      _populateIntfSchedInfo( mode, interfaces[ supportedIntf.name ],
                              supportedIntf, trafficClassLabelFunc )
   return queueCounters

def populatePollInfo( aleCountersConfig, aleCountersStatus ):
   """
   Build a PollInfo model from an Ale::Counters::CliConfig and an
   Ale::Counters::Status object.

   The config and status objects are passed in so that different pollers
   (e.g. StrataCounters, and possibly SandCounters) can share this code.
   Values are copied as-is; averageFetchTime is totalFetchTime / fetchCount,
   or 0.0 when no fetch has happened yet.
   """
   status = aleCountersStatus
   model = PollInfo()
   model.pollInterval = aleCountersConfig.periodPoll
   model.lastPollTimestamp = Ark.switchTimeToUtc( status.timestampLastPoll )
   model.totalPollCount = status.pollCount
   model.latePollCount = status.latePollCount

   # Read fetchCount exactly once. If the polling interval is changed just
   # before this runs, the live value can drop to zero between a zero check
   # and the division (see BUG614561).
   fetchCount = status.fetchCount
   model.totalFetchCount = fetchCount
   model.averageFetchTime = (
      status.totalFetchTime / fetchCount if fetchCount else 0.0 )
   model.maximumFetchTime = status.maxFetchTime
   model.reEnqueueCount = status.reEnqueueCount
   model.retryDequeueCount = status.retryDequeueCount
   return model

def populateFastPollInfo( aleCountersConfig, aleCountersStatus ):
   """
   Build the FastPollInfo CLI model from the fast-poll config and status.

   The average and maximum fetch times are the status values multiplied by
   1000 (presumably seconds converted to milliseconds -- confirm against the
   model). averageFetchTime is 0.0 when no fetch has happened yet.
   """
   status = aleCountersStatus
   model = FastPollInfo()
   model.pollFastInterval = aleCountersConfig.periodFastPoll
   model.lastPollTimestamp = Ark.switchTimeToUtc( status.timestampLastPoll )
   model.totalPollCount = status.pollCount
   model.latePollCount = status.latePollCount

   # Read fetchCount once so the zero check and the division see the same value.
   fetchCount = status.fetchCount
   model.totalFetchCount = fetchCount
   model.averageFetchTime = (
      status.totalFetchTime / fetchCount * 1000 if fetchCount else 0.0 )
   model.maximumFetchTime = status.maxFetchTime * 1000
   model.reEnqueueCount = status.reEnqueueCount
   model.retryDequeueCount = status.retryDequeueCount
   return model

def populateFastPollInfoDetail(
      aleCountersStatus, latePollHistory, hourSnapshotDir, daySnapshotDir ):
   """
   Build the FastPollInfoDetail CLI model.

   aleCountersStatus: live status whose cumulative counters are diffed
      against the earliest snapshot of each window.
   latePollHistory: object whose latePoll collection holds events with
      timestamp and delay attributes.
   hourSnapshotDir / daySnapshotDir: objects whose snapshot collection holds
      copies of the status counters. The first value is treated as the
      earliest snapshot of the window.

   Delays and fetch times are the status values multiplied by 1000
   (presumably seconds converted to milliseconds -- confirm). statsLastHour is
   filled only when hour snapshots exist. statsLastDay is filled only when, in
   addition, the earliest day snapshot is older than the earliest hour
   snapshot.
   """
   detail = FastPollInfoDetail()

   def latePollEvent( timestamp, delay ):
      out = LatePollEvent()
      out.timestamp = Ark.switchTimeToUtc( timestamp )
      out.delay = delay
      return out
   detail.latePollHistory = [
         latePollEvent( event.timestamp, event.delay * 1000 ) for
         event in latePollHistory.latePoll.values() ]

   def snapshotToStats( snapshots, snapshotName ):
      """Return FastPollStats for the window starting at snapshots[ 0 ]."""
      earliestSnapshot = snapshots[ 0 ]
      stats = FastPollStats()
      stats.statsStartTimestamp = Ark.switchTimeToUtc(
         earliestSnapshot.timestampSnapshot )
      stats.totalPollCount = (
            aleCountersStatus.pollCount - earliestSnapshot.pollCount )
      stats.latePollCount = (
            aleCountersStatus.latePollCount - earliestSnapshot.latePollCount )
      avgFetch = 0.0
      fetchCount = aleCountersStatus.fetchCount - earliestSnapshot.fetchCount
      if fetchCount:
         fetchTime = (
               aleCountersStatus.totalFetchTime - earliestSnapshot.totalFetchTime )
         avgFetch = fetchTime / fetchCount * 1000
      stats.averageFetchTime = avgFetch
      # Compare the snapshot maxima and the per-user running maximum in the
      # same unit, then scale once. Previously the snapshot maximum was
      # scaled by 1000 but the per-user value was not, so the per-user
      # maximum was effectively ignored. Assumes maxFetchTimePerUser uses the
      # same unit as maxFetchTime (which is scaled by 1000 everywhere else)
      # -- confirm against the agent.
      maxFetchTime = max( s.maxFetchTime for s in snapshots )
      maxFetchTime = max(
            maxFetchTime,
            aleCountersStatus.maxFetchTimePerUser.get( snapshotName, 0 ) )
      stats.maximumFetchTime = maxFetchTime * 1000.
      return stats

   # Materialize each snapshot collection once instead of on every access.
   hourSnapshots = list( hourSnapshotDir.snapshot.values() )
   if hourSnapshots:
      detail.statsLastHour = snapshotToStats( hourSnapshots, 'hour' )
      daySnapshots = list( daySnapshotDir.snapshot.values() )
      # Only report the day window when it actually reaches further back
      # than the hour window.
      if daySnapshots and ( daySnapshots[ 0 ].timestampSnapshot <
                            hourSnapshots[ 0 ].timestampSnapshot ):
         detail.statsLastDay = snapshotToStats( daySnapshots, 'day' )
   return detail

def registerShowCpuCountersQueueCommand( showCommandClass ):
   """
   Register showCommandClass through BasicCli.addShowCommandClass.

   Unfortunately, each platform implements its own "show cpu counters queue"
   command (with different CAPI models). This causes GrammarError problems for
   btests/stests when guards are disabled. It would be nice to make this API
   handle that more intelligently, but for now it just registers the command.
   """
   addShowCommand = BasicCli.addShowCommandClass
   addShowCommand( showCommandClass )
