# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

from collections import namedtuple
import os
import tempfile
import EosVersion
import Swi
import subprocess
from SwimHelperLib import (
      getSwimSqshMapAndOverlayDirs, mountPathOverLowers,
      parseSwimSqshMap, unmountAndCleanupWd )
from SwimTopology import installRpms, getInstalledRpms
import Tac

# Per-optimization RPM details, as built by mountAndFetchRpmDetails:
#   existingRpms - result of getInstalledRpms() on the mounted union dir
#   newRpms      - set of RPM names returned by getDeps() for the extra packages
OptimRpmDetails = namedtuple( 'OptimRpmDetails', [ 'existingRpms', 'newRpms' ] )
# Dependency tags collected for newly added RPMs. Each one is passed to
# "rpm -q" as "--<tag>" by getRpmMetadata and is used as a key of the metadata
# dict written by writeSpecFormattedMetadata.
MetadataTags = [ "provides", "requires", "obsoletes", "conflicts" ]

def applyCorrection( sqshdir, rpmsToRemove, debug=False ):
   '''Erase the given RPMs from the RPM database rooted at sqshdir.

   Runs "rpm --root=<sqshdir> --nodeps --erase <names...>" as root, passing
   at most 50 package names per invocation.

   sqshdir:      directory used as the rpm --root
   rpmsToRemove: list of RPM names to erase (must be a list)
   debug:        when True the rpm output is inherited, otherwise discarded
   '''
   assert isinstance( rpmsToRemove, list )
   batchSize = 50
   eraseCmd = [ "rpm", "--root=%s" % sqshdir, "--nodeps", "--erase" ]
   total = len( rpmsToRemove )
   offset = 0
   while offset < total:
      # Slicing clamps at the end of the list, so the last batch may be short
      batch = rpmsToRemove[ offset : offset + batchSize ]
      offset += batchSize
      outputMode = Tac.INHERIT if debug else Tac.DISCARD
      Tac.run( eraseCmd + batch, stdout=outputMode, stderr=outputMode,
               asRoot=True )

def getDeps( rpms, yumConfig, repoArgs,
             installRoot ):
   '''To get the list of dependencies that'll be pulled in,
   use yum install --assumeno and parse the output under
   the Installing: and Installing dependencies: sections.

   repoquery --requires --resolve doesn't work as expected.

   rpms:        packages to pass to "dnf install"
   yumConfig:   dnf config file passed with -c; when falsy, "a4 dnf" is run
                with A4_YUM_NATIVE_ARCH=1 instead
   repoArgs:    iterable of objects with "op.value" and "repo" attributes,
                each turned into a "--<op>repo=<repo>" option
   installRoot: directory passed as --installroot

   Returns the set of first whitespace-separated tokens of every line listed
   under the "Installing:" and (if present) "Installing dependencies:"
   headers. Asserts if the "Installing:" header is missing from the output.
   '''
   if yumConfig:
      cmd = [ "dnf", "-c", yumConfig ]
   else:
      cmd = [ "env", "A4_YUM_NATIVE_ARCH=1", "a4", "dnf" ]
   for repoArg in repoArgs:
      cmd += [ f"--{ repoArg.op.value }repo={ repoArg.repo }" ]
   cmd += [ "--assumeno", "install", "--installroot", installRoot ]
   # disable installation of recommended packages to match `swi installrootfs`
   cmd += [ "--setopt", "install_weak_deps=False" ]
   cmd += rpms

   def parseRpmLines( lines, startIndex ):
      '''Collect the first column of each line of one section, starting at
      lines[ startIndex ] and ending at a blank line or the next header.'''
      rpmSet = set()
      # Read from the list that was passed in, not from the enclosing scope
      for rawLine in lines[ startIndex : ]:
         line = rawLine.strip()
         # stop on an empty line or next section header
         if not line or line.endswith( ':' ):
            break
         # NOTE(review): if dnf wraps a long package name onto its own line,
         # the continuation line's first token would be picked up as well -
         # confirm whether that can happen with the output seen here.
         rpmSet.add( line.split()[ 0 ] )
      return rpmSet

   # The return code is ignored; presumably --assumeno makes dnf exit non-zero
   # after printing the transaction summary.
   output = Tac.run( cmd, stdout=Tac.CAPTURE, stderr=Tac.CAPTURE,
                     ignoreReturnCode=True )
   outputLines = output.splitlines()

   installingHdr = "Installing:"
   assert installingHdr in outputLines, (
         "'%s' section not found in output of '%s':\n%s" %
         ( installingHdr, ' '.join( cmd ), output ) )
   newRpms = parseRpmLines( outputLines,
                            outputLines.index( installingHdr ) + 1 )

   installingDepsHdr = "Installing dependencies:"
   if installingDepsHdr in outputLines:
      newRpms.update( parseRpmLines(
            outputLines, outputLines.index( installingDepsHdr ) + 1 ) )

   return newRpms

def mountAndFetchRpmDetails( extDir, lowerDirs, extraPkgs, rpmDbCorrection,
                             repoArgs ):
   '''
   Mount lowerDirs under a scratch upper dir inside extDir, apply the RPMDB
   correction so the RPMDB seen in the union dir matches what the lower squash
   layers really contain, then do a dry-run install of extraPkgs to learn
   which RPMs would be pulled in.

   Returns an OptimRpmDetails( existingRpms, newRpms ). The union mount (if it
   was created) and the scratch upper dir are always cleaned up, even on error.
   '''
   scratchUpper = os.path.join( extDir, "tmp-upperdir.dir" )
   mountPoint = None
   try:
      Tac.run( [ 'mkdir', '-p', scratchUpper ] )
      mountPoint = mountPathOverLowers( extDir, scratchUpper, lowerDirs )
      applyCorrection( mountPoint, rpmDbCorrection )
      installed = getInstalledRpms( mountPoint )
      toInstall = getDeps( rpms=extraPkgs, yumConfig=None,
                           repoArgs=repoArgs, installRoot=mountPoint )
   finally:
      # Only unmount when the mount actually succeeded
      if mountPoint:
         unmountAndCleanupWd( extDir, mountPoint )
      Tac.run( [ "rm", "-rf", scratchUpper ], asRoot=True )
   return OptimRpmDetails( installed, toInstall )

def getRpmDbDirs( extDir ):
   '''Return the names (not full paths) of the entries directly under extDir
   whose name ends with ".rpmdb.dir", in os.listdir() order.'''
   return [ entry for entry in os.listdir( extDir )
            if entry.endswith( '.rpmdb.dir' ) ]

def validateExtractedSwim( extDir, flavorName, supersetOptim, extraPkgs,
                           repoArgs ):
   ''' The new RPMs for the flavor will be installed in a new layer,
   which will be layered just below the existing RPMDB layer.
   Make sure the extracted SWI is a valid SWIM, has the RPMDB layer as the layer
   right below the rootfs, and there's only one RPMDB layer.

   Also ensure that the same new squash layer can be re-used for all optimizations.
   This is done by making sure a fresh install of the new RPMS off each combination
   of lower layers pull in the same dependencies.

   extDir:        directory the SWIM was extracted into
   flavorName:    name of the flavor; "flavor-<name>.dir" and
                  "flavor-<name>.sqsh" must not already exist in extDir
   supersetOptim: optimization that must be present in the squash map
   extraPkgs:     packages whose dry-run install is compared across optims
   repoArgs:      repo arguments forwarded to the dry-run install

   Returns a dict mapping each optimization name to its OptimRpmDetails.
   Any failed check raises AssertionError. '''

   filesInDir = os.listdir( extDir )
   assert EosVersion.swimHdrFile in filesInDir, (
         "Did not find swimSqshMap in extracted SWIM" )
   # The names of the layer about to be created must still be free. The flavor
   # name has to be substituted in, otherwise the check can never fire.
   flavorDir = "flavor-%s.dir" % flavorName
   flavorSqsh = "flavor-%s.sqsh" % flavorName
   assert flavorDir not in filesInDir, (
         "%s already exists in extracted SWIM" % flavorDir )
   assert flavorSqsh not in filesInDir, (
         "%s already exists in extracted SWIM" % flavorSqsh )

   _, swimOverlayDirs = getSwimSqshMapAndOverlayDirs( extDir, None )
   assert supersetOptim in swimOverlayDirs, (
      "Superset optimization %s not found in squash map, available optims: %s" %
      ( supersetOptim, ' '.join( swimOverlayDirs.keys() ) ) )

   supersetOverlayDirs = swimOverlayDirs[ supersetOptim ]
   lowerDirs = supersetOverlayDirs[ 'lower' ]
   lowerDirsList = lowerDirs.split( ':' )
   topmostLowerDir = lowerDirsList[ 0 ]
   assert topmostLowerDir.endswith( '.rpmdb.dir' )

   rpmDbDirs = getRpmDbDirs( extDir )
   assert len( rpmDbDirs ) == 1, (
         "We expect one and only one RPM DB layer" )
   rpmDbDir = rpmDbDirs[ 0 ]

   # For each optimization, try to do a dummy install of the RPMS
   # and fetch what dependencies are pulled in
   optimToRpmDetails = {}
   dbCorrectionRelPath = "etc/RpmDbCorrection"
   for optim, info in swimOverlayDirs.items():
      # The superset optimization needs no correction; every other one reads
      # its list of RPMs to erase from the RPMDB layer.
      rpmDbCorrection = []
      if optim != supersetOptim:
         dbCorrectionAbsPath = os.path.join(
               extDir, rpmDbDir, dbCorrectionRelPath, optim )
         assert os.path.exists( dbCorrectionAbsPath ), (
               "RpmDbCorrection file not found at %s" % dbCorrectionAbsPath )
         with open( dbCorrectionAbsPath ) as f:
            rpmDbCorrection = f.read().splitlines()
      optimToRpmDetails[ optim ] = mountAndFetchRpmDetails(
            extDir, info[ 'lower' ], extraPkgs, rpmDbCorrection,
            repoArgs )
   newRpmsInSupersetOptim = optimToRpmDetails[ supersetOptim ].newRpms

   # Assert that the set of the new RPMS pulled in is the same for all
   # optimizations
   for optim, details in optimToRpmDetails.items():
      assert details.newRpms == newRpmsInSupersetOptim, (
            "New RPMs to be installed for optim %s is %s, it doesn't match with "
            " new RPMS to be installed for superset optim %s which is:\n%s" %
            ( optim, ' '.join( details.newRpms ),
              supersetOptim, newRpmsInSupersetOptim ) )
   return optimToRpmDetails

def getRpmMetadata( pkgs, tag, root ):
   '''Query one dependency tag of the given packages with rpm.

   Runs "rpm -q --root=<root> --<tag> <pkgs...>" with stderr merged into
   stdout, drops every output line containing "warning:" and de-duplicates
   the rest.

   pkgs: iterable of package names to query
   tag:  rpm query option name without the leading "--" (e.g. "provides")
   root: directory passed as rpm --root

   Returns the unique output lines as a sorted list, so that the result (and
   any file generated from it) is identical from run to run; plain set
   iteration order of strings changes with hash randomization.
   Raises subprocess.CalledProcessError if rpm exits non-zero.
   '''
   cmd = [ "rpm", "-q", "--root=%s" % root, "--" + tag ]
   cmd.extend( pkgs )
   output = subprocess.check_output( cmd, stderr=subprocess.STDOUT )
   uniqueLines = { line for line in output.decode().splitlines()
                   if "warning:" not in line }
   return sorted( uniqueLines )

def handleSimultaneousRequireAndProvideEntries( metadata ):
   '''Drop from metadata[ 'requires' ] every entry whose first
   whitespace-separated token is also the first token of some entry in
   metadata[ 'provides' ].

   metadata is modified in place ('requires' is replaced by a new list);
   nothing is returned.'''
   providedNames = { entry.split()[ 0 ] for entry in metadata[ 'provides' ] }
   keptRequires = []
   for entry in metadata[ 'requires' ]:
      if entry.split()[ 0 ] in providedNames:
         continue
      keptRequires.append( entry )
   metadata[ 'requires' ] = keptRequires

def writeSpecFormattedMetadata( metadata, filename ):
   '''Write metadata to filename as spec-style dependency lines.

   metadata: dict mapping a tag name to a list of dependency strings
   filename: path of the file to (over)write

   Each tag with at least one entry produces one line of the form
   "<Tag>: dep1, dep2, ...", with the tag name capitalized. Tags with no
   entries are skipped: emitting a bare "<Tag>: " without a line terminator
   would glue the next tag onto the same line.
   '''
   with open( filename, "w" ) as f:
      for tag, deps in metadata.items():
         if not deps:
            continue
         f.write( "%s: %s\n" % ( tag.capitalize(), ", ".join( deps ) ) )

def flavorSwimWithExtraPkgs( image, supersetOptim,
                             flavorName, extraPkgs,
                             repoArgs, outputFile, outputMetadata=None ):
   '''Flavor the SWIM image with the specified superset optimization
   supersetOptim by installing the specified extraPkgs into a new squash layer,
   then update the RPMDB layer with the new RPM metadata, update the swimSqshMap
   with the new layer for all optimizations, and finally repackage and create
   the newly flavored image at outputFile

   image:          SWIM image to extract and flavor
   supersetOptim:  optimization whose lower dirs the new RPMs are installed on
   flavorName:     flavor name; names the new layer and is written as
                   SWI_FLAVOR into the version files
   extraPkgs:      packages to install into the new layer
   repoArgs:       repo arguments forwarded to the installs
   outputFile:     path of the image produced by "swi create"
   outputMetadata: optional path; when given, the provides/requires/obsoletes/
                   conflicts of the added RPMs are written there

   The temporary extraction directory is always removed, even on failure.
   Consistency problems are reported through AssertionError.'''

   extDir = tempfile.mkdtemp()
   try:
      Swi.extract.extract( image, Swi.SwiOptions( dirname=extDir,
                                                  use_existing=True ),
                           quiet=False, rpmOp=False, readOnly=False )

      optimToRpmDetails = validateExtractedSwim(
            extDir,
            flavorName, supersetOptim,
            extraPkgs, repoArgs )

      _, swimOverlayDirs = getSwimSqshMapAndOverlayDirs( extDir, None )
      lowerDirs = swimOverlayDirs[ supersetOptim ][ 'lower' ]
      upperDir = "flavor-%s.dir" % flavorName
      upperSqsh = "flavor-%s.sqsh" % flavorName
      upperDirFullPath = os.path.join( extDir, upperDir )
      rpmDbDirs = getRpmDbDirs( extDir )
      rpmDbDirFullPath = os.path.join( extDir, rpmDbDirs[ 0 ] )

      # Install the new RPMs.

      oldRpmSet = getInstalledRpms( rpmDbDirFullPath )
      # oldRpmSet was derived from mounting the RPMDb layer, while the details
      # were derived with the lower dirs also mounted.
      assert oldRpmSet == optimToRpmDetails[ supersetOptim ].existingRpms

      installRpms( rpms=extraPkgs,
                   yumConfig=None,
                   erasePattern=r'$',
                   repoArgs=repoArgs,
                   sqshdir=upperDirFullPath,
                   lowerdirs=lowerDirs, printOutput=True )
      # Use upperDirFullPath as installroot because rpmdb would have been copied up
      newRpmSet = getInstalledRpms( upperDirFullPath )

      # Nothing should have been removed. Report what went missing, i.e. the
      # RPMs that were present before but are not any more.
      assert oldRpmSet.issubset( newRpmSet ), (
            "RPMs %s were unexpectedly removed after new RPMs were added" %
            ( oldRpmSet - newRpmSet ) )

      # See if only expected RPMs were added
      newRpmsAdded = newRpmSet - oldRpmSet
      # Remove gpg-pubkey since it's a fake rpm which may be added
      # when installing rpms
      newRpmsAdded.discard( "gpg-pubkey" )
      expectedRpmsToBeAdded = optimToRpmDetails[ supersetOptim ].newRpms
      assert newRpmsAdded == expectedRpmsToBeAdded, (
            "We expected these RPMs to be added: %s, "
            "but found that these were added: %s" %
            ( expectedRpmsToBeAdded, newRpmsAdded ) )

      # If output metadata file has been provided, extract metadata from
      # installed RPMs
      if outputMetadata is not None:
         metadata = {}
         for tag in MetadataTags:
            metadata[ tag ] = getRpmMetadata( newRpmsAdded, tag,
                                            upperDirFullPath )
         handleSimultaneousRequireAndProvideEntries( metadata )
         writeSpecFormattedMetadata( metadata, outputMetadata )

      # Copy RPMDB down to RPMDB squash after installation
      # No need to wipe rpmdb from the new layer, swi create will do it for us
      Tac.run( [ "rsync", "-axHAX",
                 "%s/var/lib/rpm" % upperDirFullPath,
                 "%s/var/lib" % rpmDbDirFullPath ], asRoot=True )
      # Swap only the trailing ".dir" for ".sqsh" (getRpmDbDirs guarantees the
      # name ends with ".rpmdb.dir"), keeping the dot so the path names the
      # "*.rpmdb.sqsh" file listed in the squash map.
      rpmDbSqshName = rpmDbDirs[ 0 ][ : -len( '.dir' ) ] + '.sqsh'
      rpmDbSquashFullPath = os.path.join( extDir, rpmDbSqshName )
      # Remove the RPMDB squash as it needs to be resquashed anyway after this update
      Tac.run( [ "rm", "-rf", rpmDbSquashFullPath ], asRoot=True )

      # Tweak SWIM squash map to add new layer to all optimizations
      swimSquashMapPath = os.path.join( extDir, EosVersion.swimHdrFile )
      swimSqshMap = parseSwimSqshMap( swimSquashMapPath )
      for optim, optimLayers in swimSqshMap.items():
         lowerDirsList = optimLayers.split( ':' )
         assert lowerDirsList[ 1 ].endswith( '.rpmdb.sqsh' )
         # The new layer goes right after the RPMDB squash entry
         lowerDirsList.insert( 2, upperSqsh )
         swimSqshMap[ optim ] = ':'.join( lowerDirsList )
      with open( swimSquashMapPath, "w" ) as f:
         for optim, optimLayers in swimSqshMap.items():
            f.write( f"{optim}={optimLayers}\n" )

      # Update version files of all opts to new flavor
      versionFiles = [ 'version' ] + [ x + '.version' for x in swimSqshMap ]
      for verFile in versionFiles:
         verFilePath = os.path.join( extDir, verFile )
         assert os.path.exists( verFilePath )
         sedCmd = [ 'sed', '-i', 's/SWI_FLAVOR=.*/SWI_FLAVOR=%s/' % flavorName,
                    verFilePath ]
         Tac.run( sedCmd )

      # Repackages changes into new SWI image
      Tac.run( [ 'swi', 'create', '-d', extDir, outputFile ] )
   finally:
      Tac.run( [ "rm", "-rf", extDir ], asRoot=True )
