# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from Swi import installrootfs

# Module-level alias for installrootfs.run; the functions below hand their
# command lists to it.
run = installrootfs.run

def getInstallCmd( yumConfig, erasePattern, repoArgs, verifyProvenance ):
   """Build the base "swi installrootfs" argument list.

   The returned list always starts with "swi installrootfs --keepdoc"; the
   optional pieces are appended in a fixed order:

   yumConfig        -- if truthy, added as "-c <yumConfig>".
   repoArgs         -- iterable of objects exposing "op.value" and "repo";
                       each one becomes a single "--<op.value>repo=<repo>"
                       argument, in iteration order.
   erasePattern     -- if truthy, added as "--erase-pattern <erasePattern>".
   verifyProvenance -- if truthy, "--verify-provenance" is added.

   A new list is returned on every call, so callers may extend it freely.
   """
   argv = [ "swi", "installrootfs", "--keepdoc" ]
   if yumConfig:
      argv.extend( ( "-c", yumConfig ) )
   # One combined option per repository argument.
   argv.extend( f"--{ arg.op.value }repo={ arg.repo }" for arg in repoArgs )
   if erasePattern:
      argv.extend( ( "--erase-pattern", erasePattern ) )
   if verifyProvenance:
      argv.append( "--verify-provenance" )
   return argv

def getInstalledRpms( sqshdir, printOutput=False ):
   """Return the set of package names reported by "rpm -qa" under sqshdir.

   sqshdir     -- directory passed to rpm as "--root=<sqshdir>".
   printOutput -- forwarded unchanged to run().

   The query is executed through run() with asRoot=True and its captured
   stdout is split on whitespace (the query format emits each %{NAME}
   followed by a space).
   """
   query = [ "rpm", f"--root={sqshdir}", "-qa" ]
   query += [ "--queryformat", "%{NAME} " ]
   names = run( query, asRoot=True, captureStdout=True,
                printOutput=printOutput )
   # split() with no separator already ignores leading/trailing whitespace.
   return set( names.split() )

def installRpms( rpms, yumConfig, erasePattern, repoArgs, sqshdir,
                  lowerdirs, verifyProvenance=False, printOutput=False ):
   """Install rpms into sqshdir with "swi installrootfs", in batches.

   rpms             -- iterable of package names; each one is passed to the
                       command as "-p <name>".
   yumConfig, erasePattern, repoArgs, verifyProvenance
                    -- forwarded to getInstallCmd() to build the base command
                       for every batch.
   sqshdir          -- target directory; appended as the last argument of
                       every command and queried for installed packages.
   lowerdirs        -- if truthy, added as "-l <lowerdirs>".
   printOutput      -- forwarded to every run() call and to
                       getInstalledRpms().

   After each full batch the installed package names are re-read from
   sqshdir, and any later entry of rpms whose name is already in that set is
   skipped. Nothing is run when rpms yields no packages to install.
   """
   def installBatch( pkgArgs ):
      # Single place that assembles and runs the command, so the full batches
      # and the trailing partial batch always get identical options.
      cmd = getInstallCmd( yumConfig, erasePattern, repoArgs, verifyProvenance )
      if lowerdirs:
         cmd += [ "-l", lowerdirs ]
      cmd += pkgArgs
      cmd += [ sqshdir ]
      run( cmd, asRoot=True, printOutput=printOutput )

   pArgs = []
   installedRpms = set()
   for p in rpms:
      if p in installedRpms:
         # Already present in sqshdir as of the last completed batch.
         continue
      pArgs += [ "-p", p ]
      # NOTE(review): the threshold counts argument entries, and each package
      # contributes two ( "-p", name ), so a batch holds 15 packages. Kept as
      # is; confirm whether 30 packages per batch was intended.
      if len( pArgs ) >= 30:
         installBatch( pArgs )
         pArgs = []
         installedRpms |= getInstalledRpms( sqshdir, printOutput )

   if pArgs:
      # Trailing partial batch. Previously this call omitted printOutput, so
      # the caller's setting only applied to the full batches.
      installBatch( pArgs )
