# Copyright (c) 2022 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from enum import Enum
import errno
import os
import tempfile

from AbootVersion import AbootVersion
import Auf
from AufVersion import ValidVersion as AufVersion, ParseError as AufVersionParseError
from BiosInstallLib import ABOOT_UPDATE_DIR, ABOOT_UPDATE_LOG_DIR, HistoryCode,     \
   LocalInstaller, Logger as BiosLogger
import BiosLib
import CEosHelper
import FirmwareUtils
from FirmwareUtils import UpgradeError
from FlashUtil import FlashUtilErrorCode
import Logging
import SwiSignLib
import Tac
from Toggles.AbootEosToggleLib import toggleCertificateBundleFirmwarePluginEnabled
import Url

# Python2 compatibility
# On interpreters whose `os` module has no `scandir`, install the `scandir`
# backport package's implementation as `os.scandir`, so the directory scan in
# CertificateBundle.getLocalVersion can call `os.scandir` unconditionally.
# NOTE(review): this module also uses f-strings and zero-argument `super()`,
# which need Python 3.6+, where `os.scandir` always exists. This shim looks
# unreachable; confirm before removing it.
if not hasattr( os, "scandir" ):
   # pkgdeps: py_version scandir 2
   import scandir # pylint: disable=import-error
   os.scandir = scandir.scandir

# TAC types from the SecureBoot library. CertSection/CertPath name the upgrade
# certificate's SPI flash section and its path (the path is also used as the
# root CA when checking AUF signatures); SpiLoaderState is the state of the SPI
# certificate loader state machine.
# pkgdeps: library SecureBoot
CertSection = Tac.Type( "SecureBoot::CertSection" )
CertPath = Tac.Type( "SecureBoot::CertPath" )
SpiLoaderState = Tac.Type( "SecureBoot::SpiLoader::State" )

# Product name an AUF version must carry to be treated as a certificate bundle.
PRODUCT = "CertificateBundle"
# Local directory scanned for certificate bundle AUFs to install.
ABOOT_UPDATE_SRC_DIR = "/usr/share/abootupdates"
# Prefix of the version line in the certificate bundle metadata read from the
# SPI flash (the metadata is handled as raw bytes, hence a bytes literal).
CERTS_VERSION_KEY = b"CERTS_VERSION="
# Bundle line to look for when the installed bundle's version cannot be read.
DEFAULT_CERT_BUNDLE_LINE = 1

# Warning syslog: a certificate bundle update was skipped because of the
# system's configuration. The %s argument is one of the warning REASON_* strings.
SECUREBOOT_CERTIFICATE_UPGRADENOTAPPLICABLE = Logging.LogHandle(
      "SECUREBOOT_CERTIFICATE_UPGRADENOTAPPLICABLE",
      severity=Logging.logWarning,
      fmt="A certificate bundle update could not be automatically applied: %s",
      explanation="A certificate bundle update could not be automatically applied "
      "due to a configuration issue.",
      recommendedAction="If a certificate bundle upgrade is desired, resolve the "
      "configuration issue and reboot the system to retry the installation. "
      "Otherwise, no action is required." )

# Error syslog: reading or applying a certificate bundle update failed. The %s
# argument is one of the error REASON_* strings.
SECUREBOOT_CERTIFICATE_UPGRADEFAILED = Logging.LogHandle(
      "SECUREBOOT_CERTIFICATE_UPGRADEFAILED",
      severity=Logging.logError,
      fmt="Failed to apply a certificate bundle upgrade: %s",
      explanation="A certificate bundle update failed to automatically apply.",
      recommendedAction=Logging.CONTACT_SUPPORT )

# Reasons passed as the %s argument of the syslog handles above; each one
# completes the sentence in the handle's `fmt`.
# Error cases (used with SECUREBOOT_CERTIFICATE_UPGRADEFAILED)
REASON_VALIDATION_FAILED = "AUF validation failed."
REASON_FLASH_FAILED = "the system failed to flash the upgrade."
REASON_METADATA_READ_FAILED = "an error occurred while reading metadata."
REASON_FOUND_WRONG_VERSION = "an unexpected version was found after installation."
# Warning cases (used with SECUREBOOT_CERTIFICATE_UPGRADENOTAPPLICABLE)
REASON_ABOOT_UPDATE_NEEDED = "an Aboot update is required to complete installation."
REASON_SPIUPDATE_REQUIRED_DISABLED = ( "the Aboot SPI flash update feature is "
                                       "required but is disabled." )
REASON_SPIUPDATE_UNSUPPORTED_LOCKED = ( "the Aboot SPI flash update feature is not "
                                        "supported by the installed Aboot version "
                                        "and the SPI flash is locked." )
REASON_SPIUPDATE_DISABLED_LOCKED = ( "the Aboot SPI flash update feature is "
                                     "disabled and the SPI flash is locked." )

def syslogUpgradeWarn( reason ):
   """Syslog that a certificate bundle update was not applied, with `reason`."""
   handle = SECUREBOOT_CERTIFICATE_UPGRADENOTAPPLICABLE
   Logging.log( handle, reason )

def syslogUpgradeError( reason ):
   """Syslog that a certificate bundle update failed, with `reason`."""
   handle = SECUREBOOT_CERTIFICATE_UPGRADEFAILED
   Logging.log( handle, reason )

class Logger( BiosLogger ):
   """BiosInstallLib logger that writes to a file in `logDir` and mirrors every
   message to `tee`.

   The log file is a uniquely named "CertificateBundle-*.log" that is only
   created the first time something is logged.
   """

   def __init__( self, tee, logDir ):
      # Make sure the log directory exists; an existing path is accepted.
      try:
         os.makedirs( logDir )
      except FileExistsError:
         pass
      # The `logFile` property creates its file inside `self.logDir`.
      self.logDir = logDir
      super().__init__( None, None )
      self.tee = tee

   @property
   def logFile( self ):
      # Created on first use so that boots with nothing to report do not leave
      # empty files behind. Assumes BiosLogger initializes `_logFile` to a falsy
      # value -- TODO confirm.
      if self._logFile:
         return self._logFile
      self._logFile = tempfile.NamedTemporaryFile( prefix="CertificateBundle-",
                                                   suffix=".log",
                                                   dir=self.logDir,
                                                   mode="w", delete=False )
      return self._logFile

   def writeLog( self, msg ):
      """Log `msg` through the base class, then echo it to `tee`."""
      super().writeLog( msg )
      self.tee.all( msg )

   def printMsg( self, msg, error=False, standby=None ):
      """Localize `msg` and log it. `error` is accepted for interface
      compatibility but is not used."""
      self.writeLog( self.localizeMsg( msg, standby=standby ) )

class Version( FirmwareUtils.FirmwareVersion ):
   """Certificate bundle version, backed by a parsed AufVersion.

   A missing or unparseable version string yields an invalid Version, whose
   `product`, `line`, `major` and `minor` accessors all return None.
   """

   def __init__( self, version=None ):
      parsed = None
      if version is not None:
         try:
            parsed = AufVersion( version )
         except AufVersionParseError:
            # Keep `parsed` as None: this becomes an invalid Version.
            pass
      super().__init__( parsed )

   def isValid( self ):
      """Return True if a version string was successfully parsed."""
      return self.version is not None

   def _parsedField( self, attr ):
      # None for an invalid version, otherwise the named AufVersion attribute.
      if not self.isValid():
         return None
      return getattr( self.version, attr )

   @property
   def product( self ):
      return self._parsedField( "product" )

   @property
   def line( self ):
      return self._parsedField( "line" )

   @property
   def major( self ):
      return self._parsedField( "major" )

   @property
   def minor( self ):
      return self._parsedField( "minor" )

   def __str__( self ):
      return str( self.version )

   @property
   def fileName( self ):
      """Name of the AUF file that carries this version: "<version>.auf"."""
      return "%s.auf" % str( self )

class CertificateBundle( FirmwareUtils.Device ):
   """
   This class implements upgrading certificate bundles on the SPI flash.

   Certificate bundles are not actually a physical "Device", but framing them as
   devices allows them to fit nicely into the FirmwarePlugin framework.

   `version` is the bundle version currently installed (read from the SPI flash
   metadata by the inventory), and `getLocalVersion()` is the version of the
   matching AUF found in `updateDir`. `upgrade()` either flashes that AUF
   directly from EOS or stages it under `destUrl` for Aboot to install. In the
   Aboot case, `postUpgradeReboot()` asks for a reboot.
   """

   def __init__( self, inventory, version, updateDir, description,
                 destUrl=ABOOT_UPDATE_DIR ):
      """
      inventory: owning CertificateBundleInventory. It supplies the installer,
         the logger and the Aboot/SPI state consulted by `upgrade()`.
      version: Version currently installed on the SPI flash (may be invalid).
      updateDir: local directory searched for candidate AUFs.
      description: human-readable name prefix used in log messages.
      destUrl: where the AUF is staged when it must be installed by Aboot.
      """
      super().__init__( "CertificateBundle" )
      self.inventory = inventory
      self.installer = inventory.installer
      self.logger = inventory.logger
      self.version = version
      self.updateDir = updateDir
      self.description = description + " certificate bundle"
      self.destUrl = destUrl
      # Chosen by upgrade(); False until an upgrade selects the Aboot path.
      self.installFromAboot = False

   def logPreUpgradeInfo( self, auf, aufSha ):
      """Write the current/new versions, checksum, signer and per-section
      size/hash of `auf` to the upgrade log."""
      self.logger.writeLog( "Current Version: " + str( self.version ) )
      self.logger.writeLog( "New Version: " + auf.getName() )
      self.logger.writeLog( "Auf version: " + str( auf.version ) )

      self.logger.writeLog( "Checksum: " + aufSha )

      # The signing certificate is only reported when the signature verifies
      # against the upgrade certificate.
      if SwiSignLib.swiSignatureExists( auf.aufFile ):
         if SwiSignLib.verifySwiSignature( auf.aufFile, rootCA=CertPath.upgrade ):
            self.logger.writeLog( "Auf signed:" )
            self.logger.writeLog( "\tCert " + CertPath.upgrade )

      self.logger.writeLog( "Sections:" )
      for sectionName, payload in auf.sections.items():
         self.logger.writeLog( "\t%s: size 0x%x bytes" %
                               ( sectionName, payload.size ) )
         self.logger.writeLog( "\t\tsha256(%s)" % payload.sha )

   def upgrade( self ):
      """Install the AUF for `getLocalVersion()`, from EOS or via Aboot.

      Raises UpgradeError (after syslogging) when the Aboot installer cannot
      handle the AUF, when AUF validation fails, or when flashing fails.
      """
      # The Tpm agent isn't running yet, which normally loads certs later in boot. We
      # need to validate the cert bundle AUF now, so load the upgrade cert early if
      # necessary.
      self.inventory.maybeLoadUpgradeCert()

      aufFile = os.path.join( self.updateDir, self.getLocalVersion().fileName )

      # Open AUF
      aufSha = Auf.calcSha( aufFile )
      try:
         auf = Auf.Auf( aufFile )
         self.logPreUpgradeInfo( auf, aufSha )
      except ( OSError, SyntaxError, TypeError, ValueError ) as e:
         historyCode = ( HistoryCode.ZIP_ERROR
                         if isinstance( e, ( IOError, SyntaxError ) )
                         else HistoryCode.SHA_ERROR )
         self.installer.writeHistory( aufSha, historyCode )
         # NOTE(review): unlike every other failure path in this method, this
         # one returns instead of raising UpgradeError and emits no syslog.
         # Confirm this is intended.
         self.logger.printMsg( "Skipping %s upgrade: AUF parsing failed with error: "
                               "%s" % ( self.description, format( e ) ) )
         return
      aufName = auf.getName()

      # Determine how we are going to install the AUF. Either from EOS or Aboot.
      # Prefer EOS. If the SPI flash is locked or the AUF requires it, use Aboot.
      forceInstallFromAboot = auf.abootOnly
      self.installFromAboot = self.inventory.spiLocked or forceInstallFromAboot
      # If we're installing from Aboot, check that the Aboot installer is present,
      # enabled and compatible with the AUF
      if self.installFromAboot:
         if forceInstallFromAboot:
            if self.inventory.abootInstallerRev == -1:
               syslogUpgradeWarn( REASON_ABOOT_UPDATE_NEEDED )
               raise UpgradeError( "Skipping %s upgrade: Aboot SPI update not "
                                   "supported but AUF requires it" %
                                   self.description )
            if not self.inventory.spiUpdateEnabled:
               syslogUpgradeWarn( REASON_SPIUPDATE_REQUIRED_DISABLED )
               raise UpgradeError( "Skipping %s upgrade: Aboot SPI update disabled "
                                   "but AUF requires it" % self.description )

         if self.inventory.abootInstallerRev < auf.abootInstallerMinRev:
            syslogUpgradeWarn( REASON_ABOOT_UPDATE_NEEDED )
            info = ( self.description, auf.abootInstallerMinRev,
                     self.inventory.abootInstallerRev )
            raise UpgradeError( "Skipping %s upgrade: Aboot installer revision %d "
                                "required, has %d" % info )

      # Validate AUF
      historyCode = self.installer.validateAuf( auf )
      if historyCode != HistoryCode.VALIDATION_SUCCESS:
         syslogUpgradeError( REASON_VALIDATION_FAILED )
         self.installer.writeHistory( aufSha, historyCode, aufName )
         raise UpgradeError( "Skipping %s upgrade: AUF validation failed" %
                             self.description )

      # Install AUF
      if self.installFromAboot:
         # Stage the AUF under destUrl for Aboot to pick up.
         dest = os.path.join( self.destUrl, self.getLocalVersion().fileName )
         context = Url.Context( None, True )
         dest = Url.parseUrl( dest, context=context )
         src = Url.parseUrl( "file:%s" % aufFile, context=context )
         self.installer.moveAuf( src, dest )
      else:
         historyCode = self.installer.flashAufSection( auf )
         self.installer.writeHistory( aufSha, historyCode, aufName )
         if historyCode != HistoryCode.FLASH_SUCCESS:
            syslogUpgradeError( REASON_FLASH_FAILED )
            raise UpgradeError( "Failure while flashing %s" % self.description )

   def postUpgradeReboot( self ):
      """Return True when the last upgrade was staged for Aboot."""
      return self.installFromAboot

   @Tac.memoize
   def getLocalVersion( self ):
      """Return the Version of the candidate AUF in `updateDir`.

      Only files named "<version>.auf" whose version has product PRODUCT and
      the same line as the installed bundle are considered. An invalid Version
      is returned if `updateDir` does not exist or holds no matching AUF.
      """
      # If the local version couldn't be read from the SPI section, fall back to a
      # global default. This may need to be more complex in the future depending on
      # how lines are used.
      targetLine = ( self.version.line if self.version.isValid()
                     else DEFAULT_CERT_BUNDLE_LINE )

      # Look for updates for targetLine. We expect only 0 or 1 such AUF to exist.
      try:
         entries = os.scandir( self.updateDir )
      except OSError as e:
         if e.errno == errno.ENOENT:
            return Version()
         raise
      # Close the directory handle deterministically, including when we return
      # early from inside the loop.
      with entries:
         for entry in entries:
            if not entry.is_file() or not entry.name.endswith( ".auf" ):
               continue
            version = Version( version=entry.name[ : -len( ".auf" ) ] )
            # Skip other products/lines, and names that only parse after
            # normalization, so `version.fileName` always names this entry.
            if ( version.product != PRODUCT or version.line != targetLine or
                 version.fileName != entry.name ):
               continue
            return version
      return Version()

   def getHwVersion( self ):
      """Return the installed (SPI flash) bundle version."""
      return self.version

   def getDescription( self ):
      return self.description

   def needsUpgrade( self ):
      """Return True if a matching AUF exists and the base-class comparison
      (presumably installed vs. local version) says an upgrade is needed."""
      # If the SWI doesn't have an embedded certificate bundle that matches the
      # current line, then we can't upgrade anything.
      return ( self.getLocalVersion().isValid() and
               super().needsUpgrade() )

class CertificateBundleInventory( FirmwareUtils.Inventory ):
   """FirmwarePlugin inventory of the certificate bundles on the SPI flash.

   The constructor sets up logging and the local AUF installer, and records
   the running Aboot's SPI-update feature revision. `populate()` checks
   whether this system can take certificate bundle updates and, if so, adds
   a CertificateBundle per known bundle section.
   """
   RESET_MESSAGE = "Rebooting into Aboot to complete upgrades."

   @staticmethod
   def runCmd( *cmd, **kwargs ):
      """Run `cmd` as root through Tac.run, capturing and returning stdout."""
      return Tac.run( list( cmd ), asRoot=True, stdout=Tac.CAPTURE, **kwargs )

   def __init__( self ):
      super().__init__()
      urlContext = Url.Context( None, True )

      # Bind the directory name before the try, so the fallback message can name
      # it even when resolving the URL itself is what failed. Otherwise the
      # fallback would raise NameError on an unbound variable.
      logDirName = ABOOT_UPDATE_LOG_DIR
      try:
         logDirName = Url.parseUrl( ABOOT_UPDATE_LOG_DIR,
                                    context=urlContext ).localFilename()
         self.logger = Logger( self.tee, logDirName )
      except OSError:
         self.logger = Logger( self.tee, '/tmp' )
         self.logger.printMsg( f"Failed to use {logDirName} as log "
                               "directory, falling back to /tmp" )

      try:
         dest = Url.parseUrl( ABOOT_UPDATE_DIR, context=urlContext )
         self.installer = LocalInstaller( dest.localFilename() )
         self.installer.setLogger( self.logger )
      except OSError as e:
         # checkConfiguration() treats a missing installer as incompatible.
         self.installer = None
         self.logger.printMsg( f"Error while creating the installer: {e}" )

      self.abootVersionString = BiosLib.getRunningAbootVersion()
      features = BiosLib.getAbootFeatures()
      # -1 means the running Aboot has no SPI upgrade feature.
      self.abootInstallerRev = features.get( BiosLib.ABOOT_FEATURE_SPIUPGRADE, -1 )
      # Refined by checkConfiguration().
      self.spiLocked = False
      self.spiUpdateEnabled = True
      self.incompatible = False
      self.upgradeCertLoaded = False

   def maybeLoadUpgradeCert( self ):
      """Load the upgrade certificate from the SPI flash, at most once.

      A load failure is only written to the upgrade log; either way the load
      is not attempted again by this inventory.
      """
      if self.upgradeCertLoaded:
         return

      # Instantiate the loader Sm manually
      # pkgdeps: library SecureBootSpiLoader
      reader = Tac.newInstance( "SecureBoot::SpiReader" )
      loader = Tac.newInstance( "SecureBoot::SpiCertLoader", CertSection.upgrade,
                                CertPath.upgrade, reader )
      # Wait for the Sm to finish
      Tac.waitFor( lambda: loader.state == SpiLoaderState.done or loader.error )
      if loader.error:
         self.logger.writeLog( "Failed to load the SPI upgrade certificate." )

      self.upgradeCertLoaded = True

   def maybeAddBundle( self, section, updateDir, description ):
      """Add a CertificateBundle for SPI flash `section` to the inventory.

      The installed version comes from the first valid CERTS_VERSION line of
      the section's metadata. If none is found, the bundle is still added with
      an invalid Version. Nothing is added if the section cannot be read: an
      unknown section is skipped silently, and any other read error is syslogged.
      """
      # Read the certificate bundle metadata from the SPI flash
      try:
         # pkgdeps: rpmwith %{_bindir}/flashUtil
         meta = self.runCmd( "flashUtil", "-r", section, "-", text=False )
      except Tac.SystemCommandError as e:
         if e.error != FlashUtilErrorCode.FAIL_UNKNOWN_SECTION:
            syslogUpgradeError( REASON_METADATA_READ_FAILED )
            self.logger.writeLog( "flashUtil failed unexpectedly while reading "
                                  "section %s: %s" % ( section, e.error ) )
         return

      # Parse the version field
      version = Version()
      for line in meta.rstrip( b"\n\x00" ).splitlines():
         if line.startswith( CERTS_VERSION_KEY ):
            try:
               versionStr = line[ len( CERTS_VERSION_KEY ) : ].decode()
            except UnicodeDecodeError:
               continue
            version = Version( version=versionStr )
            if version.isValid():
               break

      bundle = CertificateBundle( self, version, updateDir, description )
      self.inventory.append( bundle )

   class AbootCompat( Enum ):
      COMPATIBLE = 0
      INCOMPATIBLE_VERSION = 1
      INCOMPATIBLE_LINE = 2

   def checkAbootCompatible( self ):
      """Classify the running Aboot as an AbootCompat value."""
      # Quick feature compatibility check, running version is quicker & enough here.
      # Let the AUF checking code do the deeper installed compatibility verification.
      try:
         abootVersion = AbootVersion( self.abootVersionString )
      except Exception: # pylint: disable=broad-except
         # Any unparseable version is treated as an unsupported Aboot line.
         # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
         return self.AbootCompat.INCOMPATIBLE_LINE

      if abootVersion.norcalId in ( None, 1, 2, 3, 4, 5, 8 ):
         return self.AbootCompat.INCOMPATIBLE_LINE

      if abootVersion.hasSecureBootCpuWP():
         syslogUpgradeWarn( REASON_ABOOT_UPDATE_NEEDED )
         self.logger.writeLog(
            "Cannot update certificate bundle for this Aboot version" )
         return self.AbootCompat.INCOMPATIBLE_VERSION

      return self.AbootCompat.COMPATIBLE

   def checkConfiguration( self ):
      """Set `self.incompatible` if certificate bundle updates cannot be
      applied on this system, and record the SPI lock/update-enable state."""
      if CEosHelper.isCeos():
         self.incompatible = True
         return

      # If we failed to access the update directory
      if not self.installer:
         self.incompatible = True
         self.logger.printMsg( "No installer available" )
         return

      # Either no TPM or no SB NVRAM index defined
      # Most likely not compatible with SB, skip certificates update
      self.spiLocked = self.installer.isSpiLocked()
      if self.spiLocked is None:
         self.incompatible = True
         return

      self.spiUpdateEnabled = self.installer.isSpiUpdateEnabled()
      # Short-circuit: a hardwired WP skips the Aboot check (and its syslog).
      if self.installer.isSpiFlashWpHardwired() or \
         self.checkAbootCompatible() != self.AbootCompat.COMPATIBLE:
         self.incompatible = True
         return

      # A locked SPI flash can only be updated through Aboot, which must support
      # and enable SPI updates.
      if self.spiLocked:
         if self.abootInstallerRev == -1:
            syslogUpgradeWarn( REASON_SPIUPDATE_UNSUPPORTED_LOCKED )
            self.logger.writeLog( "Cannot update certificate bundles as Aboot SPI "
                                  "update is not supported and the SPI flash is "
                                  "locked" )
            self.incompatible = True
            return

         if not self.spiUpdateEnabled:
            syslogUpgradeWarn( REASON_SPIUPDATE_DISABLED_LOCKED )
            self.logger.writeLog( "Cannot update certificate bundles as Aboot SPI "
                                  "update is disabled and the SPI flash is locked" )
            self.incompatible = True
            return

   def populate( self ):
      """Fill the inventory, unless checkConfiguration() rules updates out."""
      self.checkConfiguration()

      if self.incompatible:
         # Skip everything in this case, we won't ever be able to process the update
         # on this product so warning message is pointless.
         return

      # In the future, if we have more than one set of certificate bundles, multiple
      # objects can be added to the inventory.
      self.maybeAddBundle( "certbundle_metadata", ABOOT_UPDATE_SRC_DIR,
                           "secure boot" )

   def syslogUpgradeFailed( self, _device, afterReboot ):
      """Syslog an unexpected post-reboot version; earlier failures are already
      syslogged where they happen."""
      if afterReboot:
         syslogUpgradeError( REASON_FOUND_WRONG_VERSION )
      # Syslogging before reboot is handled by the `upgrade()` and
      # `checkConfiguration()` functions

   def reset( self ):
      """Reboot the system so Aboot can complete staged upgrades."""
      # NOTE(review): unlike runCmd(), this does not pass asRoot=True;
      # presumably the plugin already runs with enough privilege -- confirm.
      Tac.runActivities( 0.5 )
      Tac.run( [ "reboot" ] )
      Tac.runActivities( 1 )

def Plugin( context ):
   """FirmwarePlugin entry point: register a CertificateBundleInventory with
   `context`, but only while the CertificateBundleFirmwarePlugin toggle is on."""
   if not toggleCertificateBundleFirmwarePluginEnabled():
      return
   context.register( CertificateBundleInventory() )
