# Copyright (c) 2022 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

from __future__ import absolute_import, division, print_function

import re
import six

from CliModel import Dict, Int, Float, Str, Model

import BasicCli
import Cell
import CliCommand
import CliMatcher
import CliParser
import CliPlugin.TechSupportCli
import LazyMount
import ShowCommand
import Tracing
from CliToken.Clear import clearKwNode
from CliToken.Platform import platformMatcherForShow
from CliToken.Platform import platformMatcherForClear


# Trace handle for this plugin. t5 is bound to the handle's trace5 attribute and
# is called in getCounterInfo() when a counter has no registered SMBus device.
traceHandle = Tracing.Handle( 'SmbusCli' )
t5 = traceHandle.trace5

class Opcode( object ): # pylint: disable=useless-object-inheritance
   # Integer codes for the two opcode names used by this plugin: rx and tx
   rx, tx = 1, 2

# Module-level handles. Each one is None until Plugin() assigns it.
# Root returned by entityManager.root()
sysdbRoot = None
# Mount of "hardware/smbus/topology"
smbusTopology = None
# Mount of "hardware/counter/smbus"
counterDir = None
# Mount of "hardware/counter/cli/smbus"; Counter.clear() writes to it
clearCounterDir = None
# Mount of "hardware/smbus"; read by accelGeneration()
smbusConfig = None


# Three-line table header for the Tx counters: a right-aligned "Est. Tx" title
# row, the column names, and a row of dashes as wide as each column.
txBanner = "".join( (
   "{0:>73}\n".format( "Est. Tx" ),
   "{0:<18} {1:<24} {2:>20} {3:>10}\n".format(
      "Device", "Address", "Tx bytes", "bytes/sec" ),
   "{0:<18} {1:<24} {2:>20} {3:>10}".format(
      *[ "-" * width for width in ( 18, 24, 20, 10 ) ] ),
) )

# The Rx header is the Tx header with every "Tx" replaced by "Rx"
rxBanner = txBanner.replace( "Tx", "Rx" )

class SmbusDevice():
   # Plain record for one SMBus device found in the topology: the accelerator
   # and bus ids it was reached through, its device id, the formatted address
   # string, and a display name (defaults to "SMBusDevice").
   def __init__( self, accelId, busId, deviceId, deviceAddr,
                 deviceName="SMBusDevice" ):
      self.accelId, self.busId = accelId, busId
      self.deviceId, self.deviceAddr = deviceId, deviceAddr
      self.deviceName = deviceName

def counterGuard( mode, token ):
   # CLI guard. Returns None (token allowed) only when sysdbRoot[ "hardware" ]
   # has a "counter" entry which in turn has an entry named after the token;
   # otherwise returns CliParser.guardNotThisPlatform. mode is unused.
   hardware = sysdbRoot[ "hardware" ]
   if ( "counter" not in hardware.entryState or
        token not in hardware[ "counter" ].entryState ):
      return CliParser.guardNotThisPlatform
   return None

def isModular():
   # True for every cell type other than "fixed"
   cellType = Cell.cellType()
   return cellType != "fixed"

# "show platform ?" only shows "smbus" option when Smbus agent is running
# Keyword matcher for the 'smbus' token.
matcherSmbus = CliMatcher.KeywordMatcher(
   'smbus',
   helpdesc='SMBus-device info' )

# Node that wraps matcherSmbus and is guarded by counterGuard.
# NOTE(review): nodeSmbus is assigned again further down in this file (via
# CliCommand.guardedKeyword) before any command class references it, so this
# Node appears to be unused -- confirm and consider removing one of the two.
nodeSmbus = CliCommand.Node(
   matcher=matcherSmbus,
   guard=counterGuard )

def parseBus( busName ):
   # Split a name beginning with "bus<digits>:<digits>" into its two numeric
   # fields, returned as a tuple of digit strings. Anything after the second
   # number is ignored. Returns ( None, None ) when the name does not match.
   found = re.match( r"bus(\d+)\:(\d+)", busName )
   return found.groups() if found else ( None, None )

def createSmbusDeviceMap():
   # Walk smbusTopology and return a dict mapping a formatted SMBus address
   # string to the SmbusDevice registered at that address.
   #
   # The address is "<accel>/<bus>/0x<deviceId>" (accel and bus as two-digit
   # decimals, deviceId as two-digit hex). On modular systems it is prefixed
   # with "Supervisor<n>/" or "<nodeType>/". Devices whose name starts with
   # "powerSupply" produce two entries (FRU and PMBus addresses).
   smbusDeviceMap = {}
   # pylint: disable-next=too-many-nested-blocks
   for nodeType, node in six.iteritems( smbusTopology.node ):

      if nodeType == "Chassis":
         continue

      # Accel/bus ids (as strings from parseBus) of the supervisor bus that this
      # card's mux hangs off; stay None when no such upstream bus is found
      supeAccelId = None
      supeBusId = None
      # Traverse line/fabric card smbus mux
      if ( ( nodeType.startswith( "Linecard" ) or
             nodeType.startswith( "Fabric" ) or
             nodeType.startswith( "Switchcard" ) ) and
           ( "supe1Bus" in node.hwSmbus ) and
           ( "supe2Bus" in node.hwSmbus ) ):
         # The upstream bus addresses are the same for
         # active and standby supes
         # NOTE(review): the else branch reads .name from the supe2Bus endpoint
         # without checking that it is set -- confirm at least one of the two
         # endpoints is always connected.
         if node.hwSmbus[ "supe1Bus" ].connectionEndpoint[ 0 ].otherEndConnectedTo:
            upstreamBus = node.hwSmbus[ "supe1Bus" ].connectionEndpoint[
                          0 ].otherEndConnectedTo.name
         else:
            upstreamBus = node.hwSmbus[ "supe2Bus" ].connectionEndpoint[
                          0 ].otherEndConnectedTo.name

         if upstreamBus.startswith( "bus" ):
            supeAccelId, supeBusId = parseBus( upstreamBus )

      for busAddr, bus in six.iteritems( node.hwSmbus ):
         # Only buses named bus*, cardBus* or isolatorBus* are considered
         if not busAddr.startswith( "bus" ) and \
            not busAddr.startswith( "cardBus" ) and \
            not busAddr.startswith( "isolatorBus" ):
            continue

         # Work out which accel/bus pair this bus is addressed through
         if busAddr.startswith( "bus" ):
            accelId, busId = parseBus( busAddr )
            if accelId is None:
               continue
         elif supeAccelId is not None:
            # Line/Fabric Smbus devices accessed by supe
            # i.e. cardBus or isolatorBus
            accelId = supeAccelId
            busId = supeBusId
         elif 0 in bus.connectionEndpoint:
            # Card bus connected directly to supervisor accel (no mux). There may be
            # multiple buses on a card
            endpoint = bus.connectionEndpoint[ 0 ].otherEndConnectedTo
            if not endpoint:
               continue
            accelId, busId = parseBus( endpoint.name )
            if accelId is None:
               continue
         else:
            continue

         if len( bus.device ):
            for deviceName, device in six.iteritems( bus.device ):
               deviceId = device.deviceId
               deviceAddr = ""

               # Modular systems get a "<owner>/" prefix in front of the address
               # NOTE(review): "Switchcard" nodes are accepted by the mux check
               # above but are not matched by the Linecard/Fabric test here, so
               # they fall through to the final else -- confirm this is intended.
               if isModular():
                  if nodeType in [ "1", "2" ]:
                     # Supervisor
                     deviceAddr += "Supervisor%s/" % nodeType
                  elif ( nodeType.startswith( "Linecard" ) or
                         nodeType.startswith( "Fabric" ) ) and \
                       ( busAddr.startswith( "cardBus" ) or
                         busAddr.startswith( "isolator" ) ):
                     # Line/Fabric Smbus devices accessed by current supe
                     deviceAddr += "Supervisor%d/" % Cell.activeCell()
                  elif nodeType.startswith( "Fabric" ):
                     # Fabriccards
                     deviceAddr += "Supervisor%d/" % Cell.activeCell()
                  else:
                     # Linecards
                     deviceAddr += "%s/" % nodeType

               deviceAddr += "%02d/%02d/0x%02x" % ( int( accelId ),
                                int( busId ), int( deviceId ) )

               if deviceName.startswith( "powerSupply" ):
                  # The power supply is represented with Smbus
                  # address 0x0 on the topology. For fixed systems,
                  # the power supply has the FRU and PMBUS address
                  # offsets fixed at 0x50 and 0x58, respectively.
                  # Psu Fru
                  # Two map entries are created, one per address. The replace()
                  # below only changes the address when it contains "0x00".
                  fruAddr = "0x%02x" % ( deviceId + 0x50 )
                  pmbusAddr = "0x%02x" % ( deviceId + 0x58 )
                  for addr in [ fruAddr, pmbusAddr ]:
                     newAddr = deviceAddr.replace( "0x00", addr )
                     newSmbusDevice = SmbusDevice( accelId, busId, deviceId, newAddr,
                                                   deviceName )
                     smbusDeviceMap[ newAddr ] = newSmbusDevice
               else:
                  # A name made up only of digits (or an empty name) is replaced
                  # by something more descriptive
                  if re.match( r"^([0-9]*)$", deviceName ):
                     # For tempsensors and Sol chip, use "modelName" instead
                     deviceName = "SMBusDevice"
                     if device.modelName:
                        deviceName =  device.modelName
                     elif device.api:
                        # Use the part of the api string before the first "-"
                        deviceName = device.api.split( "-" )[ 0 ]
                  elif deviceName.find( "PowerController" ) >= 0:
                     # Abbreviate PowerController to fit the CLI column
                     deviceName = deviceName.replace( "PowerController", "DPM" )
                  elif deviceName.startswith( ( "Ethernet", "Xcvr", "Fabric" ) ):
                     # Xcvr devices are represented with Smbus address 0x0
                     # on the topology. The 0x0 address is assigned
                     # by the following code in the FDL
                     # xcvrCtrl.xcvrController.smbusDeviceBase = ( xcvrName, 0x0 )
                     deviceAddr = deviceAddr.replace( "0x00", "0x50" )
                  newSmbusDevice = SmbusDevice( accelId, busId, deviceId, deviceAddr,
                                                deviceName )
                  smbusDeviceMap[ deviceAddr ] = newSmbusDevice
   return smbusDeviceMap

def accelGeneration( hostType, hostId ):
   # Return the engineGenerationId of smbusConfig[ hostType ][ hostId ], or
   # False when either lookup yields nothing.
   host = smbusConfig.get( hostType )
   if host:
      # Strip leading "fixed-" from hostId. The test is a substring check, so
      # any id containing "fixed" loses its first six characters.
      key = hostId[ 6: ] if "fixed" in hostId else hostId
      accel = host.get( key )
      if accel:
         return accel.engineGenerationId
   return False

def sameGeneration( gen1, gen2 ):
   # Truthy only when gen1 is set, gen1.valid is truthy and gen1 equals gen2.
   # Follows `and` semantics: the first falsy operand is returned as-is, so the
   # result is not necessarily a bool.
   if not gen1:
      return gen1
   isValid = gen1.valid
   if not isValid:
      return isValid
   return gen1 == gen2

def smbusCounters():
   # Generator yielding a Counter for every device of every host under
   # counterDir. A host is skipped unless its generation passes
   # sameGeneration() against the value returned by accelGeneration().
   for hostType, hostDir in six.iteritems( counterDir ):
      for hostId, host in six.iteritems( hostDir ):
         currentGen = accelGeneration( hostType, hostId )
         if sameGeneration( currentGen, host.generation ):
            for deviceId, device in six.iteritems( host.device ):
               yield Counter( hostType, hostId, deviceId, device )

# Counter addresses of the form Supervisor<n>/<n>/<n>/0x54, 0x5c or 0x3c. On a
# modular system an unregistered counter at such an address is reported as a
# power supply. Compiled once here rather than on every loop iteration.
_unregisteredPsuAddrRe = re.compile(
   r'Supervisor(\d+)\/(\d+)\/(\d+)\/0x(54|5c|3c)' )

def getCounterInfo():
   # Return a list of ( deviceName, addr, counter ) tuples covering all SMBus
   # activity: first every registered device sorted by address (counter is None
   # when the device has no counter), then every counter that has no registered
   # device, also sorted by address.

   # smbusDevices is a collection of all registered smbus devices in the system
   smbusDevices = createSmbusDeviceMap()
   # counters is a collection of all smbus counters in the system. There are some
   # counters that appear without a registered smbus device, so we must expose both
   counters = { counter.addr : counter for counter in smbusCounters() }

   output = [ ( smbusDevices[ addr ].deviceName, addr, counters.get( addr ) )
              for addr in sorted( smbusDevices ) ]

   for addr in sorted( counters ):
      if addr in smbusDevices:
         # Already reported together with its registered device
         continue
      t5( "Found unregistered SMBus counter at address", addr )
      if isModular() and _unregisteredPsuAddrRe.match( addr ):
         deviceName = "powerSupply"
      else:
         deviceName = "SMBusDevice"
      output.append( ( deviceName, addr, counters[ addr ] ) )
   return output

class Counter( object ): # pylint: disable=useless-object-inheritance
   # Snapshot of one device's SMBus counters, copied from a counter entry at
   # construction time. Counts are reported relative to the values recorded by
   # the last clear(); rates are exposed unchanged as txRate / rxRate.
   def __init__( self, hostType, hostId, deviceId, device ):
      self.hostType = hostType
      self.hostId = hostId
      self.deviceId = deviceId
      self._txCount = device.txCount
      self.txRate = device.txRate
      self._rxCount = device.rxCount
      self.rxRate = device.rxRate
      self._timeoutErrorCount = device.timeoutErrorCount
      self._ackErrorCount = device.ackErrorCount
      self._busConflictErrorCount = device.busConflictErrorCount
      # Key of this host under clearCounterDir
      self.hostName = "{}/{}".format( self.hostType, self.hostId )
      if isModular():
         # "<hostId>/<deviceId>", with "Supervisor" in front for cell hosts
         prefix = "Supervisor" if self.hostType == "cell" else ""
         self.addr = prefix + "%s/%s" % ( self.hostId, self.deviceId )
      else:
         self.addr = self.deviceId

   def getAllCounts( self ):
      # Return [ rx, tx, timeout, ack, busConflict ], each with the value saved
      # at the last clear subtracted. The saved values are used only when the
      # clear entry's generation still matches the accelerator's; otherwise the
      # raw counts are returned.
      baseline = [ 0, 0, 0, 0, 0 ]
      accelGen = accelGeneration( self.hostType, self.hostId )
      host = clearCounterDir.host.get( self.hostName )
      if host and sameGeneration( accelGen, host.generation ):
         cleared = host.device.get( self.deviceId )
         if cleared:
            baseline = [ cleared.lastRxClear,
                         cleared.lastTxClear,
                         cleared.lastTimeoutClear,
                         cleared.lastAckClear,
                         cleared.lastBusConflictClear ]
      current = [ self._rxCount,
                  self._txCount,
                  self._timeoutErrorCount,
                  self._ackErrorCount,
                  self._busConflictErrorCount ]
      return [ now - base for now, base in zip( current, baseline ) ]

   def readCount( self ):
      # ( rx, tx ) byte counts since the last clear
      counts = self.getAllCounts()
      return counts[ 0 ], counts[ 1 ]

   def errorCount( self ):
      # ( timeout, ack, busConflict ) error counts since the last clear
      timeouts, acks, busConflicts = self.getAllCounts()[ 2: ]
      return timeouts, acks, busConflicts

   def clear( self ):
      # Record the current raw counts as the new baseline for this device and
      # stamp the host entry with the accelerator's current generation
      host = clearCounterDir.newHost( self.hostName )
      snapshot = host.newDevice( self.deviceId )
      for field, value in ( ( "lastTxClear", self._txCount ),
                            ( "lastRxClear", self._rxCount ),
                            ( "lastTimeoutClear", self._timeoutErrorCount ),
                            ( "lastAckClear", self._ackErrorCount ),
                            ( "lastBusConflictClear",
                              self._busConflictErrorCount ) ):
         setattr( snapshot, field, value )
      host.generation = accelGeneration( self.hostType, self.hostId )

class DeviceErrorModel( Model ):
   # Error counters of a single SMBus device
   name = Str( help="Name of SMBus Device" )
   timeout = Int( help="Number of timeout errors for SMBus device")
   ack = Int( help="Number of ack errors for SMBus device" )
   busConflict = Int( help="Number of bus conflict errors for SMBus device")

   def fmtErrPrint( self, addr ):
      # One table row: name, address, then the three error counts
      columns = ( self.name, addr, self.timeout, self.ack, self.busConflict )
      return "{0:<18} {1:<24} {2:>9} {3:>7} {4:>13}".format( *columns )

class DeviceModel( Model ):
   # Byte counters and estimated byte rates of a single SMBus device
   name = Str( help="Name of SMBus device" )
   rxByteCount = Int( help="Number of rx bytes processed" )
   rxEstByteRate = Float( help="Estimated rate of rx bytes per second" )
   txByteCount = Int( help="Number of tx bytes processed" )
   txEstByteRate = Float( help="Estimated rate of tx bytes per second" )

   def fmtRxPrint( self, addr ):
      # One Rx table row: name, address, byte count, rate with two decimals
      columns = ( self.name, addr, self.rxByteCount, self.rxEstByteRate )
      return "{0:<18} {1:<24} {2:>20} {3:>10.2f}".format( *columns )

   def fmtTxPrint( self, addr ):
      # One Tx table row: name, address, byte count, rate with two decimals
      columns = ( self.name, addr, self.txByteCount, self.txEstByteRate )
      return "{0:<18} {1:<24} {2:>20} {3:>10.2f}".format( *columns )

class SmbusErrorCountersModel( Model ):
   # Error counters of all SMBus devices, keyed by device address
   devices = Dict( valueType=DeviceErrorModel,
                   help="Mapping from SMBus device address to device "
                         "error counter info" )

   def render( self ):
      # Print a three-line header followed by one row per device, ordered by
      # address
      banner = "".join( (
         "{0:>53} {1:>7} {2:>13}\n".format( "Timeout", "Ack", "Bus Conflict" ),
         "{0:<18} {1:<24} {2:>9} {3:>7} {4:>13}\n".format(
            "Device", "Address", "Errors", "Errors", "Errors" ),
         "{0:<18} {1:<24} {2:>9} {3:>7} {4:>13}".format(
            "-" * 18, "-" * 24, "-" * 9, "-" * 7, "-" * 13 ),
      ) )

      errorRows = [ devInfo.fmtErrPrint( smbusAddr )
                    for smbusAddr, devInfo in sorted( self.devices.items() ) ]

      print( banner )
      print( '\n'.join( errorRows ) )

class SmbusCountersModel( Model ):
   # Rx/Tx counters of all SMBus devices, keyed by device address
   devices = Dict( valueType=DeviceModel,
                   help="Mapping from SMBus device address to device "
                        "counter info" )

   def render( self ):
      # Print the Rx table, a blank line, then the Tx table. Both tables list
      # the devices ordered by address.
      ordered = sorted( self.devices.items() )
      rxRows = [ devInfo.fmtRxPrint( smbusAddr ) for smbusAddr, devInfo in ordered ]
      txRows = [ devInfo.fmtTxPrint( smbusAddr ) for smbusAddr, devInfo in ordered ]

      print( rxBanner )
      print( '\n'.join( rxRows ) )
      print( '\n' + txBanner )
      print( '\n'.join( txRows ) )

# "show platform ?" only shows "smbus" option when Smbus agent is running
# Guarded 'smbus' keyword used by the show and clear commands below;
# counterGuard decides whether the token is offered.
# NOTE(review): this rebinds nodeSmbus, which is also assigned earlier in this
# file with CliCommand.Node -- confirm which definition is meant to remain.
nodeSmbus = CliCommand.guardedKeyword( keyword='smbus', 
                                       helpdesc='SMBus-device info',
                                       guard=counterGuard )

def doShowSmbusCountersErrors():
   # Build the SmbusErrorCountersModel: one DeviceErrorModel per address
   # returned by getCounterInfo()
   errorRet = {}
   for deviceName, smbusAddr, counter in getCounterInfo():
      # counter may be None if there's an inactive device; report zeros then
      if counter:
         timeouts, acks, busConflicts = counter.errorCount()
      else:
         timeouts, acks, busConflicts = 0, 0, 0

      errorRet[ smbusAddr ] = DeviceErrorModel( name=deviceName,
                                                timeout=timeouts,
                                                ack=acks,
                                                busConflict=busConflicts )

   return SmbusErrorCountersModel( devices=errorRet )

def doShowSmbusCounters():
   # Build the SmbusCountersModel: one DeviceModel per address returned by
   # getCounterInfo()
   devRet = {}
   for deviceName, smbusAddr, counter in getCounterInfo():
      if counter:
         rxCount, txCount = counter.readCount()
         rxRate, txRate = counter.rxRate, counter.txRate
      else:
         # counter may be None if there's an inactive device; report zeros
         rxCount, txCount = 0, 0
         rxRate, txRate = 0.0, 0.0

      devRet[ smbusAddr ] = DeviceModel( name=deviceName,
                                         rxByteCount=rxCount,
                                         rxEstByteRate=rxRate,
                                         txByteCount=txCount,
                                         txEstByteRate=txRate )

   return SmbusCountersModel( devices=devRet )

#-----------------------------------------------------------------------------
# show platform smbus counters
#-----------------------------------------------------------------------------
class ShowPlatformSmbusCounters( ShowCommand.ShowCliCommandClass ):
   # Show command whose handler returns the SmbusCountersModel built by
   # doShowSmbusCounters(). The 'smbus' token uses nodeSmbus, so it is subject
   # to counterGuard.
   syntax = 'show platform smbus counters'
   data = {
      'platform' : platformMatcherForShow,
      'smbus' : nodeSmbus,
      'counters' : 'Hardware-access counters for each device',
   }

   cliModel = SmbusCountersModel

   @staticmethod
   def handler( mode, args ):
      # mode and args are not used; the command takes no parameters
      return doShowSmbusCounters()

BasicCli.addShowCommandClass( ShowPlatformSmbusCounters )

#-----------------------------------------------------------------------------
# show platform smbus counters errors
#-----------------------------------------------------------------------------
class ShowPlatformSmbusCountersErrors( ShowCommand.ShowCliCommandClass ):
   # Show command whose handler returns the SmbusErrorCountersModel built by
   # doShowSmbusCountersErrors(). The 'smbus' token uses nodeSmbus, so it is
   # subject to counterGuard.
   syntax = 'show platform smbus counters errors'
   data = {
      'platform' : platformMatcherForShow,
      'smbus' : nodeSmbus,
      'counters' : 'Hardware-access counters for each device',
      'errors' : 'Error info',
   }

   cliModel = SmbusErrorCountersModel

   @staticmethod
   def handler( mode, args ):
      # mode and args are not used; the command takes no parameters
      return doShowSmbusCountersErrors()

BasicCli.addShowCommandClass( ShowPlatformSmbusCountersErrors )

#-----------------------------------------------------------------------------
# clear platform smbus counters
#-----------------------------------------------------------------------------
class ClearPlatformSmbusCounters( CliCommand.CliCommandClass ):
   # Clear command, registered on BasicCli.EnableMode. It does not touch the
   # counters themselves: it records their current values under
   # clearCounterDir so that later reads are reported relative to this point.
   syntax = 'clear platform smbus counters'
   data = {
      'clear' : clearKwNode,
      'platform' : platformMatcherForClear,
      'smbus' : nodeSmbus,
      'counters' : 'Hardware-access counters for each device',
   }

   @staticmethod
   def handler( mode, args ):
      # Drop every previously saved host entry, then save a fresh baseline for
      # each counter yielded by smbusCounters()
      clearCounterDir.host.clear()
      for counter in smbusCounters():
         counter.clear()

BasicCli.EnableMode.addCommandClass( ClearPlatformSmbusCounters )

# Register the error-counter show command with the tech-support command list.
# cmdsGuard enables it only when counterGuard accepts the 'smbus' token (the
# guard returns None when the token is allowed).
# NOTE(review): the first argument is a timestamp string; presumably it orders
# this entry relative to other registrations -- confirm against TechSupportCli.
CliPlugin.TechSupportCli.registerShowTechSupportCmd(
   '2023-05-05 10:32:22',
   cmds=[ 'show platform smbus counters errors' ],
   cmdsGuard=lambda: counterGuard( None, 'smbus' ) is None )

def Plugin( entityManager ):
   # Plugin entry point: stores the entity manager's root and lazily mounts the
   # paths this module reads and writes into the module-level handles.
   global sysdbRoot, counterDir, smbusTopology, smbusConfig, clearCounterDir

   def mount( path, typeName, flags ):
      return LazyMount.mount( entityManager, path, typeName, flags )

   sysdbRoot = entityManager.root()
   counterDir = mount( "hardware/counter/smbus", "Tac::Dir", "ri" )
   smbusTopology = mount( "hardware/smbus/topology",
                          "Hardware::SmbusTopology", "r" )
   clearCounterDir = mount( "hardware/counter/cli/smbus",
                            "HardwareCounter::ClearCounterDir", "w" )
   smbusConfig = mount( "hardware/smbus", "Tac::Dir", "ri" )
