#!/usr/bin/env python3

# pylint: disable=consider-using-f-string

# Copyright (c) 2010  Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import argparse
import errno
import json
import os
import re
import subprocess
import sys
import FirmwareRev

from collections import defaultdict
from distutils.spawn import find_executable

import TpmGeneric.Defs as TpmDefs
from TpmGeneric.Tpm import TpmGeneric

from FlashUtil import (
   flashromDetect,
   FlashUtilErrorCode as ErrorCode,
   FlashUtilLock,
   getRomToFlash,
   parseLayoutFile,
   run,
)

###########################################################################
##### Hardcoded values for different supported platforms ##################
###########################################################################

# Fallback table: architecture group -> names of the platforms it contains.
# The archGroup is normally taken from the "Aboot" option of the kernel command
# line; this table is only consulted when that option is missing or unusable
# (older Aboot versions), in which case the "platform" option is looked up here.
ARCHGROUP_TO_PLAT = {
   # flashUtil was not used for norcal1
   'norcal1' : [],
   'norcal2' : [ 'raven' ],
   'norcal3' : [
      'oak',
      'blackbird',
      'eaglepeak',
   ],
   'norcal4' : [
      'crow',
      'mendocino',
   ],
   'norcal7' : [ 'oldfaithful' ],
   'norcal9' : [ 'woodpecker' ],
   'norcal10' : [
      'lorikeet',
      'cormorant',
      'shearwater',
   ],
   'norcal11' : [
      'puffin',
      'alcatraz',
      'prairieisland',
   ],
   'norcal12' : [ 'councilbluffs' ],
   'norcal13' : [ 'independence' ],
}

# Built-in pseudo targets: archGroup -> pseudo target name -> tuple of layout
# sections. A pseudo target is a write-only alias that flashes all of the listed
# sections in a single flashrom invocation.
PSEUDO_TARGET_LUT = {
   'norcal7' : {
      'image' : (
         'fallback',
         'normal',
         'microcode',
         'bootblock',
      ),
   },
   'norcal9' : {
      'image' : (
         'coreboot',
         'agesa',
         'bootblock',
      ),
   },
   'norcal10' : {
      'normal' : (
         'sec_pei',
         'dxe',
      ),
      'fallback' : (
         'bkp_sec_pei',
         'bkp_dxe',
      ),
      'image' : (
         'agesa',
         'sec_pei',
         'bkp_sec_pei',
         'dxe',
         'bkp_dxe',
      ),
   },
   'norcal11' : {
      'image' : (
         'normal',
         'fallback',
      ),
   },
}

###########################################################################
##### End hardcoded values section ########################################
###########################################################################

# Directory scanned for flashrom layout files; each file name is used as an
# archGroup key.
LAYOUT_DIR = '/usr/share/EosAbootFirmwareUtils'

# Binary size units, expressed in bytes
KiB = 1 << 10
MiB = KiB << 10

def exitWithPrint( code, out=None, err=None ):
   """Emit optional messages, then terminate the process with `code`.

   out, when non-empty, is written to stdout followed by a newline; err, when
   non-empty, is written to stderr the same way. The function never returns:
   it always ends with sys.exit( code ).
   """
   for stream, message in ( ( sys.stdout, out ), ( sys.stderr, err ) ):
      if message:
         stream.write( message + '\n' )
   sys.exit( code )

class FlashUtilBinary( object ): # pylint: disable=useless-object-inheritance
   """Read or write sections of the SPI flash by driving the flashrom binary.

   Layouts are loaded from LAYOUT_DIR into flashOffsetLUT, keyed by layout file
   name; that name doubles as the archGroup. A layout maps a section name to a
   dict whose 'start' and 'length' entries are used as byte offset and byte
   count when slicing images.
   """

   def __init__( self, opts ):
      """Store the parsed command line options and reset the flash state.

      opts is the namespace of command line options; its attributes (read,
      write, filename, archGroup, layout, verbose, ...) are read by the other
      methods.
      """
      self.opts = opts
      # archGroup -> layout; indexing a missing archGroup yields an empty layout
      self.flashOffsetLUT = defaultdict( dict )
      # Size in bytes of the detected SPI flash, filled in by run()
      self.spiTotalLen = 0
      self.archGroup = ''
      self.abootVersion = FirmwareRev.abootFirmwareRevNumbers()
      # Pseudo targets loaded from the user supplied JSON file: name -> sections
      self.extraPseudoTarget = {}

      self._pod = None

   @property
   def pod( self ): # Platform Offset Dictionary
      """Layout of the current archGroup, or None when no layout is known."""
      return self.flashOffsetLUT.get( self.archGroup, None )

   @property
   def totalSectionLen( self ):
      """Length of the 'total' section of the current layout (0 if no layout)."""
      return self.pod[ 'total' ][ 'length' ] if self.pod else 0

   def norcalLine( self ):
      """Return N for an archGroup containing 'norcalN', None otherwise."""
      if not self.archGroup:
         return None

      m = re.search( r'norcal(\d+)', self.archGroup )
      if not m:
         return None
      return int( m.group( 1 ) )

   def computeArchGroup( self ):
      """Work out which archGroup (layout name) applies to this system.

      An archGroup given on the command line is returned as is, without any
      validation. Otherwise the kernel command line is parsed: the "Aboot"
      option is preferred and the "platform" option is used as a fallback
      through ARCHGROUP_TO_PLAT.

      Returns the archGroup name, suffixed with '-<size>m' when the detected
      flash size differs from the base layout, or '' when nothing could be
      determined. Exits the process when the archGroup has no usable layout.
      """
      if self.opts.archGroup:
         return self.opts.archGroup

      platform = ''
      archGroup = ''

      try:
         with open( '/proc/cmdline' ) as f:
            cmdline = f.read()
      except OSError:
         # Without a kernel command line there is nothing to parse
         cmdline = ''

      platformRe = re.search( r"platform=(.*?)($| |\n)", cmdline )
      archGroupRe = re.search( r"Aboot=Aboot-(\w+)-", cmdline )
      if platformRe:
         platform = platformRe.group( 1 )
      if archGroupRe:
         archGroup = archGroupRe.group( 1 ).lower()
         # In case some Aboot version strings cause us to get garbage
         if archGroup not in self.flashOffsetLUT:
            archGroup = ''

      # Older versions of Aboot did not include the archgroup in
      # /proc/cmdline, so fall back on using the platform (eventually
      # we should get rid of this code)
      if not archGroup and platform:
         for archGroupItem, platforms in ARCHGROUP_TO_PLAT.items():
            if platform in platforms:
               archGroup = archGroupItem
               break

      if not archGroup:
         # Nothing identified this system: hand the empty name back so that the
         # caller can tell the user to pass -a, instead of failing here with a
         # "no layout" message about an empty archGroup name.
         return ''

      pod = self.flashOffsetLUT.get( archGroup, None )
      if not pod:
         exitWithPrint( ErrorCode.FAIL_GENERIC,
                        err="No layout file found for archGroup %s" % archGroup )

      if 'total' not in pod:
         exitWithPrint( ErrorCode.FAIL_GENERIC,
            err="'total' section layout must be hardcoded for archGroup %s." %
                archGroup )

      if pod[ 'total' ][ 'length' ] != self.spiTotalLen:
         # The base layout does not match the detected chip; look for a size
         # specific variant such as 'norcal11-32m'.
         archGroup = '%s-%dm' % ( archGroup, self.spiTotalLen // MiB )
         if not self.flashOffsetLUT.get( archGroup, None ):
            exitWithPrint( ErrorCode.FAIL_GENERIC,
               err="SPI flash size %d kB does not match layout" %
                   ( self.spiTotalLen // KiB ) )

      return archGroup

   def flashromBin( self ):
      """Return the flashrom executable name to use.

      "flashrom-diag" is selected for writes when it is found on PATH,
      "flashrom" in every other case.
      """
      # distutils (find_executable) is removed from Python 3.12 onwards;
      # shutil.which is the standard library replacement.
      from shutil import which # pylint: disable=import-outside-toplevel

      if self.opts.write and which( "flashrom-diag" ):
         return "flashrom-diag"
      return "flashrom"

   def layoutPathFromPlatform( self, p ):
      """Return the path of the layout file to use for layout name p.

      With Aboot version '11.0.0' the norcal11 layouts are redirected to their
      '-old' counterparts.
      """
      if p == 'norcal11' and self.abootVersion == '11.0.0':
         p = 'norcal11-old'
      if p == 'norcal11-32m' and self.abootVersion == '11.0.0':
         p = 'norcal11-old-32m'
      return os.path.join( LAYOUT_DIR, p )

   def loadLayoutFiles( self ):
      """Parse every file of LAYOUT_DIR into flashOffsetLUT, keyed by file name.

      Exits the process when the directory cannot be listed or when a layout
      file fails to parse (parseLayoutFile returning None).
      """
      try:
         fileNames = os.listdir( LAYOUT_DIR )
      except OSError as e:
         exitWithPrint( ErrorCode.FAIL_GENERIC,
                        err="Cannot list layout directory %s: %s" %
                            ( LAYOUT_DIR, e ) )

      for fileName in fileNames:
         layout = parseLayoutFile( self.layoutPathFromPlatform( fileName ) )
         if layout is None:
            sys.exit( 1 )
         self.flashOffsetLUT[ fileName ] = layout

   def loadPseudoTargetFile( self ):
      """Load and validate the user supplied pseudo target definitions.

      The file must hold a JSON object mapping each pseudo target name to a
      list of section names that exist in the current layout. On success the
      result is stored in extraPseudoTarget; any violation exits the process.
      Nothing happens when no pseudo target file was given.
      """
      if self.opts.pseudoTargetFile is None:
         return

      try:
         extraPseudoTarget = json.load( self.opts.pseudoTargetFile )
      except ValueError as e:
         # Covers malformed JSON as well as undecodable file content
         exitWithPrint( ErrorCode.FAIL_GENERIC,
                        err='Invalid pseudo target file: %s' % e )

      if not isinstance( extraPseudoTarget, dict ):
         exitWithPrint( ErrorCode.FAIL_GENERIC,
                        err='Pseudo target file must contain a JSON object' )

      # Go through the pod property rather than indexing flashOffsetLUT, so that
      # an unknown archGroup does not get an empty layout inserted for it.
      layout = self.pod or {}

      for name, sections in extraPseudoTarget.items():
         if not isinstance( sections, list ):
            exitWithPrint( ErrorCode.FAIL_GENERIC,
                           err='Invalid target list for pseudo target %s' % name )
         for section in sections:
            # Non-string entries (e.g. nested lists) cannot name a section
            if not isinstance( section, str ) or section not in layout:
               exitWithPrint( ErrorCode.FAIL_GENERIC,
                              err='No section %s defined for %s' % ( section,
                                 self.archGroup ) )

      self.extraPseudoTarget = extraPseudoTarget

   def pseudoTargetDict( self ):
      """Return the built-in pseudo targets of the current archGroup.

      An empty dict is returned when the archGroup has no pseudo targets or
      does not follow the 'norcalN' naming.
      """
      # We can't directly use self.archGroup here, because in the case of *-32m
      # archGroup, we won't find pseudo targets. This is why we recompute the
      # norcalX here from the norcalLine extracted from the archGroup.
      line = self.norcalLine()
      if line is None:
         return {}
      return PSEUDO_TARGET_LUT.get( 'norcal%d' % line, {} )

   def computeTargets( self ):
      """Resolve the requested target into a sequence of layout sections.

      A target that is a section of the layout is returned as a 1-tuple.
      Otherwise it is looked up among the user supplied pseudo targets, then
      among the built-in ones. Pseudo targets are rejected for reads. Exits
      the process when the target cannot be resolved.
      """
      target = self.opts.read or self.opts.write

      pod = self.pod
      if pod is None:
         # Happens when -a names an archGroup that has no layout at all
         exitWithPrint( ErrorCode.FAIL_GENERIC,
                        err="No layout file found for archGroup %s" %
                            self.archGroup )

      if target in pod:
         return ( target, )

      self.loadPseudoTargetFile()

      targets = self.extraPseudoTarget.get( target, [] )
      if targets:
         if self.opts.read:
            exitWithPrint( ErrorCode.FAIL_GENERIC,
                           err="PseudoTarget is not supported for read" )
         return targets

      builtinTargets = self.pseudoTargetDict()
      if target not in builtinTargets:
         exitWithPrint( ErrorCode.FAIL_UNKNOWN_SECTION,
            err="Target section %s is not supported for archGroup %s." %
                ( target, self.archGroup ) )

      if self.opts.read:
         exitWithPrint( ErrorCode.FAIL_GENERIC,
                        err="PseudoTarget is not supported for read" )

      targets = builtinTargets[ target ]
      # Every member must be a real section, and an alias of one is pointless
      valid = set( targets ).issubset( pod ) and len( targets ) > 1
      if not valid:
         exitWithPrint( ErrorCode.FAIL_GENERIC,
            err="PseudoTarget %s is invalid on archGroup %s." %
                ( target, self.archGroup ) )

      return targets

   def doRead( self, target, layoutFile ):
      """Read section `target` from the flash into opts.filename ('-': stdout).

      flashrom is asked for an image restricted to `target`; the bytes of the
      section are then cut out of that image at the section's offset. Exits
      the process when the image is too small to contain the section and
      re-raises subprocess.CalledProcessError when flashrom fails.
      """
      start = self.pod[ target ][ 'start' ]
      length = self.pod[ target ][ 'length' ]

      cmd = [ self.flashromBin(), '-l', layoutFile, '-i', target ]
      flashromFileOut = '-'
      if self.opts.verbose:
         # If we're in verbose mode, we save the flashrom output to the file provided
         # by the user. We read it back after the flashrom run for proper processing.
         flashromFileOut = self.opts.filename
      else:
         cmd += [ '-q' ]
      cmd += [ '-r', flashromFileOut ]

      try:
         flashromOutput = run( cmd, verboseOut=self.opts.verbose )
      except subprocess.CalledProcessError as e:
         # Write debug log on failure...
         sys.stderr.flush()
         sys.stderr.buffer.write( e.output + b'\n' )
         raise

      if self.opts.verbose:
         # ... or if verbose flag was set
         sys.stdout.flush()
         sys.stdout.buffer.write( flashromOutput + b'\n' )

         with open( self.opts.filename, 'rb' ) as f:
            flashromOutput = f.read()

      # The section is sliced out at [ start, start + length ), so the image has
      # to reach that far; otherwise the slice would silently come out short.
      if len( flashromOutput ) < start + length:
         exitWithPrint( ErrorCode.FAIL_GENERIC,
                        err="Unexpected output size returned by flashrom" )

      dataOutput = flashromOutput[ start : start + length ]

      if self.opts.filename != "-":
         with open( self.opts.filename, 'wb' ) as f:
            f.write( dataOutput )
      else:
         try:
            os.write( 1, dataOutput )
         except OSError as e:
            # Ignore broken pipes
            if e.errno != errno.EPIPE:
               raise

   def _flashromWriteFile( self, flashromBaseCmd, fileToWrite, start, length ):
      """Run flashromBaseCmd extended with the write of fileToWrite.

      When ( start, length ) does not describe the whole flash, the file is
      first passed through getRomToFlash (presumably producing a full-size
      image with the data placed at `start` -- confirm against
      FlashUtil.getRomToFlash). flashromBaseCmd is left unmodified. Re-raises
      subprocess.CalledProcessError when flashrom fails.
      """
      if start != 0 or length != self.spiTotalLen:
         fileToWrite = getRomToFlash( fileToWrite, start, self.spiTotalLen )

      # Work on a copy so that the caller's command list is not altered
      cmd = list( flashromBaseCmd )
      if self.opts.try_single_erase:
         cmd.append( '--try-single-erase' )

      try:
         flashromOutput = run( cmd + [ '-w', fileToWrite ],
                               verboseOut=self.opts.verbose )
      except subprocess.CalledProcessError as e:
         # Write debug log on failure...
         sys.stderr.flush()
         sys.stderr.buffer.write( e.output + b'\n' )
         raise

      if self.opts.verbose:
         # ... or if verbose flag was set
         sys.stdout.flush()
         sys.stdout.buffer.write( flashromOutput + b'\n' )

   def doWriteTarget( self, target, layoutFile ):
      """Write opts.filename to the single section `target`.

      A file as large as the 'total' section is treated as a full image and
      written at offset 0; any other file must fit in the section and is
      written at the section's offset. Exits the process when it does not fit.
      """
      start = self.pod[ target ][ 'start' ]
      length = self.pod[ target ][ 'length' ]

      # Check input file size
      imageSize = os.path.getsize( self.opts.filename )

      cmd = [ self.flashromBin(), '-N', '-l', layoutFile, '-i', target ]
      if imageSize != self.totalSectionLen:
         if imageSize > length:
            exitWithPrint( ErrorCode.FAIL_GENERIC,
               err="Image size doesn't match for target section %s (%d)." %
                   ( target, length ) )
         self._flashromWriteFile( cmd, self.opts.filename, start, length )
      else:
         self._flashromWriteFile( cmd, self.opts.filename, 0, self.totalSectionLen )

   def doWritePseudoTarget( self, targets, layoutFile ):
      """Write all sections of `targets` from a full-size image, in one run.

      Exits the process when opts.filename is not exactly as large as the
      'total' section.
      """
      if os.path.getsize( self.opts.filename ) != self.totalSectionLen:
         exitWithPrint( ErrorCode.FAIL_GENERIC,
                        err="PseudoTarget needs a full-size image" )

      cmd = [ self.flashromBin(), '-N', '-l', layoutFile ]
      for target in targets:
         cmd += [ '-i', target ]
      self._flashromWriteFile( cmd, self.opts.filename, 0, self.totalSectionLen )

   def run( self ):
      """Execute the action selected on the command line.

      Probes the flash chip, answers the informational options (each of which
      exits the process), selects the layout and finally performs the
      requested read or write.
      """
      vendor, model, size = flashromDetect()
      if self.opts.name:
         exitWithPrint( ErrorCode.SUCCESS, out=model )
      elif self.opts.info:
         out = f'Vendor: {vendor}\nModel: {model}\nSize: {size}'
         exitWithPrint( ErrorCode.SUCCESS, out=out )

      self.spiTotalLen = size
      if not self.spiTotalLen:
         exitWithPrint( ErrorCode.FAIL_GENERIC,
                        err="Failed to detect SPI flash size" )

      self.loadLayoutFiles()

      self.archGroup = self.computeArchGroup().lower()
      if not self.archGroup:
         exitWithPrint( ErrorCode.FAIL_GENERIC,
            err="Cannot determine archGroup, please supply one with -a option." )

      if self.opts.show_total_size:
         exitWithPrint( ErrorCode.SUCCESS, out=str( self.totalSectionLen ) )

      if self.opts.layout:
         # Override existing layout with provided one
         layout = parseLayoutFile( self.opts.layout )
         if layout is None:
            sys.exit( 1 )
         self.flashOffsetLUT[ self.archGroup ] = layout
         layoutFile = self.opts.layout
      else:
         layoutFile = self.layoutPathFromPlatform( self.archGroup )

      if self.opts.show_layout:
         exitWithPrint( ErrorCode.SUCCESS, out=str( layoutFile ) )

      targets = self.computeTargets()
      if self.opts.read:
         self.doRead( targets[ 0 ], layoutFile )
      elif len( targets ) == 1:
         self.doWriteTarget( targets[ 0 ], layoutFile )
      else:
         self.doWritePseudoTarget( targets, layoutFile )

# Description shown at the top of the command line help. It is meant to be
# rendered verbatim, hence the explicit line breaks.
FLASHUTIL_DESCRIPTION = (
   'Wrapper Utility around flashrom to perform read write '
   'action on NorCal SPI flash.\n'
   '\n'
   'The script can be directed to read or write any given SECTION.\n'
   'During writes <FILENAME> provides the source data and during reads\n'
   'data is written to it.\n'
   'Caution should be used with the total option as the system can\n'
   'be rendered unbootable if a write does not complete.\n'
   '\n'
   'Section can be:\n'
   '  total ( entire flash image )\n'
   '  prefdl\n'
   '  fdl\n'
   '  mac\n'
   '  image ( aboot )\n'
   '  fallback\n'
   '\n'
   '* not all sections are supported on all platforms.\n'
)

def parseCommandLine():
   """Parse and sanity-check the flashUtil command line.

   Returns the argparse namespace. The process exits with FAIL_GENERIC when
   verbose output is combined with '-' as FILENAME, when both a read and a
   write are requested, or when FILENAME is missing although none of the
   informational options was given.
   """
   parser = argparse.ArgumentParser(
      prog='flashUtil',
      description=FLASHUTIL_DESCRIPTION,
      formatter_class=argparse.RawTextHelpFormatter )
   addArg = parser.add_argument

   addArg( 'filename', metavar='FILENAME', nargs='?', type=str,
           help='Input/output file to read/write the ROM' )
   addArg( "-r", "--read", help="Read from Flash" )
   addArg( "-w", "--write", help="Write to Flash" )
   addArg( "-a", "--archGroup", help="Specify an archGroup (e.g. norcalN)" )
   addArg( "-n", "--flash-name", dest='name', action='store_true',
           help="Only probe for flash chip name" )
   addArg( "--flash-info", dest='info', action='store_true',
           help="Only probe for flash chip information" )
   addArg( "-l", "--layout",
           help="Use the provided layout file instead of the "
                "installed ones" )
   addArg( "-v", "--verbose", action='store_true', help="verbose output" )
   addArg( "--pseudo-target-file", dest='pseudoTargetFile',
           type=argparse.FileType( 'r' ),
           help='JSON file that contains additional pseudo '
                'target definitions' )
   addArg( "--try-single-erase", action='store_true',
           help="Only try the first defined erase opcode" )
   addArg( "--show-total-size", action='store_true',
           help="Show the size of the total section" )
   addArg( "--show-layout", action='store_true',
           help="Show the path of the layout that would currently be used" )
   addArg( "--no-tpm-check", action='store_true',
           help="Bypass TPM related checks" )

   opts = parser.parse_args()

   # stdout cannot carry both the data and the verbose log
   if opts.verbose and opts.filename == "-":
      exitWithPrint( ErrorCode.FAIL_GENERIC,
                     err="Verbose is not supported when using stdout" )

   # Reading and writing are mutually exclusive
   if opts.read and opts.write:
      parser.print_help()
      exitWithPrint( ErrorCode.FAIL_GENERIC, err="Invalid option." )

   # Only the purely informational options can do without a FILENAME
   infoOnly = ( opts.name, opts.info, opts.show_total_size, opts.show_layout )
   if not opts.filename and not any( infoOnly ):
      parser.print_help()
      exitWithPrint( ErrorCode.FAIL_GENERIC, err="FILENAME not specified." )

   return opts

def main():
   """Entry point: validate the environment, then run flashUtil under lock.

   Requires root. Before a write, and unless --no-tpm-check was given, the TPM
   is queried and the process exits with FAIL_SPI_WP_ACTIVE when the
   UNLOCKSPIFLASH toggle bit is not set.
   """
   opts = parseCommandLine()

   # Messages about the flock state are suppressed when the output is expected
   # to be clean: probing options (-n, --flash-info, --show-*) and data sent to
   # stdout.
   skipLockWarning = any( (
      opts.name,
      opts.info,
      opts.filename == '-',
      opts.show_total_size,
      opts.show_layout,
   ) )

   if os.getuid() != 0:
      exitWithPrint( ErrorCode.FAIL_GENERIC,
                     err="flashUtil needs to be run as root" )

   checkTpm = opts.write and not opts.no_tpm_check
   if checkTpm:
      try:
         unlocked = TpmGeneric().isToggleBitSet(
            TpmDefs.SBToggleBit.UNLOCKSPIFLASH )
         if not unlocked:
            exitWithPrint( ErrorCode.FAIL_SPI_WP_ACTIVE,
                           err="SPI flash write protection is active" )
      except ( TpmDefs.NoTpmDevice, TpmDefs.NoTpmImpl, TpmDefs.NoSBToggle ):
         # These three errors are deliberately ignored: the write goes ahead
         pass
      except TpmDefs.Error as e:
         # Any other TPM error only produces a warning, not a failure
         if not skipLockWarning:
            print( 'WARNING: Failed to check SPI flash write protection '
                   'status: %s' % ( e, ) )

   # NOTE: Every operation done with flashrom **MUST** be done locked. If there's
   # a possibility for 2 flashrom instances to run at the same time, this will
   # cause puzzling failures in EOS, like BUG451105, BUG532344, ...
   # To make sure we don't re-introduce bugs like these in the future, just run
   # all the flashUtil logic with the lock grabbed.
   with FlashUtilLock( not skipLockWarning ):
      FlashUtilBinary( opts ).run()

if __name__ == '__main__':
   main()
