# Copyright (c) 2022 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import BasicCliUtil
import LazyMount
import re
import Tac
import Tracing
from TypeFuture import TacLazyType

# Trace handle for this module plus trace-level shorthands.
__defaultTraceHandle__ = Tracing.Handle( "ResourceMgrCliLib" )
t0 = Tracing.trace0
t4 = Tracing.trace4

# dma mem *based* hardware table base models
AmaMemIndexTableType = TacLazyType( 'Asic::AmaMem::IndexTable' )
DmaMemHashTableType = TacLazyType( 'Asic::AmaMem::HashTable' )
# base class tac model for AsicUtils-generated IndexTable with
# Direct Memory Access
DirectIndexTableType = TacLazyType( 'Asic::DirectIndexTable' )
# Asic::HwTableStub singleton, created at import time; its tableDesc
# collection is indexed by table name to obtain a table descriptor
# (see mountTables).
ResourceStub = Tac.singleton( 'Asic::HwTableStub' )

class HwTableMountInfo:
   '''Mount information collected for a single hardware table.

   collectMountInfos() creates one instance per matching table name and feeds
   it, via updateMountInfoDict(), the unitMountInfo collection of every
   feature agent that has the table mounted.
   '''
   def __init__( self, tableName, hamOwner ):
      self.tableName = tableName
      self.tableType = Tac.Type( self.tableName ).tacType
      # Owner of the config path the table was found under; AsicResourceMgr
      # is the hamOwner for shared hardware tables.
      self.hamOwner = hamOwner
      # featureName -> that feature's table unitMountInfo. Shared tables get
      # one entry per agent sharing the same table.
      self.mountInfoDict = {}
      # Every unit id recorded in mountInfoDict (a set, despite the name).
      self.unitIdList = set()
      # Storage backing the table; checkAggregationSupported() switches it
      # to 'dmamem'.
      self.storage = 'smash'

   def updateMountInfoDict( self, featureName, unitMountInfo, unitIdArg=None ):
      '''Record featureName's unitMountInfo for this table.

      Args:
         featureName: key under which the mount info is stored.
         unitMountInfo: collection keyed by unit id; nothing is recorded when
            it is empty or None.
         unitIdArg: optional unit id filter. When given (unit 0 included) and
            unitMountInfo has a truthy entry for it, only that unit's entry is
            kept; otherwise the whole collection is kept.
      '''
      if not unitMountInfo:
         return
      # Compare against None so that unit 0 is honoured as a filter; a plain
      # truthiness test treated unit 0 as "no filter given".
      if unitIdArg is not None and unitMountInfo.get( unitIdArg ):
         self.mountInfoDict[ featureName ] = {
               unitIdArg: unitMountInfo[ unitIdArg ] }
         self.unitIdList.add( unitIdArg )
      else:
         # NOTE(review): a requested unit that this feature does not have
         # falls back to recording every unit -- confirm this is intended.
         self.mountInfoDict[ featureName ] = unitMountInfo
         self.unitIdList.update( unitMountInfo )

   def isDirectIndexTable( self ):
      '''Return True if the table type derives from Asic::DirectIndexTable.'''
      return self.tableType.isDerived( DirectIndexTableType.tacType )

   def checkAggregationSupported( self ):
      '''Return True if an aggregated view of this hardware table is available.

      The aggregated view reflects the content of the table as it is written
      to the hardware. It is available for these hwTable types:
      1. Shared index tables whose aggregated version uses dmamem.
      2. Exclusively owned direct index tables: index tables with dmamem
         access, where a single copy of the hardware table is written by the
         feature agent and used to program the hardware.
      3. Hash tables backed by dmamem: a feature-facing smash table written
         by the feature agent, plus a direct index table containing the
         aggregated table view used to program the hardware.

      Side effect: sets self.storage to 'dmamem' when returning True.
      '''
      baseTableTypeList = [
         DmaMemHashTableType,
         AmaMemIndexTableType
      ]
      if ( any( self.tableType.isDerived( dmaMemTableType.tacType ) for
         dmaMemTableType in baseTableTypeList ) or self.isDirectIndexTable()
      ):
         self.storage = 'dmamem'
         return True
      return False

class CustomTable:
   '''Used for AsicResourceMgrPlugins to create the customTableMappings
   and customTableHandlers'''
   def __init__( self, name, agent, hwTableList, handler=None, extension=None ):
      self.name = name
      self.agent = agent
      self.hwTableList = hwTableList
      self.handler = handler
      self.extension = extension

   @staticmethod
   def _mountInfosOnChips( hwTableMountInfos, chipList ):
      '''Return True if any collected table name lies on a chip in chipList.

      The chip name is the text before the first '::' of each full table
      name. A chipList of None places no restriction on the chip.
      '''
      if not hwTableMountInfos:
         return False
      if chipList is None:
         return True
      return any( fullTableName.split( '::' )[ 0 ] in chipList
                  for fullTableName in hwTableMountInfos )

   def validateCustomTableName( self, allConfigPathDir, chipList=None ):
      '''Validate the custom table by checking underlying tables exist.

      Every table in hwTableList must have mount info (collected for this
      table's agent) on at least one chip in chipList. Previously only the
      chip of the last matching table name was checked, and the default
      chipList=None raised TypeError.
      '''
      return all(
         self._mountInfosOnChips(
            collectMountInfos( allConfigPathDir, hwTable, self.agent ),
            chipList )
         for hwTable in self.hwTableList )

class CustomTableFilter:
   '''Base class for the platform specific resource Filters for custom tables'''
   def __init__( self, name, entityManager, customTableMapping, chipList ):
      # Read-only lazy mount of the hardware resource config directory,
      # consulted when validating each custom table.
      self.allConfigPathDir = LazyMount.mount(
         entityManager, 'hardware/resource', 'Tac::Dir', 'ri' )
      self.name = name
      self.entityManager = entityManager
      self.customTableMapping = customTableMapping
      self.chipList = chipList

   def _validCustomTables( self ):
      '''Yield, in mapping order, each custom table whose underlying hardware
      tables validate against this filter's chip list.'''
      for customTable in self.customTableMapping.values():
         if customTable.validateCustomTableName( self.allConfigPathDir,
                                                 self.chipList ):
            yield customTable

   def getCustomTable( self ):
      '''Return the set of names of the valid custom tables.'''
      return { customTable.name for customTable in self._validCustomTables() }

   def getCustomHandler( self ):
      '''Map each valid custom table name to its ( handler, extension ) pair;
      a later mapping entry with the same name wins.'''
      return { customTable.name: ( customTable.handler, customTable.extension )
               for customTable in self._validCustomTables() }

   def getOwnerAgent( self, table ):
      '''Return the set of agents of all custom tables named table.'''
      return { customTable.agent
               for customTable in self.customTableMapping.values()
               if customTable.name == table }

def numEntries( tableFullName ):
   '''Return numEntries() of the table descriptor registered for tableFullName.

   Reuses the module-level ResourceStub (the Asic::HwTableStub singleton),
   as mountTables() does, instead of looking the singleton up on every call.
   An unknown name presumably raises KeyError, which is what mountTables()
   expects from the same lookup -- confirm.
   '''
   desc = ResourceStub.tableDesc[ tableFullName ]
   return desc.numEntries()

def isValidHwTable( tableFullName ):
   '''Return True if the hardware table tableFullName is usable.

   A table counts as valid only when it has memory associated with it, which
   is checked as numEntries() > 1. Events that leave a table invalid include
   a uftMode change that results in zero physical tables being associated
   with it.
   '''
   entryCount = numEntries( tableFullName )
   return entryCount > 1

def populateSmashParams( entityManager, featureName, tableDesc ):
   """Populate the SmashParams for mounting smash tables via unified
   mounting architecture.

   Args:
      entityManager: entity manager; its optional isLocalEm attribute
         (default False) is copied into localEm.
      featureName: used as both the user and the hamOwner of the mount.
      tableDesc: table descriptor whose numEntries() sizes both the entry
         and the status-entry collections.

   Returns:
      An 'Asic::HwTableMountParams::SmashParams' value with
      mountAsExclusiveReader set.
   """
   isLocalEm = getattr( entityManager, 'isLocalEm', False )
   # Query the descriptor once; both collections use the same size.
   collSize = tableDesc.numEntries()
   return Tac.Value( 'Asic::HwTableMountParams::SmashParams',
                     localEm=isLocalEm,
                     user=featureName,
                     hamOwner=featureName,
                     entryCollSize=collSize,
                     statusEntryCollSize=collSize,
                     mountAsExclusiveReader=True )

def mountTables( entityManager,
                 unitId,
                 tableName,
                 featureName=None,
                 tableMountInfo=None,
                 isHardwareAccess=False ):
   """Unified mounting infrastructure for dmamem or smash tables.

   Args:
      entityManager: entity manager whose sysname goes into the mount params;
         also used to unmount the previous dmamem mount path.
      unitId: unit to mount the table for (converted with int() for the
         dmamem mount info lookup).
      tableName: key into ResourceStub.tableDesc and the Tac type name
         passed to Tac.singleton().
      featureName: must be truthy (asserted) when the table's storage type is
         'smash'; becomes the smash user and hamOwner.
      tableMountInfo: when given, the table is mounted directly, together
         with its response table if tableMountInfo.statusPath is set. When
         None, the table's AmaMem mount path is unmounted and the table is
         remounted under root privilege.
      isHardwareAccess: forwarded to Asic::HwTableMountParams.

   Returns:
      ( hwTable, configStatus ) when tableMountInfo is given -- configStatus
      is None without a statusPath -- otherwise ( hwTable, mountInfo ) where
      mountInfo is the table descriptor's AmaMem mount info for unitId.

   Raises:
      KeyError: tableName has no table descriptor (a message is printed
         first).
   """
   try:
      tableDesc = ResourceStub.tableDesc[ tableName ]
   except KeyError:
      # NOTE(review): the message does not name the missing table; consider
      # including tableName.
      print( 'Table does not have a table descriptor' )
      raise
   tableEntity = Tac.singleton( tableName )
   # Request a response only when the table type exposes writeResponse0.
   responseNeeded = hasattr( tableEntity, 'writeResponse0' )
   # Arguments are positional; presumably they follow the
   # Asic::HwTableMountParams constructor order -- confirm before reordering.
   hwTableParams = Tac.Value( 'Asic::HwTableMountParams',
                              tableDesc,
                              isHardwareAccess,
                              unitId,
                              responseNeeded,
                              entityManager.sysname )

   mountHwTableWrapper = Tac.Type( 'Asic::MountHwTableWrapper' )

   if hwTableParams.storageType() == 'smash':
      assert featureName
      # Populate the smash params for the hwtable
      hwTableParams.smashParams = populateSmashParams(
         entityManager, featureName, tableDesc )

   # With a tableMountInfo: mount the table and return it along with its
   # configStatus (response) table if tableMountInfo has a statusPath.
   # Without one: take the dmamem path and return the table with its
   # mount info.
   # NOTE(review): the branch is chosen by tableMountInfo, not by
   # storageType(); confirm callers always pass it for smash tables.
   if tableMountInfo:
      hwTable = mountHwTableWrapper.doMountHwTable(
            hwTableParams )
      if tableMountInfo.statusPath:
         configStatus = mountHwTableWrapper.doMountHwTableResponse(
            hwTableParams )
      else:
         configStatus = None
      return hwTable, configStatus
   else:
      mountInfo = tableDesc.getAmaMemMountInfo( int( unitId ) )
      # clear out any references to hardware tables, to force
      # syncing in the new table
      entityManager.unmount( mountInfo.path )
      # need root privilege for dmamem mounting for cli
      with BasicCliUtil.RootPrivilege():
         hwTable = mountHwTableWrapper.doMountHwTable(
               hwTableParams )
      return hwTable, mountInfo

def collectMountInfos( allConfigPathDir, hwTblArg='*',
                       featureAgentArg='*', unitIdArg=None ):
   """Collect mount information for the hardware tables programming the
   active chip and profile, i.e. those with a non-empty unitMountInfo.

   Args:
      allConfigPathDir: hamOwner -> ( tableName -> ( feature agent ->
         resource mount info ) ); each resource mount info exposes
         featureName and unitMountInfo.
      hwTblArg: substring a table name must contain; '*' or '' selects
         every table.
      featureAgentArg: feature agent whose entry is used; '*' or a falsy
         value selects every agent.
      unitIdArg: optional unit id filter forwarded to
         HwTableMountInfo.updateMountInfoDict().

   Returns:
      dict tableName -> HwTableMountInfo, holding only tables for which at
      least one mount info was recorded.
   """
   agentFiltered = bool( featureAgentArg ) and featureAgentArg != '*'

   def resourceMountInfosFor( tableDir ):
      # Only the selected agent's entry, or every agent's entry when no
      # agent filter applies.
      if not agentFiltered:
         return tableDir.values()
      if featureAgentArg in tableDir:
         return [ tableDir[ featureAgentArg ] ]
      return []

   collected = {}
   for hamOwner, configPathDir in allConfigPathDir.items():
      t4( "figure out the configPath per resource per agent per unit id" )
      for tableName, tableDir in configPathDir.items():
         wanted = hwTblArg in tableName or hwTblArg in '*'
         if not wanted:
            continue
         tableInfo = HwTableMountInfo( tableName, hamOwner )
         for resourceMountInfo in resourceMountInfosFor( tableDir ):
            tableInfo.updateMountInfoDict( resourceMountInfo.featureName,
                  resourceMountInfo.unitMountInfo, unitIdArg )
         if tableInfo.mountInfoDict:
            collected[ tableName ] = tableInfo
   return collected

def shouldSkipEntry( numSubEntries, entry, subIdx, verbose ):
   """Return True if an entry (or one of its sub-entries) should be hidden.

   With more than one sub-entry per entry, the validity of sub-entry subIdx
   is checked; otherwise the validity of the whole entry. Invalid items are
   hidden unless verbose is set.
   """
   if numSubEntries > 1:
      isValid = entry.isSubEntryValid( subIdx )
   else:
      isValid = entry.isValid()
   return not isValid and not verbose

def getAlignments( attrName, attrType=None, attrPrefix='' ):
   """Return the display alignment, "left" or "right", for an attribute.

   The index columns (attrPrefix + 'entryIndex' / attrPrefix + 'dataIndex')
   are always right aligned; they are usually added without a TacAttr, so
   attrType may be omitted for them. Members deriving from Arnet::IpGenAddr,
   Arnet::IpAddr, Arnet::EthAddr or Ark::PerUnitBitmap are left aligned and
   everything else is right aligned. With no attrType there is nothing to
   inspect, so the default "right" is returned instead of raising
   AttributeError.
   """
   if attrName in ( attrPrefix + 'entryIndex', attrPrefix + 'dataIndex' ):
      return "right"
   if attrType is None:
      return "right"
   memberType = attrType.memberType
   leftAlignedTypeNames = ( "Arnet::IpGenAddr", "Arnet::IpAddr",
                            "Arnet::EthAddr", "Ark::PerUnitBitmap" )
   if any( memberType.isDerived( Tac.Type( typeName ).tacType )
           for typeName in leftAlignedTypeNames ):
      return "left"
   # everything else is right aligned
   return "right"

def getFilteredEntryAlignments( entry, numSubEntries, verbose,
                                indexedTable, attrPrefix='' ):
   """Get attribute alignments depending on a single entry.

   The entry's displayable attributes (and those of its entryKey for
   non-indexed tables) are filtered the same way regardless of verbose, and
   each surviving attribute is mapped to getAlignments(). The entryIndex
   alignment is always added. The dataIndex alignment is also added when a
   (non byte-array) collection attribute is found on a table with a single
   sub-entry, which is when getEntryFields() can emit a dataIndex column.

   Args:
      entry: table entry used as the template for the alignments.
      numSubEntries: number of sub-entries packed into one entry.
      verbose: unused here; kept for interface compatibility.
      indexedTable: when False, entry.entryKey (if set) is inspected too.
      attrPrefix: prefix for the entryIndex/dataIndex alignment keys.

   Returns:
      dict attribute name -> "left" / "right".
   """
   bitmapRe = re.compile( r'mem\d+(Key|Mask)\d*' )
   keyFields = set()
   alignments = {}

   def addFieldAlignments( field, isKey=False ):
      # Record alignments for field's displayable attributes and return True
      # if a collection needing a dataIndex column was seen. The result is
      # returned rather than assigned: rebinding an outer name (or a
      # parameter) inside this nested function never reaches the caller.
      needsDataIndex = False
      for attr in field.attributes:
         attrType = field.tacType.attr( attr )
         if attr in keyFields:
            # for AsicHwHashTable model tables, the key fields are contained
            # in the entry; avoid adding them twice.
            continue
         if attr in ( 'entryIndex', 'viewT', 'transactionId' ):
            continue
         if isNotMutable( field, attrType ):
            # filter out functions, physical field consts, etc.
            continue
         if isBitMap( attrType ) and bitmapRe.match( attr ):
            # bitmap key/mask fields are not part of the alignments
            continue
         if isString( attrType ):
            continue
         if attrType.isCollection:
            if attrType.hasDataMember:
               # (data/dataByte/byte) collections are never displayed
               continue
            if numSubEntries == 1:
               needsDataIndex = True
         # NOTE(review): keyed by the bare attr, while getEntryFields keys
         # values by attrPrefix + attr -- confirm for non-empty prefixes.
         alignments[ attr ] = getAlignments( attr, attrType )
         if isKey:
            keyFields.add( attr )
      return needsDataIndex

   needsDataIndex = False
   key = getattr( entry, 'entryKey', None )
   if not indexedTable and key:
      needsDataIndex = addFieldAlignments( key, isKey=True )
   # The entry pass must always run, so it goes on the left of the `or`.
   needsDataIndex = addFieldAlignments( entry ) or needsDataIndex
   captureIndexAttribute( numSubEntries, attrPrefix, alignments,
                          captureDataIndex=needsDataIndex )
   return alignments

def getEntryFields( entry, subIdx, numSubEntries, verbose, indexedTable,
      entryKey, attrPrefix='' ):
   """Split one table entry into displayable key and data field rows.

   Args:
      entry: the table entry; for non-indexed tables its entryKey attribute
         (when set) is expanded into the key fields.
      subIdx: sub-entry selected from collection attributes when
         numSubEntries > 1; also the row used for index columns.
      numSubEntries: number of sub-entries packed into one entry.
      verbose: when False, zero ints (including an int transactionId), bitmap
         key/mask fields, string attributes and values rejected by
         shouldCaptureValue() are hidden.
      indexedTable: when True, entryIndex (and, for expanded collections,
         dataIndex) columns are added to the data rows.
      entryKey: value written to the entryIndex column (multiplied by
         numSubEntries and offset by the sub-entry when numSubEntries > 1).
      attrPrefix: prefix prepended to every emitted column name.

   Returns:
      ( keyFieldsList, dataFieldsList ): lists of dicts mapping prefixed
      column name to its string value.
   """
   bitmapRe = re.compile( r'mem\d+(Key|Mask)\d*' )
   # Key attribute names already emitted, so the data pass skips them.
   keyFields = set()
   keyFieldsList = []
   dataFieldsList = []
   # Example format for the {key|data}FieldsList, it is a list of dictionaries
   # dataFieldsList[ 0 ] = {'field1':1, 'field2':2, 'entryIndex': 5, 'dataIndex': 0}
   # dataFieldsList[ 1 ] = {'field1':4, 'field2':4, 'entryIndex': 5, 'dataIndex': 1}

   def getFieldMap( field, fieldMapList, isKey=False ):
      # Append field's displayable attributes to fieldMapList; isKey marks
      # the entry-key pass.
      # dataIndex columns are written only while expanding the first
      # collection attribute.
      indexCaptured = False
      for attr in field.attributes:
         val = getattr( field, attr )
         attrType = field.tacType.attr( attr )
         if attr in keyFields:
            # for AsicHwHashTable model tables, the key fields
            # are contained in the entry. Avoid adding redundant
            # fields to the dataFields dict.
            continue
         if attr in ( 'entryIndex', 'viewT' ):
            continue
         if isNotMutable( field, attrType ):
            # filter out all fields such as functions
            # and entryKey, physical field consts, etc.
            continue
         if isinstance( val, int ) and not verbose and ( val == 0 or
                                                         attr == 'transactionId' ):
            # remove the transactionId from the generic output, and
            # display them only in the verbose mode
            continue
         if isBitMap( attrType ):
            if not verbose and bitmapRe.match( attr ):
               # move displaying bitmap fields such as key/mask to the
               # verbose mode.
               continue
            val = bitmapToString( val )
         if not verbose and isString( attrType ):
            # remove static string attributes
            continue
         if attrType.isCollection:
            if attrType.hasDataMember:
               # The condition is used to filter out
               # collections of the form
               # (data/dataByte/byte)#{memId}. These
               # are generated as tacc byte arrays
               # to hold the entry value in chip-native format.
               continue
            # Multiple sub-entries: hide this sub-entry's element when it is
            # a zero int (non-verbose only).
            if not verbose and numSubEntries > 1 and \
               isinstance( val[ subIdx ], int ) and \
                  val[ subIdx ] == 0:
               continue
            if numSubEntries == 1:
               # Single sub-entry: each collection element becomes its own
               # row (colIndex); the element count comes from the size of
               # the collection's index type.
               # NOTE(review): this path `continue`s before keyFields.add(),
               # so key collections are not recorded in keyFields.
               numOfEntries = Tac.Type( attrType.indexType.fullTypeName ).size
               for colIndex in range( numOfEntries ):
                  captureValue( val[ colIndex ], attr, verbose, fieldMapList,
                     attrPrefix, colIndex )
                  if indexedTable and not indexCaptured:
                     captureIndex( numSubEntries, entryKey, colIndex,
                        fieldMapList, attrPrefix, captureDataIndex=True )
               indexCaptured = True
               continue
            # Multiple sub-entries: display only the selected element.
            val = val[ subIdx ]
         captureValue( val, attr, verbose, fieldMapList, attrPrefix )
         if indexedTable:
            captureIndex( numSubEntries, entryKey, subIdx, fieldMapList, attrPrefix )
         if isKey:
            keyFields.add( attr )

   key = getattr( entry, 'entryKey', None )
   if not indexedTable and key:
      getFieldMap( key, keyFieldsList, isKey=True )
   getFieldMap( entry, dataFieldsList )
   # Always emit at least the index columns, even if every field was filtered.
   if not dataFieldsList:
      captureIndex( numSubEntries, entryKey, subIdx, dataFieldsList, attrPrefix )
   return keyFieldsList, dataFieldsList

def getEntryFieldValue( attrList, fieldVal1, fieldVal2 ):
   """Return the value of each attribute named in attrList, in order.

   A name is looked up in fieldVal1 when present there, otherwise in
   fieldVal2; names found in neither resolve to '0'.
   """
   return [ ( fieldVal1 if name in fieldVal1 else fieldVal2 ).get( name, '0' )
            for name in attrList ]

def isNotMutable( field, attrType ):
   """Return True if the attribute described by attrType is not user-mutable.

   When field provides an isUserWritable predicate, it decides, given the
   attribute id. Otherwise functions and non-writable attributes (such as
   physical const fields and tableFields) are treated as not mutable.
   """
   userWritableCheck = getattr( field, 'isUserWritable', None )
   if not userWritableCheck:
      return attrType.isFunction or not attrType.writable
   return not userWritableCheck( attrType.attributeId )

def isBitMap( attrType ):
   """Return True if attrType's member type is Ark::PerUnitBitmap."""
   memberTypeName = attrType.memberType.fullTypeName
   return memberTypeName == 'Ark::PerUnitBitmap'

def isString( attrType ):
   """Return True if attrType's member type is Tac::String."""
   memberTypeName = attrType.memberType.fullTypeName
   return memberTypeName == 'Tac::String'

def bitmapToString( val ):
   """Render a bitmap as one compact string of hexadecimal words.

   Takes val.bitmapToString(), replaces every ' 0x' separator with '_' and
   strips any leading '_'.
   """
   rawText = val.bitmapToString()
   joined = rawText.replace( ' 0x', '_' )
   return joined.lstrip( '_' )

def shouldCaptureValue( value, verbose ):
   """Return True if value should appear in the output.

   Verbose mode shows everything. Otherwise zero ints and values whose
   string form is '0x0' are hidden.
   """
   if verbose:
      return True
   if isinstance( value, int ):
      return value != 0
   return str( value ) != '0x0'

def captureValue( value, attr, verbose, fieldMapList, attrPrefix, colIndex=0 ):
   """Store str( value ) under attrPrefix + attr in row colIndex.

   fieldMapList is first padded with empty dicts until row colIndex exists;
   this happens even if the value is then filtered out. The value is
   written only when shouldCaptureValue( value, verbose ) allows it.
   """
   missingRows = colIndex + 1 - len( fieldMapList )
   fieldMapList.extend( {} for _ in range( missingRows ) )
   if not shouldCaptureValue( value, verbose ):
      return
   fieldMapList[ colIndex ][ attrPrefix + attr ] = str( value )

def captureIndexAttribute( numSubEntries,
      attrPrefix, alignments, captureDataIndex=False ):
   """Capture the entryIndex/dataIndex alignments depending
   on if dataIndex is needed.

   Adds alignments[ attrPrefix + 'entryIndex' ] and, when captureDataIndex
   is set, alignments[ attrPrefix + 'dataIndex' ]. numSubEntries is unused
   and kept for interface compatibility.

   The prefix is passed to getAlignments() by keyword. It used to be passed
   positionally into getAlignments()' attrType slot, which only worked by
   accident.
   """
   indexNames = [ 'entryIndex' ]
   if captureDataIndex:
      indexNames.append( 'dataIndex' )
   for indexName in indexNames:
      prefixedName = attrPrefix + indexName
      alignments[ prefixedName ] = getAlignments( prefixedName,
                                                  attrPrefix=attrPrefix )

def captureIndex( numSubEntries, entryKey, colIndex, fieldMapList,
      attrPrefix, captureDataIndex=False ):
   """Write the index column(s) for an entry into fieldMapList.

   With numSubEntries > 1, the logical index
   entryKey * numSubEntries + colIndex is written as entryIndex of row 0.
   Otherwise entryKey is written as entryIndex of row colIndex, together
   with dataIndex = colIndex when captureDataIndex is set.

   fieldMapList is padded with empty dicts until the target row exists, as
   captureValue() does. Only the first row used to be created, so a
   colIndex past the end raised IndexError.
   """
   targetRow = 0 if numSubEntries > 1 else colIndex
   # Always keep at least one row, matching the original guarantee.
   while len( fieldMapList ) <= max( targetRow, 0 ):
      fieldMapList.append( {} )
   row = fieldMapList[ targetRow ]
   if numSubEntries > 1:
      logicalIndex = entryKey * numSubEntries + colIndex
      row[ attrPrefix + 'entryIndex' ] = str( logicalIndex )
      return
   row[ attrPrefix + 'entryIndex' ] = str( entryKey )
   if captureDataIndex:
      row[ attrPrefix + 'dataIndex' ] = str( colIndex )

def getEntryAtIdx( entryView, entryIdx, mode ):
   """Retrieve the entry at the given index or error.

   Returns { entryIdx: entry } on success. Returns None when entryIdx is
   None, or after reporting through mode.addError() when the index is not
   permitted (IndexError) or is permitted but empty (KeyError).
   """
   if entryIdx is None:
      return None
   try:
      found = entryView[ entryIdx ]
   except IndexError:
      # The index is not permitted in the table
      mode.addError( f'entryIdx {entryIdx} is not permitted in entryView' )
      return None
   except KeyError:
      # The index is permitted, but has no entry in the table
      mode.addError( f'entryIdx {entryIdx} is permitted but has no entry in the '
         'entryView' )
      return None
   return { entryIdx: found }
