# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from collections import namedtuple
from enum import IntEnum
import datetime
import errno
import os
import time

from AbootVersion import AbootVersion
import BiosLib
import Prefdl
import Tac
import TpmGeneric.Defs as TpmDefs
from TpmGeneric.Tpm import TpmGeneric

#--------------------------------------------------------------------------------
# Installation history helpers
#--------------------------------------------------------------------------------

# History codes
# Keep in sync with Aboot spi-upgrade
class HistoryCode( IntEnum ):
   '''
   Result codes written to the install history file.

   The numbering must stay in sync with Aboot's spi-upgrade, so members are
   never renumbered.
   '''
   FLASH_SUCCESS = 0
   SIGNATURE_ERROR = 1   # invalid AUF signature
   ZIP_ERROR = 2         # AUF could not be decompressed
   VERSION_ERROR = 3     # AUF does not match the running Aboot version
   COMPAT_ERROR = 4      # AUF compatibility check failed
   LAYOUT_ERROR = 5      # AUF carries no Aboot layout
   IMAGE_ERROR = 6       # a section has no image to flash
   SHA_ERROR = 7         # image checksum failed
   SPI_ERROR = 8         # writing the SPI flash failed

   # Same value as FLASH_SUCCESS, so Enum makes it an alias of that member:
   # HistoryCode.VALIDATION_SUCCESS is HistoryCode.FLASH_SUCCESS.
   VALIDATION_SUCCESS = FLASH_SUCCESS

def historyError( errorCode ):
   '''
   Return a human-readable description for a failing install-history code.

   :param errorCode: a HistoryCode member ( an equal plain int also matches,
                     since IntEnum members hash and compare like ints ).
   :returns: the description string, or None for codes that have none
             ( e.g. FLASH_SUCCESS / VALIDATION_SUCCESS ).
   '''
   errorCodeToString = {
      HistoryCode.SIGNATURE_ERROR: 'Invalid AUF signature',
      HistoryCode.ZIP_ERROR: 'Unable to decompress AUF',
      HistoryCode.VERSION_ERROR: 'AUF cannot be applied to this version of Aboot',
      HistoryCode.COMPAT_ERROR: 'AUF compatibility check failed',
      HistoryCode.LAYOUT_ERROR: 'No Aboot layout in AUF',
      HistoryCode.IMAGE_ERROR: 'No img to flash Aboot with',
      HistoryCode.SHA_ERROR: 'img checksum failed',
      HistoryCode.SPI_ERROR: 'Error flashing Aboot',
   }
   return errorCodeToString.get( errorCode )

def historyFormat( aufSha, aufName, errorType ):
   '''
   Build one install-history record: "<epoch>,<aufSha>,<aufName>,<code>".

   The timestamp is the current time truncated to whole seconds. The code is
   rendered with %d, so a HistoryCode is written as its integer value.
   '''
   fields = (
      '%d' % int( time.time() ),
      str( aufSha ),
      str( aufName ),
      '%d' % errorType,
   )
   return ','.join( fields )

# Aboot update directory on flash and its log subdirectory.
ABOOT_UPDATE_DIR = 'flash:/aboot/update/'
ABOOT_UPDATE_LOG_DIR = os.path.join( ABOOT_UPDATE_DIR, 'logs/' )
# File name of the install history, joined onto the installer's update dir.
ABOOT_UPDATE_HISTORY = 'history'

# sid prefixes of CPU cards whose SPI flash write protect is hardwired; they
# are matched against the prefdl sid read from PREFDL_PATH.
PERMALOCKED_SPI_FLASH_SIDS = [
   'Redondo',
]
PREFDL_PATH = "/etc/prefdl"

class InstallerErrorCode( IntEnum ):
   '''
   Status codes of an installer run ( LocalInstaller.installStatus starts out
   as ABORT ).
   '''
   SUCCESS = 1
   FAIL = 2
   ABORT = 3
   # More catastrophic than a regular failure
   UPGRADE_FAIL = 4
   SETUP_SUCCESS = 5

# Install result for the active and the standby supervisor ( presumably
# InstallerErrorCode values -- confirm with callers ).
InstallStatus = namedtuple( "InstallStatus", [ "active", "standby" ] )

# Supervisors to update. Options are mutually exclusive.
class Supervisor( IntEnum ):
   '''
   Target of an update: all, active, standby, or no supervisor.

   Values are spelled out explicitly, starting at 1. str() yields a readable
   target for ALL, ACTIVE and STANDBY and raises NotImplementedError for NONE.
   '''
   ALL = 1
   ACTIVE = 2
   STANDBY = 3
   NONE = 4

   def __str__( self ):
      descriptions = {
         Supervisor.ALL: "all supervisors",
         Supervisor.ACTIVE: "active supervisor",
         Supervisor.STANDBY: "standby supervisor",
      }
      description = descriptions.get( self )
      if description is None:
         # NONE has no printable form
         raise NotImplementedError
      return description

class LoggerModeStub:
   '''
   Minimal stand-in for a CLI mode: errors and messages are both printed to
   stdout, one per line.
   '''
   @staticmethod
   def _emit( text ):
      print( text )

   def addError( self, msg ):
      self._emit( msg )

   def addMessage( self, msg ):
      self._emit( msg )

class Logger:
   '''
   Report installer messages through a CLI mode and record errors in a log.

   `mode` must provide addError( msg ) and addMessage( msg ) ( as
   LoggerModeStub does ). `logFile` must be a writable text file object; its
   write()/flush() methods and `name` attribute are used.
   '''
   def __init__( self, mode, logFile ):
      self._mode = mode
      self._logFile = logFile

   @property
   def logFile( self ):
      return self._logFile

   @property
   def filename( self ):
      '''Name of the underlying log file.'''
      return self.logFile.name

   def writeLog( self, msg ):
      '''Append "<timestamp>\\t<msg>" to the log file and flush it.'''
      try:
         self.logFile.write( f'{datetime.datetime.now()}\t{msg}\n' )
         self.logFile.flush()
      except OSError:
         # Be resilient to write/flush failures and don't crash the caller
         # because we couldn't write a log. For example, target may be full.
         # see 899669
         pass

   @staticmethod
   def localizeMsg( msg, standby=None ):
      '''
      Fill the optional '%s' placeholder in `msg` with the supervisor location.

      With standby True/False the placeholder becomes ' on the standby
      supervisor' / ' on the active supervisor'; with standby None it is
      dropped. A message without '%s' is returned unchanged whatever `standby`
      is, instead of making the % operator raise TypeError.
      '''
      if '%s' not in msg:
         return msg
      if standby is None:
         location = ''
      else:
         location = ' on the %s supervisor' % ( 'standby' if standby else 'active' )
      return msg % location

   def printMsg( self, msg, error=False, standby=None ):
      '''
      Show `msg` ( localized via localizeMsg ) through the mode. Errors go to
      addError() and are also written to the log file; other messages only go
      to addMessage().
      '''
      msg = self.localizeMsg( msg, standby=standby )
      if error:
         self._mode.addError( msg )
         self.writeLog( msg )
      else:
         self._mode.addMessage( msg )

class LocalInstaller:
   '''
   Validate an AUF and flash its sections into this system's SPI flash.

   Each result is appended to an install history file in the update directory
   ( see writeHistory ). User-facing messages go through the Logger installed
   with setLogger(). The comments in printSupeMsg and getLocalFilepath note
   that the base class has no active/standby distinction and operates locally.
   '''
   FLASH_RETRY_LIMIT = 3 # Arbitrary retry count

   def __init__( self, updateDir, historyGroup='eosadmin' ):
      '''
      :param updateDir: local directory holding the install history file;
                        created if it does not exist.
      :param historyGroup: group that must own the history file.
      '''
      # Enable loopback for file replication in simulation mode
      self.simulation = 'SIMULATION_VMID' in os.environ
      # Set before any helper runs; the real Logger comes from setLogger().
      self.logger = None

      self.historyGroup = historyGroup
      self.makeDir( updateDir )
      self.historyPath = os.path.join( updateDir, ABOOT_UPDATE_HISTORY )
      self._checkHistoryFile()

      self.installStatus = InstallerErrorCode.ABORT

   def _checkHistoryFile( self ):
      '''
      Make sure the history file exists, is owned by self.historyGroup and has
      mode 664, running chgrp/chmod only when the current value differs.
      '''
      # Rely on /mnt/flash/aboot ACL to set proper permissions if missing
      self.runCmd( f"touch {self.historyPath}" )

      # runCmd() splits the command on whitespace and hands the resulting argv
      # to Tac.run, so no shell removes quotes: with "-c '%G'" stat would print
      # the quote characters too and the comparisons below would never match.
      historyGroup = self.runCmd( f"stat -c %G {self.historyPath}" ).strip()
      if historyGroup != self.historyGroup:
         self.runCmd( f"chgrp {self.historyGroup} {self.historyPath}" )

      historyPerms = self.runCmd( f"stat -c %a {self.historyPath}" ).strip()
      if historyPerms != "664":
         self.runCmd( f"chmod 664 {self.historyPath}" )

   def isSpiLocked( self ):
      '''
      Return True if the UNLOCKSPIFLASH toggle bit is not set, False if it is
      set, and None when there is no TPM device or no toggle.
      '''
      # TpmGeneric allows reaching the TPM before Tpm{2,} agent is up
      tpm = TpmGeneric()
      try:
         return not tpm.isToggleBitSet( TpmDefs.SBToggleBit.UNLOCKSPIFLASH )
      except ( TpmDefs.NoTpmDevice, TpmDefs.NoSBToggle ):
         return None

   def isSpiUpdateEnabled( self ):
      '''
      Return whether the ENABLESPIUPDATE toggle bit is set, or None when there
      is no TPM device or no toggle.
      '''
      tpm = TpmGeneric()
      try:
         return tpm.isToggleBitSet( TpmDefs.SBToggleBit.ENABLESPIUPDATE )
      except ( TpmDefs.NoTpmDevice, TpmDefs.NoSBToggle ):
         return None

   def isSpiFlashWpHardwired( self ):
      '''
      Return True if the prefdl sid starts with one of
      PERMALOCKED_SPI_FLASH_SIDS; False otherwise, including when the prefdl
      has no sid.
      '''
      # hardwiredSpiFlashWP attribute is defined in product fdl and then exposed
      # in Sysdb by the Tpm agent.
      # But we can't rely on agents/sysdb in early boot, so match prefdl sid with
      # a predefined list of hardwired CPU cards.
      with open( PREFDL_PATH ) as f:
         prefdl = Prefdl.parsePrefdl( f.read().encode() )
      sid = prefdl.get( 'sid' )
      if not sid:
         # A missing sid cannot match any hardwired card
         return False
      return sid.startswith( tuple( PERMALOCKED_SPI_FLASH_SIDS ) )

   def setLogger( self, logger ):
      '''Install the Logger used for all subsequent messages.'''
      self.logger = logger

   def makeDir( self, directory ):
      '''Create `directory` and its parents; an already existing path is OK.'''
      try:
         os.makedirs( directory )
      except OSError as e:
         if e.errno != errno.EEXIST:
            raise

   def getAbootVersion( self ):
      '''Return the running Aboot version as reported by BiosLib.'''
      # Only read from /proc/cmdline in simulation.
      if self.simulation:
         return AbootVersion( BiosLib.getRunningAbootVersion() )
      return BiosLib.getAbootVersion()

   def copyFile( self, src, dst ):
      '''Copy `src` to `dst` with cp, as root.'''
      Tac.run( [ 'cp', src, dst ], asRoot=True )

   def moveAuf( self, src, dest ):
      '''Copy the AUF at src.localFilename() to dest.localFilename().'''
      self.logger.printMsg( 'Moving AUF to the update directory...' )
      self.copyFile( src.localFilename(), dest.localFilename() )
      self.logger.printMsg( 'Done.' )

   def runScript( self, script ):
      '''Run `script` with bash as root and return its captured output.'''
      cmd = ( 'bash %s' % ( script ) ).split()
      return Tac.run( cmd, asRoot=True, stdout=Tac.CAPTURE, stderr=Tac.CAPTURE )

   def printSupeMsg( self, msg ):
      '''Print a message whose optional '%s' names the supervisor.'''
      # No "active/standby" message as we don't have the distinction in the
      # base class.
      self.logger.printMsg( msg )

   def validateAuf( self, auf ):
      '''
      Validate the AUF, stopping at the first failing check. The checks run in
      order: signature, Aboot version compatibility, then the AUF's optional
      compatibility script. A failure is reported as an error through the
      logger.

      :returns: HistoryCode.VALIDATION_SUCCESS, or SIGNATURE_ERROR,
                VERSION_ERROR or COMPAT_ERROR for the failing check.
      '''
      self.printSupeMsg( 'Validating the AUF%s...' )
      # Check signature ( the dev CA is accepted in simulation )
      if not auf.isSigned( useDevCA=self.simulation ):
         self.logger.printMsg( 'Invalid AUF signature', error=True )
         return HistoryCode.SIGNATURE_ERROR

      # The AUF must target the running Aboot line and major version, with a
      # minor version at least as new as the running one.
      # NOTE(review): there is no special case for an AUF without version
      # information; confirm what auf.line/major/minor hold in that case.
      version = self.getAbootVersion()
      isLine = version.line == auf.line
      isMajor = version.major == auf.major
      isMinor = version.minor <= auf.minor

      # Verify version is compatible
      if not isLine or not isMajor or not isMinor:
         self.logger.printMsg( 'Incompatible AUF. Running Aboot version is '
                               '%s but AUF is for version %s.%s.%s.' %
                               ( version.version, auf.line, auf.major, auf.minor ),
                               error=True )
         return HistoryCode.VERSION_ERROR

      # If exists, run compat script
      compatScript = auf.getCompatibilityScript()
      if compatScript:
         try:
            self.logger.writeLog( self.runScript( compatScript ) )
         except Tac.SystemCommandError as e:
            self.logger.writeLog( e.output )
            self.logger.printMsg(
               f'Extended log available in {self.logger.filename}',
               error=True )
            self.logger.printMsg(
               f'AUF compatibility script failed with error code: {e.error}',
               error=True )
            return HistoryCode.COMPAT_ERROR
      self.logger.printMsg( 'Done.' )
      return HistoryCode.VALIDATION_SUCCESS

   def runCmd( self, cmd ):
      '''
      Run `cmd` as root and return the result of Tac.run with stdout/stderr
      captured. The command string is split on whitespace; quotes are not
      interpreted.
      '''
      cmd = cmd.split()
      return Tac.run( cmd, asRoot=True, stdout=Tac.CAPTURE, stderr=Tac.CAPTURE )

   def getLocalFilepath( self, filePath ):
      '''Map an AUF file path to the path used on this system.'''
      # Locally, just return the same path
      return filePath

   def flashAufSection( self, auf ):
      '''
      Flash every section of the AUF into the SPI flash with flashUtil.

      Each section is retried up to FLASH_RETRY_LIMIT times. Sections flashed
      before a failing one are not rolled back.

      :returns: HistoryCode.FLASH_SUCCESS; LAYOUT_ERROR if the AUF has no
                layout; IMAGE_ERROR if a section has no image path; SPI_ERROR
                if a section still fails after all retries.
      '''
      self.printSupeMsg( 'Flashing%s...' )
      sections = auf.getSections()

      layout = auf.getLayout()
      if not layout:
         return HistoryCode.LAYOUT_ERROR
      layout = self.getLocalFilepath( layout )

      self.logger.writeLog( 'Upgrading the following sections:\n%s\n' %
                            ( '\n'.join( sections ) ) )

      for imgName, payload in sections.items():
         img = payload.imgPath
         if not img:
            return HistoryCode.IMAGE_ERROR
         img = self.getLocalFilepath( img )

         cmd = f'flashUtil -l {layout} -w {imgName} {img}'
         self.logger.writeLog( cmd )

         self.logger.printMsg( 'Upgrading section "%s"...' % ( imgName ) )
         for retry in range( self.FLASH_RETRY_LIMIT ):
            if retry:
               self.logger.printMsg( 'Retrying...' )

            try:
               self.logger.writeLog( self.runCmd( cmd ) )
               self.logger.printMsg( 'Done.' )
               break
            except Tac.SystemCommandError as e:
               self.logger.printMsg( 'Section "%s" upgrade failed with error: %s'
                                     % ( imgName, format( e ) ), error=True )

            # Only reached after a failed attempt: separate it in the log
            self.logger.writeLog( '\n\n' )
         else:
            # Every attempt failed
            return HistoryCode.SPI_ERROR

      return HistoryCode.FLASH_SUCCESS

   def writeHistory( self, aufSha, errorType, aufName='' ):
      '''
      Append one historyFormat() record to the history file and fsync it.
      A write failure is reported through the logger ( when one is set ) and
      never raised.
      '''
      historyLog = historyFormat( aufSha, aufName, errorType )

      try:
         with open( self.historyPath, 'a' ) as historyFile:
            historyFile.write( f'{historyLog}\n' )
            historyFile.flush()
            os.fsync( historyFile.fileno() )
      except OSError:
         # Be resilient to write/flush failures and don't crash the caller
         # because we couldn't write a log. For example, target may be full.
         # see 899669
         if self.logger is not None:
            self.logger.printMsg(
                  f'Failed to extend install history with: "{historyLog}"' )