#!/usr/bin/env python3
# Copyright (c) 2022 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import argparse
import sys
import time
import AgentDirectory
import Pci
import PyClient
import Tac

class FRAMSpiSettings:
   """Locate the FRAM SPI accelerator registers and open a resource for them.

   The location comes either from the command line (an mmap address or a PCI
   address, plus a register base address and stride) or, when neither mmapAddr
   nor pciAddr is given, from the blackbox HAM in Sysdb.

   On any failure an error is printed to stderr and ``self.resource`` is left
   as None; callers must check it before use.
   """

   def __init__( self, mmapAddr=None, pciAddr=None, regBaseAddr=None,
                 regStride=None, fifoSize=32, verbose=False ):
      self.regBaseAddr = regBaseAddr
      self.regStride = regStride
      self.fifoSize = fifoSize
      self.verbose = verbose

      # In the case of any failure, resource will be left as None.
      self.resource = None

      if mmapAddr or pciAddr:
         if not self._validateManualSettings( regBaseAddr, regStride ):
            return

      if mmapAddr:
         # Map enough to cover the four registers spaced regStride apart
         # (command FIFO, read FIFO, control, logging control).
         endOffset = mmapAddr + regBaseAddr + regStride * 4
         self.resource = Pci.MmapResource( '/dev/mem', startOffset=mmapAddr,
                                           endOffset=endOffset )
         # Same adjustment as the sysdb mmap path; presumably MmapResource
         # register offsets include the mmap address -- confirm.
         self.regBaseAddr += mmapAddr
      elif pciAddr:
         self.resource = Pci.Device( pciAddr ).resource( 0 )
      else:
         self.initFromSysdb()
         if not self.resource:
            print( 'Could not determine settings from sysdb. Specify manually',
                   file=sys.stderr )

   @staticmethod
   def _validateManualSettings( regBaseAddr, regStride ):
      """Return True if a manually specified base address and stride are usable.

      Prints the reason to stderr and returns False otherwise. A base address
      of 0 is valid; only a missing (None) value is rejected as unspecified.
      """
      if regBaseAddr is None:
         print( 'No base address specified.', file=sys.stderr )
         return False

      if regBaseAddr < 0 or regBaseAddr % 4:
         print( 'Base address must be non-negative and word-aligned',
                file=sys.stderr )
         return False

      if regStride is None:
         print( 'No register stride specified.', file=sys.stderr )
         return False

      if regStride <= 0 or regStride % 4:
         print( 'Stride must be positive and word-aligned',
                file=sys.stderr )
         return False

      return True

   def initFromSysdb( self ):
      """Fill in regBaseAddr, regStride and resource from the Sysdb blackbox HAM.

      Requires a running Sysdb agent. Only a recursive HAM whose base is
      either a PCI device or a memory-mapped region is supported; on any
      other layout an error is printed and resource stays None.
      """
      if not AgentDirectory.agent( 'ar', 'Sysdb' ):
         print( 'Sysdb is not running.', file=sys.stderr )
         return

      sysdbClient = PyClient.PyClient( 'ar', 'Sysdb' )
      sysdb = sysdbClient.agentRoot()
      hwCellDir = sysdb[ 'cell' ][ '1' ][ 'hardware' ]
      if 'blackbox' not in hwCellDir:
         print( 'Blackbox dir not found in sysdb', file=sys.stderr )
         return

      bbHam = hwCellDir[ 'blackbox' ][ 'config' ].ham
      self.regBaseAddr = bbHam.offset
      self.regStride = bbHam.size

      if bbHam.kind != 'hamTypeRecursive':
         print( f'Unsupported bbHam kind {bbHam.kind}',
                file=sys.stderr )
         return

      if bbHam.base.kind == 'hamTypePci':
         pciAddr = bbHam.base.address.stringValue()
         self.resource = Pci.Device( pciAddr ).resource( 0 )
         if self.verbose:
            print( 'Discovered PCI FRAM access:' )
            # pylint: disable-next=consider-using-f-string
            print( '   PCI address: {}, base address: 0x{:x}, '
                   'stride: 0x{:x}'.format( pciAddr, self.regBaseAddr,
                                            self.regStride ) )
      elif bbHam.base.kind == 'hamTypeMemMapped':
         mmapAddr = bbHam.base.offset
         self.regBaseAddr += mmapAddr
         self.resource = Pci.MmapResource(
            '/dev/mem', startOffset=mmapAddr,
            endOffset=self.regBaseAddr + self.regStride * 4 )
         if self.verbose:
            print( 'Discovered mmap FRAM access:' )
            # pylint: disable-next=consider-using-f-string
            print( '   base address: 0x{:x}, offset: 0x{:x}, '
                   'stride: 0x{:x}'.format( mmapAddr, bbHam.offset,
                                            self.regStride ) )
      else:
         print( f'Unsupported bbHam base kind {bbHam.base.kind}',
                file=sys.stderr )

class FRAMDevice:
   """Drive an SPI FRAM through a FIFO-based SPI accelerator.

   The accelerator exposes four 32-bit registers, regStride apart starting at
   regBaseAddr: a command FIFO, a read FIFO, a control register and a
   logging-control register.

   Judging from the command sequences below (confirm against the hardware
   spec), each command-FIFO word works as follows:
   - It carries one SPI byte in its low bits.
   - 0x1000 requests that a byte be clocked back into the read FIFO.
   - 0x80000000 marks the last byte of a transaction.
   """

   def __init__( self, settings ):
      """Bind to the registers described by settings and reset the accelerator.

      settings must carry a non-None resource (see FRAMSpiSettings).
      """
      self.fifoSize = settings.fifoSize
      self.verbose = settings.verbose
      self.resource = settings.resource
      assert self.resource
      self.cmdFifo = settings.regBaseAddr
      self.readFifo = settings.regBaseAddr + settings.regStride
      self.controlReg = settings.regBaseAddr + 2 * settings.regStride
      self.logCtrlReg = settings.regBaseAddr + 3 * settings.regStride
      self.resetAccel()

   @staticmethod
   def _checkAddr( addr ):
      """Raise ValueError if addr does not fit in the two address bytes sent."""
      if not 0 <= addr <= 0xffff:
         raise ValueError( f'FRAM address {addr:#x} out of range 0x0-0xffff' )

   @staticmethod
   def _checkByte( val ):
      """Raise ValueError if val does not fit in one data byte.

      Data words are OR'ed with control flags before being pushed to the
      command FIFO, so wider values would corrupt the command stream.
      """
      if not 0 <= val <= 0xff:
         raise ValueError( f'Data value {val:#x} does not fit in one byte' )

   def loggingEnabled( self ):
      """Return True if bit 0 of the logging-control register is set."""
      logCtrlVal = self.resource.read32( self.logCtrlReg )
      return bool( logCtrlVal & 0x1 )

   def enableLogging( self ):
      """Write 1 to the logging-control register."""
      self.resource.write32( self.logCtrlReg, 0x1 )
      if self.verbose:
         print( 'Logging enabled' )

   def disableLogging( self ):
      """Write 0 to the logging-control register."""
      self.resource.write32( self.logCtrlReg, 0x0 )
      if self.verbose:
         print( 'Logging disabled' )

   def resetAccel( self ):
      """Clear accelerator error bits and drain any stale FIFO contents.

      Each drain step pops the read FIFO and waits (Tac.waitFor, 5s timeout)
      for either FIFO count to drop.
      """
      # Write 1 to bits 29-31 to clear error conditions
      self.resource.write32( self.controlReg, 0xe0000000 )

      def _fifoCounts():
         # Read-FIFO count in bits 0-5, write-FIFO count in bits 16-21.
         ctrlVal = self.resource.read32( self.controlReg )
         return ctrlVal & 0x3f, ( ctrlVal & 0x3f0000 ) >> 16

      readCount, writeCount = _fifoCounts()
      if not readCount and not writeCount:
         return

      if self.verbose:
         print( 'Draining read and write FIFO queues' )

      def _queueProgress( origReadCount, origWriteCount ):
         readCount, writeCount = _fifoCounts()
         return readCount < origReadCount or writeCount < origWriteCount

      while readCount or writeCount:
         self.resource.read32( self.readFifo )
         Tac.waitFor(
            lambda: _queueProgress( readCount, writeCount ),
            description='progress in draining queues',
            timeout=5 )
         readCount, writeCount = _fifoCounts()

   def getReadBytes( self, count ):
      """Wait for count bytes in the read FIFO, pop them and return them.

      Returns a list of ints in 0-255. Raises RuntimeError if a popped entry
      does not have bit 31 set.
      """
      def _numReadBytes():
         ctrlVal = self.resource.read32( self.controlReg )
         return 0x3f & ctrlVal

      def _waitForCount( expCount ):
         Tac.waitFor(
            lambda: _numReadBytes() == expCount,
            description='ctrl reg read bytes to be expected value',
            timeout=10 )
      _waitForCount( count )

      readBytes = []
      while len( readBytes ) < count:
         val = self.resource.read32( self.readFifo )
         # Explicit check rather than assert so it is not stripped by -O.
         if not 0x80000000 & val:
            raise RuntimeError( 'Invalid data read from read register' )
         readBytes.append( 0xff & val )
         _waitForCount( count - len( readBytes ) )

         # For some reason, on memory mapped FRAMs, even after the read FIFO
         # count changes, the same byte is available. Sleep 10ms to avoid this.
         time.sleep( 0.01 )

      return readBytes

   def doRdidCmd( self ):
      """Send RDID (0x9f) and return the 9 identification bytes."""
      rdIdCmdId = 0x9f
      rdIdCmd = [ rdIdCmdId ] + [ 0x1000 ] * 8 + [ 0x80001000 ]
      for val in rdIdCmd:
         self.resource.write32( self.cmdFifo, val )

      return self.getReadBytes( 9 )

   def doRdStatusCmd( self ):
      """Send RDSR (0x05) and return the status register byte."""
      rdStatusCmdId = 0x05
      rdStatusCmd = [ rdStatusCmdId ] + [ 0x80001000 ]
      for val in rdStatusCmd:
         self.resource.write32( self.cmdFifo, val )

      return self.getReadBytes( 1 )[ 0 ]

   def doReadCmd( self, addr, dataCount ):
      """Read dataCount bytes starting at addr and return them as ints.

      Raises ValueError if dataCount < 1 (at least one read word is always
      issued) or addr does not fit in the two address bytes sent.
      """
      if dataCount < 1:
         raise ValueError( f'Read count must be at least 1, got {dataCount}' )
      self._checkAddr( addr )

      readCmdId = 0x3
      self.resource.write32( self.cmdFifo, readCmdId )
      self.resource.write32( self.cmdFifo, ( addr >> 8 ) & 0xff )
      self.resource.write32( self.cmdFifo, addr & 0xff )
      for _ in range( dataCount - 1 ):
         self.resource.write32( self.cmdFifo, 0x1000 )
      self.resource.write32( self.cmdFifo, 0x80001000 )

      return self.getReadBytes( dataCount )

   def doWriteEnableCmd( self ):
      """Send WREN (0x06) as a single-byte transaction."""
      wrEnCmdId = 0x6
      self.resource.write32( self.cmdFifo, wrEnCmdId | 0x80000000 )

   def doWriteDisableCmd( self ):
      """Send WRDI (0x04) as a single-byte transaction."""
      wrDiCmdId = 0x4
      self.resource.write32( self.cmdFifo, wrDiCmdId | 0x80000000 )

   def enableWrite( self ):
      """Send WREN if status bit 1 is clear, then poll until it is set."""
      # First, ensure write is enabled
      writeEnabledMask = 0x2
      status = self.doRdStatusCmd()
      if not status & writeEnabledMask:
         self.doWriteEnableCmd()
         Tac.waitFor( lambda: self.doRdStatusCmd() & writeEnabledMask,
                      description='write to be enabled', timeout=30 )

   def doWriteCmd( self, addr, data ):
      """Write the byte values in data starting at addr.

      Raises ValueError for an out-of-range address or data value. An empty
      data sequence is a no-op: without a last byte there is nothing to carry
      the end-of-transaction bit.
      """
      self._checkAddr( addr )
      for val in data:
         self._checkByte( val )
      if not data:
         return

      self.enableWrite()

      writeCmdId = 0x2
      self.resource.write32( self.cmdFifo, writeCmdId )
      self.resource.write32( self.cmdFifo, ( addr >> 8 ) & 0xff )
      self.resource.write32( self.cmdFifo, addr & 0xff )
      lastIdx = len( data ) - 1
      for i, val in enumerate( data ):
         endFlag = 0x80000000 if i == lastIdx else 0
         self.resource.write32( self.cmdFifo, val | endFlag )

      self.doWriteDisableCmd()

   def doWrStatusCmd( self, val ):
      """Write val to the status register via WRSR (0x01).

      Raises ValueError if val does not fit in one byte.
      """
      self._checkByte( val )
      self.enableWrite()

      wrStatusCmdId = 0x01
      self.resource.write32( self.cmdFifo, wrStatusCmdId )
      self.resource.write32( self.cmdFifo, val | 0x80000000 )

      self.doWriteDisableCmd()

class MfgDecoder:
   """Decode and print the 9-byte RDID response of an SPI FRAM.

   Cypress and Fujitsu layouts are recognized; anything else is dumped raw.
   """

   def __init__( self, mfgBytes ):
      """Raise ValueError unless mfgBytes holds exactly 9 byte values."""
      if len( mfgBytes ) != 9:
         raise ValueError(
            f'Expected 9 manufacturer bytes, got {len( mfgBytes )}' )
      self.mfgBytes = mfgBytes
      # Continuation code and vendor IDs as they appear in the RDID response.
      self.continuationCode = 0x7f
      self.cypressId = 0xc2
      self.fujitsuId = 0x4

   def _decodeCypress( self, vals ):
      """Print a Cypress decode if vals matches; return whether it matched.

      Cypress responses are six continuation codes, the Cypress ID, then a
      16-bit product ID split into the bit fields printed below.
      """
      if any( val != self.continuationCode for val in vals[ : 6 ] ):
         return False
      if vals[ 6 ] != self.cypressId:
         return False

      idStr = [ hex( x ) for x in vals[ : 7 ] ]
      print( f'Manufacturer ID {idStr}: Cypress' )
      productId = ( vals[ 7 ] << 8 ) | vals[ 8 ]
      # 16-bit value: pad to four hex digits so leading zeros are kept.
      print( f'Product ID: 0x{productId:04x}' )
      print( f'   Family(15-13): 0x{( productId & 0xe000 ) >> 13:02x}' )
      print( f'   Density(12-8): 0x{( productId & 0x1f00 ) >> 8:02x}' )
      print( f'   Sub(7-6): 0x{( productId & 0xc0 ) >> 6:02x}' )
      print( f'   Rev(5-3): 0x{( productId & 0x38 ) >> 3:02x}' )
      print( f'   Rsvd(2-0): 0x{productId & 0x7:02x}' )

      return True

   def _decodeFujitsu( self, vals ):
      """Print a Fujitsu decode if vals matches; return whether it matched.

      Fujitsu responses start with the Fujitsu ID, a continuation code, then
      two product ID bytes.
      """
      if vals[ 0 ] != self.fujitsuId:
         return False

      print( f'Manufacturer ID (0x{vals[ 0 ]:02x}): Fujitsu' )
      if vals[ 1 ] != self.continuationCode:
         # Both values in hex so the expected/found pair is comparable.
         print( 'Warning: expected continuation code '
                f'0x{self.continuationCode:02x} but found 0x{vals[ 1 ]:02x}' )
      print( f'Product ID(1st byte): 0x{vals[ 2 ]:02x}' )
      print( f'Product ID(2nd byte): 0x{vals[ 3 ]:02x}' )
      return True

   def dump( self ):
      """Print the raw bytes, then a vendor-specific decode if one matches."""
      byteStr = [ hex( x ) for x in self.mfgBytes ]
      print( f'Manufacturer data: {byteStr}' )

      if ( not self._decodeCypress( self.mfgBytes ) and
           not self._decodeFujitsu( self.mfgBytes ) ):
         print( 'Unrecognized vendor. Manufacturer bytes:' )
         for val in self.mfgBytes:
            print( '   ' + hex( val ) )

def verifyCount( device, count ):
   """Exit with status 1 unless 1 <= count <= device.fifoSize.

   A single read or write moves at most one FIFO's worth of bytes. A read
   always issues at least one read word, so a count below 1 cannot be honored.
   """
   if count < 1:
      print( f'Count must be at least 1 (got {count})', file=sys.stderr )
      sys.exit( 1 )
   if count > device.fifoSize:
      print( f'Count is greater than fifo size ({device.fifoSize})',
             file=sys.stderr )
      print( 'Specify --fifoSize if the device supports a larger FIFO',
             file=sys.stderr )
      sys.exit( 1 )

def verifyLoggingDisabled( device ):
   """Exit with status 1 if the device currently has logging turned on.

   Raw FRAM accesses are refused while logging is active.
   """
   if not device.loggingEnabled():
      return
   for msg in ( 'Any reads or writes require logging to be disabled',
                'Use the disable-logging command to disable it' ):
      print( msg, file=sys.stderr )
   sys.exit( 1 )

def handleReadCmd( args, device ):
   """Read args.count bytes at args.offset and print them as hex on one line."""
   verifyCount( device, args.count )
   verifyLoggingDisabled( device )
   start = args.offset
   readBack = device.doReadCmd( start, args.count )
   print( f'Data starting at 0x{start:x}:' )
   print( ''.join( f' {byte:02x}' for byte in readBack ) )

def handleDumpCmd( args, device ):
   """Write args.count raw bytes starting at args.offset to args.output.

   The region is read in chunks of at most device.fifoSize bytes, so a
   larger --fifoSize means fewer reads and a faster dump.
   """
   verifyLoggingDisabled( device )
   offset = args.offset
   remaining = args.count
   collected = bytearray()
   while remaining > 0:
      chunk = device.doReadCmd( offset, min( remaining, device.fifoSize ) )
      collected.extend( chunk )
      offset += len( chunk )
      remaining -= len( chunk )

   args.output.write( collected )

def handleWriteCmd( args, device ):
   """Write the bytes in args.data to the FRAM starting at args.offset.

   Exits with status 1 in any of these cases:
   - a data value is not a byte;
   - the offset does not fit in two address bytes;
   - the byte count exceeds the FIFO size;
   - logging is enabled.
   """
   # doWriteCmd ORs the last data word with 0x80000000 before pushing it to
   # the command FIFO, so anything wider than a byte corrupts the command.
   badBytes = [ val for val in args.data if not 0 <= val <= 0xff ]
   if badBytes:
      print( 'Data bytes must be in the range 00-ff; got: ' +
             ', '.join( f'{val:x}' for val in badBytes ), file=sys.stderr )
      sys.exit( 1 )
   # Only two address bytes are sent to the device.
   if not 0 <= args.offset <= 0xffff:
      print( f'Offset must be in the range 0-ffff (got {args.offset:x})',
             file=sys.stderr )
      sys.exit( 1 )
   verifyCount( device, len( args.data ) )
   verifyLoggingDisabled( device )
   device.doWriteCmd( args.offset, args.data )

def handleEnableLoggingCmd( args, device ):
   """Turn on FRAM logging; this command takes no options."""
   del args  # shared handler signature; unused here
   device.enableLogging()

def handleDisableLoggingCmd( args, device ):
   """Turn off FRAM logging; this command takes no options."""
   del args  # shared handler signature; unused here
   device.disableLogging()

def handleInventoryCmd( args, device ):
   """Issue RDID and print the decoded manufacturer/product information."""
   verifyLoggingDisabled( device )
   MfgDecoder( device.doRdidCmd() ).dump()

def handleStatusCmd( args, device ):
   """Read the FRAM status register and print it in hex."""
   verifyLoggingDisabled( device )
   print( f'FRAM status: 0x{device.doRdStatusCmd():x}' )

def handleWriteStatusCmd( args, device ):
   """Write args.data to the FRAM status register.

   Exits with status 1 if the value is not a byte or logging is enabled.
   """
   # doWrStatusCmd ORs the value with 0x80000000 before pushing it to the
   # command FIFO, so anything wider than a byte corrupts the command.
   if not 0 <= args.data <= 0xff:
      print( f'Status value must be in the range 00-ff (got {args.data:x})',
             file=sys.stderr )
      sys.exit( 1 )
   verifyLoggingDisabled( device )

   device.doWrStatusCmd( args.data )

def main():
   """Parse the command line, locate the FRAM and run the chosen subcommand."""
   parser = argparse.ArgumentParser( description='Blackbox Utility' )
   parser.add_argument( '-v', '--verbose', action='store_true', help='Be verbose' )
   parser.add_argument( '--mmap', type=lambda x: int( x, 16 ),
                        help='Device address as an mmap' )
   parser.add_argument( '--pciAddr', help='PCI address of the device' )
   parser.add_argument( '--spiBaseAddr', type=lambda x: int( x, 16 ),
                        help='SPI base address' )
   parser.add_argument( '--spiRegStride', type=lambda x: int( x, 16 ),
                        help='SPI register stride' )
   parser.add_argument( '--fifoSize', type=int, default=32,
                        help='FIFO size' )

   # This file already requires Python 3 (it uses f-strings), so no Python 2
   # fallback is needed; required subparsers need Python 3.7+.
   subparsers = parser.add_subparsers( dest='cmd', required=True,
                                       help='bbutil command' )

   readParser = subparsers.add_parser( 'read', help='read bytes from the FRAM' )
   readParser.add_argument( 'offset', type=lambda x: int( x, 16 ),
                            help='Offset in bytes' )
   readParser.add_argument( 'count', type=int,
                            help='Number of bytes to read (1-32)' )
   readParser.set_defaults( func=handleReadCmd )

   dumpParser = subparsers.add_parser( 'dump',
                            help='read region of bytes from the FRAM' )
   dumpParser.add_argument( 'offset', type=lambda x: int( x, 16 ),
                            help='Offset in bytes' )
   dumpParser.add_argument( 'count', type=int,
                            help='Number of bytes to read' )
   dumpParser.add_argument( '-o', '--output',
                            type=argparse.FileType( "wb" ),
                            default=sys.stdout.buffer,
                            help='Output file to write binary data' )
   dumpParser.set_defaults( func=handleDumpCmd )

   writeParser = subparsers.add_parser( 'write', help='write bytes to the FRAM' )
   writeParser.add_argument( 'offset', type=lambda x: int( x, 16 ),
                            help='Offset in bytes' )
   writeParser.add_argument( 'data', type=lambda x: int( x, 16 ), nargs='+',
                            help='Data to write' )
   writeParser.set_defaults( func=handleWriteCmd )

   enableLoggingParser = subparsers.add_parser( 'enable-logging',
                                                help='enable console logging' )
   enableLoggingParser.set_defaults( func=handleEnableLoggingCmd )

   disableLoggingParser = subparsers.add_parser( 'disable-logging',
                                                 help='disable console logging' )
   disableLoggingParser.set_defaults( func=handleDisableLoggingCmd )

   inventoryParser = subparsers.add_parser( 'inventory', help='print inventory' )
   inventoryParser.set_defaults( func=handleInventoryCmd )

   statusParser = subparsers.add_parser( 'status', help='print status register' )
   statusParser.set_defaults( func=handleStatusCmd )

   writeStatusParser = subparsers.add_parser(
      'write-status', help='write to the status register' )
   writeStatusParser.add_argument( 'data', type=lambda x: int( x, 16 ),
                                   help='Data to write' )
   writeStatusParser.set_defaults( func=handleWriteStatusCmd )

   args = parser.parse_args()

   # Reads are issued in chunks of at most fifoSize bytes, so a non-positive
   # size could never make progress.
   if args.fifoSize < 1:
      parser.error( '--fifoSize must be at least 1' )

   settings = FRAMSpiSettings( mmapAddr=args.mmap, pciAddr=args.pciAddr,
                               regBaseAddr=args.spiBaseAddr,
                               regStride=args.spiRegStride,
                               fifoSize=args.fifoSize,
                               verbose=args.verbose )
   if not settings.resource:
      sys.exit( 1 )

   device = FRAMDevice( settings )
   args.func( args, device )

if __name__ == "__main__":
   # Run the CLI only when executed as a script, not when imported.
   main()
