#!/usr/bin/env python3
# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import optparse # pylint: disable=deprecated-module
import time
import re

import EapiClientLib
import syslog
from CliCommon import ResponseFormats
import Tac
import Cell
import EntityManager
import SharedMem

# trace0: errors
# trace2: debug
# Global debug flag; set to True by the --debug command-line option. Gates trace2().
Debug = False
def trace0( msg ):
   '''Always-on trace: emit msg to stdout and to syslog at NOTICE priority.'''
   print( '~~~ %s' % msg )
   syslog.syslog( syslog.LOG_NOTICE, msg )

def trace2( msg ):
   '''Debug trace: emit msg to stdout and syslog (DEBUG priority) only when
   the module-level Debug flag is set.'''
   if not Debug:
      return
   print( '~~~ %s' % msg )
   syslog.syslog( syslog.LOG_DEBUG, msg )

# Number of per-interface egress transmit queues polled each cycle.
# NOTE(review): hard-coded to 8 unicast + 8 multicast — confirm this matches
# the platform's actual queue counts.
numUcastQueues = 8
numMcastQueues = 8

class EapiClientWrapper:
   ''' A wrapper for the EAPI client used for fetching counters.

   EAPI is used to run CLI 'show' commands, while interface and queue
   counters are read directly through shared-memory (Smash) accessors set up
   in initEthCounter() / initQueueCounter(). '''
   def __init__( self, entityManager ):
      # AAA is disabled since this runs locally as an agent.
      self.eapiClient = EapiClientLib.EapiClient( disableAaa=True )
      self.entityMgr = entityManager
      trace2( 'Initializing EapiClient' )

      # For queue counters (read via Smash, see initQueueCounter)
      self.queueWriterConfigDir = None
      self.queueCounterAccessor = None
      self.queueReaderSm = None

      # For eth intf counters, e.g. PFC class frames (see initEthCounter)
      self.ethIntfWriterStatusDir = None
      self.shmemEm = None
      self.ethIntfCounterReaderSm = None
      self.allEthIntfCounterDir = None

      self.initEthCounter()
      self.initQueueCounter()

   def initEthCounter( self ):
      '''Mount the ethIntf counter writer status from Sysdb and create a
      shared-memory reader exposing per-interface counters. Idempotent.'''
      if self.allEthIntfCounterDir is not None:
         return

      mg = self.entityMgr.mountGroup()
      self.ethIntfWriterStatusDir = mg.mount(
         "interface/ethIntfCounter/writerStatus", "Tac::Dir", "ri" )
      mg.close( blocking=True )

      # NOTE(review): the shared-memory entity manager is bound to a local
      # name only, so self.shmemEm stays None — presumably the reader SM
      # keeps it alive; confirm the intended lifetime.
      shmemEm = SharedMem.entityManager( sysdbEm=self.entityMgr )
      self.ethIntfCounterReaderSm = Tac.newInstance(
         'Interface::EthIntfCounterReaderSm', shmemEm, self.ethIntfWriterStatusDir )
      self.ethIntfCounterReaderSm.enableLegacyShmemManSupport()
      self.ethIntfCounterReaderSm.handleInitialized()
      self.allEthIntfCounterDir = \
         self.ethIntfCounterReaderSm.legacyShmemManCounterAccessor

   def initQueueCounter( self ):
      '''Mount the queue-counter writer config from Sysdb and create the
      accessor / reader SM used by getQueueCounters(). Idempotent.'''
      if self.queueCounterAccessor is not None:
         return

      mg = self.entityMgr.mountGroup()
      self.queueWriterConfigDir = mg.mount(
         Cell.path( "interface/queueCounter/writerConfigDir" ),
         "Tac::Dir", "ri" )
      mg.close( blocking=True )
      self.queueCounterAccessor = Tac.newInstance(
         'Interface::QueueCounter::CounterAccessor' )
      mountHelper = Tac.newInstance(
         'Interface::QueueCounter::SmashMountHelper',
         self.entityMgr.cEntityManager() )
      self.queueReaderSm = Tac.newInstance(
         'Interface::QueueCounter::WriterConfigDirSm',
         self.queueWriterConfigDir, self.queueCounterAccessor, mountHelper )

   def allEthIntfs( self ):
      '''Return the names of all interfaces in the counter dir whose name
      starts with 'Ethernet'.'''
      return [ intf for intf in self.allEthIntfCounterDir.intfCounterDir
               if intf.startswith( 'Ethernet' ) ]

   def runCmd( self, cmd, revision=1, formatRequested=ResponseFormats.JSON ):
      '''Run a single CLI command (after 'enable') via EAPI and return its
      result dict. On EAPI failure, log the error and return {} — callers
      must tolerate an empty response.'''
      try:
         resp = self.eapiClient.runCmds( revision, [ 'enable', cmd ],
                                         formatRequested )
         return resp[ 'result' ][ 1 ]
      except EapiClientLib.EapiException as e:
         trace0( 'Unable to get EAPI response: ' + str( e ) )
      return {}

   def isQueueStarvationPossible( self, intf, qid, qosInterfacesStatus ):
      '''Returns True if there is a possibility of the queue being starved,
      i.e. the matching tx queue has no operational guaranteed bandwidth.
      Only the unicast tx queues are scanned; returns False if qid is not
      found.'''
      for q in range( numUcastQueues ):
         qosTxQueueStatus = qosInterfacesStatus[ 'intfAllQosAll' ][ '%s' % intf ][ \
               'intfQosModel' ][ 'txQueueQosModel' ][ 'txQueueList' ][ q ]

         if qosTxQueueStatus[ 'txQueue' ] == str( qid ):
            return 'operationalGuaranteedBw' not in qosTxQueueStatus
      return False

   def getTxQueueToCosMap( self ):
      ''' Returns the global TxQueue -> cos map, derived by composing the
      cos->tc and tc->txQueue maps from 'show qos maps' and inverting the
      result.

      NOTE(review): the inversion is lossy — if several cos values map to
      the same tx queue, only the last one survives. Confirm this is the
      intended behavior. '''
      # get qos maps in json format via EAPI
      qosMaps = self.runCmd( 'show qos maps' )
      trace2( 'qos maps: %s' % qosMaps )

      cosToTcMap = qosMaps[ 'cosToTcMap' ]
      tcToTxQueueMap = qosMaps[ 'tcToTxQueueMap' ]
      cosToTxQueueMap = {}
      # JSON keys are strings; re-key by integer cos values
      for cos in range( len( cosToTcMap ) ):
         tc = cosToTcMap[ str( cos ) ]
         cosToTxQueueMap[ cos ] = tcToTxQueueMap[ str( tc ) ]

      txQueueToCosMap = { v: k for k, v in cosToTxQueueMap.items() }
      return txQueueToCosMap

   def getUsedBytesCount( self, intf, qid, qtype, mmuQueueStatus ):
      '''Return the 'used' count for the given interface/queue from the
      pre-fetched 'show platform trident mmu queue status' output.
      qtype selects between the unicast and multicast queue tables.'''
      intfQueueStatus = mmuQueueStatus[ 'interfaces' ][ intf ]
      if qtype == "unicast":
         usedCount = int( intfQueueStatus[ 'ucastQueues' ][ str( qid ) ][ "used" ] )
      else:
         usedCount = int( intfQueueStatus[ 'mcastQueues' ][ str( qid ) ][ "used" ] )
      return usedCount

   def getPfcCounters( self, intf, pfcPriority ):
      '''Return the current rx PFC class-frame count for the given interface
      and priority; 0 if the interface or its current counter is absent.'''
      counter = self.allEthIntfCounterDir.intfCounterDir.get( intf, None )
      if counter is None:
         return 0
      currentCounter = counter.intfCounter.get( 'current', None )
      if currentCounter is None:
         return 0

      pfcCounter = currentCounter.inPfcClassFrames.count.get( pfcPriority, 0 )
      trace2( 'Get pfc counter for %s priority %s: %s' %
              ( intf, pfcPriority, pfcCounter ) )
      return pfcCounter

   def getQueueCounters( self, intf, qid, qtype ):
      '''
      Returns the egress queue packet counters for each queue for each interface.
      Returns 0 if the interface or queue has no counter entry.
      '''
      intfQueueCounter = self.queueCounterAccessor.counter( intf )
      if intfQueueCounter is None:
         return 0

      # Multicast queue ids are stored after the unicast ones, so offset by
      # the per-interface unicast queue count.
      qid = qid if qtype == 'unicast' else (
         qid + self.queueCounterAccessor.numUnicastQueues( intf ) )
      qStat = intfQueueCounter.intfQueueStat.get( qid, None )
      if qStat is None:
         return 0

      trace2( f'Get queue counter for {intf} {qid}: {qStat.pkts}' )
      return qStat.pkts

class MmuQueueMonitorAgent:
   '''
   MmuQueueMonitorAgent handles user configuration and runs the loop that
   constantly monitors the state of queue related config & counters (via EAPI) and
   makes a collective decision whether a queue is stuck or not. If a queue is stuck,
   a syslog is generated and (based on the user configuration), resetting of chip is
   done.
   '''
   def __init__( self, entityManager, options ):
      # entityManager: Sysdb entity manager used for Smash mounts
      # options: parsed command-line options; only 'resetChip' is consulted here
      self.entityManager = entityManager
      self.eapiClient = EapiClientWrapper( self.entityManager )
      # how often (in seconds) we look at the counters to detect queue stuck
      self.monitorInterval = 60
      # optional extra check; when set to 'guaranteedBw', queues that may be
      # legitimately starved (no guaranteed bandwidth) are not flagged.
      # Never set in this file — presumably reserved for future config.
      self.constraintCheck = None

      # the action taken after mmu queue stucks is configured by users.
      self.resetChip = getattr( options, 'resetChip', False )
      if self.resetChip:
         trace2( 'Reset chip if MMU queue is detected stuck' )

      # to store the current counter values (and compare with the next poll cycle)
      self.queueCountersDict = {}
      self.pfcCountersDict = {}
      self.initCounterDicts()

      # to store a mapping of interface name to chipName (LinecardA/B)
      self.intfToChipNameDict = {}

      self.initialized = True
      trace2( 'MmuQueueMonitorAgent is initialized' )

   def initCounterDicts( self ):
      '''
      Initializes the counter dictionaries to default values of 0, e.g.:
      { 'Ethernet1/1' :
         { 'unicast':
               { 0 : x,
                 1 : y,
                 ... more queue numbers ... },
           'multicast':
               { 0 : a,
                 1 : b,
                 ... more queue numbers ... }
         },
         ... more interfaces ...
      }
      '''
      for intf in self.eapiClient.allEthIntfs():
         self.queueCountersDict[ intf ] = { 'unicast' : {}, 'multicast' : {} }
         self.pfcCountersDict[ intf ] = { 'unicast' : {}, 'multicast' : {} }

         for qid in range( numUcastQueues ):
            self.queueCountersDict[ intf ][ 'unicast' ][ qid ] = 0
            self.pfcCountersDict[ intf ][ 'unicast' ][ qid ] = 0

         for qid in range( numMcastQueues ):
            self.queueCountersDict[ intf ][ 'multicast' ][ qid ] = 0
            self.pfcCountersDict[ intf ][ 'multicast' ][ qid ] = 0

   def run( self ):
      '''
      Main monitoring loop: every self.monitorInterval seconds, poll queue
      occupancy, queue packet counters and PFC counters for every Ethernet
      interface/queue, and declare a queue stuck when it is occupied,
      inactive, not paused and (optionally) not starvation-eligible.
      Loops while self.initialized — never cleared in this file, so this
      effectively runs until interrupted.
      '''
      while self.initialized:
         # List of all chipNames that are reset in this iteration
         resetChipList = []

         iterStart = time.time()
         txQueueToCosMap = self.eapiClient.getTxQueueToCosMap()

         # pre-fetch mmu queue status in json format via EAPI
         mmuQueueStatus = self.eapiClient.runCmd(
            'show platform trident mmu queue status' )
         trace2( 'mmuQueueStatus: %s' % mmuQueueStatus )

         # pre-fetch qos interfaces status
         qosInterfacesStatus = self.eapiClient.runCmd( 'show qos interfaces' )

         # get interface to linecard mapping
         systemPortCmd = 'show platform trident agent port'
         systemPortCmdOutput = self.eapiClient.runCmd(
            systemPortCmd, formatRequested=ResponseFormats.TEXT )
         systemPortCmdOutputSplit = systemPortCmdOutput[ 'output' ].splitlines()

         # fetch linecard name (chipName) given the intf & populate dict
         # expected format: ' Fabric9/1/1 Linecard9/1 19 65 1 0'
         expr = r'\s*(?P<intf>\S+)\s+(?P<chip>\S+)\s+\d+\s+\d+\s+\d+\s+\d+'
         for outputLine in systemPortCmdOutputSplit:
            match = re.match( expr, outputLine )
            if not match:
               continue
            intf = match.group( 'intf' )
            chipName = match.group( 'chip' )
            self.intfToChipNameDict[ intf ] = chipName

         # Closure over resetChipList, the counter dicts and the pre-fetched
         # EAPI state above. Runs the per-queue stuck-detection checks.
         def detectStuckQueue( intf, qid, qtype, chipName ):

            # Populate queue counter variables & dicts
            qCountersPre = self.queueCountersDict[ intf ][ qtype ][ qid ]
            qCountersPost = self.eapiClient.getQueueCounters( intf, qid, qtype )
            self.queueCountersDict[ intf ][ qtype ][ qid ] = qCountersPost

            # Populate PFC counter variables & dicts
            # NOTE(review): raises KeyError if qid is absent from the
            # inverted map — confirm all qids appear in 'show qos maps'.
            pfcPriority = txQueueToCosMap[ qid ]
            pfcCountersPre = self.pfcCountersDict[ intf ][ qtype ][ qid ]
            pfcCountersPost = self.eapiClient.getPfcCounters( intf, pfcPriority )
            self.pfcCountersDict[ intf ][ qtype ][ qid ] = pfcCountersPost

            # CHECK 1: Check occupancy (if queue empty)
            queueOccupied = self.eapiClient.getUsedBytesCount( intf, qid,
                              qtype=qtype, mmuQueueStatus=mmuQueueStatus )
            if not queueOccupied:
               trace2( 'interface %s %s queue %d is empty, moving on' %
                       ( intf, qtype, qid ) )
               return
            trace2( 'interface %s %s queue %d is occupied, checking counters next' %
                    ( intf, qtype, qid ) )

            # CHECK 2: Check activity (if queue counters incrementing)
            isQCounterIncrementing = qCountersPost - qCountersPre
            if isQCounterIncrementing:
               trace2( 'interface %s %s queue %d is occupied and '
                       'active, moving on' % ( intf, qtype, qid ) )
               return
            trace2( 'interface %s %s queue %d is occupied but '
                    'inactive, checking PFC counters next...' %
                    ( intf, qtype, qid ) )

            # CHECK 3: Check if queue is paused (PFC frames received)
            isQueuePaused = pfcCountersPost - pfcCountersPre
            if isQueuePaused:
               trace2( 'interface %s %s queue %d is paused, moving on' %
                       ( intf, qtype, qid ) )
               return
            trace2( 'interface %s %s queue %d is occupied, inactive & '
                    'not paused, checking if starvation possible' %
                    ( intf, qtype, qid ) )

            # CHECK 4 (Optional): Check if queue starvation possible
            if self.constraintCheck == 'guaranteedBw':
               if self.eapiClient.isQueueStarvationPossible( intf, qid,
                                                             qosInterfacesStatus ):
                  trace2( 'interface %s %s queue %d is possibly '
                          'starved... moving on' % ( intf, qtype, qid ) )
                  return

            # ALL CHECKS ASSERT => Take action, queue is stuck
            queueStuckLog = 'interface %s %s queue %d stuck condition detected' % \
                            ( intf, qtype, qid )
            trace0( queueStuckLog )

            if self.resetChip:
               trace0( 'reset chip %s' % chipName )
               resetChipList.append( chipName )
               self.eapiClient.runCmd( 'platform trident %s reset' % chipName,
                                       formatRequested=ResponseFormats.TEXT )

         # Iterate through all intfs, queues to detectStuckQueue
         for intf in self.eapiClient.allEthIntfs():
            # Interfaces without a known chip mapping are skipped entirely
            if intf in self.intfToChipNameDict:
               chipName = self.intfToChipNameDict[ intf ]
            else:
               continue

            if chipName in resetChipList:
               continue

            for qid in range( numUcastQueues ):
               # chip was reset earlier in this iteration; skip its queues
               if chipName in resetChipList:
                  continue
               detectStuckQueue( intf, qid, qtype="unicast", chipName=chipName )

            for qid in range( numMcastQueues ):
               if chipName in resetChipList:
                  continue
               detectStuckQueue( intf, qid, qtype="multicast", chipName=chipName )

         # Sleep off the remainder of the poll interval while servicing
         # Tac activities; a zero wait still yields once.
         iterLen = time.time() - iterStart
         trace2( 'Time taken to run this loop: %s ' % iterLen )
         if iterLen < self.monitorInterval:
            Tac.runActivities( self.monitorInterval - iterLen )
         else:
            Tac.runActivities( 0 )

if __name__ == '__main__':
   # Identify ourselves to syslog before anything else logs.
   syslog.openlog( 'MmuQueueMonitor', syslog.LOG_PID, syslog.LOG_LOCAL4 )

   # parse the command line input
   optParser = optparse.OptionParser()
   optParser.add_option( '--reset-chip', action='store_true', dest='resetChip' )
   optParser.add_option( '--debug', action='store_true', dest='debug' )
   parsedOptions, _ = optParser.parse_args()

   # Turn on verbose tracing when requested.
   if getattr( parsedOptions, 'debug', False ):
      Debug = True

   sysdbEm = EntityManager.Sysdb( 'ar' )
   monitorAgent = MmuQueueMonitorAgent( sysdbEm, parsedOptions )
   if monitorAgent.initialized:
      try:
         monitorAgent.run()
      except KeyboardInterrupt:
         # Clean exit on Ctrl-C
         pass
