#!/usr/bin/python3
#
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import argparse
import subprocess
import os
import AddInOrder
import RepoArg

# rpm query options collected for the downloaded packages.  Each name is passed
# to rpm as "--<name>" and is also the key under which the resulting lines are
# stored in the metadata dict built by createPayload.
MetadataTags = [ "requires", "conflicts", "provides", "obsoletes" ]

def getRpmMetadata( pkgs, tag ):
   """Return the distinct `rpm -q --<tag> <pkgs>` output lines, sorted.

   Args:
      pkgs: the single argument string handed to rpm after the query option.
         It is passed as one argv entry, with no shell involved.
      tag: query option name without the leading dashes, e.g. "requires".

   Returns:
      A sorted list of the unique output lines, excluding every line that
      contains "warning:".

   Raises:
      subprocess.CalledProcessError: if rpm exits with a non-zero status.
   """
   tagCmd = [ "rpm", "-q", f"--{tag}", pkgs ]
   # stderr is merged into stdout, so anything rpm prints there arrives as
   # ordinary lines; the "warning:" filter below drops those.
   output = subprocess.check_output( tagCmd, stderr=subprocess.STDOUT )
   lines = output.decode( 'utf-8' ).splitlines()
   # Sort after de-duplicating: iteration order of a set of strings changes
   # from one interpreter run to the next, which would make the result (and
   # any file generated from it) non-reproducible.
   return sorted( { line for line in lines if "warning:" not in line } )

def handleSimultaneousRequireAndProvideEntries( data ):
   """Remove from data[ 'requires' ] every entry also found in data[ 'provides' ].

   Entries are matched on their first whitespace-separated word only, so
   whatever follows the name (such as a version constraint) is ignored.
   Blank or whitespace-only entries have no name: they are ignored in
   'provides' and dropped from 'requires' instead of raising IndexError.

   Args:
      data: dict holding lists of strings under the 'requires' and 'provides'
         keys.  data[ 'requires' ] is rebound to a new, filtered list; all
         other entries are left untouched.

   Returns:
      None; the dict is updated in place.
   """
   providedNames = set()
   for prov in data[ 'provides' ]:
      words = prov.split()
      if words:
         providedNames.add( words[ 0 ] )

   remaining = []
   for req in data[ 'requires' ]:
      words = req.split()
      # Keep only named requirements that nothing in this set provides.
      if words and words[ 0 ] not in providedNames:
         remaining.append( req )
   data[ 'requires' ] = remaining

def writeSpecFormattedDepsMetadata( data, filename ):
   """Write data to filename as one "Tag: dep1, dep2, ..." line per tag.

   The tag name is written capitalized (e.g. "requires" -> "Requires").  Tags
   whose dependency list is empty are skipped, so every line written is
   complete and newline-terminated.

   Args:
      data: dict mapping a tag name to a list of dependency strings; tags are
         written in the dict's iteration order.
      filename: path of the file to create or overwrite.
   """
   with open( filename, "w" ) as f:
      for tag, deps in data.items():
         if not deps:
            # Writing the "Tag: " prefix without any entry would leave the
            # line unterminated and glue the next tag onto it.
            continue
         joined = ", ".join( deps )
         f.write( f"{tag.capitalize()}: {joined}\n" )

def printExceptionDetails( command, exception ):
   """Print a report for a failed command to stdout.

   The first line names the command and the exception; it is followed by the
   exception's stdout and stderr attributes, each printed on its own.

   Args:
      command: the command that failed, formatted into the message as is.
      exception: the raised exception; must expose stdout and stderr.
   """
   print( f"Command {command} failed with exception {exception}" )
   for streamName in ( "stdout", "stderr" ):
      print( getattr( exception, streamName ) )

def _runChecked( cmd, **kwargs ):
   """Run cmd with subprocess.run( check=True ); report and re-raise on failure."""
   try:
      return subprocess.run( cmd, check=True, **kwargs )
   except subprocess.CalledProcessError as e:
      printExceptionDetails( cmd, e )
      raise

def createPayload( wdir, archName, rpmList, output, repoArgs ):
   """Download RPMs, unpack them into a squashfs image and write its metadata.

   Steps, in order:
     1. create <wdir>/rpms, <wdir>/rpms/cpio and <wdir>/squashroot;
     2. download rpmList into <wdir>/rpms with "a4 yumdownloader";
     3. write the packages' dependency metadata to "<output>.deps";
     4. convert every downloaded .rpm to a cpio archive with rpm2cpio;
     5. unpack every cpio archive into <wdir>/squashroot;
     6. build the squashfs image at output with mksquashfs;
     7. write "<output>.install", "<output>.uninstall" and "<output>.list"
        from the image's file listing.

   Args:
      wdir: work directory; made absolute before use.
      archName: value passed to yumdownloader as --archlist.
      rpmList: iterable of package names appended to the yumdownloader command.
      output: squashfs file name, also the prefix of the metadata files.
      repoArgs: iterable of extra command-line options for yumdownloader.

   Raises:
      subprocess.CalledProcessError: if any of the external commands fails.
      AssertionError: if a listed path does not start with "squashfs-root".
   """
   # Create directory tree
   abswdir = os.path.abspath( wdir )
   rpmdir = os.path.join( abswdir, "rpms" )
   cpiodir = os.path.join( rpmdir, "cpio" )
   sqfsrootdir = os.path.join( abswdir, "squashroot" )
   subprocess.run( [ "mkdir", "-p", abswdir, rpmdir, cpiodir, sqfsrootdir ],
                   check=True )

   print( "Created directory tree" )

   # Download all requested RPMs
   downloadCmd = [ "a4", "yumdownloader", f"--destdir={rpmdir}",
                   f"--archlist={archName}" ]
   downloadCmd.extend( repoArgs )
   downloadCmd.extend( rpmList )
   _runChecked( downloadCmd )

   print( f"Downloaded requested rpms to {rpmdir}" )

   # Extract rpmdb metadata
   rpmPattern = os.path.join( rpmdir, "*.rpm" )
   metadata = {}
   for metadataTag in MetadataTags:
      metadata[ metadataTag ] = getRpmMetadata( rpmPattern, metadataTag )
   handleSimultaneousRequireAndProvideEntries( metadata )
   writeSpecFormattedDepsMetadata( metadata, f"{output}.deps" )

   print( f"Output rpmdb metadata to {output}.deps" )

   # Extract CPIOs from the RPMs
   rpmSuffix = ".rpm"
   for rpm in os.listdir( rpmdir ):
      if not rpm.endswith( rpmSuffix ):
         continue
      rpmPath = os.path.join( rpmdir, rpm )
      # Swap only the trailing extension; str.replace would also rewrite a
      # ".rpm" occurring earlier in the file name.
      cpioName = os.path.join( cpiodir, rpm[ : -len( rpmSuffix ) ] + ".cpio" )
      with open( cpioName, "wb" ) as cpio:
         _runChecked( [ "rpm2cpio", rpmPath ], stdout=cpio )

   print( f"Extracted cpios to {cpiodir}" )

   # Extract the CPIOs to the mksquashfs root.  The child is started with
   # cwd=sqfsrootdir instead of calling os.chdir, so this process never leaves
   # its own working directory, even when an extraction fails.
   extractCmd = [ "cpio", "-idv" ]
   for cpio in os.listdir( cpiodir ):
      if not cpio.endswith( ".cpio" ):
         continue
      # The archive is binary data that is only handed to the child as stdin.
      with open( os.path.join( cpiodir, cpio ), "rb" ) as inCpio:
         _runChecked( extractCmd, stdin=inCpio, cwd=sqfsrootdir )

   print( f"Installed cpio content to {sqfsrootdir}" )

   # Generate the squashfs file
   _runChecked( [ "mksquashfs", sqfsrootdir, output ] )

   print( f"Created {output}" )

   # Generate the install/uninstall payload commands and file list
   listcmd = [ "unsquashfs", "-lls", output ]
   try:
      listing = subprocess.check_output( listcmd )
   except subprocess.CalledProcessError as e:
      printExceptionDetails( listcmd, e )
      raise
   filelist = listing.decode( "utf-8" ).splitlines()

   uninstallCmds = []
   with open( f"{output}.install", "w" ) as installFile, \
        open( f"{output}.uninstall", "w" ) as uninstallFile, \
        open( f"{output}.list", "w" ) as listFile:
      for line in filelist:
         line = line.strip()

         # Only regular files ("-") and directories ("d") are handled; every
         # other listing line is skipped.
         if not line.startswith( ( "-", "d" ) ):
            continue

         cols = line.split()

         # mksquashfs/unsquashfs uses squashfs-root to mark the root dir,
         # this might change at some point so make sure it is actually there
         assert cols[ 5 ].startswith( "squashfs-root" ), \
            f"unexpected unsquashfs listing line: {line}"

         attrs = cols[ 0 ]
         # The path starts at the sixth column; rejoin it in case the name
         # itself contains spaces, then drop the squashfs-root prefix.
         name = ' '.join( cols[ 5 : ] ).replace( "squashfs-root", "", 1 )

         # The root directory itself has an empty name: nothing to do.
         if not name:
            continue

         # Write the install commands and generate uninstall commands.
         # Directories that already exist on the filesystem this script runs
         # on get neither a mkdir nor a rmdir entry.
         if attrs.startswith( "d" ) and not os.path.isdir( name ):
            installFile.write( f"mkdir -p \"{name}\"\n" )
            uninstallCmds.append( f"rmdir \"{name}\"\n" )
         elif attrs.startswith( "-" ):
            installFile.write( f"ln -s \"%{{EXT_ROOTDIR}}{name}\" \"{name}\"\n" )
            uninstallCmds.append( f"unlink \"{name}\"\n" )
            listFile.write( f"{name}\n" )

      # Write the uninstall commands in reverse creation order, so files are
      # unlinked before the directories that contain them are removed.
      for line in reversed( uninstallCmds ):
         uninstallFile.write( line )

   print( f"Generated the {output}.install/uninstall/list files" )

if __name__ == '__main__':
   # Command-line entry point: parse the options and build the payload.
   # --workdir, --output and --rpms are mandatory; without them the run would
   # fail later with a TypeError instead of a usage message.
   ap = argparse.ArgumentParser()
   ap.add_argument( '-w', '--workdir', required=True,
                    help='Work directory' )
   ap.add_argument( '-a', '--arch',
                    help="Target architecture" )
   ap.add_argument( '--enablerepo',
                    action=AddInOrder.AddEnableDisableReposInOrder,
                    help="Enable repository (may be specified multiple "
                         "times)" )
   ap.add_argument( '--disablerepo',
                    action=AddInOrder.AddEnableDisableReposInOrder,
                    help="Disable repository (may be specified multiple "
                         "times)" )
   ap.add_argument( '-o', '--output', required=True,
                    help="Output squashfs filename, used also as prefix for "
                         "metadata files" )
   ap.add_argument( '--rpms', nargs="+", required=True,
                    help="One or more packages to install to the squashfs" )
   args = ap.parse_args()

   out = args.output
   workdir = os.path.abspath( args.workdir )
   rpms = args.rpms
   arch = args.arch

   # ordered_args is not declared above; presumably the AddInOrder action adds
   # it when a repo option is given (TODO confirm).  Only the attribute lookup
   # is guarded, so an AttributeError raised inside createRepoArgs is not
   # mistaken for "no repo options given".
   try:
      rawRepoArgs = args.ordered_args
   except AttributeError:
      orderedArgs = []
   else:
      orderedArgs = RepoArg.createRepoArgs( rawRepoArgs )
   repoArgsList = [ f"--{arg.op.value}repo={arg.repo}" for arg in orderedArgs ]

   createPayload( workdir, arch, rpms, out, repoArgsList )
