# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import os
import Plugins
import Logging
import Tracing
import sys

# Trace handle for this module, registered under the name 'FirmwareUtils'.
traceHandle = Tracing.Handle( 'FirmwareUtils' )
# Shorthand for Tracing.trace0. Not referenced in the code shown here;
# NOTE(review): presumably kept for ad-hoc debugging — confirm before removing.
t0 = Tracing.trace0

# This framework can be used to upgrade the firmware of any device.
# To use this framework, create a Python plugin in the /src/<pkg>/FirmwarePlugin/
# directory. The plugin needs to extend the following classes and
# implement/override the methods mentioned below.
# 1. class Inventory
# Extend this class to keep an inventory of the physical devices that need to
# be upgraded. The plugin needs to implement the populate() method, which
# discovers all the devices and adds them to the inventory, and the reset()
# method, which is called to power cycle the system when an upgraded device
# needs a reboot.
# 2. class FirmwareVersion
# Every firmware has a different version format. The plugin needs to extend
# this class in order to compare firmware versions, and must implement
# __str__() so that versions can be printed.
# 3. class Device
# The plugin needs to extend this class in order to map the actual physical
# device properties to class members. The upgrade() method needs to be
# implemented; the framework calls it to upgrade the device. The
# postUpgradeReboot(), getHwVersion(), getLocalVersion() and getDescription()
# methods must be implemented as well.
# For an example, please check out test/sampleFirmwarePlugin.py.

# Syslog message raised when a device firmware upgrade fails. Inventory's
# syslogUpgradeFailed() fills the three %s fields with the device name, the
# version currently in hardware and the locally available (target) version.
HARDWARE_FIRMWARE_UPGRADEFAILED = Logging.LogHandle(
   "HARDWARE_FIRMWARE_UPGRADEFAILED",
   severity=Logging.logError,
   fmt=( "Failed to perform an upgrade of %s from version %s "
         "to version %s. The system could not perform firmware "
         "upgrade. This may be due to a hardware or software "
         "failure. System functionality will be severely limited." ),
   explanation="Failed to perform an upgrade of a device firmware",
   recommendedAction=Logging.CONTACT_SUPPORT )

# Print to console or to stdout or both.
class Tee:
   """Print messages to the console, to stdout, or to both.

   Each output device is bound lazily on first use and independently of the
   other, so failing to open the console never prevents messages from
   reaching stdout. Every message is flushed immediately: callers may power
   cycle the system right after printing (Inventory.upgrade() prints its reset
   message and then calls reset()), which would discard buffered output.
   """
   # Console device path; may be overridden per instance (e.g. in tests).
   consolePath = "/dev/console"
   # Set once opening the console has failed, so the open is not retried for
   # every message. Class-level default keeps instances created without
   # __init__ working.
   consoleUnavailable_ = False

   def __init__( self, *devices ):
      # ``devices`` is accepted for backward compatibility but is not used.
      self.consoleDevice = None
      self.stdoutDevice = None

   def init( self ):
      """Bind every output device that is not bound yet (idempotent)."""
      self._initConsole()
      self._initStdout()

   def _initConsole( self ):
      """Open the console once; on failure leave consoleDevice as None."""
      if self.consoleDevice is not None or self.consoleUnavailable_:
         return
      try:
         self.consoleDevice = open( self.consolePath, "a" )
      except OSError:
         # The console is a best-effort sink; keep going with stdout only.
         self.consoleUnavailable_ = True

   def _initStdout( self ):
      """Bind stdout to whatever sys.stdout is at first use."""
      if self.stdoutDevice is None:
         self.stdoutDevice = sys.stdout

   def __print( self, msg, devices ):
      # Unavailable devices are None and are silently skipped.
      for f in devices:
         if f is not None:
            print( msg, file=f, flush=True )

   def console( self, msg ):
      """Print ``msg`` to the console only (no-op if it cannot be opened)."""
      self._initConsole()
      self.__print( msg, [ self.consoleDevice ] )

   def out( self, msg ):
      """Print ``msg`` to stdout only."""
      self._initStdout()
      self.__print( msg, [ self.stdoutDevice ] )

   def all( self, msg ):
      """Print ``msg`` to the console (if available) and to stdout."""
      self.init()
      self.__print( msg, [ self.consoleDevice, self.stdoutDevice ] )

class UpgradeError( Exception ):
   """Raised by a Device.upgrade() implementation when the upgrade fails.

   ``msg``, when given, is printed in place of the framework's default
   "upgrade failed" message.
   """
   def __init__( self, msg=None ):
      self.msg = msg
      super().__init__( msg )

# To remember that a firmware upgrade is in progress, we need a flag/marker
# that persists across a system reboot.
# Class Device uses an instance of class Marker (see Device.createMarker()) to
# create this persistent marker before the device is upgraded.
class Marker:
   """A zero-length file whose presence records an in-progress upgrade.

   The marker has to survive a reboot, and a power cycle may follow shortly
   after it is created or deleted. So create() and delete() also try to flush
   the file and its directory entry to stable storage. That flushing is
   best-effort; failing to create or delete the file itself still raises.
   """
   def __init__( self, path ):
      self.path = path

   def create( self ):
      """Create (or truncate) the marker file and try to make it durable."""
      with open( self.path, "w" ) as f:
         self._fsyncQuietly( f.fileno() )
      self._syncDir()

   def delete( self ):
      """Remove the marker file; a missing marker is not an error."""
      try:
         os.unlink( self.path )
      except FileNotFoundError:
         return
      self._syncDir()

   def exists( self, path='' ):
      """Return True if ``path`` (default: this marker's path) exists."""
      if not path and self.path:
         path = self.path
      return os.path.exists( path )

   @staticmethod
   def _fsyncQuietly( fd ):
      """fsync ``fd``, ignoring filesystems that cannot do it."""
      try:
         os.fsync( fd )
      except OSError:
         pass

   def _syncDir( self ):
      """Best-effort fsync of the marker's directory.

      This persists the creation or removal of the directory entry, not just
      the file contents.
      """
      dirPath = os.path.dirname( os.path.abspath( self.path ) )
      try:
         fd = os.open( dirPath, os.O_RDONLY )
      except OSError:
         return
      try:
         self._fsyncQuietly( fd )
      finally:
         os.close( fd )

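# An instance of the Firmware class is passed as the plugin context to
# Plugins.loadPlugins(). Plugins call its register() method to register
# themselves; each registered object must provide the Inventory interface used
# by firmwareUpgrade(): populate(), upgrade() and handleFailedUpgrades().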
class Firmware:
   """Registry of firmware plugin objects, used as the plugin-loading context.

   Plugins add themselves with register(); pluginList() hands back the
   registered objects in registration order.
   """
   def __init__( self ):
      self.pluginList_ = list()

   def pluginList( self ):
      """Return the live list of registered plugin objects."""
      return self.pluginList_

   def register( self, firmwareObj ):
      """Append ``firmwareObj`` to the registry."""
      self.pluginList_ += [ firmwareObj ]

# This class is extended by each firmware plugin. An instance of the extended
# class is created for each device whose firmware may need to be upgraded.
class Device:
   """Base class for a single device whose firmware can be upgraded.

   Firmware plugins subclass this and implement the methods that raise
   NotImplementedError. The framework remembers an in-progress upgrade with a
   marker file named ``.<name>_marker`` in ``markerDir``.
   """
   def __init__( self, name ):
      self.name = name
      # Marker object for this device; created lazily by the marker helpers.
      self.marker = None
      # Directory that holds the marker file; may be changed by subclasses.
      self.markerDir = "/mnt/flash"

   def upgrade( self ):
      """Upgrade the device firmware. Must be overridden.

      Implementations report a failed upgrade by raising UpgradeError.
      """
      raise NotImplementedError

   def postUpgradeReboot( self ):
      """Return True if the device needs a reboot after a successful upgrade."""
      raise NotImplementedError

   def getLocalVersion( self ):
      """Return the latest firmware version available locally."""
      raise NotImplementedError

   def getHwVersion( self ):
      """Return the firmware version currently in the hardware."""
      raise NotImplementedError

   def getDescription( self ):
      """Return a human-readable description used in status messages."""
      raise NotImplementedError

   def needsUpgrade( self ):
      """Return True if the hardware version differs from the local version."""
      hwVersion = self.getHwVersion()       # Current firmware version.
      localVersion = self.getLocalVersion() # Latest available firmware version.
      return hwVersion != localVersion

   def _markerPath( self ):
      """Return this device's marker file path for the current markerDir."""
      return f"{self.markerDir}/.{self.name}_marker"

   def createMarker( self ):
      """Create the persistent marker file for this device."""
      self.marker = Marker( self._markerPath() )
      self.marker.create()

   def deleteMarker( self ):
      """Delete the marker file, if present, and forget the marker object.

      This also works when no marker object has been created yet in this
      process, e.g. when called before createMarker()/markerExists().
      """
      marker = self.marker
      if marker is None:
         marker = Marker( self._markerPath() )
      marker.delete()
      self.marker = None

   def markerExists( self ):
      """Return True if this device's marker file exists."""
      markerPath = self._markerPath()
      if self.marker is None:
         self.marker = Marker( markerPath )
      return self.marker.exists( markerPath )

# Every type of device (i.e. every Inventory plugin) needs to extend this class
# to wrap its device firmware version in a comparable object.
class FirmwareVersion:
   """Comparable wrapper around a device firmware version.

   Plugins extend this class for their own version format and must implement
   __str__(). Equality, ordering and hashing use the wrapped ``version``.
   Comparing against an object without a ``version`` attribute (e.g. None)
   returns NotImplemented, so Python applies its default behaviour instead of
   raising AttributeError: ``==`` is False, ``!=`` is True and ``<`` raises
   TypeError.
   """
   def __init__( self, version ):
      # version could be integer, string etc.
      self.version = version

   def __eq__( self, other ):
      if not hasattr( other, "version" ):
         return NotImplemented
      return self.version == other.version

   def __ne__( self, other ):
      # Derived from __eq__ so that subclasses overriding only __eq__ stay
      # consistent; NotImplemented must be propagated, not negated (negating it
      # would yield False).
      equal = self.__eq__( other )
      if equal is NotImplemented:
         return NotImplemented
      return not equal

   def __lt__( self, other ):
      if not hasattr( other, "version" ):
         return NotImplemented
      return self.version < other.version

   def __key( self ):
      return self.version

   def __hash__( self ):
      # Consistent with __eq__: equal versions hash equally.
      return hash( self.__key() )

   def __repr__( self ):
      return f"{type( self ).__name__}({self.version!r})"

   def __str__( self ):
      raise NotImplementedError

# Each firmware plugin should extend this class to keep an inventory of devices
# of the same type.
class Inventory:
   """Base class for a firmware plugin: an inventory of same-type devices.

   Subclasses fill ``self.inventory`` with Device objects in populate() and
   implement reset(). upgrade() then upgrades every device that needs it and,
   if any upgraded device requires a reboot, calls reset().
   """
   # Printed to the console and stdout just before reset() is called.
   RESET_MESSAGE = "Power cycling the system after successful upgrade."

   def __init__( self ):
      # Devices managed by this plugin; filled in by populate().
      self.inventory = []
      self.tee = Tee()

   # Extended classes must override populate() to fill self.inventory with
   # devices of the same type.
   # e.g. some plugins might refer to fdl, some might populate devices dynamically.
   def populate( self ):
      raise NotImplementedError

   # Extended classes must override reset() to power cycle the system. It is
   # called by upgrade() when an upgraded device needs a reboot and is not
   # expected to return.
   def reset( self ):
      raise NotImplementedError

   # Child class may override this to customise the syslog behavior.
   # _afterReboot is True when the failure is detected on boot (the marker was
   # still present) and False when upgrade() raised UpgradeError; it is unused
   # by this default implementation.
   def syslogUpgradeFailed( self, device, _afterReboot ):
      Logging.log( HARDWARE_FIRMWARE_UPGRADEFAILED, device.name,
                   device.getHwVersion(), device.getLocalVersion() )

   def handleMarkerExistsOnBootUp( self, device ):
      """Report the outcome of an earlier upgrade whose marker is still present.

      If the device still needs an upgrade, the attempt is logged and reported
      as failed. The marker is kept, so this boot does not retry the upgrade;
      handleFailedUpgrades() removes it later. Otherwise success is reported and
      the marker is deleted.
      """
      if device.needsUpgrade():
         # First boot after unsuccessful upgrade. tee and continue.
         # Failed upgrade case gets handled after all devices are upgraded.
         self.syslogUpgradeFailed( device, True )
         msg = ( "Upgrade of the %s failed.\n"
               % device.getDescription() )
         self.tee.all( msg )
      else:
         # First boot after successful upgrade.
         # Marker is placed to avoid the upgrade loop on upgrade failure.
         msg = ( "Upgrade of the %s completed successfully."
               % device.getDescription() )
         self.tee.all( msg )
         device.deleteMarker()

   def doUpgrade( self, device ):
      """Upgrade ``device``; return a truthy value if a reboot is required.

      The return value is device.postUpgradeReboot()'s result on success and
      False when device.upgrade() raises UpgradeError. The marker is kept only
      when a reboot is required; success is then verified on the next boot.
      """
      # Upgrade the firmware.
      # It is important that the marker is created before the call to
      # upgrade(), as the storage plugin may update the /mnt/flash firmware.
      # The device may be unavailable until reboot.
      device.createMarker()
      msg = ( "Upgrading %s to %s." %
            ( device.getDescription(), device.getLocalVersion() ) )
      self.tee.out( msg )
      msg = ( "\n-----------------------------------------------------\n"
              " Upgrading the %s.\n"
              " This process can take several minutes.\n"
              " Please do not reboot your switch.\n"
              "-----------------------------------------------------"
            % device.getDescription() )
      self.tee.all( msg )
      try:
         device.upgrade()
      except UpgradeError as e:
         self.syslogUpgradeFailed( device, False )
         msg = e.msg or ( "Upgrading the %s failed.\nSystem functionality will be "
                          "severely limited." % device.getDescription() )
         self.tee.all( msg + "\n" )
         # If upgrade failed, then we don't need to reboot.
         deviceNeedsReboot = False
      else:
         deviceNeedsReboot = device.postUpgradeReboot()
         if not deviceNeedsReboot:
            # The result of `needsUpgrade()`/`getHwVersion()` may be cached by the
            # device object implementation. Consider an `upgrade()` run that doesn't
            # throw `UpgradeError` and doesn't need to reboot as a success. If a
            # reboot is needed, success will be checked on the next boot up.
            msg = ( "Upgrade of the %s completed successfully."
                  % device.getDescription() )
            self.tee.all( msg )

      if not deviceNeedsReboot:
         # A marker exists solely to prevent boot loop. If we're not going
         # to reboot for this device, then we don't need a marker.
         device.deleteMarker()

      return deviceNeedsReboot

   def upgrade( self ):
      """Upgrade every device in the inventory that needs it.

      Devices that still carry a marker are not upgraded again; the result of
      the earlier attempt is reported instead. If any upgraded device needs a
      reboot, RESET_MESSAGE is printed and reset() is called. reset() is not
      expected to return; if it does, AssertionError is raised.
      """
      tee = self.tee
      postUpgradeReboot = False
      for device in self.inventory:
         hwVersion = device.getHwVersion()
         localVersion = device.getLocalVersion()
         msg = ( "Checking to see if the %s needs to be upgraded.\n"
                 "Version in hardware is %s, have version %s locally." %
                 ( device.getDescription(), hwVersion, localVersion ) )
         tee.out( msg )

         if device.markerExists():
            self.handleMarkerExistsOnBootUp( device )
         elif device.needsUpgrade():
            # doUpgrade() returns postUpgradeReboot()'s result, which need not
            # be a bool (e.g. None). `postUpgradeReboot |= ...` would then raise
            # TypeError after the device was already upgraded, so test
            # truthiness instead, as doUpgrade() itself does.
            if self.doUpgrade( device ):
               postUpgradeReboot = True
         else:
            msg = ( "%s does not need to be upgraded." %
                    device.getDescription() )
            tee.out( msg )

      if postUpgradeReboot:
         tee.all( self.RESET_MESSAGE + "\n" )
         self.reset()
         msg = "Failed to power cycle system"
         tee.all( msg )
         raise AssertionError( msg )

   def handleFailedUpgrades( self ):
      """Delete markers left by failed upgrades so that a later boot retries."""
      for device in self.inventory:
         if device.markerExists() and device.needsUpgrade():
            # Marker was placed to avoid the upgrade loop.
            # Since it is not required and user would want to try the upgrade
            # again, delete the marker.
            device.deleteMarker()

def firmwareUpgrade():
   """Load all firmware plugins and upgrade their devices.

   The "FirmwarePlugin" plugins are loaded with a shared Firmware registry as
   their context, skipping files whose name contains "Test.py". Each
   registered plugin is populated and upgraded in registration order. Once
   that loop finishes, every plugin gets to clear markers left behind by
   failed upgrades.
   """
   registry = Firmware()

   def isNotTestFile( filename ):
      return "Test.py" not in filename

   Plugins.loadPlugins( "FirmwarePlugin", context=registry,
                        filenameFilter=isNotTestFile )

   for plugin in registry.pluginList():
      plugin.populate()
      plugin.upgrade()

   for plugin in registry.pluginList():
      plugin.handleFailedUpgrades()
