#!/usr/bin/env python3
# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import argparse
import sys
import traceback
import textwrap

import BcmAsicQspiUtils
from BcmAsicQspiUtils import CmicGen
import EntityManager
import Pci
import ScdRegisters
import Tac

def strToInt( s ):
   """Parse an integer literal, auto-detecting the base from its prefix.

   Accepts the same forms as ``int( s, 0 )``: decimal, ``0x`` hex, ``0o`` octal
   and ``0b`` binary. Raises ValueError for anything else.

   Defined with ``def`` rather than as a lambda because argparse uses the type
   callable's ``__name__`` in its error messages; this makes a bad value report
   "invalid strToInt value" instead of "invalid <lambda> value".
   """
   return int( s, base=0 )

class ScdSmbus:
   """QSPI flash access through an SCD SMBus accelerator, with SCD-register reset.

   ASIC reset is optional: `reset` writes `resetMask` to the `resetSetAddr`
   register and `clearReset` writes it to `resetClearAddr` (both offset by the
   SCD PCI offset). Those operations need the matching address and the mask to
   have been supplied; the flash operations do not.
   """

   def __init__(
         self, scdPciAddr, accelId, busId, deviceId, smbusSpacing=0x80,
         resetSetAddr=None, resetClearAddr=None, resetMask=None,
         cmicGen=CmicGen.CMICX ):
      self.resetSetAddr = resetSetAddr
      self.resetClearAddr = resetClearAddr
      self.resetMask = resetMask

      scdPath = ScdRegisters.scdPciResourceFile( pciAddress=scdPciAddr )
      self.scd = Pci.MmapResource( scdPath )
      # Added to every reset register address before it is read or written.
      self.scdOffset = ScdRegisters.scdPciOffset()

      self.qspi = BcmAsicQspiUtils.QspiFlash(
         scdPciAddr, accelId, busId, 4, smbusSpacing, responderAddr=deviceId,
         cmicGen=cmicGen )

   def _requireResetRegs( self, addr, option ):
      """Raise ValueError unless `addr` and `resetMask` are both configured.

      An explicit check rather than `assert`: these values come from optional
      command-line options, and the check must still run under `python -O`
      (otherwise the failure surfaces later as a TypeError on `None`).
      """
      if addr is None or self.resetMask is None:
         raise ValueError(
            f'{option} and --resetMask are required for this operation' )

   def isReset( self ):
      """Return True if every bit of `resetMask` is set in the reset-set register.

      Raises ValueError if `resetSetAddr` or `resetMask` was not supplied.
      """
      self._requireResetRegs( self.resetSetAddr, '--resetSetAddr' )
      value = self.scd.read32( self.scdOffset + self.resetSetAddr )
      # Parenthesised for readability; in Python `&` already binds tighter
      # than `==`.
      return ( value & self.resetMask ) == self.resetMask

   def reset( self ):
      """Assert ASIC reset by writing `resetMask` to the reset-set register.

      Raises ValueError if `resetSetAddr` or `resetMask` was not supplied.
      """
      self._requireResetRegs( self.resetSetAddr, '--resetSetAddr' )
      self.scd.write32( self.scdOffset + self.resetSetAddr, self.resetMask )

   def clearReset( self ):
      """Deassert ASIC reset by writing `resetMask` to the reset-clear register.

      Raises ValueError if `resetClearAddr` or `resetMask` was not supplied.
      """
      self._requireResetRegs( self.resetClearAddr, '--resetClearAddr' )
      self.scd.write32( self.scdOffset + self.resetClearAddr, self.resetMask )

   def readVersion( self ):
      """Return the firmware version reported by the QSPI helper."""
      return self.qspi.getFwVersion()

   def erase( self, size ):
      """Erase `size` bytes of QSPI flash."""
      self.qspi.eraseFlash( size )

   def read( self, start, size ):
      """Return `size` bytes of flash starting at offset `start`."""
      # The device read takes an end offset and may return extra bytes, so
      # trim the result to the requested size.
      return self.qspi.qspiDeviceRead( start, start + size )[ :size ]

   def program( self, filename ):
      """Program the flash with the image in `filename`."""
      self.qspi.programFlash( filename )

   def verify( self, filename ):
      """Verify the flash contents against the image in `filename`."""
      self.qspi.verifyFlash( filename )

   def readFlashId( self ):
      """Return the QSPI device IDs as reported by qspiDeviceReadId()."""
      return self.qspi.qspiDeviceReadId()

# For Fabric Upgrades
class SupeSmbus:
   """QSPI flash access over SMBus, with ASIC reset driven through Archer GPIOs.

   Pin state is written to the Sysdb mount "hardware/archer/gpio/config" and
   read back from "hardware/archer/gpio/status". Every pin name is optional;
   operations skip pins that were not supplied. Unlike ScdSmbus, readVersion,
   program and verify pass QspiInitMode.MSPI16, and read uses
   qspiDeviceRead16.
   """

   # Timeout (seconds) passed to each Tac.waitFor on GPIO status.
   gpioWaitTimeout = 15

   def __init__(
         self, scdPciAddr, accelId, busId, deviceId, smbusSpacing=0x80,
         resetPinName=None, pcieResetPinName=None, smbusEnPinName=None,
         cmicGen=CmicGen.CMICX ):
      self.resetPinName = resetPinName
      self.pcieResetPinName = pcieResetPinName
      self.smbusEnPinName = smbusEnPinName

      # NOTE(review): `sysdb` is not kept on self; if the mounts below rely on
      # the entity manager staying alive, store it as an attribute — confirm.
      sysdb = EntityManager.Sysdb( "ar" )
      mg = sysdb.mountGroup()
      self.config = mg.mount( "hardware/archer/gpio/config", "Tac::Dir", "wi" )
      self.status = mg.mount( "hardware/archer/gpio/status", "Tac::Dir", "ri" )
      mg.close( blocking=True )

      self.qspi = BcmAsicQspiUtils.QspiFlash(
         scdPciAddr, accelId, busId, 4, smbusSpacing,
         responderAddr=deviceId, factoryType=None, cmicGen=cmicGen )

   def _resetPinNames( self ):
      """Return the configured reset pins: ASIC reset first, then PCIe reset."""
      return [ name for name in ( self.resetPinName, self.pcieResetPinName )
               if name ]

   def _waitForAll( self, names, predicate ):
      """Wait until `predicate( name )` holds for every name in `names`.

      Does nothing when `names` is empty; all names share one timeout window.
      """
      if names:
         Tac.waitFor( lambda: all( predicate( name ) for name in names ),
                      timeout=self.gpioWaitTimeout )

   def _waitForPinGroups( self, smbusEnPredicate, resetPredicate ):
      """Wait on the SMBus-enable pin, then on the reset pins together.

      The two groups are waited for separately, so each gets its own timeout.
      """
      smbusEnNames = [ self.smbusEnPinName ] if self.smbusEnPinName else []
      self._waitForAll( smbusEnNames, smbusEnPredicate )
      self._waitForAll( self._resetPinNames(), resetPredicate )

   def createGpioEntities( self ):
      """Create config entities for the configured pins; wait for their status.

      Blocks until every configured pin name is present in the status dir.
      """
      allPins = [ self.smbusEnPinName ] if self.smbusEnPinName else []
      allPins += self._resetPinNames()
      for name in allPins:
         self.config.newEntity( "Hardware::Gpio::ArcherPinConfig", name )
      Tac.runActivities( 0 )

      def isPresent( name ):
         return name in self.status

      self._waitForPinGroups( isPresent, isPresent )

   def _setPins( self, smbusEnValue, resetValue ):
      """Drive the pins and wait until their status reflects the new values.

      `smbusEnValue` goes to the SMBus-enable pin and `resetValue` to every
      configured reset pin.
      """
      self.createGpioEntities()
      if self.smbusEnPinName:
         self.config[ self.smbusEnPinName ].value = smbusEnValue
      for name in self._resetPinNames():
         self.config[ name ].value = resetValue
      Tac.runActivities( 0 )

      self._waitForPinGroups(
         lambda name: self.status[ name ].value == smbusEnValue,
         lambda name: self.status[ name ].value == resetValue )

   def isReset( self ):
      """Return True when every configured reset pin reads 1.

      Also returns True when no reset pin is configured.
      """
      self.createGpioEntities()
      return all( self.status[ name ].value == 1
                  for name in self._resetPinNames() )

   def reset( self ):
      """Write 0 to the SMBus-enable pin and 1 to the reset pins, then wait."""
      self._setPins( smbusEnValue=0, resetValue=1 )

   def clearReset( self ):
      """Write 1 to the SMBus-enable pin and 0 to the reset pins, then wait."""
      self._setPins( smbusEnValue=1, resetValue=0 )

   def readVersion( self ):
      """Return the firmware version, read using the MSPI16 init mode."""
      return self.qspi.getFwVersion( BcmAsicQspiUtils.QspiInitMode.MSPI16 )

   def erase( self, size ):
      """Erase `size` bytes of QSPI flash."""
      self.qspi.eraseFlash( size )

   def read( self, start, size ):
      """Return `size` bytes of flash starting at `start` (16-bit reads)."""
      # The device read may return more than requested; trim to `size`.
      return self.qspi.qspiDeviceRead16( start, start + size )[ :size ]

   def program( self, filename ):
      """Program the flash from `filename` using the MSPI16 init mode."""
      self.qspi.programFlash( filename, BcmAsicQspiUtils.QspiInitMode.MSPI16 )

   def verify( self, filename ):
      """Verify the flash against `filename` using the MSPI16 init mode."""
      self.qspi.verifyFlash( filename, BcmAsicQspiUtils.QspiInitMode.MSPI16 )

   def readFlashId( self ):
      """Return the QSPI device IDs as reported by qspiDeviceReadId()."""
      return self.qspi.qspiDeviceReadId()


def checkReset( args, impl ):
   """Operation handler: exit with status 1 if the ASIC is held in reset.

   Returns normally when it is not in reset.
   """
   del args  # Unused; every operation handler takes ( args, impl ).
   inReset = impl.isReset()
   if not inReset:
      return
   sys.exit( 1 )

def reset( args, impl ):
   """Operation handler: ask the access scheme to put the ASIC into reset."""
   del args  # Unused; shared operation-handler signature.
   doReset = impl.reset
   doReset()

def clearReset( args, impl ):
   """Operation handler: ask the access scheme to take the ASIC out of reset."""
   del args  # Unused; shared operation-handler signature.
   doClear = impl.clearReset
   doClear()

def readVersion( args, impl ):
   """Operation handler: print the firmware version read via the scheme."""
   del args  # Unused; shared operation-handler signature.
   version = impl.readVersion()
   print( f'Version: {version}' )

def erase( args, impl ):
   """Operation handler: erase `args.size` bytes of QSPI flash."""
   byteCount = args.size
   impl.erase( byteCount )

def read( args, impl ):
   """Operation handler: dump `args.size` flash bytes from `args.start`.

   The bytes are written unmodified to stdout's binary buffer.
   """
   data = impl.read( args.start, args.size )
   binaryOut = sys.stdout.buffer
   binaryOut.write( data )

def program( args, impl ):
   """Operation handler: program the flash with the image `args.filename`."""
   imagePath = args.filename
   impl.program( imagePath )

def verify( args, impl ):
   """Operation handler: verify the flash against the image `args.filename`."""
   imagePath = args.filename
   impl.verify( imagePath )

def readFlashId( args, impl ):
   """Operation handler: print the QSPI manufacturer and device IDs in hex."""
   del args  # Unused; shared operation-handler signature.
   ids = impl.readFlashId()
   manufacturerId, deviceId = ids
   print( f'QSPI manufacturer: {manufacturerId:#x}, device ID: {deviceId:#x}' )

def runScdSmbus( args ):
   path = args.smbusPath.split( '/' )
   assert len( path ) == 5 and path[ 1 ] == 'scd'

   impl = ScdSmbus(
      args.scdAddress, strToInt( path[ 2 ] ), strToInt( path[ 3 ] ),
      strToInt( path[ 4 ] ), smbusSpacing=args.smbusSpacing,
      resetSetAddr=args.resetSetAddr, resetClearAddr=args.resetClearAddr,
      resetMask=args.resetMask, cmicGen=args.cmicGen )
   args.op( args, impl )

def runSupeParser( args ):
   path = args.smbusPath.split( '/' )
   assert len( path ) == 5 and path[ 1 ] == 'scd'

   impl = SupeSmbus(
      args.scdAddress, strToInt( path[ 2 ] ), strToInt( path[ 3 ] ),
      strToInt( path[ 4 ] ), smbusSpacing=args.smbusSpacing,
      resetPinName=args.resetPinName, pcieResetPinName=args.pcieResetPinName,
      smbusEnPinName=args.smbusEnPinName, cmicGen=args.cmicGen )
   args.op( args, impl )


def _addSchemeArguments( schemeParser ):
   """Add the SMBus path options and operation subcommands common to schemes."""
   schemeParser.add_argument(
      'smbusPath', help='smbus command-style path (e.g. /scd/4/6/0x44)' )
   schemeParser.add_argument(
      '--smbusSpacing', help='SMBus accelerator address spacing', type=strToInt,
      default='0x80' )

   opParsers = schemeParser.add_subparsers( dest='operation' )
   # Require an operation: otherwise parse_args() succeeds without setting
   # args.op and the run fails with an AttributeError traceback instead of a
   # usage error. Set as an attribute because add_subparsers( required=... )
   # only exists from Python 3.7.
   opParsers.required = True

   def addOperation( name, helpText, handler ):
      # Register one operation subcommand that dispatches to `handler`.
      opParser = opParsers.add_parser(
         name, help=helpText, formatter_class=argparse.RawTextHelpFormatter )
      opParser.set_defaults( op=handler )
      return opParser

   addOperation(
      'checkReset', 'Check if the ASIC is in reset (exit status 1)', checkReset )
   addOperation( 'reset', 'Reset the ASIC (where applicable)', reset )
   addOperation(
      'clearReset', 'Take the ASIC out of reset (where applicable)', clearReset )
   addOperation(
      'readVersion', 'Read the PCIe firmware version from QSPI', readVersion )

   eraseParser = addOperation( 'erase', 'Erase QSPI flash', erase )
   eraseParser.add_argument(
      'size', type=strToInt, help='Number of bytes to erase' )

   readParser = addOperation(
      'read', 'Read arbitrary data from the PCIe firmware QSPI', read )
   readParser.add_argument( 'start', type=strToInt, help='Flash offset' )
   readParser.add_argument(
      'size', type=strToInt, help='Size of data to read (in bytes)' )

   programParser = addOperation( 'program', 'Program QSPI flash', program )
   programParser.add_argument( 'filename', help='Image to flash' )

   verifyParser = addOperation( 'verify', 'Verify QSPI flash', verify )
   verifyParser.add_argument( 'filename', help='Image to verify with flash' )

   addOperation( 'readFlashId', 'Read QSPI device IDs', readFlashId )


def main():
   """Parse the command line and run the chosen operation via its scheme."""
   parser = argparse.ArgumentParser(
      description=textwrap.dedent(
      '''
      Utility to flash Broadcom ASIC QSPI via a variety of access methods (designed
      for modular systems).
      ''' ), formatter_class=argparse.ArgumentDefaultsHelpFormatter )
   parser.add_argument(
      '--scdAddress', help='PCI address of SCD on linecard', required=True )
   parser.add_argument(
      '--cmicGen', help='CMIC generation', type=CmicGen, default=CmicGen.CMICX )

   schemeParsers = parser.add_subparsers( dest='scheme', help='QSPI access method' )
   # Require a scheme: without this a missing scheme leaves args.runScheme
   # unset and fails with an AttributeError instead of a usage error.
   schemeParsers.required = True

   scdSmbusParser = schemeParsers.add_parser(
      'scdSmbus', formatter_class=argparse.ArgumentDefaultsHelpFormatter )
   scdSmbusParser.set_defaults( runScheme=runScdSmbus )
   scdSmbusParser.add_argument(
      '--resetSetAddr', help='SCD register for setting ASIC reset', type=strToInt )
   scdSmbusParser.add_argument(
      '--resetClearAddr', help='SCD register for clearing ASIC reset',
      type=strToInt )
   scdSmbusParser.add_argument(
      '--resetMask', help='Mask for setting / clearing only ASIC reset',
      type=strToInt )

   supeSmbusParser = schemeParsers.add_parser(
      'supeSmbus', formatter_class=argparse.ArgumentDefaultsHelpFormatter )
   supeSmbusParser.set_defaults( runScheme=runSupeParser )
   supeSmbusParser.add_argument(
      '--resetPinName', help='GPIO Name for Ramon Reset Pin' )
   supeSmbusParser.add_argument(
      '--pcieResetPinName', help='GPIO Name for Ramon Pcie Reset Pin' )
   supeSmbusParser.add_argument(
      '--smbusEnPinName', help='GPIO Name for Enabling Ramon Smbus access' )

   for schemeParser in ( scdSmbusParser, supeSmbusParser ):
      _addSchemeArguments( schemeParser )

   args = parser.parse_args()
   args.runScheme( args )


if __name__ == '__main__':
   try:
      main()
   except Exception: # pylint: disable=broad-except
      # checkReset uses exit status 1 (also Python's status for an uncaught
      # exception) to mean "in reset", so report unexpected failures with -1.
      # SystemExit is not an Exception subclass, so sys.exit() calls made
      # inside main() and argparse usage errors pass through untouched.
      traceback.print_exc()
      raise SystemExit( -1 )
