# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
from __future__ import absolute_import, division, print_function

import errno
import os
import tempfile
import shutil
import fnmatch
import Plugins
import EosVersion
import Logging
import SwiSignLib

def doCheckProperties( noExtract=False, plugins=None, pluginPath=None,
                       versionPath=None, exceptionsPath=None ):
   '''
   Decorator factory that sets the RELOADPOLICY_* environment variables read
   by doCheck for the duration of each call to the decorated function. It is
   used to control where doCheck loads ReloadPolicy plugins from should the
   next image not support ReloadPolicy.

   noExtract is "True" if doCheck shouldn't load plugins from the next Swi.
   plugins is a comma-separated list of names of plugins to run.
   pluginPath is a comma-separated list of directories containing plugins.
   versionPath is a 'swi-version' file.
   exceptionsPath is a ReloadPolicy exceptions file.

   Only truthy arguments are applied. After the decorated function returns or
   raises, every variable this decorator set is restored to its previous
   value, or removed if it was not set before the call.

   Usage:
            @ReloadPolicy.doCheckProperties( noExtract=<"True" | "False">,
                                             plugins=<list of plugin names>,
                                             pluginPath=<list of plugin directories>,
                                             versionPath=<version file>,
                                             exceptionsPath=<exceptions file> )
            def func(): ...

   '''
   import functools

   allOverrides = [ ( 'RELOADPOLICY_NO_EXTRACT', noExtract ),
                    ( 'RELOADPOLICY_PLUGINS', plugins ),
                    ( 'RELOADPOLICY_PLUGIN_PATH', pluginPath ),
                    ( 'RELOADPOLICY_VERSION_PATH', versionPath ),
                    ( 'RELOADPOLICY_EXCEPTIONS_PATH', exceptionsPath ) ]
   # Falsy settings leave the corresponding variable untouched.
   overrides = [ ( var, value ) for var, value in allOverrides if value ]

   def decorator( func ):
      @functools.wraps( func )
      def setPath( *args, **kwargs ):
         # Remember the pre-existing values so they can be restored afterwards
         # instead of being deleted.
         saved = dict( ( var, os.environ.get( var ) ) for var, _ in overrides )
         try:
            for var, value in overrides:
               os.environ[ var ] = value
            return func( *args, **kwargs )
         finally:
            # Runs even if func raises, so overrides never leak into later calls.
            for var, oldValue in saved.items():
               if oldValue is None:
                  os.environ.pop( var, None )
               else:
                  os.environ[ var ] = oldValue
      return setPath
   return decorator

def doCheck( imagePath, category=None, mode=None,
             abootSbStatus=None, syslogNextSwiVersion=False ):
   '''
   Run ReloadPolicy checks of the specified category from the image in imagePath
   and return a collection of errors and warnings generated by the checks.

   Returns a tuple ( retVal, caUsed ). retVal is the _PluginContext.PluginRetVal
   holding the collected errors, warnings and per-plugin results. caUsed is the
   third value returned by SwiSignLib.verifySwiSignature for imagePath.

   Test overrides are read from the RELOADPOLICY_* environment variables, which
   doCheckProperties sets.
   '''
   # Create plugin context.
   pc = _PluginContext( imagePath, category, mode )
   verified, msg, caUsed = SwiSignLib.verifySwiSignature( imagePath, userHint=True )
   if not verified:
      # If the path is not a SWI, then add an error and exit.
      # If SecureBoot is enabled and the image is unsigned or has a bad
      # signature, then add an error and exit as the switch isn't going
      # to allow the image to boot.
      # If the image is unsigned or has a bad signature and SecureBoot
      # is disabled, then add a warning and continue.
      if ( 'does not seem to be a swi image' in msg
           or ( abootSbStatus and abootSbStatus.supported and
                not abootSbStatus.securebootDisabled ) ):
         pc.addError( msg )
         return pc.retVal, caUsed
      else:
         pc.addWarning( msg )
   try:
      pc.getVersionInfo( imagePath, syslogNextSwiVersion=syslogNextSwiVersion )
   except _PluginContext.VersionError as e:
      pc.addError( str( e ) )
      return pc.retVal, caUsed

   # Check for test overrides. An unset variable yields None, which never
   # equals 'True', so the default is to extract plugins from the next image.
   noExtract = os.environ.get( 'RELOADPOLICY_NO_EXTRACT' )
   plugins, pluginPath = None, None
   if os.environ.get( 'RELOADPOLICY_PLUGINS' ):
      plugins = os.environ[ 'RELOADPOLICY_PLUGINS' ].split( ',' )
   if os.environ.get( 'RELOADPOLICY_PLUGIN_PATH' ):
      pluginPath = os.environ[ 'RELOADPOLICY_PLUGIN_PATH' ].split( ',' )

   # Check if the image is in the exceptions file.
   ec = _ExceptionChecker( pc.nextVersion.version() )
   if ec.isBlock():
      # Block rule: prevent loading the next image.
      pc.addError( "Image has incompatible version %s." %
                   pc.nextVersion.version() )
      return pc.retVal, caUsed
   elif ec.isOverride() or noExtract == 'True':
      # Override rule: prevent loading the next image's plugins.
      loadedPlugins = _loadPlugins( pc, plugins, pluginPath )
   else:
      # Extract plugins from the next image.
      try:
         with _PluginExtractor( imagePath ) as pe:
            baseModuleName, extractPath = pe.extract()
            pluginPath = [ extractPath ]
            loadedPlugins = _loadPlugins( pc, plugins, pluginPath, baseModuleName )
      except _PluginExtractor.NoReloadPolicyError:
         pc.addWarning( "Image does not support next image compatibility checks."
                        " Running only checks from the current image." )
         loadedPlugins = _loadPlugins( pc, plugins, pluginPath )
      except ( IOError, OSError ) as e:
         # OSError is listed explicitly because on Python 2 tempfile.mkdtemp,
         # os.rename and shutil.rmtree raise OSError, which is not an IOError.
         pc.addWarning( "Failed to run next image compatibility checks on the image "
                        "(%s). Resolve the error to ensure compatibility or proceed "
                        "with reload if the image is known to be supported on this "
                        "system." % str( e ) )
         return pc.retVal, caUsed

   # Run the policies.
   assert loadedPlugins
   pc.runPolicies()

   # Return a collection of errors and warnings.
   return pc.retVal, caUsed

class _PluginExtractor( object ):
   '''
   Context manager to extract ReloadPolicy plugins from nextImagePath into
   a temporary plugin directory for loading them as modules. The temporary
   directory is removed when the context exits, so callers don't need to
   clean it up.

   Entering the context raises _PluginExtractor.NoReloadPolicyError if the
   image has no entries under 'ReloadPolicyPlugin/'.

   Usage:
            with _PluginExtractor( imagePath ) as pe:
               baseModuleName, pluginPath = pe.extract()
               # load plugins from the pluginPath temporary directory
   '''

   class NoReloadPolicyError( Exception ):
      ''' Raised when the image contains no ReloadPolicy plugins. '''
      pass

   def __init__( self, nextImagePath ):
      self.pluginDirRoot = None # temporary directory, created by __enter__
      self.imagePath = nextImagePath
      self.plugins = None # zip member names under 'ReloadPolicyPlugin/'

   def __enter__( self ):
      ''' Check that plugins exist and create a directory for extraction. '''
      import zipfile
      with zipfile.ZipFile( self.imagePath ) as imageFile:
         self.plugins = [ path for path in imageFile.namelist()
                          if path.split( '/' )[ 0 ] == 'ReloadPolicyPlugin' ]
         if not self.plugins:
            raise self.NoReloadPolicyError()
         self.pluginDirRoot = tempfile.mkdtemp()
         return self

   def __exit__( self, *args ):
      ''' Remove the temporary plugin directory. Exceptions are not suppressed. '''
      # Best-effort cleanup: a failure to remove the directory must not mask an
      # exception raised in the with-block, nor turn a successful extraction
      # into an error.
      if self.pluginDirRoot:
         shutil.rmtree( self.pluginDirRoot, ignore_errors=True )
      return False

   def extract( self ):
      '''
      Extract plugins from the image to the temporary directory. Return a tuple
      ( baseModuleName, pluginPath ), where pluginPath is the directory holding
      the extracted plugins.

      The extracted base module name is 'ReloadPolicyPlugin-' plus a suffix that
      uniquely identifies the image. The extracted directory is renamed to match.

      I/O errors are re-raised. On ENOSPC, the parent directory of the temporary
      directory is appended to the error text.
      '''

      def imageInfoChecksum( imageZipFile ):
         # Don't compute the checksum of the entire image, as that takes too long.
         # Instead, just compute the checksum of the file info list and the content
         # of the version file. Using this checksum makes development and testing
         # easier as it allows us to alter the image contents post abuild and have
         # ReloadPolicy recognize it as a new image.
         import hashlib
         chksm = hashlib.md5()

         for fileInfo in imageZipFile.infolist():
            chksm.update( '{} {} {} {}\n'.format(
                          # Upper 16 bits of external_attr have permissions.
                          fileInfo.external_attr >> 16,
                          fileInfo.file_size,
                          fileInfo.date_time,
                          fileInfo.filename ).encode( 'utf-8' ) )
         # We can assume that the version file exists since it's always checked
         # by getVersionInfo first.
         chksm.update( imageZipFile.read( 'version' ) )
         return chksm.hexdigest()

      import zipfile
      try:
         with zipfile.ZipFile( self.imagePath ) as imageFile:
            imageFile.extractall( self.pluginDirRoot, self.plugins )
            suffix = imageInfoChecksum( imageFile )
            oldPluginPath = os.path.join( self.pluginDirRoot, 'ReloadPolicyPlugin' )
            baseModuleName = 'ReloadPolicyPlugin' + '-' + suffix
            newPluginPath = os.path.join( self.pluginDirRoot, baseModuleName )
            os.rename( oldPluginPath, newPluginPath )
            return ( baseModuleName, newPluginPath )
      except ( IOError, OSError ) as e:
         # OSError is listed for Python 2, where os.rename raises OSError.
         # If the disk is full, specify the device in the error message. Skip
         # this when strerror is unset, since concatenating None would raise
         # TypeError and hide the real error.
         if e.errno == errno.ENOSPC and self.pluginDirRoot and e.strerror:
            e.strerror = e.strerror + ' ' + os.path.dirname( self.pluginDirRoot )
         raise

class _PluginContext( object ):
   '''
   The plugin context collects the ReloadPolicy policy and cleanup functions
   that are registered with addPolicy/addCleanup while Plugins.loadPlugins runs.
   The context is then passed to each of these functions when it runs, so the
   functions can add errors and warnings resulting from the checks.
   '''

   class PluginRetVal( object ):
      '''
      Simple object to store errors and warnings collected from the ReloadPolicy
      checks, plus the non-None result of each policy function keyed by plugin
      name.
      '''
      def __init__( self ):
         self.errors = []
         self.warnings = []
         self.policySuccess = {} # plugin name : True | False

      def addError( self, error ):
         self.errors.append( error )

      def addWarning( self, warning ):
         self.warnings.append( warning )

      def hasError( self ):
         return bool( self.errors )

      def addPluginResult( self, plugin, success ):
         '''
         Record a policy function's result. plugin is the function's full module
         name. The result is stored under the part after the last period.
         A success of None is not recorded.
         '''
         if success is None:
            return
         # [ -1 ] keeps a name without any period intact instead of raising
         # IndexError.
         pluginName = plugin.rsplit( '.', 1 )[ -1 ]
         assert pluginName not in self.policySuccess
         self.policySuccess[ pluginName ] = success

   def __init__( self, imagePath, category, mode ):
      '''
      The category argument is a list of plugin categories. Only policy
      and cleanup functions from plugins of categories contained in the
      category list will be run. An empty or None category runs everything.
      '''
      self.nextImage = imagePath
      self.category = category
      self.mode = mode
      self.policies = []
      self.cleanups = []
      self.currentVersion = None # set by getVersionInfo
      self.nextVersion = None # set by getVersionInfo
      self.retVal = self.PluginRetVal()

   class VersionError( Exception ):
      ''' Error if error occurs getting version info. '''
      pass

   def getVersionInfo( self, nextImagePath, syslogNextSwiVersion=False ):
      '''
      Get the current and next image version information and add it to the
      context. Raise VersionError if either cannot be determined. If
      syslogNextSwiVersion is set, both versions are also logged.
      '''

      SYS_SYSTEM_INFO = Logging.LogHandle(
                    "SYS_SYSTEM_INFO",
                    severity=Logging.logInfo,
                    fmt="%s",
                    explanation="[ Informational log message ]",
                    recommendedAction=Logging.NO_ACTION_REQUIRED )

      def _getCurrentVersionInfo():
         '''
         Read the 'swi-version' file in the running image (or in the specified
         file if the hook is used) and create a VersionInfo object. Raise a
         VersionError if the file cannot be read.
         '''
         currImagePath = os.environ.get( 'RELOADPOLICY_VERSION_PATH' )
         if not currImagePath:
            currImagePath = '/etc/swi-version'
         try:
            with open( currImagePath, 'r' ) as cp:
               info = cp.read()
            currInfo = EosVersion.VersionInfo( None, versionFile=info, arch="" )
         except IOError:
            if syslogNextSwiVersion:
               Logging.log( SYS_SYSTEM_INFO, "Software image: ( unknown )" )
            raise self.VersionError( "Unable to read version information from"
                                     " running image." )
         if syslogNextSwiVersion:
            Logging.log( SYS_SYSTEM_INFO, "Software image: %s" % currInfo.version() )
         return currInfo

      def _getNextVersionInfo( nextImagePath ):
         '''
         Read the 'version' file from the next image and create a VersionInfo
         object. Raise a VersionError if it has no version.
         '''
         nextInfo = EosVersion.swiVersion( nextImagePath )
         if not nextInfo.version():
            if syslogNextSwiVersion:
               Logging.log( SYS_SYSTEM_INFO,
                            "Attempting reload to software"
                            " image version: ( unknown )" )
            raise self.VersionError( "Image has invalid version file." )
         if syslogNextSwiVersion:
            Logging.log( SYS_SYSTEM_INFO,
                         "Attempting reload to software image version: %s" %
                         ( nextInfo.version() ) )
         return nextInfo

      self.currentVersion = _getCurrentVersionInfo()
      self.nextVersion = _getNextVersionInfo( nextImagePath )
      assert self.currentVersion and self.nextVersion

   def _wantsCategory( self, category ):
      '''
      Return True if a function with the given categories should be registered.
      '''
      # With no category filter everything is wanted. Otherwise at least one of
      # the function's categories must appear in the filter.
      return not self.category or not set( category ).isdisjoint( self.category )

   def addPolicy( self, func, category ):
      '''
      Register a policy function of the specified category to the
      context. Do not register functions of unwanted categories.
      '''
      if self._wantsCategory( category ):
         self.policies.append( func )

   def addCleanup( self, func, category ):
      '''
      Register a cleanup function of the specified category to the
      context. Do not register functions of unwanted categories.
      '''
      if self._wantsCategory( category ):
         self.cleanups.append( func )

   def addError( self, error ):
      ''' Add an error message to the output collection. '''
      self.retVal.addError( error )

   def addWarning( self, warning ):
      ''' Add a warning message to the output collection. '''
      self.retVal.addWarning( warning )

   def runPolicies( self ):
      '''
      Run every registered policy function and record its result under its
      module name. If any error was collected, then run every registered
      cleanup function.
      '''
      for policyFunc in self.policies:
         success = policyFunc( self )
         self.retVal.addPluginResult( policyFunc.__module__, success )
      if self.retVal.hasError():
         for cleanupFunc in self.cleanups:
            cleanupFunc( self )

class _ExceptionChecker( object ):
   '''
   Check the ReloadPolicy exceptions file to see if the next image matches
   an exception. Takes the next image's version string as an argument in
   the constructor.

   The exceptions come from the file named by RELOADPOLICY_EXCEPTIONS_PATH if
   that file exists, otherwise from the ReloadPolicy.Exceptions module. Either
   source must define a 'policyExceptions' dictionary. A source that cannot be
   imported, read or parsed, or that lacks that dictionary, is treated as
   "no match".
   '''
   def __init__( self, version ):
      self.match = None

      try:
         exceptionsModule = self._loadExceptionsModule()
      except ( ImportError, IOError, SyntaxError ):
         # Consider a bad exceptions file a "no match".
         return

      policyExceptions = getattr( exceptionsModule, 'policyExceptions', None )
      if policyExceptions is None:
         # A file without the expected dictionary is also a bad file.
         return

      # Look for a match.
      self.match = self._findMatch( version, policyExceptions )

   @staticmethod
   def _loadExceptionsModule():
      '''
      Return a module object holding the exceptions. Use the test-hook file if
      RELOADPOLICY_EXCEPTIONS_PATH names an existing file, otherwise the
      packaged ReloadPolicy.Exceptions module.
      '''
      exceptionsPath = os.environ.get( 'RELOADPOLICY_EXCEPTIONS_PATH' )
      if exceptionsPath and os.path.isfile( exceptionsPath ):
         # types.ModuleType replaces imp.new_module. The imp module is
         # deprecated and was removed in Python 3.12.
         import types
         with open( exceptionsPath ) as ep:
            source = ep.read()
         exceptionsModule = types.ModuleType( 'Exceptions' )
         # NOTE(review): this executes arbitrary code from a file named by an
         # environment variable. It is intended only as a test hook and must
         # never point at untrusted input.
         exec( source, exceptionsModule.__dict__ )
         return exceptionsModule
      import ReloadPolicy.Exceptions as exceptionsModule
      return exceptionsModule

   def _findMatch( self, version, policyExceptions ):
      '''
      policyExceptions is a dictionary whose keys are version strings supporting
      Unix-style wildcards. The version string that best matches the next image's
      version string will be considered the match. Here, "best" is determined as
      follows:
               1. If the exception version string is an exact match to the
                  next image version string, choose that match.
               2. Otherwise, the longest wildcard match is taken.
               3. If no wildcard matches, then no match.
      Return the value stored for the best match, or None.
      '''

      def bestMatch( match ):
         # Exact matches sort above every wildcard match; ties go to length.
         return match == version, len( match )

      def isMatch( exceptionVersion ):
         return exceptionVersion == version or fnmatch.fnmatch( version,
                                                                exceptionVersion )

      # Filter the policyException version strings for matches, then take the
      # "best" match as defined above.
      matches = [ p for p in policyExceptions if isMatch( p ) ]
      if not matches:
         return None
      return policyExceptions[ max( matches, key=bestMatch ) ]

   def isBlock( self ):
      return self.match == 'Block'

   def isOverride( self ):
      return self.match == 'Override'

def _loadPlugins( ctx, plugins, pluginPath, baseModuleName='ReloadPolicyPlugin' ):
   '''
   Load ReloadPolicy plugins against the plugin context ctx by delegating to
   Plugins.loadPlugins, and return whatever it returns.

   plugins (names of plugins, or None) and pluginPath are passed through
   unchanged. pluginPath must be a list of directories when given.
   baseModuleName is the module name the plugins are loaded under.
   '''
   # Guard against callers passing the raw comma-separated override string.
   assert isinstance( pluginPath, list ) or not pluginPath
   loaded = Plugins.loadPlugins( baseModuleName, ctx, plugins, pluginPath )
   return loaded

