# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import contextlib
import os
import re
import sys
import tempfile

from AbootVersion import AbootVersion
import BiosBlinkLib
import BiosSlimLib
import FileReplicationCmds
import FirmwareRev
import Fru
import Tac

# Path of the Aboot feature list read by getAbootFeatures(); each useful line
# has the form "<feature-name>:<integer-revision>".
ABOOT_FEATURES_FILE = "/export/aboot/features"
# Feature name for the SPI-upgrade Aboot feature. NOTE(review): not referenced
# in this file; presumably passed to matchAbootFeature() by callers -- confirm.
ABOOT_FEATURE_SPIUPGRADE = "spiupgrade"
# NOTE(review): not referenced in this file; its meaning (and the unit of 26)
# is not visible here -- confirm with its users before relying on it.
NOT_DEPRECATED_SB = 26

#--------------------------------------------------------------------------------
# Aboot firmware readers
#--------------------------------------------------------------------------------

class AbootReaderBase:
   """Common interface for objects that report Aboot firmware versions.

   Concrete readers implement getProgrammedVersion() and getFallbackVersion().
   Every getter accepts `standby`; when True the concrete readers in this
   module run their commands against the standby supervisor instead of the
   local one.
   """

   def _parseVersion( self, sectionData ):
      """Extract an AbootVersion from raw firmware bytes, or return None.

      The strict AbootVersion.abootRe() pattern is tried first; the looser
      AbootVersion.abootLooseRe() pattern is only consulted when the strict
      one finds nothing. The whole matched text (group 0) is handed to
      AbootVersion.
      """
      candidates = ( AbootVersion.abootRe(), AbootVersion.abootLooseRe() )
      for pattern in candidates:
         # The patterns are str; sectionData is raw bytes, so encode first.
         found = re.search( pattern.encode(), sectionData )
         if found:
            return AbootVersion( found.group( 0 ) )
      return None

   def getProgrammedVersion( self, standby=False ):
      """Return the programmed (active) Aboot version; subclasses override."""
      raise NotImplementedError

   def getFallbackVersion( self, standby=False ):
      """Return the fallback Aboot version; subclasses override."""
      raise NotImplementedError

   def getProgrammedAndFallbackVersions( self, standby=False ):
      """Return ( programmed, fallback ), querying programmed first."""
      programmed = self.getProgrammedVersion( standby=standby )
      fallback = self.getFallbackVersion( standby=standby )
      return ( programmed, fallback )

class AbootSpiFlashReaderBase( AbootReaderBase ):
   """Reads Aboot versions out of SPI flash sections using `flashUtil`.

   Subclasses set `imageSection` / `fallbackSection` to the flashUtil section
   names holding the programmed and fallback images. A section left as None
   makes readFlash() dump the entire flash ("total"), and the version is then
   parsed from that dump.
   """
   fallbackSection = None
   imageSection = None

   def getProgrammedVersion( self, standby=False ):
      """Return the AbootVersion parsed from `imageSection`, or None."""
      raw = self.readFlash( section=self.imageSection, standby=standby )
      return self._parseVersion( raw )

   def getFallbackVersion( self, standby=False ):
      """Return the AbootVersion parsed from `fallbackSection`, or None."""
      raw = self.readFlash( section=self.fallbackSection, standby=standby )
      return self._parseVersion( raw )

   def readFlash( self, section=None, standby=False ):
      """Return the raw bytes of flash `section` (whole flash if falsy).

      With standby=True the flashUtil command is wrapped by
      FileReplicationCmds.runCmd() for the current slot instead of being
      executed locally.
      """
      # Any falsy section (None, '') means "dump the whole flash".
      region = section or "total"
      cmd = "flashUtil -r %s -" % region
      if standby:
         argv = FileReplicationCmds.runCmd( Fru.slotId(), cmd, useKey=True )
      else:
         argv = cmd.split()
      # text=False: flash contents are binary and must stay bytes.
      return Tac.run( argv, asRoot=True, stdout=Tac.CAPTURE,
                      stderr=Tac.DISCARD, text=False )

@contextlib.contextmanager
def mountRo( label, standby=False ):
   """Mount the filesystem labelled `label` read-only on a temporary directory.

   Yields the mount point path. On exit the filesystem is unmounted (only if
   the mount succeeded) and the temporary directory is removed, including
   when the mount itself fails.

   With standby=True every step (mktemp, mount, umount, rm) is run through
   FileReplicationCmds.runCmd() for the current slot instead of locally.
   Command failures propagate from Tac.run().
   """
   def run( cmdStr, **kwargs ):
      # Run cmdStr as root, either locally (split into argv) or wrapped by
      # FileReplicationCmds for the standby supervisor.
      if standby:
         argv = FileReplicationCmds.runCmd( Fru.slotId(), cmdStr, useKey=True )
      else:
         argv = cmdStr.split()
      return Tac.run( argv, asRoot=True, **kwargs )

   # Create the temporary mount point.
   if standby:
      # mktemp prints the directory followed by a newline; strip it so the
      # path is usable in the commands below and in os.path.join() by callers.
      mountDir = run( "mktemp -d", stdout=Tac.CAPTURE,
                      stderr=Tac.DISCARD ).strip()
   else:
      mountDir = tempfile.mkdtemp()

   mounted = False
   try:
      run( f"mount -r -L {label} {mountDir}" )
      mounted = True
      yield mountDir
   finally:
      # Unmount only what was actually mounted; if umount fails the
      # directory is left alone (rm -r on a live mount would be unsafe).
      if mounted:
         run( f"umount {mountDir}" )
      run( f"rm -r {mountDir}" )

class AbootBlkPartReaderBase( AbootReaderBase ):
   """Reads Aboot versions from image files stored on labelled partitions.

   The partition is mounted read-only via mountRo(), the image file inside it
   is read as bytes with readBinFileOnSupe(), and the version is parsed from
   those bytes.
   """

   def __init__( self, partitionLabel, imageName, fallbackPartLabel=None,
                 fallbackName=None ):
      # Programmed image: partition label + file name inside that partition.
      self.partLabel, self.imageName = partitionLabel, imageName
      # Optional fallback image; both parts must be set for a lookup.
      self.fallbackPartLabel, self.fallbackName = fallbackPartLabel, fallbackName

   def mountAndRead( self, partitionLabel, file, standby=False ):
      """Mount `partitionLabel` read-only and return the bytes of `file`."""
      with mountRo( partitionLabel, standby=standby ) as mountPoint:
         return readBinFileOnSupe( os.path.join( mountPoint, file ),
                                   standby=standby )

   def getProgrammedVersion( self, standby=False ):
      """Return the AbootVersion parsed from the programmed image, or None."""
      return self._parseVersion(
         self.mountAndRead( self.partLabel, self.imageName, standby=standby ) )

   def getFallbackVersion( self, standby=False ):
      """Return the fallback image's AbootVersion; None if not configured."""
      if self.fallbackPartLabel and self.fallbackName:
         data = self.mountAndRead( self.fallbackPartLabel, self.fallbackName,
                                   standby=standby )
         return self._parseVersion( data )
      return None

class AbootUnknownReader( AbootSpiFlashReaderBase ):
   """Fallback reader used when the Aboot line is not recognised.

   Dumps the entire flash and collects every AbootVersion.abootRe() match,
   assuming the fallback image appears before the programmed one.
   """

   # Both single-version getters go through the combined whole-flash scan.
   # There is no saving over the base class, but the interface is kept.
   def getProgrammedVersion( self, standby=False ):
      return self.getProgrammedAndFallbackVersions( standby=standby )[ 0 ]

   def getFallbackVersion( self, standby=False ):
      return self.getProgrammedAndFallbackVersions( standby=standby )[ 1 ]

   def getProgrammedAndFallbackVersions( self, standby=False ):
      """Return ( programmed, fallback ) AbootVersions found in the flash.

      No match gives ( None, None ). A single match is taken as the
      programmed version with no fallback. Otherwise the last match is the
      programmed version and the first match is the fallback.
      """
      pattern = AbootVersion.abootRe().encode()
      hits = re.findall( pattern, self.readFlash( standby=standby ) )
      if not hits:
         return ( None, None )
      # NOTE(review): re.findall() yields tuples when the pattern has several
      # groups, so [ 0 ] selects the first capture group, not the whole
      # match; this presumes abootRe()'s first group spans the version.
      programmed = AbootVersion( hits[ -1 ][ 0 ] )
      if len( hits ) == 1:
         return ( programmed, None )
      return ( programmed, AbootVersion( hits[ 0 ][ 0 ] ) )

class Aboot2Reader( AbootSpiFlashReaderBase ):
   """SPI-flash reader selected for norcal2; version in the 'image' section.

   No fallback section is named, so the inherited getFallbackVersion() reads
   and scans the whole flash.
   """
   fallbackSection = None
   imageSection = 'image'

class Aboot6Reader( AbootSpiFlashReaderBase ):
   """SPI-flash reader selected for norcal4 and norcal6.

   The programmed image is in the 'image' section, the fallback in 'fallback'.
   """
   fallbackSection = 'fallback'
   imageSection = 'image'

class Aboot7Reader( AbootSpiFlashReaderBase ):
   """SPI-flash reader selected for norcal7 ('normal' / 'fallback' sections)."""
   fallbackSection = 'fallback'
   imageSection = 'normal'

class Aboot9Reader( AbootSpiFlashReaderBase ):
   """SPI-flash reader selected for norcal9.

   The programmed version is read from 'coreboot', the fallback from
   'bootblock'.
   """
   fallbackSection = 'bootblock'
   imageSection = 'coreboot'

class Aboot10Reader( AbootSpiFlashReaderBase ):
   """SPI-flash reader selected for norcal10 ('dxe' / 'bkp_dxe' sections).

   Instead of the regex-based parser, the version is decoded from the raw
   section bytes with BiosBlinkLib.getAbootVersion().
   """
   imageSection = 'dxe'
   fallbackSection = 'bkp_dxe'

   def _blinkVersionOf( self, sectionName, standby ):
      # Dump one flash section and decode its Aboot version via BiosBlinkLib.
      blob = self.readFlash( section=sectionName, standby=standby )
      return AbootVersion( BiosBlinkLib.getAbootVersion( blob ) )

   def getProgrammedVersion( self, standby=False ):
      return self._blinkVersionOf( self.imageSection, standby )

   def getFallbackVersion( self, standby=False ):
      return self._blinkVersionOf( self.fallbackSection, standby )

class Aboot11Reader( AbootSpiFlashReaderBase ):
   """SPI-flash reader selected for norcal11 ('normal' / 'fallback' sections)."""
   fallbackSection = 'fallback'
   imageSection = 'normal'

class Aboot12Reader( AbootSpiFlashReaderBase ):
   """SPI-flash reader selected for norcal12 ('normal' / 'fallback' sections)."""
   fallbackSection = 'fallback'
   imageSection = 'normal'

class Aboot13Reader( AbootSpiFlashReaderBase ):
   """SPI-flash reader selected for norcal13; decodes with BiosSlimLib.

   No fallback section is configured, so getFallbackVersion() returns None
   instead of dumping the whole flash.
   """
   imageSection = 'image'
   fallbackSection = None

   def _slimVersionOf( self, sectionName, standby ):
      # Dump one flash section and decode its Aboot version via BiosSlimLib.
      blob = self.readFlash( section=sectionName, standby=standby )
      return AbootVersion( BiosSlimLib.getAbootVersion( blob ) )

   def getProgrammedVersion( self, standby=False ):
      return self._slimVersionOf( self.imageSection, standby )

   def getFallbackVersion( self, standby=False ):
      if self.fallbackSection is not None:
         return self._slimVersionOf( self.fallbackSection, standby )
      return None

class Aboot16Reader( AbootReaderBase ):
   """Reader selected for norcal16.

   The programmed and fallback versions come from the 'Aboot.rom' files on
   the partitions labelled 'aboot0' and 'aboot1' (via `emmcReader`). The
   golden version comes from an SPI-flash reader (`goldenSpiFlashReader`).
   """

   class Aboot16SpiFlashReader( AbootSpiFlashReaderBase ):
      # NOTE(review): only fallbackSection is set, but getGoldenVersion() calls
      # getProgrammedVersion(), whose imageSection is the inherited None, so
      # the whole flash is dumped and scanned -- confirm this is intended.
      fallbackSection = 'aboot'

   def __init__( self ):
      romName = 'Aboot.rom'
      self.emmcReader = AbootBlkPartReaderBase( 'aboot0', romName,
                                                fallbackPartLabel='aboot1',
                                                fallbackName=romName )
      self.goldenSpiFlashReader = self.Aboot16SpiFlashReader()

   def getProgrammedVersion( self, standby=False ):
      """Version of the image on partition 'aboot0'."""
      return self.emmcReader.getProgrammedVersion( standby=standby )

   def getFallbackVersion( self, standby=False ):
      """Version of the image on partition 'aboot1'."""
      return self.emmcReader.getFallbackVersion( standby=standby )

   def getGoldenVersion( self, standby=False ):
      """Version reported by the SPI-flash golden reader."""
      return self.goldenSpiFlashReader.getProgrammedVersion( standby=standby )

class Aboot17Reader( AbootSpiFlashReaderBase ):
   """SPI-flash reader selected for norcal17; only the 'image' section exists."""
   fallbackSection = None
   imageSection = 'image'

   def getFallbackVersion( self, standby=False ):
      # Deliberately avoid the inherited behaviour. With no fallbackSection it
      # would dump the whole flash ("flashUtil -r total"), which fails on
      # Willamette because the ME firmware region is not readable. That would
      # make getProgrammedAndFallbackVersions() raise.
      return None

def getAbootReader( abootLine ):
   """Return a new reader instance appropriate for the given Aboot line.

   `abootLine` is a norcal identifier such as 'norcal9' (getAbootVersion()
   passes AbootVersion.norcal). Unrecognised values get an
   AbootUnknownReader, which scans the whole flash.
   """
   readerClassByLine = {
      'norcal2': Aboot2Reader,
      'norcal4': Aboot6Reader,  # norcal4 is handled by the norcal6 reader
      'norcal6': Aboot6Reader,
      'norcal7': Aboot7Reader,
      'norcal9': Aboot9Reader,
      'norcal10': Aboot10Reader,
      'norcal11': Aboot11Reader,
      'norcal12': Aboot12Reader,
      'norcal13': Aboot13Reader,
      'norcal16': Aboot16Reader,
      'norcal17': Aboot17Reader,
   }
   readerClass = readerClassByLine.get( abootLine, AbootUnknownReader )
   return readerClass()

def getRunningAbootVersion( standby=False ):
   """Return the running Aboot firmware revision string.

   Locally this is FirmwareRev.abootFirmwareRev(). For the standby, the same
   call runs in a Python subprocess launched through
   FileReplicationCmds.runCmd(). Its printed output is decoded as UTF-8 and
   undecodable bytes are ignored.
   """
   if not standby:
      return FirmwareRev.abootFirmwareRev()

   pyCode = 'import FirmwareRev; print( FirmwareRev.abootFirmwareRev() )'
   remoteCmd = FileReplicationCmds.runCmd( Fru.slotId(),
                                           f'{sys.executable} -c "{pyCode}"',
                                           useKey=True )
   raw = Tac.run( remoteCmd, asRoot=True, stdout=Tac.CAPTURE,
                  stderr=Tac.DISCARD, text=False )
   # NOTE(review): print() leaves a trailing newline in this result that the
   # local path does not return -- confirm callers tolerate it.
   return raw.decode( 'UTF-8', 'ignore' )

def getAbootVersion( standby=False ):
   """Return the best-known AbootVersion for this (or the standby) supervisor.

   The running revision is parsed first, and its norcal line selects a flash
   reader. If that reader yields a programmed version, the missing fields are
   filled in from the running version and the programmed version is
   returned. Otherwise the running version is returned.
   """
   running = AbootVersion( getRunningAbootVersion( standby=standby ) )
   reader = getAbootReader( running.norcal )
   programmed = reader.getProgrammedVersion( standby=standby )
   if not programmed:
      return running
   programmed.reconcileMissingFields( running )
   return programmed

def copyFileOnSupe( filepath ):
   """Copy `filepath` from the peer supervisor into a local temporary file.

   Returns the open tempfile.NamedTemporaryFile holding the copy. Its `.name`
   is the local path, and closing it deletes the copy. The caller owns the
   returned object and must close it.

   Raises OSError if `test -f filepath` fails on the peer. Errors from the
   copy/chgrp/chmod commands propagate unchanged, and the temporary file is
   closed before they do.
   """
   simulation = 'SIMULATION_VMID' in os.environ
   slotId = Fru.slotId()

   cmd = FileReplicationCmds.runCmd( slotId, "test -f %s" % filepath,
                                     useKey=True, loopback=simulation )
   try:
      Tac.run( cmd, stdout=Tac.DISCARD, stderr=Tac.DISCARD, asRoot=True )
   except Tac.SystemCommandError as e:
      raise OSError( f"{filepath} not found on peer supervisor" ) from e

   # Create the temporary file only once the source is known to exist. Close
   # (and thereby delete) it again if any later step fails.
   tmp = tempfile.NamedTemporaryFile()
   try:
      cmd = FileReplicationCmds.copyFile( slotId, tmp.name, source=filepath,
                                          useKey=True, loopback=simulation,
                                          peerSource=True )
      Tac.run( cmd, stdout=Tac.CAPTURE, stderr=Tac.DISCARD, asRoot=True )
      # copyFile() uses rsync --no-p to support vfat, and the copy above ran
      # as root, so we might not be able to read the result. Grant group
      # eosadmin read access.
      Tac.run( [ "chgrp", "eosadmin", tmp.name ], stdout=Tac.DISCARD,
               stderr=Tac.DISCARD, asRoot=True )
      Tac.run( [ "chmod", "g+r", tmp.name ], stdout=Tac.DISCARD,
               stderr=Tac.DISCARD, asRoot=True )
   except BaseException:
      tmp.close()
      raise
   return tmp

def readBinFileOnSupe( filepath, standby=False ):
   """Return the full contents of `filepath` as bytes.

   With standby=True the file is first copied from the peer supervisor by
   copyFileOnSupe() into a local temporary file. That file is always closed
   afterwards, even if opening or reading it fails.

   Raises OSError if the file is missing or cannot be opened. Command
   failures during the copy propagate from copyFileOnSupe().
   """
   if not standby:
      with open( filepath, "rb" ) as f:
         return f.read()

   # The `with` guarantees the temporary copy is released on every path.
   with copyFileOnSupe( filepath ) as tmpFile:
      with open( tmpFile.name, "rb" ) as f:
         return f.read()

def readFileOnSupe( filepath, standby=False ):
   """Yield the lines of text file `filepath`, line endings included.

   With standby=True the file is first copied from the peer supervisor by
   copyFileOnSupe() (on the first iteration, since this is a generator). The
   temporary copy is closed when iteration finishes, when the consumer stops
   early (generator close), or when an error occurs.
   """
   if not standby:
      with open( filepath ) as f:
         yield from f
      return

   # `with` releases the temporary copy even if the generator is abandoned.
   with copyFileOnSupe( filepath ) as tmpFile:
      with open( tmpFile.name ) as f:
         yield from f

def getAbootFeatures( standby=False ):
   """Return { featureName: revision } parsed from the Aboot features file.

   Each line beginning with "<name>:<digits>" contributes one entry, where the
   name matches [\\w\\-\\.]+. Other lines are ignored, and later duplicates
   win. If the file cannot be read (OSError), whatever was parsed so far is
   returned, which is an empty dict when the file is absent.
   """
   featureLine = re.compile( r'([\w\-\.]+):([0-9]+)' )
   features = {}
   with contextlib.suppress( OSError ):
      for line in readFileOnSupe( ABOOT_FEATURES_FILE, standby=standby ):
         m = featureLine.match( line )
         if m is None:
            continue
         name, revision = m.groups()
         # [0-9]+ only matches ASCII digits, so int() cannot fail here.
         features[ name ] = int( revision )
   return features

def matchAbootFeature( featureName, func, standby=False ):
   """Return func( revision ) for Aboot feature `featureName`.

   `revision` is the integer from getAbootFeatures(). It is None when the
   feature is not listed, including when the features file could not be
   read.
   """
   revision = getAbootFeatures( standby=standby ).get( featureName )
   return func( revision )
