# Copyright (c) 2008-2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import tempfile
import glob
import os
import errno
import fnmatch
import re
import sys
import subprocess
import shutil
import zipfile
import Swi.extract, Swi.create
import Swi.sign
import SwimFlatten
import SwiSignLib
import VerifySwi
from SwimHelperLib import parseSwimSqshMap, mountSwimRootfs, unmountSwimRootfs, \
                          unmountAndCleanupWd, getSwimSqshMapAndOverlayDirs, \
                          mountPathOverLowers

# Architecture string compared against the SWI_ARCH value of a version file.
# getArchFromVersionFile falls back to it when no SWI_ARCH can be read, and
# setupRpmDbDir only creates lib64 directories/symlinks for this arch.
ARCH_x86_64 = "x86_64"

class SwiOptions:
   """Simple attribute bag: every keyword argument becomes an attribute.

   Used to pass option sets to Swi.extract/Swi.create; update() adds new
   attributes or overwrites existing ones.
   """
   def __init__( self, **kargs ):
      self.update( **kargs )

   def update( self, **kargs ):
      vars( self ).update( kargs )

def run( argv ):
   """Run argv (a list, no shell) and raise CalledProcessError on a non-zero
   exit status. Output is not captured."""
   subprocess.run( argv, check=True )

def runAndReturnOutput( argv, printStdout=True ):
   """Run argv, echo its output and return its stdout as text.

   Args:
      argv: command and arguments as a list (executed without a shell).
      printStdout: echo the captured stdout when True. Captured stderr is
         always echoed.

   Returns:
      The command's stdout, decoded as text.

   Raises:
      AssertionError: the command exited with a non-zero status. It is raised
         explicitly rather than via ``assert`` so the check is not stripped
         under ``python -O``; the exception type is kept for compatibility.
   """
   with subprocess.Popen( argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                          universal_newlines=True ) as p:
      stdoutdata, stderrdata = p.communicate()
   failed = p.returncode != 0
   if failed:
      print( "Command %s returned error code %d" % ( argv, p.returncode ) )
   if stdoutdata and printStdout:
      print( stdoutdata )
   if stderrdata:
      print( stderrdata )
   if failed:
      raise AssertionError( "Command %s returned error code %d"
                            % ( argv, p.returncode ) )
   return stdoutdata

def secureBootDutGrabbed( quiet=False ):
   """Return True if SwiSecureBootHelper.py reports a grabbed secure-boot DUT.

   If the helper is available, it is used to determine whether any DUT is
   using secure boot, in which case the SWI signature needs updating.
   Returns False when the helper is not on PATH, exits with an error, times
   out (12 s) or does not print "True".

   Args:
      quiet: suppress the informational progress messages (warnings and the
         "unavailable" notice are still printed).
   """
   helper = "SwiSecureBootHelper.py"
   # shutil.which avoids depending on an external `which` binary (whose
   # absence would raise an uncaught FileNotFoundError) and does not echo
   # the resolved path to stdout.
   if shutil.which( helper ) is None:
      print( "SwiSecureBootHelper unavailable. " )
      return False

   if not quiet:
      print( "Checking if user has any secure-boot DUT grabbed." )
   try:
      hasSecboot = subprocess.check_output( [ helper ], timeout=12 )
   except subprocess.CalledProcessError:
      print( "Warning: SwiSecureBootHelper exited abnormally." )
      return False
   except subprocess.TimeoutExpired:
      print( "Warning: SwiSecureBootHelper timed out. Resuming as"
             " if no secure boot DUTs are grabbed." )
      return False

   if "True" in hasSecboot.decode():
      if not quiet:
         print( "Grabbed DUT has secure-boot." )
      return True
   return False

def rpmdbAndVersionFilesWildcard():
   """Return a new list of zip-entry wildcards for the rpmdb squashfs and the
   version files of a SWI, rpmdb pattern first."""
   rpmdbPattern = "*.rpmdb.sqsh"
   versionPatterns = [ "*.version", "version" ]
   return [ rpmdbPattern ] + versionPatterns

def getFirstRpmdbDir( dir ): # pylint: disable=redefined-builtin
   """Return the first ``*.rpmdb.dir`` path found in dir, or None.

   "First" is in glob order, which is not sorted.
   """
   matches = glob.glob( os.path.join( dir, "*.rpmdb.dir" ) )
   return next( iter( matches ), None )

def getVersionFiles( dir ): # pylint: disable=redefined-builtin
   """Return paths of the version files in dir: every ``*.version`` match
   followed by a plain ``version`` file if one exists."""
   patterns = ( "*.version", "version" )
   return [ path for pattern in patterns
            for path in glob.glob( os.path.join( dir, pattern ) ) ]

def warnIfMultipleRpmdbs( rpmDbDirs ):
   """Print a warning when the image has more than one RPM database."""
   if len( rpmDbDirs ) <= 1:
      return
   print( "WARNING: SWI command is being run on an image with multiple "
          "RPM databases which may result in undesired behavior" )

def zipContains( zipFileName, zippedFileWildcard ):
   """Tell whether any member of a zip archive matches a shell-style wildcard.

   Args:
      zipFileName: path of the zip archive (e.g. a SWI).
      zippedFileWildcard: fnmatch pattern matched against member names.

   Returns:
      True if at least one member name matches, otherwise False.

   Raises:
      FileNotFoundError: zipFileName is not an existing regular file.
   """
   if not os.path.isfile( zipFileName ):
      raise FileNotFoundError( errno.ENOENT, os.strerror( errno.ENOENT ),
                               zipFileName )
   with zipfile.ZipFile( zipFileName ) as archive:
      members = archive.namelist()
   return any( fnmatch.fnmatch( member, zippedFileWildcard )
               for member in members )

def extractRpmdbAndVersionFiles( swifile, workdir, quiet, updateVersion=True ):
   """Unzip the rpmdb squashfs (and optionally the version files) of a SWI
   into workdir, then extract the rpmdb via Swi.extract.extractRootfs.

   Args:
      swifile: path of the SWI zip archive.
      workdir: directory the zip entries are extracted into.
      quiet: pass -q to unzip and suppress its echoed output.
      updateVersion: also extract the ``*.version``/``version`` files.

   Returns:
      0 on success, 1 if no ``*.rpmdb.sqsh`` ended up in workdir.

   Raises:
      KeyError: a required entry is missing from the archive.
      FileNotFoundError: swifile does not exist.
      subprocess.CalledProcessError: unzip failed (stdout/stderr attached).
   """
   flist = ( rpmdbAndVersionFilesWildcard() if updateVersion
             else [ "*.rpmdb.sqsh" ] )

   # Unzip each entry separately in case any of the patterns won't match
   # returning an error.
   for entry in flist:
      # Skip *.version file check until BUG976637 is fixed
      if entry != "*.version" and not zipContains( swifile, entry ):
         raise KeyError( f'The {swifile} is missing {entry}' )

      unzipcmd = [ "unzip" ]
      # Quiet arg must come before the zip file, otherwise unzip binary doesn't
      # run in quiet mode.
      if quiet:
         unzipcmd.append( "-q" )
      unzipcmd += [ swifile, entry, "-d", workdir ]

      with subprocess.Popen( unzipcmd, stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE ) as p:
         out, err = p.communicate()
      errText = err.decode( errors="replace" )
      # unzip exits with 11 when a pattern matched nothing; that is only
      # expected for the optional *.version pattern, which is not
      # pre-checked above.
      if ( p.returncode == 11 and
           ( "filename not matched:  %s" % entry ) in errText ):
         continue
      # Treat a non-zero exit as a failure even when stderr is empty.
      if err or p.returncode != 0:
         raise subprocess.CalledProcessError( returncode=p.returncode,
                                              cmd=p.args, output=out,
                                              stderr=err )

      if not quiet and out:
         print( out.decode() )

   rpmdbs = glob.glob( os.path.join( workdir, "*.rpmdb.sqsh" ) )
   if not rpmdbs:
      print( "No rpmdb squash found in the swi." )
      return 1
   rpmdb = rpmdbs[ 0 ]

   warnIfMultipleRpmdbs( rpmdbs )

   Swi.extract.extractRootfs( rpmdb, quiet=quiet, unsquashArgs=None )

   return 0

def packAndFreshenRpmdbAndVersionFiles( swifile, workdir, quiet, fast, zstd,
                                        output, forceSign, zipRoot, targets,
                                        updateVersion, extraFiles ):
   """Prepares rpmdb and version files for packing and zips everything back
   to swifile/output file.

   The ``*.rpmdb.dir`` in workdir is re-squashed into the matching
   ``*.rpmdb.sqsh``, the version files are refreshed (if updateVersion),
   and all of these plus any zipRoot targets and extraFiles are freshened
   into a copy of swifile inside workdir, which is then copied to output.

   Args:
      swifile: original SWI; it is copied into workdir before being updated.
      workdir: directory holding the extracted rpmdb/version files.
      quiet: reduce progress output.
      fast, zstd: passed through to Swi.create.mksquashfs.
      output: destination path of the updated SWI.
      forceSign: unused in this function.
      zipRoot, targets: when both are set, the destination of each
         ( src, dst ) pair in targets is also freshened into the zip.
      updateVersion: refresh and re-zip the version files.
      extraFiles: additional workdir-relative names to freshen, or None.

   Returns:
      0 on success, 1 if the rpmdb dir or the version files are missing.

   Raises:
      subprocess.CalledProcessError: the zip command failed.
   """
   # Get all the directories and filenames to operate on
   rpmdbdir = getFirstRpmdbDir( workdir )
   if not rpmdbdir:
      print( "No rpmdb directory found in %s." % workdir )
      return 1
   # "<name>.rpmdb.dir" -> "<name>.rpmdb.sqsh"
   rpmdbfs = os.path.splitext( rpmdbdir )[ 0 ] + ".sqsh"

   verfiles = []
   if updateVersion:
      verfiles = [ os.path.basename( f ) for f in getVersionFiles( workdir ) ]
      if not verfiles:
         print( "No version files found, aborting." )
         return 1

   # Resquash and update the version files if applicable
   Swi.create.mksquashfs( rpmdbdir, rpmdbfs, fast=fast, rootfsRpmDbOnly=True,
                          zstd=zstd )
   if updateVersion:
      Swi.create.refreshVersionFiles( workdir, verfiles )

   # Generate exact filelist to freshen in zipfile. Names are relative to
   # workdir because zip runs with workdir as the current directory.
   filesToFreshen = [ os.path.basename( rpmdbfs ) ] + verfiles
   if zipRoot and targets:
      filesToFreshen.extend( dst for _, dst in targets )
   if extraFiles:
      filesToFreshen.extend( extraFiles )

   if not quiet:
      print( "Creating a copy of swi file at %s." % workdir )

   # Zip updated files into copied swi, then copy it back into swifile/output
   workswi = os.path.join( workdir, os.path.basename( swifile ) )
   owd = os.getcwd()
   shutil.copyfile( swifile, workswi )
   os.chdir( workdir )
   try:
      if not quiet:
         print( "Updating the swi file at %s." % workswi )

      zipcmd = [ "zip", "-0", "-u", workswi ] + filesToFreshen
      if quiet:
         zipcmd.append( "-q" )
      subprocess.check_call( zipcmd )

      if not quiet:
         print( "Copying the file back to %s." % output )
      # NOTE: a relative output path resolves against workdir here.
      shutil.copyfile( workswi, output )
   finally:
      # Restore the caller's working directory even if zip or copy fails.
      os.chdir( owd )

   return 0

def fastInSwi( swifile, fns, output=None, fast=False, readOnly=False,
               zstd=False, quiet=False, forceSign=False, zipRoot=False,
               targets=None, updateVersion=True ):
   """Performs a function on rpmdb only and updates all the version files in
   swifile if applicable. Requires a SWIM image and no overrideRpmOp.

   Only the rpmdb squashfs (plus the version files when updateVersion) is
   pulled out of the SWI into a temporary workdir. Each callable in fns is
   invoked with the extracted rpmdb directory; any file names it returns
   are freshened into the output SWI together with the rpmdb.

   Args:
      swifile: path of the input SWI.
      fns: callables taking the rpmdb dir; each may return a list of extra
         workdir-relative file names to add/update in the zip.
      output: destination SWI path; defaults to swifile (in place).
      fast, zstd: passed through to the rpmdb re-squash.
      readOnly: only run fns; do not re-pack or sign.
      quiet: reduce progress output.
      forceSign: re-sign the output SWI after packing.
      zipRoot, targets: passed through to packAndFreshenRpmdbAndVersionFiles.
      updateVersion: extract and refresh the version files.

   Exits the process via sys.exit with a non-zero status on failure.
   """
   if output is None:
      output = swifile

   output = os.path.abspath( output )
   swifile = os.path.abspath( swifile )

   swiPermissionsCheck( swifile, output, readOnly )
   workdir = tempfile.mkdtemp()

   try:
      retcode = extractRpmdbAndVersionFiles( swifile, workdir, quiet=quiet,
                                             updateVersion=updateVersion )
      if retcode != 0:
         sys.exit( retcode )

      rpmdbdir = getFirstRpmdbDir( workdir )
      if not rpmdbdir:
         print( "No rpmdb directory found in %s. Aborting." % workdir )
         # retcode is 0 at this point; exit with an explicit failure status.
         sys.exit( 1 )
      if not readOnly:
         # NOTE(review): with updateVersion=False the version file is not
         # extracted, so this falls back to the default arch.
         arch = getArchFromVersionFile( "%s/version" % workdir )
         setupRpmDbDir( rpmdbdir, quiet, arch )
      extraUpdatedFiles = []
      for fn in fns:
         files = fn( rpmdbdir )
         if files:
            extraUpdatedFiles += files

      if not readOnly:
         retcode = packAndFreshenRpmdbAndVersionFiles( swifile, workdir,
                                                       quiet=quiet, fast=fast,
                                                       zstd=zstd, output=output,
                                                       forceSign=forceSign,
                                                       zipRoot=zipRoot,
                                                       targets=targets,
                                                       updateVersion=updateVersion,
                                                       extraFiles=extraUpdatedFiles )
         if retcode != 0:
            sys.exit( retcode )

         if forceSign:
            if not quiet:
               print( "Updating signatures" )
            result = Swi.sign.updateSwiSignature( output )
            if not quiet:
               print( "Result: %s" % Swi.sign.UPDATE_SIG_MESSAGE[ result ] )
         else:
            print( "Not updating SWI signature. Use --fSign to override "
                   "this behavior" )
            print( "SWI signature will be updated automatically for "
                   "secure-boot DUTs during sanitize/newimage" )
   finally:
      subprocess.check_call( [ "sudo", "rm", "-rf", workdir ] )

def swiPermissionsCheck( swiFile, outputFile, readOnly ):
   """Exit with status 1 unless swiFile is readable and, for non-read-only
   operations, an existing outputFile is writable.

   A message describing the problem is printed before exiting.
   """
   if not os.access( swiFile, os.R_OK ):
      print( "Permission Error: Can't read " + swiFile + ". Does it exist?" )
      sys.exit( 1 )
   outputBlocked = ( os.path.exists( outputFile ) and
                     not os.access( outputFile, os.W_OK ) )
   if outputBlocked and not readOnly:
      print( "Permission Error: Can't write to " + outputFile +
             ". Retry with sudo?" )
      sys.exit( 1 )

def repackContainerImage( rootFsDir, outputFile, tmpFile ):
   """Tar rootFsDir into tmpFile, xz-compress it with pxz and move the
   compressed tarball to outputFile (all via sudo).

   pxz is installed with `a4 yum` if `which pxz` fails.
   """
   # Use --xattrs to ensure that capabilities (see "man capabilities")
   # are packed into the tarball.
   Swi.runAndReturnOutput( [ 'sudo', 'tar', '--xattrs', '-cf', tmpFile,
                             '-C', rootFsDir, '.' ] )
   # Install pxz if it doesn't exist. A check_call-based run() (like the
   # run() helper in this file) reports a failed `which` as
   # CalledProcessError, not AssertionError, so accept both.
   try:
      Swi.run( [ 'which', 'pxz' ] )
   except ( AssertionError, subprocess.CalledProcessError ):
      Swi.run( [ 'a4', 'yum', 'install', '-y', 'pxz' ] )
   Swi.runAndReturnOutput( [ 'sudo', 'pxz', '-T30', '-k', tmpFile ] )
   Swi.runAndReturnOutput( [ 'sudo', 'mv', tmpFile + '.xz', outputFile ] )

def inTarXzSwi( file, funcs, outputfile, readOnly, extImageDir, quiet ):
   """Apply funcs to the root filesystem of a .tar.xz container image.

   The tarball is extracted (with xattrs) into a new directory under a
   temporary directory (or under extImageDir), each callable in funcs is
   invoked with that directory, and unless readOnly is set the tree is
   re-packed into outputfile.

   Args:
      file: path of the input .tar.xz image.
      funcs: callables taking the extracted root filesystem directory.
      outputfile: where the re-packed image is written.
      readOnly: only run funcs; do not re-pack or write outputfile, matching
         inSwi, which also skips creating the output when readOnly.
      extImageDir: existing directory to extract under; it is not removed
         afterwards. When None, a temporary directory is used and removed.
      quiet: suppress progress messages.
   """
   swiPermissionsCheck( file, outputfile, readOnly )
   # Create temporary directory to extract SWI to
   Swi.extract.setTempDirIfNeeded()
   tmpdir = extImageDir or tempfile.mkdtemp()
   try:
      rootFsDir = tempfile.mkdtemp( dir=tmpdir )

      if not quiet:
         print( f"Extracting {file} to {rootFsDir}" )
      Swi.runAndReturnOutput( [ 'sudo', 'tar', '--xattrs', '-xJf', file,
                                '-C', rootFsDir ] )

      for func in funcs:
         func( rootFsDir )

      # Read-only operations must not rewrite the image: re-packing would
      # replace outputfile (which defaults to the input) with a file produced
      # by sudo commands, even though the permission check above skipped the
      # writability test for readOnly.
      if not readOnly:
         if not quiet:
            print( f"Constructing {outputfile} from {rootFsDir}" )
         # Only a unique path is wanted here; the file object is discarded
         # and the path is recreated by tar.
         # pylint: disable-next=consider-using-with
         tmpFileForRepack = tempfile.NamedTemporaryFile( dir=tmpdir ).name
         repackContainerImage( rootFsDir, outputfile, tmpFileForRepack )
         Swi.run( [ 'sudo', 'rm', tmpFileForRepack ] )
   finally:
      if not extImageDir:
         if not quiet:
            print( "Cleaning up", tmpdir )
         run( [ "sudo", "rm", "-rf", tmpdir ] )

def inSwi( file, funcs, fast=False, outputfile=None, readOnly=False,
         flattenSwim=False, nonOverlayFlatten=False,
         zstd=False, extImageDir=None, quiet=False,
         optimization=None, rpmOp=False, forceSign=False,
         updateVersion=True ):
   """Extract a SWI, apply funcs to its root filesystem(s) and re-create it.

   Each callable in funcs is invoked with a directory. For rpmOp on an image
   with an rpmdb dir it gets the rpmdb dir (readOnly) or an overlay union
   mount of it. Otherwise it gets each mounted SWIM flavor's union dir, or
   the first ``*rootfs-*.dir`` found. Unless readOnly is set, the SWI is
   then re-created at outputfile and its signature updated. Signing is
   skipped for rpmOp runs on dev-CA-signed images outside AUTOTEST unless
   forceSign is set. ``.tar.xz`` images are handed to inTarXzSwi.

   Args:
      file: path of the input SWI.
      funcs: callables taking a root filesystem (or rpmdb) directory.
      fast, zstd, updateVersion: passed to Swi.create via SwiOptions.
      outputfile: destination; defaults to file (in place).
      readOnly: only run funcs; do not re-create or sign the SWI.
      flattenSwim, nonOverlayFlatten: flatten a multi-sqsh SWIM image into a
         single sqsh first (SwimFlatten.swimFlatten).
      extImageDir: an already (possibly partially) extracted SWI directory
         to use instead of extracting; it is not removed afterwards.
      quiet: reduce progress output.
      optimization: specifies which optimization rootfs to apply funcs in.
         optimization=None will result in funcs being applied to all found
         rootfs sqshes in the image.
      rpmOp: the funcs perform an rpm operation (enables the rpmdb-only path
         and the dev-CA signing shortcut).
      forceSign: always update the SWI signature.
   """
   if outputfile is None:
      outputfile = file

   # We chdir below, and stay in the rootfs for the SWI. The user specified a
   # filename relative to their current directory, not somewhere in /tmp, so
   # use a sensible absolute path when we've a different current dir.
   outputfile = os.path.abspath( outputfile )

   if file.endswith( ".tar.xz" ):
      inTarXzSwi( file, funcs, outputfile, readOnly, extImageDir, quiet )
      return

   swiPermissionsCheck( file, outputfile, readOnly )
   # Create temporary directory to extract SWI to
   Swi.extract.setTempDirIfNeeded()
   tmpdir = extImageDir or tempfile.mkdtemp()
   # Mount bookkeeping lives outside the try so the finally clause can always
   # see it; it is reset once the mounts have been torn down.
   swimOverlayDirs = {}
   rpmDbDir = None
   rpmDbUnionDir = None
   try:
      # If extImageDir is passed in, we already have an extracted SWI to work
      # with or a partially extracted one in case of `extractImage`
      if not extImageDir:
         if not quiet:
            print( "Extracting", file, "to", tmpdir )
         Swi.extract.extract( file, SwiOptions( dirname=tmpdir,
                                                use_existing=True ),
                              quiet=quiet, rpmOp=rpmOp, readOnly=readOnly )

      # We don't support use of most swi commands on EOS-meta,
      # so we shouldn't see multiple rpmdb dirs here
      rpmDbDirs = glob.glob( os.path.join( tmpdir, "*.rpmdb.dir" ) )

      if rpmOp and rpmDbDirs:
         warnIfMultipleRpmdbs( rpmDbDirs )
         if readOnly:
            rpmDbDir = rpmDbDirs[ 0 ]
         else:
            _, tmpSwimOverlayDirs = getSwimSqshMapAndOverlayDirs( tmpdir,
                                                                  None )

            if 'Default' in tmpSwimOverlayDirs:
               lowersToMount = tmpSwimOverlayDirs[ "Default" ][ "lower" ]
            elif 'Default-DPE' in tmpSwimOverlayDirs:
               lowersToMount = tmpSwimOverlayDirs[ "Default-DPE" ][ "lower" ]
            else:
               _, dirs = list( tmpSwimOverlayDirs.items() )[ 0 ]
               lowersToMount = dirs[ "lower" ]

            dbDir = rpmDbDirs[ 0 ]
            # The rpmdb dir itself becomes the mount target, so drop it from
            # the lower layers.
            lowersToMount = lowersToMount.replace( "%s:" % dbDir, "" )
            rpmDbUnionDir = mountPathOverLowers( tmpdir, dbDir, lowersToMount )

      # For rpmOps on a SWI with an rpmdb sqsh, there's no need to flatten
      # or mount SWIM dirs
      if not readOnly and not rpmDbDir and not rpmDbUnionDir:
         # Flatten a multi sqsh FS into a single squash FS
         if flattenSwim:
            if not quiet:
               print( "Flattening multiple sqsh SWI into a single sqsh" )
            SwimFlatten.swimFlatten( tmpdir, nonOverlay=nonOverlayFlatten )

         # Passing optimization=None to getSwimSqshMapAndOverlayDirs()
         # will return all found SWIM Dirs, while passing a valid
         # optimization will only fill swimOverlayDirs with
         # information pertaining to the specified optimization
         _, swimOverlayDirs = getSwimSqshMapAndOverlayDirs( tmpdir, optimization )

         if swimOverlayDirs:
            mountSwimRootfs( swimOverlayDirs )

      for func in funcs:
         if rpmDbDir or rpmDbUnionDir:
            dbDir = ( rpmDbDir if rpmDbDir is not None else rpmDbUnionDir )
            if not quiet:
               print( "Working on RPMDB dir/union mount %s" % dbDir )
            if not readOnly:
               arch = getArchFromVersionFile( "%s/version" % tmpdir )
               setupRpmDbDir( dbDir, quiet, arch )
            func( dbDir )
         elif swimOverlayDirs:
            for flavor, flavorDirs in swimOverlayDirs.items():
               if not quiet:
                  print( "Working on swim flavor %s" % flavor )
               func( flavorDirs[ "union" ] )
         else:
            rootfsDir = glob.glob( os.path.join( tmpdir, "*rootfs-*.dir" ) )[ 0 ]
            if not quiet:
               print( "Working on rootfs %s" % rootfsDir )
            func( rootfsDir )

      # Need to cleanup SWIM or rpmdb overlay mounts and dirs before we
      # recreate the SWI. Forget them afterwards so the finally clause does
      # not unmount/clean them a second time.
      if rpmDbUnionDir:
         unmountAndCleanupWd( tmpdir, rpmDbUnionDir )
         rpmDbUnionDir = None
      elif swimOverlayDirs:
         unmountSwimRootfs( swimOverlayDirs )
         swimOverlayDirs = {}

      if not readOnly:
         if not quiet:
            print( "Creating", outputfile, "from", tmpdir )
         options = SwiOptions( dirname=tmpdir, squashfs=False, fast=fast,
                               installfs_only=False, force=True, trace=True,
                               zstd=zstd, force_resquash=False,
                               update_version=updateVersion )
         if glob.glob( os.path.join( tmpdir, "*rootfs-*.sqsh" ) ):
            options.update( squashfs=True )
         Swi.create.create( outputfile, options )

         # Skip resigning of image during rpmOp commands run in user workspaces
         # to speed-up image modification for devs
         if ( 'AUTOTEST' not in os.environ and
              rpmOp and
              not forceSign and
              SwiSignLib.swiSignedWithDevCA( outputfile ) ):
            print( "Not updating SWI signature. Use --fSign to override "
                   "this behavior" )
            print( "SWI signature will be updated automatically for "
                   "secure-boot DUTs during sanitize/newimage" )
         else:
            print( "Updating SWI Signature" )
            result = Swi.sign.updateSwiSignature( outputfile )
            print( Swi.sign.UPDATE_SIG_MESSAGE[ result ] )
   finally:
      # Cleanup SWIM mounts/dirs in case exception occurred in try block
      if rpmDbUnionDir:
         unmountAndCleanupWd( tmpdir, rpmDbUnionDir )
      elif swimOverlayDirs:
         unmountSwimRootfs( swimOverlayDirs )

      if not extImageDir:
         if not quiet:
            print( "Cleaning up", tmpdir )
         run( [ "sudo", "rm", "-rf", tmpdir ] )

def getBlessedAndVersion( swiFile ):
   ''' Returns a tuple of ( BLESSED, SWI_VERSION ) from the 'version'
       file in @swiFile.

       BLESSED is converted to an int and SWI_VERSION is a string, e.g.
       ( 1, '4.20.0F' ) for a blessed image, ( None, '4.20.0F' ) for a
       non-blessed one (no BLESSED key), and ( None, None ) if the archive
       has no 'version' member. Only the first '=' on a line separates key
       and value, so values may themselves contain '='.

       Raises ValueError if BLESSED is not an integer; a file that is not a
       zip archive raises zipfile.BadZipFile.'''
   blessed = None
   version = None
   try:
      with zipfile.ZipFile( swiFile, 'r' ) as swi:
         # ZipFile.open raises KeyError when there is no 'version' member.
         with swi.open( 'version', 'r' ) as versionFile:
            # Not using SimpleConfigFile here because it's an archive
            # in a zipfile and we would have to pass the file descriptor
            for rawLine in versionFile:
               key, _, value = rawLine.decode().strip().partition( '=' )
               if key == 'BLESSED':
                  blessed = int( value )
               elif key == 'SWI_VERSION':
                  version = value
   except KeyError:
      return None, None
   return blessed, version

def getArchFromVersionFile( versionPath, quiet=False ):
   """Return the SWI_ARCH value from the version file at versionPath.

   Falls back to ARCH_x86_64 when the file has no SWI_ARCH entry or cannot
   be read. In the latter case the error is printed unless quiet is set.
   """
   try:
      with open( versionPath, 'r' ) as versionFile:
         archMatch = re.search( r"SWI_ARCH=(\S+)", versionFile.read() )
   except Exception as e: # pylint: disable=broad-except
      if not quiet:
         print( "Failed to discover SWI_ARCH from version file" )
         print( e )
      return ARCH_x86_64
   return archMatch.group( 1 ) if archMatch else ARCH_x86_64

def setupRpmDbDir( rootdir, quiet, arch ):
   ''' This function prepares the rootdir by creating directories and symlinks
       which would normally exist in a rootdir had the filesystem RPM been
       installed. Since RPM database dirs do not have any RPMs installed and
       just consist of rpm database files ( /var/lib/rpm/ ), this
       func is needed to setup the filesystem so swi freshen/rpm/update installs
       files into the right place.

       lib64 directories and symlinks are only created when arch is
       ARCH_x86_64. Entries that already exist are left alone, so calling this
       repeatedly on the same rootdir (inSwi calls it once per func) does not
       create nested links inside already-linked directories. Commands run
       via sudo without a shell; failures are ignored (best effort).

       Args:
          rootdir: directory to populate (the extracted rpmdb dir).
          quiet: currently unused.
          arch: SWI architecture string (see getArchFromVersionFile).'''
   isX86_64 = arch == ARCH_x86_64

   # Directories to create, parents before children.
   dirs = [ "usr", "usr/bin", "usr/sbin", "usr/lib", "usr/lib/debug",
            "usr/lib/debug/usr", "usr/lib/debug/usr/bin",
            "usr/lib/debug/usr/sbin", "usr/lib/debug/usr/lib" ]
   if isX86_64:
      dirs += [ "usr/lib/debug/usr/lib64", "usr/lib64" ]
   dirs.append( "run" )

   # ( symlink target, symlink path relative to rootdir )
   links = [ ( "usr/bin", "bin" ),
             ( "usr/sbin", "sbin" ),
             ( "usr/lib", "lib" ),
             ( "usr/bin", "usr/lib/debug/bin" ),
             ( "usr/lib", "usr/lib/debug/lib" ),
             ( "../.dwz", "usr/lib/debug/usr/.dwz" ),
             ( "usr/sbin", "usr/lib/debug/sbin" ),
             ( "../run", "var/run" ),
             ( "../run/lock", "var/lock" ) ]
   if isX86_64:
      links += [ ( "usr/lib64", "usr/lib/debug/lib64" ),
                 ( "usr/lib64", "lib64" ) ]

   for relDir in dirs:
      path = os.path.join( rootdir, relDir )
      if not os.path.isdir( path ):
         subprocess.call( [ "sudo", "mkdir", path ] )

   for target, relLink in links:
      path = os.path.join( rootdir, relLink )
      # Running "ln -s" onto an existing symlink to a directory would create
      # a second link *inside* that directory, so skip anything present.
      # lexists() also reports dangling links (e.g. var/lock).
      if not os.path.lexists( path ):
         subprocess.call( [ "sudo", "ln", "-s", target, path ] )
