# Copyright (c) 2022 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import atexit
import os
import re
import shutil
import tempfile
import zipfile

# Matches a prefdl line of the form "SKU: <value>" and captures <value>
# without surrounding whitespace. [^\S\n] (whitespace other than newline)
# keeps the match on a single line, so an empty "SKU:" line cannot capture the
# line after it. The capture is lazy and must start with a non-space
# character, so trailing spaces or a '\r' are left outside it.
SKU_REGEX = r"^SKU:[^\S\n]*(\S.*?)\s*$"

def extractPlatformArchFile( ctx, tmpDir ):
   """Extract the platform architecture data file from the next image.

   Args:
      ctx: reload policy context; ``ctx.nextImage`` is opened as a zip
         archive, and warnings/errors are reported through it.
      tmpDir: directory to extract the data file into.

   Returns:
      The path of the extracted "PlatformSupportedArch" file inside *tmpDir*,
      or False on failure. A next image without the data file adds a warning
      to *ctx*; an image that cannot be opened or read (corrupt zip, missing
      or unreadable file) adds an error.
   """
   platformArchFn = "PlatformSupportedArch"
   platformArchPath = os.path.join( tmpDir, platformArchFn )
   try:
      with zipfile.ZipFile( ctx.nextImage ) as nextImage:
         if platformArchFn not in nextImage.namelist():
            ctx.addWarning( "Next image is missing platform architecture "
                            "data file." )
            return False
         nextImage.extract( platformArchFn, tmpDir )
   except ( zipfile.BadZipfile, IOError, OSError ):
      # A corrupt, missing or unreadable image is not safe to reload with, so
      # report an error instead of letting the exception escape the policy.
      # (BadZipfile is the py2-compatible spelling of BadZipFile.)
      ctx.addError( "Unable to read next image." )
      return False

   return platformArchPath

# The ReloadPolicy ctx's mode isn't configurable through the current infra,
# so this test-mode function reads the SKU from the prefdl instead; it is
# used by the btest.
def getTestEnvPlatformSku( prefdlPath='/etc/prefdl' ):
   """Return the system SKU read from a prefdl file, for test environments.

   Searches *prefdlPath* for a "SKU: <value>" line (case-insensitive).

   Args:
      prefdlPath: prefdl file to read; defaults to '/etc/prefdl'.

   Returns:
      A single-element list holding the SKU with surrounding whitespace
      removed, or False if the file cannot be read or has no non-empty SKU
      line.
   """
   try:
      with open( prefdlPath, "r" ) as f:
         prefdlData = f.read()
   except IOError:
      # Covers a missing file as well as permission errors.
      return False

   skuMatch = re.search( SKU_REGEX, prefdlData, re.MULTILINE | re.IGNORECASE )
   if not skuMatch:
      return False

   # Strip so trailing whitespace or a '\r' cannot stop the SKU from matching
   # the (stripped) SKU keys of the platform architecture data file.
   sku = skuMatch.group( 1 ).strip()
   if not sku:
      return False
   return [ sku ]

def getCurrentPlatformSku( ctx ):
   """Return the model names (SKUs) of this system, read from Sysdb.

   The system's own model name always comes first. When /etc/prefdl-chassis
   exists the system is treated as modular, and the model name of every
   populated card slot is appended too.

   Returns:
      A list of model-name strings, or None (after adding a warning to *ctx*)
      if the entity MIB could not be read.
   """
   isModular = os.path.exists( "/etc/prefdl-chassis" )
   try:
      entityMib = ctx.mode.entityManager.root()[ 'hardware' ][ 'entmib' ]
      skuList = [ entityMib.root.modelName ]
      if isModular:
         skuList.extend( cardSlot.card.modelName
                         for cardSlot in entityMib.chassis.cardSlot.values()
                         if cardSlot.card )
   except Exception: # pylint: disable=broad-except
      # Sysdb access can fail in many ways; report a warning instead of
      # propagating. Exception (not a bare except) still lets
      # KeyboardInterrupt/SystemExit through.
      if isModular:
         ctx.addWarning( "Unable to retrieve inserted card slot information." )
      else:
         ctx.addWarning( "Unable to retrieve system model name." )
      return None

   return skuList

# Leading dotted-numeric part of an EOS release string, e.g. "4.28.2" in
# "4.28.2F".
_VERSION_PREFIX_REGEX = re.compile( r"^\s*(\d+(?:\.\d+)*)" )

def _isOlderRelease( version, minVersion ):
   """Return True if release string *version* is older than *minVersion*.

   The leading dotted-numeric parts are compared as integers, so "4.28.10"
   sorts after "4.28.2" (a plain string comparison would call it older), and
   suffixes such as the "F" in "4.28.2F" are ignored. If either value has no
   numeric prefix, fall back to a plain ``<`` comparison.
   """
   try:
      versionMatch = _VERSION_PREFIX_REGEX.match( version )
      minMatch = _VERSION_PREFIX_REGEX.match( minVersion )
   except TypeError:
      # Not a string; let the fallback comparison below decide.
      versionMatch = minMatch = None
   if not ( versionMatch and minMatch ):
      return version < minVersion
   versionParts = [ int( p ) for p in versionMatch.group( 1 ).split( "." ) ]
   minParts = [ int( p ) for p in minMatch.group( 1 ).split( "." ) ]
   return versionParts < minParts

def _readRequiredArchs( path, skus ):
   """Map each SKU in *skus* to the OS architectures listed for it in *path*.

   Data lines have the form ``SKU=ARCH[,ARCH...]``. Blank lines, '#' comment
   lines and lines without '=' are skipped, and whitespace around SKUs and
   architecture labels is ignored. If a SKU is listed more than once, its
   first entry wins.

   Raises:
      IOError: if *path* cannot be read.
   """
   wanted = set( skus )
   reqArchs = {}
   with open( path, "r" ) as f:
      for rawLine in f:
         line = rawLine.strip()
         if not line or line.startswith( '#' ):
            continue
         sku, sep, archs = line.partition( "=" )
         sku = sku.strip()
         if not sep or sku not in wanted or sku in reqArchs:
            continue
         reqArchs[ sku ] = [ arch.strip() for arch in archs.split( "," ) ]
         if len( reqArchs ) == len( wanted ):
            # Every SKU of interest has been found; stop reading.
            break
   return reqArchs

def checkPlatformsSupportedArch( ctx ):
   """Check the next image's OS architecture against this system's SKUs.

   The next image carries a PlatformSupportedArch data file that maps a SKU
   to a comma-separated list of the OS architectures ("EOS32", "EOS64") the
   platform is supported on. Every SKU of this system (several on a modular
   system) listed in that file is checked against the next image's
   architecture.

   Returns True when the check passes or does not apply: the next image is
   older than 4.28.2, the SKUs are unavailable, or the file has no entry for
   any of this system's SKUs. Returns False when the data file cannot be
   extracted or read, or when an entry does not list the next image's
   architecture; in that last case an error is added to *ctx*.

   NOTE(review): an unsupported architecture is reported with ctx.addError;
   confirm whether that is meant to block the reload or only warn.
   """
   # Don't run reload policy check against old EOS versions. version() is
   # presumably a release string such as "4.28.2F", so compare numerically.
   if _isOlderRelease( ctx.nextVersion.version(), "4.28.2" ):
      return True

   tmpDir = tempfile.mkdtemp()
   try:
      platformSupportFile = extractPlatformArchFile( ctx, tmpDir )
      if not platformSupportFile:
         return False

      # Retrieve the SKUs of this system
      if os.environ.get( "PLATFORM_ARCH_TEST_ENV" ):
         skuList = getTestEnvPlatformSku()
      else:
         skuList = getCurrentPlatformSku( ctx )
      if not skuList:
         return True

      # Look up the supported OS architectures of every SKU of this system,
      # not just the first one found in the data file.
      try:
         reqArchsList = _readRequiredArchs( platformSupportFile, skuList )
      except IOError:
         ctx.addWarning( "Unable to access platform architecture data file from "
                         "next image." )
         return False
      if not reqArchsList:
         return True

      # Map nextImage's architecture to the OS Architecture labels used in
      # SkuDb and in the platform architecture data file
      if ctx.nextVersion.architecture() in ( "i686", "i386" ):
         nextImageArch = "EOS32"
      else:
         nextImageArch = "EOS64"

      # Report the first SKU whose supported architectures exclude the next
      # image's architecture.
      for reqArchs in reqArchsList.values():
         if nextImageArch in reqArchs:
            continue
         requiredArchString = ""
         if "EOS32" in reqArchs:
            requiredArchString = "Please use 32 bit EOS."
         elif "EOS64" in reqArchs:
            requiredArchString = "Please use 64 bit EOS."
         ctx.addError( "Next image's OS Architecture may not be supported on "
                       "this system. " + requiredArchString )
         return False

      return True
   finally:
      # The data file has been fully read by now, so remove it immediately.
      # An atexit handler would keep the directory (and register one more
      # handler per call) until the process exits.
      shutil.rmtree( tmpDir, ignore_errors=True )

def Plugin( ctx ):
   """Register checkPlatformsSupportedArch with the reload policy context.

   The check is registered for the "ASU", "ASU+" and "General" categories.
   """
   ctx.addPolicy( checkPlatformsSupportedArch, [ "ASU", "ASU+", "General" ] )
