#!/usr/bin/env python3
# Copyright (c) 2009 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-with
# pylint: disable=consider-using-f-string

import argparse
import collections
import functools
import hashlib
import importlib
import operator
import os
import re
import shutil
import subprocess
import sys
import tempfile
import yaml
import six

import EosVersionValidator
import TableOutput

from Swix import schema, Lowmem

def sha1sum( filename ):
   """Return the hexadecimal SHA-1 digest of the file at `filename`.

   The file is consumed in 64 KiB pieces rather than read whole. Any failure
   (missing file, permission error, ...) terminates the process via sys.exit
   with a message naming the file.
   """
   try:
      digest = hashlib.sha1()
      with open( filename, "rb" ) as stream:
         # iter() with a sentinel keeps calling read() until it returns b"".
         for piece in iter( lambda: stream.read( 65536 ), b"" ):
            digest.update( piece )
      return digest.hexdigest()
   except Exception as err: # pylint: disable=broad-except
      sys.exit( f"Error computing sha1 sum for {filename}: {err}\n" )


def createManifest( filename, primaryRpm, rpms ):
   """Write a SWIX manifest describing primaryRpm and rpms to `filename`.

   The manifest is a text file of "key: value" lines: "format: 1", the
   primary RPM's basename, then one "<basename>-sha1: <digest>" line per RPM
   (primary RPM first). Any error ends the process via sys.exit.
   """
   lines = [ "format: 1",
             "primaryRpm: %s" % os.path.basename( primaryRpm ) ]
   lines.extend( f"{os.path.basename( rpm )}-sha1: {sha1sum( rpm )}"
                 for rpm in [ primaryRpm ] + rpms )
   try:
      # The context manager closes the file even if the write fails.
      with open( filename, "w" ) as outfile:
         outfile.write( "\n".join( lines ) + "\n" )
   except Exception as e: # pylint: disable=broad-except
      sys.exit( f"{filename}: {e}\n" )


def renderVersionTable( versionsToRpms ):
   """Print the EOS version -> RPM mapping and ask the user to confirm it.

   versionsToRpms maps each EOS version to the collection of RPM names that
   would be installed for it. Versions and RPM names are shown sorted. The
   process exits unless the user answers "y" or "Y".
   """
   def leftWrappedFormat():
      # Both columns use the same left-justified, wrapped, unpadded layout.
      fmt = TableOutput.Format( justify="left", maxWidth=20, wrap=True )
      fmt.noPadLeftIs( True )
      fmt.noTrailingSpaceIs( True )
      return fmt

   table = TableOutput.createTable( [ "EOS Versions", "Compatible RPMs" ] )
   table.formatColumns( leftWrappedFormat(), leftWrappedFormat() )

   for version in sorted( versionsToRpms ):
      table.newRow( version, ", ".join( sorted( versionsToRpms[ version ] ) ) )
   print( table.output() )
   confirm = six.moves.input( "The above table shows which RPMs will be installed "
                              "for each EOS version using the YAML file packaged "
                              "with this swix. Are the versions and RPMs correct? "
                              "[y/n]: " )
   if confirm not in ( "y", "Y" ):
      sys.exit( "Abort: Undesired versions/RPMs in install.yaml file." )


def _rpmEntryName( entry ):
   """Return the RPM file name for one entry of a version group's RPM list.

   An entry is either a plain file name or a dict (a file with instructions)
   whose first key is the file name."""
   if isinstance( entry, dict ):
      return next( iter( entry ) )
   return entry


def checkInfo( infoFile, allRpms, noReleaseDb ):
   """Check that the install YAML file `infoFile` is well-formed and valid.

   The file is loaded with yaml.safe_load and checked against the swix schema.
   If it has a "version" section, every RPM a version group lists must be
   "all" or one of allRpms. Unless noReleaseDb is set, each version pattern
   is expanded through the release DB, and the resulting version -> RPM
   mapping is shown to the user for confirmation.

   Args:
      infoFile: path of the YAML file.
      allRpms: basenames of the RPMs packaged in the swix.
      noReleaseDb: if true, only check version syntax instead of expanding
         versions through the release DB.

   Returns:
      ( True, versionsList ) when a "version" section exists, where
      versionsList holds the expanded versions (empty with noReleaseDb);
      ( True, None ) when there is no "version" section. Errors end the
      process via sys.exit; ( False, None ) is only reached when sys.exit is
      stubbed out (e.g. by tests).
   """
   # pylint: disable-msg=too-many-nested-blocks
   with open( infoFile ) as stream:
      try:
         info = yaml.safe_load( stream )
         schema.checkInfoSchema( info )
         if "version" in info:
            versionInfo = info[ "version" ]
            versionsToRpms = collections.defaultdict( set )
            versionsList = []
            for versionGroup in versionInfo:
               version = next( iter( versionGroup ) )
               matchedVersions = {}
               if not noReleaseDb:
                  matchedVersions = \
                     EosVersionValidator.getMatchingVersions( version )
                  # The matchedVersions object is a dictionary with version pointing
                  # at the expanded list
                  if version in matchedVersions:
                     versionsList.extend( matchedVersions[ version ] )
               else:
                  # Without the release DB we can still check the version syntax
                  EosVersionValidator.parse( version )

               matchedList = functools.reduce( operator.iconcat,
                                               matchedVersions.values(), [] )
               if matchedList:
                  # Reduce entries to plain file names first: entries written as
                  # dicts (files with instructions) are unhashable and cannot
                  # be stored in a set.
                  groupRpms = set()
                  for key in versionGroup:
                     groupRpms.update( _rpmEntryName( entry )
                                       for entry in versionGroup[ key ] )
                  for matched in matchedList:
                     versionsToRpms[ matched ].update( groupRpms )

               for entry in versionGroup[ version ]:
                  rpm = _rpmEntryName( entry )
                  if rpm != "all" and rpm not in allRpms:
                     sys.exit( "%s required for version %s doesn't exist" %
                               ( rpm, version ) )
            if versionsToRpms:
               renderVersionTable( versionsToRpms )
            return True, versionsList
         return True, None
      except Exception as e: # pylint: disable=broad-except
         sys.exit( f"{infoFile}: {e}\n" )
   # We have this final return as tests stub out sys.exit
   return False, None

def getRpmCompression( rpm ):
   """Return the payload compressor and flags of the package file `rpm`.

   Runs `rpm -q -p` with the query format
   "%{PAYLOADCOMPRESSOR} %{PAYLOADFLAGS}" and returns its raw stdout bytes
   (recompressRpm treats b"xz 2" as already xz-compressed). The command's
   exit status is not checked.
   """
   queryFormat = '%{PAYLOADCOMPRESSOR} %{PAYLOADFLAGS}'
   result = subprocess.run( [ 'rpm', '-q', '--qf', queryFormat, '-p', rpm ],
                            stdout=subprocess.PIPE, check=False )
   return result.stdout

def recompressRpm( rpm ):
   """Rebuild the package file `rpm` in place with an xz payload, if needed.

   Paths not ending in ".rpm" and packages whose payload is already reported
   as b"xz 2" are left untouched. Otherwise the package is rebuilt with
   rpmrebuild (payload "w2.xzdio") and the result replaces the original file.
   This is destructive.

   Raises:
      RuntimeError: if querying the package name fails, rpmrebuild fails, or
         the rebuilt package is still not xz-compressed.
      OSError: if moving the rebuilt package over `rpm` fails.
   Exits the process with status 1 if /usr/bin/rpmrebuild is missing.
   """
   if not rpm.endswith( '.rpm' ):
      return
   if getRpmCompression( rpm ) == b'xz 2':
      # Already in xz compression, no need to do anything. This is likely to
      # happen for something like TerminAttr
      return

   # Since swix goes in the swi, if we required rpmrebuild for this process,
   # we would need rpmrebuild (and thus rpm-macros, dwz,... ) in the swi too.
   # This isn't particularly useful, so instead we will check if the binary
   # is present, and complain as needed
   if not os.path.exists( '/usr/bin/rpmrebuild' ):
      print(
         'rpmrebuild is not installed, but recompression is required to produce '
         'this swix. Please run "a ws dnf install rpmrebuild" to install it, and '
         'try again',
         file=sys.stderr )
      sys.exit( 1 )

   nameInfo = subprocess.Popen( [
      'rpm', '-q', '--qf', '%{name}.%{arch}.rpm', '-p', rpm ]
      , stdout=subprocess.PIPE )
   nameOut, _ = nameInfo.communicate()
   if nameInfo.returncode != 0 or not nameOut:
      raise RuntimeError( "Unable to query name/arch of %s" % rpm )
   name = six.ensure_str( nameOut )

   # Scratch directories, removed whether or not the rebuild succeeds.
   tempDirs = []
   try:
      workingDir = tempfile.mkdtemp()
      tempDirs.append( workingDir )
      # -d . means to do the rebuild in place, but the rebuilt rpm drops the
      # version. It should get name.arch.rpm, as we print in nameInfo, so
      # rename to that
      rrcmd = [
         'rpmrebuild', '--batch', '--define', '%_binary_payload w2.xzdio',
         '--define', '%_topdir ' + workingDir,
         '-d', '.',
         '-p', rpm
      ]

      # Default RPMREBUILD_TMPDIR is ~/.tmp, which will break locally ran tests
      # due to "exceeding disk quota" on CHD. Abuilds seem to work correctly
      rrenv = os.environ.copy()
      if "AUTOTEST" not in os.environ:
         rrTmpDir = tempfile.mkdtemp()
         tempDirs.append( rrTmpDir )
         rrenv[ "RPMREBUILD_TMPDIR" ] = rrTmpDir

      rr = subprocess.Popen( rrcmd, env=rrenv )
      rr.communicate()
      # Explicit checks instead of assert: asserts are stripped under -O.
      if rr.returncode != 0:
         raise RuntimeError( "rpmrebuild failed for %s (exit status %d)" %
                             ( rpm, rr.returncode ) )

      if name != rpm:
         # This is destructive of the original RPM file
         shutil.move( name, rpm )

      compression = getRpmCompression( rpm )
      if compression != b"xz 2":
         raise RuntimeError( "%s is not xz-compressed after rebuild: %r" %
                             ( rpm, compression ) )
   finally:
      for tempDir in tempDirs:
         shutil.rmtree( tempDir, ignore_errors=True )

def recompressRpms( rpms ):
   """Call recompressRpm on each path in `rpms`, in order."""
   for path in rpms:
      recompressRpm( path )

def _eosVersionKey( version ):
   """Return a numeric comparison key for an EOS version string.

   Each run of digits becomes an int, e.g. "4.27.1F" -> ( 4, 27, 1 ), so
   "4.9.0" orders before "4.27.1" (plain string comparison gets this wrong).
   """
   return tuple( int( part ) for part in re.findall( r"\d+", version ) )


def create( filename, primaryRpm, rpms, args=None, sign=False ):
   """Build the SWIX archive `filename` from primaryRpm and rpms.

   The archive is a stored (-0) zip without directories (-j). It holds a
   generated manifest.txt, the primary RPM, the other RPMs and, when
   args.info is given and valid, that install YAML file. RPMs may first be
   recompressed in place for older EOS releases.

   With args.lowmem_config, the payload RPM is generated from the config
   instead of using primaryRpm/rpms. A non-SWIX config produces only the
   standalone RPM at `filename`. When `sign` is true, the SWIX is signed if
   the Swix.sign module is importable.

   Args:
      filename: output path; must not already exist unless args.force is set.
      primaryRpm: path of the primary RPM.
      rpms: paths of the additional RPMs.
      args: parsed "swix create" arguments, or None.
      sign: whether to sign the resulting SWIX.

   All errors end the process via sys.exit with a message.
   """
   # Check if "--force" flag has been used
   force = ( args and args.force )
   if os.path.exists( filename ):
      if force:
         os.remove( filename )
      else:
         msg = "File %s exists: use --force to overwrite\n" % filename
         sys.exit( msg )
   # Created inside the try block and released in `finally`; they stay None
   # if an error happens before they exist.
   outfile = None
   workDir = None
   # Try to create the SWIX file
   try:
      # If a lowmem SWIX config file has been provided, generate the payload
      # RPM from the config file.
      lowmemCfgPath = getattr( args, "lowmem_config", None )
      if lowmemCfgPath:
         # Warn when providing arguments that are overridden
         if args.no_recompress or args.recompress:
            print( "--recompress/no-recompress is ignored when using "
                   "--lowmem-config. Resulting RPM/SWIX will not be "
                   "directly compressed." )
         if primaryRpm or rpms:
            print( "rpms are ignored when using --lowmem-config. Payload RPM "
                   "is generated based on the config file." )

         args.no_recompress = True
         args.recompress = False
         rpms = []

         cfg = Lowmem.config.loadConfig( os.path.abspath( lowmemCfgPath ) )
         isSwix = cfg[ "isSwix" ]
         if isSwix:
            if not filename.endswith( ".swix" ):
               raise ValueError( "lowmem SWIX output file must end in .swix: %s"
                                 % filename )
            # Swap only the suffix, not every ".swix" in the path.
            primaryRpm = filename[ : -len( ".swix" ) ] + ".rpm"
         else:
            primaryRpm = filename

         Lowmem.createExtension( primaryRpm, cfg )
         if not isSwix:
            # The standalone RPM just written to `filename` is the final
            # product; opening `filename` for the zip below would truncate it.
            return

      # Opened before any recompression (which rewrites the input RPMs) so an
      # unwritable output path fails early.
      outfile = open( filename, "w" )
      workDir = tempfile.mkdtemp( suffix=".dir",
                                  prefix=os.path.basename( filename ), dir="." )
      manifest = os.path.join( workDir, "manifest.txt" )
      createManifest( manifest, primaryRpm, rpms )
      filesToZip = [ manifest, primaryRpm ] + rpms
      if args and args.info:
         allRpms = [ os.path.basename( x ) for x in [ primaryRpm ] + rpms ]
         valid, versionList = checkInfo( args.info, allRpms, args.no_release_db )
         if valid:
            filesToZip.append( args.info )
         # Tokyo is the first release to support zstd compression of RPMS in rpm
         # If no version is specified, we will instead assume that we are targetting
         # the current release, and thus won't need to recompress
         if ( not args.no_recompress and
              ( args.recompress or
                not versionList or
                any( _eosVersionKey( v ) < ( 4, 27, 1 ) for v in versionList ) ) ):
            # Recompress the rpms as needed
            recompressRpms( [ primaryRpm ] + rpms )
      elif args and args.recompress:
         recompressRpms( [ primaryRpm ] + rpms )
      # The -j arg causes zip to strip the directory path from filenames so
      # the output zip archive contains no directories
      p = subprocess.Popen( [ "zip", "-", "-0", "-j" ] + list( filesToZip ),
                            stdout=outfile, universal_newlines=True )
      p.communicate()
      outfile.close()
      if p.returncode != 0:
         raise RuntimeError( "zip exited with status %d" % p.returncode )

      if sign:
         try:
            SwixSign = importlib.import_module( 'Swix.sign' )
            retCode = SwixSign.sign( filename, forceSign=True )
            if retCode != SwixSign.SWIX_SIGN_RESULT.SUCCESS:
               sys.exit( "Error occured during SWIX signing: %s\n" % retCode )
            else:
               print( "SWIX %s successfully signed!" % filename )
         except ImportError:
            # Swix signing only available in devel environments
            print( "Skipping SWIX signing because the service is unavilable." )

   except Exception as e: # pylint: disable=broad-except
      sys.exit( "Error occurred during generation of SWIX file: %s\n" % e )
   finally:
      if outfile is not None:
         outfile.close()
      if workDir is not None:
         shutil.rmtree( workDir, ignore_errors=True )


def parseCommandArgs( args ):
   """Parse the command line of "swix create" and return the namespace.

   Positional: the output SWIX path and one or more RPM paths. The
   recompression flags (-r/-c) are mutually exclusive, as are --sign and
   --no-sign; no_recompress and sign default to False.
   """
   description = """Build a swix from a set of rpms.  The version information
for the extension is taken from the "primary rpm", which is the first rpm specified
in the list or the one whose name is specified by the --primary-rpm argument."""
   parser = argparse.ArgumentParser( prog="swix create", description=description )

   parser.add_argument( 'outputSwix', metavar="OUTFILE.swix",
                        help="Name of output file" )
   parser.add_argument( 'rpms', metavar="PACKAGE.rpm", type=str, nargs='+',
                        help='An RPM to add to the swix' )
   parser.add_argument( '-f', '--force', action="store_true",
                        help='Overwrite OUTFILE.swix if it already exists' )
   parser.add_argument( '-i', '--info', metavar="manifest.yaml", action='store',
                        type=str,
                        help='Location of manifest.yaml file to add metadata to swix' )
   parser.add_argument( "-l", "--lowmem-config", type=str,
                        help=( "Generate a lowmem SWIX/standalone RPM using "
                               "provided config yaml. This makes the command "
                               "ignore -r, -c, rpms." ) )
   parser.add_argument( "-n", "--no-release-db", action="store_true",
                        help="Do not check the release DB for version information" )
   parser.add_argument( "--primary-rpm", action="store",
                        help=( "regexp that matches exactly one rpm which will be the "
                               "primary rpm" ) )

   recompressGroup = parser.add_mutually_exclusive_group( required=False )
   recompressGroup.add_argument( "-r", "--recompress", action="store_true",
                                 help=( "Force recompression of RPMs for compatibility with "
                                        "EOS 4.27.0 and earlier." ) )
   recompressGroup.add_argument( "-c", "--no-recompress", action="store_true",
                                 help="Force no recompression of RPMS." )

   signGroup = parser.add_mutually_exclusive_group( required=False )
   signGroup.add_argument( "--sign", dest='sign', action='store_true',
                           help="Sign the SWIX after creation" )
   signGroup.add_argument( "--no-sign", dest='sign', action='store_false',
                           help="Do not sign the SWIX after creation (default)" )

   parser.set_defaults( no_recompress=False, sign=False )
   return parser.parse_args( args )

def createHandler( args=None ):
   """Entry point for "swix create".

   Args:
      args: command-line arguments without the program name. Defaults to
         sys.argv[ 1: ] read at call time (not frozen at module import).

   Returns:
      0 once create() has run, or 1 when --primary-rpm does not match exactly
      one of the given rpms. create() itself exits the process on errors.
   """
   if args is None:
      args = sys.argv[ 1: ]
   args = parseCommandArgs( args )
   if args.primary_rpm:
      pattern = re.compile( args.primary_rpm )
      primaryRpmCandidates = [ rpm for rpm in args.rpms if pattern.search( rpm ) ]
      if len( primaryRpmCandidates ) != 1:
         print( f"Failed to find primary rpm {args.primary_rpm} from candidates "
                f"{args.rpms}" )
         return 1
      primaryRpm = primaryRpmCandidates[ 0 ]
      rpmList = [ rpm for rpm in args.rpms if rpm != primaryRpm ]
   else:
      primaryRpm = args.rpms[ 0 ]
      rpmList = args.rpms[ 1 : ]
   create( args.outputSwix, primaryRpm, rpmList, args, args.sign )
   return 0

if __name__ == "__main__":
   # Propagate createHandler's status (e.g. 1 for an unmatched --primary-rpm)
   # as the process exit code.
   sys.exit( createHandler() )
