# Copyright (c) 2022 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import os
import re
import shutil
import tempfile
import zipfile

def extractFileFromNextImage( ctx, tmpDir, nextImage, filename ):
   '''Pull a single member out of the next image archive.
   Arguments:
      ctx: _PluginContext: Receives an error when the member is absent.
      tmpDir: String: Directory the member gets extracted into
      nextImage: ZipFile: Opened next image archive
      filename: String: Name of the archive member to extract
   Returns:
      tmpDir joined with filename on success, None when the archive has no
      member with that name
   '''
   if filename in nextImage.namelist():
      nextImage.extract( filename, tmpDir )
      return os.path.join( tmpDir, filename )
   # Member is not part of the archive, report it and let the caller decide
   ctx.addError( "Next image is missing metadata file '{}'".format( filename ) )
   return None

def getImageFormatVersion( ctx, path ):
   '''Read the IMAGE_FORMAT_VERSION entry out of a version file.
   Arguments:
      ctx: _PluginContext: Receives an error if the file cannot be read.
      path: String: Path to version file
   Returns:
      The image format version as a 'd+.d+' string. '2.0' is returned when the
      file holds no IMAGE_FORMAT_VERSION entry, so that reload is not blocked.
      None is returned when path is not a regular file or reading it fails.
   '''
   if not os.path.isfile( path ):
      return None
   try:
      with open( path ) as versionFile:
         contents = versionFile.read()
   except IOError:
      ctx.addError( 'Failed while accessing next image\'s version file' )
      return None

   # The entry of interest looks like: IMAGE_FORMAT_VERSION=3.0
   found = re.search( r'IMAGE_FORMAT_VERSION=(\d+\.\d+)', contents,
                      re.MULTILINE | re.IGNORECASE )
   if found:
      return found.group( 1 )
   # Old images without reload policy infra carry no image format entry. They
   # get the current image's reload policy applied, so a missing entry is
   # tolerated in case this is a downgrade.
   return '2.0'

def compareVersions( ver1, ver2 ):
   '''Order two 'd+.d+' version strings.
   Arguments:
      ver1, ver2: String: versions to compare, each made of exactly two
                          dot-separated integers
   Returns:
      -1 if ver1 is smaller
       1 if ver1 is greater
       0 if versions are equal
   '''
   major1, minor1 = ver1.split( '.' )
   major2, minor2 = ver2.split( '.' )
   # Major numbers decide first, minor numbers only break a tie. Components
   # are compared numerically so that e.g. 3.10 is greater than 3.9.
   for left, right in ( ( major1, major2 ), ( minor1, minor2 ) ):
      left = int( left )
      right = int( right )
      if left != right:
         return -1 if left < right else 1
   return 0

def fileOrderIsCorrect( ctx, nextImage ):
   '''Verify file order is correct in next image
   The archive has to start with the two optimization control files
   'swimSid2Optimization' and 'swimSqshMap' (in either order), optionally
   preceded by a single bootstrap swix.
   Arguments:
      ctx: _PluginContext: An error is added here when the order is wrong.
      nextImage: ZipFile: Opened next image
   Returns:
      True if file order is correct, False otherwise. An archive too short to
      hold the control files is reported as incorrect as well.
   '''
   expectedFiles = set( [ 'swimSid2Optimization', 'swimSqshMap' ] )
   names = [ info.filename for info in nextImage.infolist() ]

   # First file can be a bootstrap swix, in which case the control files are
   # expected right after it
   first = 1 if names and names[ 0 ].endswith( ".swix" ) else 0
   # Slicing (rather than indexing) keeps archives with too few members from
   # raising IndexError, they simply yield a set that cannot match
   filesInArchive = set( names[ first : first + 2 ] )

   if expectedFiles != filesInArchive:
      ctx.addError( 'Next image does not begin with bootstrap swix or '
                    'optimization control files '
                    '\'{}\' and cannot be optimized'
                    .format( ', '.join( sorted( expectedFiles ) ) ) )
      return False
   return True

def getSwimSqshMap( ctx, path ):
   '''Parse swimSqshMap
   Every non-blank line is expected to have the form
      <optimization>=<squash>[:<squash>...]
   Arguments:
      ctx: _PluginContext: Any encountered errors, warnings will be added here.
      path: String: Path to swimSqshMap file
   Returns:
      swimSqshMap in form of dictionary, key is optimization (Default, Strata...)
      and value is the list of squashes needed by given optimization.
      None if the file cannot be read, contains a malformed line or holds no
      entries at all; an error is added to ctx in each of those cases.
   '''
   swimSqshMap = {}
   try:
      with open( path ) as f:
         for line in f:
            if not line.strip():
               # Blank lines (e.g. a trailing empty line) carry no mapping
               continue
            fields = line.split( '=' )
            optimization = fields[ 0 ].strip()
            if len( fields ) != 2 or not optimization:
               # Corrupt file: report it through ctx rather than letting a
               # ValueError from tuple unpacking escape to the caller
               ctx.addError( 'Failed while parsing {}'
                             .format( os.path.basename( path ) ) )
               return None
            swimSqshMap[ optimization ] = \
               [ s.strip() for s in fields[ 1 ].split( ':' ) ]
   except IOError:
      ctx.addError( 'Unable to read next image\' metadata file \'{}\''
                    .format( path ) )
      return None
   # swimSqshMap was empty?
   if not swimSqshMap:
      ctx.addError( 'Failed while parsing {}'.format( os.path.basename( path ) ) )
      return None
   return swimSqshMap

def getSidFrom( ctx, path ):
   '''Search a file for a sid entry.
   Arguments:
      ctx: _PluginContext: Receives a warning if the file cannot be read.
      path: String: Path to the file expected to contain SID details
   Returns:
      sid as a string, or None when path is not a regular file, cannot be
      read or holds no sid entry
   '''
   if not os.path.isfile( path ):
      return None
   try:
      with open( path ) as sidFile:
         contents = sidFile.read()
   except IOError:
      ctx.addWarning( 'Unable to open/parse file \'{}\'' .format( path ) )
      return None

   # How the pattern below is put together:
   #   (^|\s)  the entry starts a line or follows a whitespace character
   #   sid     the keyword itself, matched case insensitively (re.IGNORECASE)
   #   (:|=)   separator, ':' is found in prefdl files, '=' in cmdline files
   #   \s*     optional whitespace after the separator
   #   (\S+)   the run of non-whitespace characters that is the sid
   found = re.search( r'(^|\s)sid(:|=)\s*(\S+)', contents,
                      re.MULTILINE | re.IGNORECASE )
   return found.group( 3 ) if found else None

def getSid( ctx ):
   '''Discover the platform sid from /etc/prefdl or /proc/cmdline.
   Arguments:
      ctx: _PluginContext: Any encountered errors, warnings will be added here.
   Returns:
      sid as a string, or None (after adding a warning) if neither file
      yields one
   '''
   sid = None
   # /etc/prefdl is consulted first, /proc/cmdline is only a fallback
   for source in ( '/etc/prefdl', '/proc/cmdline' ):
      sid = getSidFrom( ctx, source )
      if sid:
         return sid
   ctx.addWarning( 'Unable to discover platform SID information' )
   return sid

def getSidOptimization( ctx, sid, path ):
   '''Looks for the optimization assigned to a given sid in provided file
   Every non-blank line of the file is expected to have the form
      <entry>:<optimization>
   for example 'Capitola*:Strata-4GB'. The first entry matching sid wins.
   Arguments:
      ctx: _PluginContext: Warnings about malformed lines or an unreadable
                           file will be added here.
      sid: String: sid to look for
      path: String: path to file containing sid->optimization entries
   Returns:
      optimization in form of a string, or None if no entry matches or the
      file cannot be read
   '''
   try:
      with open( path ) as f:
         for line in f:
            if not line.strip():
               # Blank lines carry no entry
               continue
            fields = line.split( ':' )
            if len( fields ) != 2 or not fields[ 0 ]:
               # Skip lines that are not '<entry>:<optimization>' instead of
               # failing with ValueError/IndexError on them
               ctx.addWarning( 'Ignoring malformed entry in file \'{}\''
                               .format( path ) )
               continue
            entry, optimization = fields
            if entry.endswith( '*' ):
               # Make that '*' regexp readable
               entry = entry[ : -1 ] + r'\S*'
            # NOTE(review): the rest of entry is used as a regular expression
            # verbatim; presumably entries contain no other metacharacters.
            # The whole sid has to match, case insensitively.
            if re.search( r'^%s$' % entry, sid, re.IGNORECASE ):
               # Remove endlines
               return optimization.strip()
   except IOError:
      ctx.addWarning( 'Unable to open/parse file \'{}\'' .format( path ) )
   return None

def getSupportedOptimizations( squashList, swimSqshMap ):
   '''Returns list of optimizations supported by the next image
   An optimization is supported when every squash it requires is present in
   squashList.
   Arguments:
      squashList: iterable: squashes in next image
      swimSqshMap: dict: swimSqshMap in form of a dictionary
   Return:
      List of supported optimizations
   '''
   # Build the lookup set once: it does not change between optimizations, and
   # rebuilding it per iteration would exhaust a one-shot iterable after the
   # first optimization
   availableSquashes = set( squashList )
   return [ optimization
            for optimization, requiredSquashes in swimSqshMap.items()
            if availableSquashes.issuperset( requiredSquashes ) ]

def checkIfSquashesExist( ctx, nextImage, optimization, swimSqshMap ):
   '''Check that the squashes an optimization needs are in the next image
   The optimization is looked up in swimSqshMap under its own name first and
   under '<optimization>-DPE' second.
   Arguments:
      ctx: _PluginContext: Any encountered errors, warnings will be added here
      nextImage: ZipFile: Opened next image
      optimization: String: Optimization to check
      swimSqshMap: dict: swimSqshMap in form of a dictionary
   Return:
      True if every squash required by the optimization (or by its -DPE
      variant) is a member of the next image, False otherwise.
   '''
   availableSquashes = [ member for member in nextImage.namelist()
                         if member.endswith( '.sqsh' ) ]

   if optimization in swimSqshMap:
      mapKey = optimization
   elif optimization + '-DPE' in swimSqshMap:
      mapKey = optimization + '-DPE'
   else:
      # Neither variant is known to the next image. Besides the error, tell
      # the user which optimizations the image could actually provide.
      ctx.addError( 'Next image does not support the {} optimization'
                    .format( optimization ) )
      supported = getSupportedOptimizations( availableSquashes, swimSqshMap )
      if not supported:
         ctx.addError( 'Next image does not support any optimizations' )
      else:
         ctx.addWarning( 'Next image supports the following optimizations: {}'
                         .format( ', '.join( supported ) ) )
      return False

   missingSquashes = set( swimSqshMap[ mapKey ] ) - set( availableSquashes )
   if missingSquashes:
      ctx.addError( 'Next image does not contain required filesystem' )
      return False
   return True

def checkNeededSquashes( ctx ):
   ''' Policy that checks if all needed squashes are present in next image,
   first two files are swimSqshMap and swimSid2Optimization. Also checks
   for the version (if version is before 3.0, this policy does not apply).
   Arguments:
      ctx: _PluginContext: Any encountered errors, warnings will be added here.
                           ctx.nextImage is opened as the next image zip.
   Returns:
      True if next image can be loaded, False otherwise
   '''
   # Scratch directory for the metadata files extracted from the image
   tmpDir = tempfile.mkdtemp( dir='/tmp' )
   try:
      with zipfile.ZipFile( ctx.nextImage ) as nextImage:
         return checkOptimizedSwimImage( ctx, nextImage, tmpDir )
   except zipfile.BadZipfile:
      ctx.addError( 'Unable to read next image' )
      return False
   finally:
      # Clean up on every exit path, including exceptions other than
      # BadZipfile which keep propagating to the caller
      shutil.rmtree( tmpDir )

def checkOptimizedSwimImage( ctx, nextImage, tmpDir ):
   '''Run the optimized image checks against an opened next image.
   Arguments:
      ctx: _PluginContext: Any encountered errors, warnings will be added here
      nextImage: ZipFile: Opened next image
      tmpDir: String: Directory the metadata files get extracted into
   Returns:
      True if the checks do not block loading the next image, False otherwise
   '''
   # The 'version' file is mandatory, not finding it is a blocker
   versionPath = extractFileFromNextImage( ctx, tmpDir, nextImage, 'version' )
   if not versionPath:
      return False
   # Not being able to determine the version means something's wrong
   formatVersion = getImageFormatVersion( ctx, versionPath )
   if not formatVersion:
      return False
   # Image formats older than 3.0 are not subject to any further checks
   if compareVersions( formatVersion, '3.0' ) == -1:
      return True

   # Extract both control files and look up the SID before judging any of
   # them, so that as much information as possible is gathered in ctx.
   sidMapPath = extractFileFromNextImage( ctx, tmpDir, nextImage,
                                          'swimSid2Optimization' )
   sqshMapPath = extractFileFromNextImage( ctx, tmpDir, nextImage,
                                           'swimSqshMap' )
   sid = getSid( ctx )

   # A missing control file means something is wrong with the next image
   if not ( sidMapPath and sqshMapPath ):
      return False
   # Control files out of order mean the image was created incorrectly
   if not fileOrderIsCorrect( ctx, nextImage ):
      return False
   # Not having a sid is a problem, but it does not block reload
   if not sid:
      return True

   # A SID without a dedicated entry gets the 'Default' optimization
   optimization = getSidOptimization( ctx, sid, sidMapPath )
   if optimization is None:
      optimization = 'Default'

   # An unparsable swimSqshMap (for example a corrupt file) blocks the
   # reload, the error itself is logged inside getSwimSqshMap
   swimSqshMap = getSwimSqshMap( ctx, sqshMapPath )
   if not swimSqshMap:
      return False
   return checkIfSquashesExist( ctx, nextImage, optimization, swimSqshMap )

def Plugin( ctx ):
   '''Plugin entry point: registers the checkNeededSquashes policy.
   Arguments:
      ctx: _PluginContext: Context the policy gets registered with
   '''
   # The policy is registered under each of the listed categories
   ctx.addPolicy( checkNeededSquashes, [ "ASU", "ASU+", "General" ] )
