#!/usr/bin/env python3
# Copyright (c) 2008, 2009 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import argparse
import os
import re
import sys
import Swi
import tempfile
import subprocess
import glob
import zipfile
from SwimHelperLib import isSwimImage

def isMatchingRegex( arglist, regex ):
   """Return True if `regex` (a pattern string or compiled pattern) is found
   by re.search in at least one element of `arglist`, otherwise False."""
   return any( re.search( regex, candidate ) for candidate in arglist )

def hasQueryModeArg( arglist ):
   """Return True if any argument is a short-option cluster containing 'q'
   (e.g. -q, -qa, -qlp) or the long option --query."""
   return isMatchingRegex( arglist,
                           r"^\s*(-[a-zA-Z]*(q)[a-zA-Z]*|--query)\s*$" )

def hasFileModeArg( arglist ):
   """Return True if any argument is a short-option cluster containing
   'i', 'U' or 'F', or one of --install, --upgrade, --freshen, --reinstall,
   i.e. a mode whose operands are package files."""
   fileModePattern = re.compile(
         r"^\s*(-[a-zA-Z]*(i|U|F)[a-zA-Z]*|--install|--upgrade|"
         r"--freshen|--reinstall)\s*$" )
   return isMatchingRegex( arglist, fileModePattern )

def hasEraseModeArg( arglist ):
   """Return True if any argument is a short-option cluster containing 'e'
   or the long option --erase."""
   return isMatchingRegex( arglist,
                           r"^\s*(-[a-zA-Z]*(e)[a-zA-Z]*|--erase)\s*$" )

def hasInstallModeArg( arglist ):
   """Return True if any argument is a short-option cluster containing 'i'
   or the long option --install (note: --reinstall does not match)."""
   return isMatchingRegex( arglist,
                           r"^\s*(-[a-zA-Z]*(i)[a-zA-Z]*|--install)\s*$" )

def hasVerifyOrOtherMiscArg( arglist ):
   """Return True if any argument selects a verify/signature-check or other
   miscellaneous rpm mode.

   Matches a short-option cluster containing 'V' (verify) or 'K' (checksig),
   or one of the long options --verify, --checksig, --showrc, --setperms,
   --setugids, --setcaps, --restore.
   """
   # Every alternative needs its own '|': previously "--showrc" and
   # "--setperms" were fused into a single "--showrc--setperms" alternative,
   # so neither option matched on its own.
   reg = re.compile( r"^\s*(-[a-zA-Z]*(V|K)[a-zA-Z]*|--verify|--checksig|"
                     r"--showrc|--setperms|--setugids|--setcaps|--restore)"
                     r"\s*$" )
   return isMatchingRegex( arglist, reg )

def getRpmsFromArglist( arglist ):
   """Return, in their original order, the arguments that do not start with
   '-' (the non-option operands: package files, globs or names)."""
   return [ arg for arg in arglist if not arg.startswith( "-" ) ]

def getRpmNames( rpmlist ):
   """Map each path in `rpmlist` to its basename cut at the first '.'.

   NOTE(review): for files named like name-version-release.arch.rpm with a
   dotted version this yields e.g. 'foo-1' rather than the package name
   'foo'; confirm the naming convention of the RPM files passed in.
   """
   return [ os.path.basename( path ).partition( '.' )[ 0 ]
            for path in rpmlist ]

def generateRpmsFileset( rpms, rpmdbdir ):
   """ Generates the set of file paths owned by the provided rpms.

   rpms: package names when `rpmdbdir` is set, package file paths otherwise.
   rpmdbdir: root passed as `rpm --root`. When set, `rpms` is first narrowed
      to the names installed in that root (`rpm -qa`) and the survivors are
      listed with `rpm -ql`. When falsy, `rpms` are listed as package files
      with `rpm -qlp`.

   Returns an empty set, without running the listing query, when nothing is
   left to query. The rpm output is split on whitespace.
   """
   if not rpmdbdir:
      query = [ "sudo", "rpm", "-qlp" ]
   else:
      rootArg = f"--root={rpmdbdir}"
      listInstalled = [ "sudo", "rpm", rootArg, "-qa", "--qf", "%{NAME}\n" ]
      installedNames = set(
            subprocess.check_output( listInstalled ).decode().split() )
      rpms = installedNames & set( rpms )
      query = [ "sudo", "rpm", rootArg, "-ql" ]

   if not rpms:
      return set()

   listing = subprocess.check_output( query + list( rpms ) ).decode()
   return set( listing.split() )

def doRpmsHaveScriptlets( workdir, arglist ):
   """ Checks if rpms provided as arguments have any install/uninstall
   scriptlets to run.

   workdir: directory holding the extracted rpmdb; Swi.getFirstRpmdbDir()
      on it supplies the `rpm --root`.
   arglist: raw rpm command-line arguments. The non-option operands are
      queried as package files (`-qp`) when a file-mode option is present
      (hasFileModeArg), otherwise as installed package names (`-q`).

   Returns True if `rpm --scripts` reports any pre/post install/uninstall
   scriptlet or program. Returns False otherwise, including when there are
   no arguments or no package operands to query.
   """
   if not arglist:
      return False

   rpms = getRpmsFromArglist( arglist )
   if not rpms:
      # rpm errors out when a query has no package operands, which would
      # make check_output raise. With nothing to query there is no scriptlet.
      return False

   rpmdbdir = Swi.getFirstRpmdbDir( workdir )
   queryMode = "-qp" if hasFileModeArg( arglist ) else "-q"
   cmd = [ "sudo", "rpm", f"--root={rpmdbdir}", queryMode, "--scripts" ]
   cmd.extend( rpms )
   output = subprocess.check_output( cmd ).decode()

   scriptletRe = re.compile(
         r"^(preinstall|postinstall|preuninstall|postuninstall) "
         r"(scriptlet|program)", re.MULTILINE )
   return bool( scriptletRe.search( output ) )

def doesFileModeCmdRemoveFiles( workdir, arglist ):
   """ Checks if the command removes any files by comparing the file set of
   the currently installed packages with that of the new package files.

   workdir: directory holding the extracted rpmdb (see Swi.getFirstRpmdbDir).
   arglist: raw rpm command-line arguments; non-option operands are expanded
      with glob.glob (patterns matching nothing are dropped).

   Returns False straight away for install-mode commands (hasInstallModeArg).
   Otherwise returns the possibly empty set of paths owned by the installed
   packages (looked up by getRpmNames of the package files) but not by the
   new package files. Callers only use its truthiness.
   """
   if hasInstallModeArg( arglist ):
      return False

   rpmdbdir = Swi.getFirstRpmdbDir( workdir )
   packageFiles = { match
                    for pattern in getRpmsFromArglist( arglist )
                    for match in glob.glob( pattern ) }
   installedFiles = generateRpmsFileset( getRpmNames( packageFiles ),
                                         rpmdbdir )
   incomingFiles = generateRpmsFileset( packageFiles, None )
   return installedFiles - incomingFiles

def isRpmdbSqshInSwi( swifile ):
   """Return True if the SWI zip archive has a member whose name ends in
   ".rpmdb.sqsh", False otherwise.

   swifile: path (or binary file object) of the SWI, opened with zipfile.

   Raises zipfile.BadZipFile if `swifile` is not a zip archive.
   """
   # The context manager closes the archive on every path, including the
   # early exit from any(). Previously the ZipFile handle was never closed.
   with zipfile.ZipFile( swifile, "r" ) as swiZip:
      return any( name.endswith( ".rpmdb.sqsh" )
                  for name in swiZip.namelist() )

def canUseFastInSwi( swifile, args, opts ):
   """ Determines the code path to use for `swi rpm` based on what mode/command
   we are using, whether it removes any files and whether it has
   install/uninstall scriptlets.

   swifile: path of the SWI being operated on.
   args: raw rpm command-line arguments.
   opts: parsed `swi rpm` options (overrideRpmOp, overrideScriptletCheck,
      readOnly, quiet are read).

   Returns:
      useFastInSwi: bool

   Exits the process with the extraction's return code if extracting the
   rpmdb and version files fails.
   """
   # In case of erase, verify and any misc modes that need the actual files,
   # we can't use fastInSwi. The same applies to explicit overrides.
   if ( not isSwimImage( swifile )
        or not isRpmdbSqshInSwi( swifile )
        or opts.overrideRpmOp
        or opts.overrideScriptletCheck
        or hasEraseModeArg( args )
        or hasVerifyOrOtherMiscArg( args ) ):
      return False

   if hasQueryModeArg( args ) or opts.readOnly:
      return True

   # Only modes that take filenames as input are analysed below. Any other,
   # unhandled command type defaults to inSwi. Deciding this up front avoids
   # extracting the rpmdb and running rpm queries whose result would be
   # discarded anyway (and which could fail for such commands).
   if not hasFileModeArg( args ):
      return False

   Swi.extract.setTempDirIfNeeded()
   # Created before `try` so the cleanup in `finally` never refers to an
   # unbound name (which would mask the real error) if creation fails.
   workdir = tempfile.mkdtemp()
   try:
      retcode = Swi.extractRpmdbAndVersionFiles( swifile, workdir,
                                                 quiet=opts.quiet )
      if retcode != 0:
         sys.exit( retcode )

      # Use fastInSwi only when no files are removed and no scriptlets run.
      # `or` short-circuits, so the scriptlet query is skipped once removed
      # files have been found.
      return not ( doesFileModeCmdRemoveFiles( workdir, args )
                   or doRpmsHaveScriptlets( workdir, args ) )
   finally:
      subprocess.check_call( [ "sudo", "rm", "-rf", workdir ] )

def rpmInSwiFunc( *args, **kwargs ):
   """Build the callback that runs rpm against an extracted SWI root.

   args: rpm command-line arguments, appended verbatim after `--root`.
   kwargs: accepted (callers pass quiet=...) but not used.

   Returns a function taking the extracted root directory. That function:
   if the root contains var/lib/rpm/rpmdb.sqlite and the host rpm expands
   `%{_db_backend}` to something mentioning bdb or ndb, prints an error and
   exits with status 1. Otherwise it runs `sudo rpm --root=<rootdir> <args>`
   through Swi.run.
   """
   def rpmWithinExtractedSwi( rootdir ):
      sqliteDb = os.path.join( rootdir, "var/lib/rpm/rpmdb.sqlite" )
      if os.path.exists( sqliteDb ):
         backendQuery = subprocess.run(
               [ "sudo", "rpm", "-E", "\"%{_db_backend}\"" ],
               check=False, capture_output=True )
         backend = backendQuery.stdout.decode()
         if any( legacy in backend for legacy in ( "bdb", "ndb" ) ):
            print( "Error: local rpm binary isn't running with sqlite _db_backend." )
            print( "Please run 'swi rpm' from a supported workspace." )
            sys.exit( 1 )

      Swi.run( [ "sudo", "rpm", f"--root={rootdir}" ] + list( args ) )

   return rpmWithinExtractedSwi

def swiOutputSanityCheck( swiFile, output ):
   """Fail fast, before any work is done, when the updated SWI cannot be
   written or we are running from a directory that no longer exists.

   swiFile: the SWI being updated; the updated image appears at this path
      when `output` is None.
   output: optional destination path for the updated SWI.

   Prints where the updated SWI will appear.

   Raises:
      OSError: when the current working directory does not exist, or
         (as FileNotFoundError) when the directory part of `output` does
         not exist.
   """
   try:
      os.getcwd()
   except OSError:
      print( 'You might be working from a directory that does not exist' )
      print( 'Please switch to an existing directory' )
      raise

   if output is None:
      result = swiFile
   else:
      result = output
      # A bare filename has an empty dirname, which means the current
      # directory. os.path.exists( '' ) is False, so normalise it to '.'.
      dirTo = os.path.dirname( output ) or '.'
      # Raise explicitly: `assert` is stripped under -O and would not match
      # the documented OSError contract.
      if not os.path.exists( dirTo ):
         raise FileNotFoundError( f'{dirTo} does not exist' )

   print( f'Updated swiFile will appear in {result}' )

def rpm( swifile, opts, *args, **kwargs ):
   """Run `rpm <args>` inside `swifile`, going through Swi.fastInSwi when
   canUseFastInSwi allows it and through Swi.inSwi otherwise.

   swifile: path of the SWI to operate on.
   opts: parsed `swi rpm` options (see addRpmOptions).
   args: rpm command-line arguments.
   kwargs: only `extraFns` is read; those callbacks run after the rpm one.

   Unless opts.readOnly is set, the output location is sanity-checked first.
   The rpmOp optimization is passed to inSwi only, disabled by
   opts.overrideRpmOp.
   """
   if not opts.readOnly:
      swiOutputSanityCheck( swifile, opts.file )

   extraFns = kwargs.get( 'extraFns' ) or []
   rpmOp = not opts.overrideRpmOp
   fastPath = canUseFastInSwi( swifile, args, opts )
   callbacks = [ rpmInSwiFunc( *args, quiet=opts.quiet ) ] + extraFns
   common = dict( fast=opts.fast, readOnly=opts.readOnly, zstd=opts.zstd,
                  quiet=opts.quiet, forceSign=opts.forceSign )

   if fastPath:
      Swi.fastInSwi( swifile, callbacks, output=opts.file, **common )
   else:
      Swi.inSwi( swifile, callbacks, outputfile=opts.file, rpmOp=rpmOp,
                 **common )

def addRpmOptions( op ):
   """Register the `swi rpm` command-line options on the argparse parser `op`.

   --fast and --zstd are mutually exclusive. -o/--output stores a value under
   `file`. --fSign is stored as `forceSign`. Every other option is a
   store_true flag kept under its own name.
   """
   compression = op.add_mutually_exclusive_group()
   compression.add_argument(
         '--fast', action='store_true',
         help='compress with gzip (will generate bigger SWI image)' )
   compression.add_argument(
         '--zstd', action='store_true',
         help='compress with zstd (will generate bigger SWI image)' )

   op.add_argument( '-o', '--output', dest='file', action='store',
                    help='output the swi into <file>' )

   # ( option strings, explicit dest or None, help text )
   storeTrueOptions = (
      ( ( '-r', '--readOnly' ), None,
        'skip swi generation after rpm execution' ),
      ( ( '--quiet', ), None,
        'silence output from all except the rpm commands' ),
      ( ( '--overrideRpmOp', ), None,
        'Disables the rpmOp optimization by unsquashing '
        'everything in the SWI' ),
      ( ( '--fSign', ), 'forceSign',
        'Force updating the swi signature after rpm commands run' ),
      ( ( '--overrideScriptletCheck', ), None,
        'Disables the rpm scriptlet and deleted files check' ),
   )
   for optionStrings, dest, helpText in storeTrueOptions:
      destKw = { 'dest': dest } if dest else {}
      op.add_argument( *optionStrings, help=helpText, action='store_true',
                       **destKw )

def rpmHandler( args=None ):
   """Entry point for `swi rpm`.

   Parses the swi-specific options (addRpmOptions). The first remaining
   argument is the SWI path and the rest are handed to rpm() as rpm
   arguments. Exits with a usage error when fewer than two remain.

   args: argument list to parse; defaults to sys.argv[ 1: ].
   """
   args = sys.argv[ 1: ] if args is None else args
   # parse_known_args() is used so rpm's own options (e.g.
   # `swi rpm EOS.swi -U foo.rpm`) are returned untouched instead of being
   # rejected as unknown. Options that addRpmOptions also defines (e.g.
   # --quiet) are consumed here and not passed on to rpm.
   op = argparse.ArgumentParser(
         prog="swi rpm",
         description="Updating/removing RPMs in a SWI file",
         # argparse already prefixes the usage line with "usage: ".
         usage="%(prog)s [--fast] [-o <file>] EOS.swi [options]" )
   addRpmOptions( op )

   opts, args = op.parse_known_args( args )
   if len( args ) < 2:
      op.error( 'Missing arguments' )

   rpm( args[ 0 ], opts, *args[ 1: ] )

