# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import glob
import os
import re
import subprocess
import sys
import zipfile
import EosVersion
import Tac
import platform

# Architecture tag used in rootfs squash filenames ( <name>.rootfs-<farch>.sqsh ):
# "aarch64" when running on an aarch64 machine, "i386" on anything else.
farch = "aarch64" if platform.machine() == 'aarch64' else "i386"  # A4NOCHECK


# Returns rootfs sqsh names found in an extracted SWI Dir
def getRootSqshNamesFromSwiDir( swiDir ):
   """Return the rootfs squash names present in an extracted SWI directory.

   Globs the top level of swiDir for files named '<name>.rootfs-<farch>.sqsh'
   and returns each basename with the '.rootfs-<farch>.sqsh' text removed.
   The list is empty when nothing matches and is in glob.glob() order.
   """
   sqshSuffix = '.rootfs-%s.sqsh' % farch
   names = []
   for sqshPath in glob.glob( f'{swiDir}/*.rootfs-{farch}.sqsh' ):
      names.append( os.path.basename( sqshPath ).replace( sqshSuffix, '' ) )
   return names

# Returns rootfs sqsh names found in a zipped SWI
def getRootSqshNamesFromSwi( swiPath ):
   """Return the rootfs squash names contained in a zipped SWI.

   A member counts as a rootfs squash only if its name *ends with*
   '.rootfs-<farch>.sqsh'; that suffix is stripped from the returned name.
   Member names are taken as-is from the archive listing, so a member stored
   under a directory keeps its directory components.

   Raises:
      Exception: if swiPath is not a zip file, or if the archive cannot be
         opened/listed (the underlying error is chained as the cause).
   """
   if not zipfile.is_zipfile( swiPath ):
      raise Exception( "Error: input swi %s is not a SWI File" % swiPath )

   suffix = '.rootfs-%s.sqsh' % farch
   try:
      with zipfile.ZipFile( swiPath, 'r' ) as swi:
         zippedFiles = swi.namelist()
   except Exception as err:
      # Only real errors are translated; KeyboardInterrupt/SystemExit
      # propagate untouched, and the root cause stays in the traceback.
      raise Exception( "Unable to read from SWI File" ) from err

   # Anchor on the suffix so that e.g. 'x.rootfs-<farch>.sqsh.sig' is not
   # mistaken for a rootfs squash (matching the glob used for extracted dirs).
   return [ name[ : -len( suffix ) ] for name in zippedFiles
            if name.endswith( suffix ) ]

def getSupportedOptimizationsFromSwi( swiPath ):
   """Return the optimizations of a zipped SWI whose squashes are all present.

   Reads the EosVersion.swimHdrFile member, where each line has the form
   '<optimization>=<sqsh>:<sqsh>:...'. An optimization is reported only when
   every squash listed for it is itself a member of the archive.

   Returns an empty list when the header member is absent. This is
   best-effort: any exception while opening, reading or parsing is caught, a
   warning is printed, and whatever was collected up to that point is
   returned.
   """
   supported = []
   try:
      with zipfile.ZipFile( swiPath, 'r' ) as swi:
         members = swi.namelist()
         if EosVersion.swimHdrFile in members:
            memberSet = set( members )
            hdrText = swi.read( EosVersion.swimHdrFile ).decode( 'utf-8' )
            for entry in hdrText.splitlines():
               name, sqshList = entry.split( '=', 1 )
               if memberSet.issuperset( sqshList.split( ':' ) ):
                  supported.append( name )
   # pylint: disable=broad-except
   except Exception:
      print( "Warning - unable to read from SWI" )
   return supported

# Accepts a SWI File or extracted SWI Dir as swiPath
# and returns a list of superset rootfs found.
# A superset rootfs includes full platform support
# in the image. Multiple supersets should only be found
# if swiPath is a meta image.
# Rootfs sqshes are in format - <rootSqshName>.rootfs-<arch>.sqsh
# (<arch> is the module-level farch value, e.g. i386)
# Pass fullSqshName=True to return the full squash filenames
def getSupersetRootfs( swiPath, failOnMissing=False, fullSqshName=False ):
   """Return the sorted superset rootfs names found in a SWI file or SWI dir.

   Args:
      swiPath: an extracted SWI directory or a zipped SWI file.
      failOnMissing: raise instead of printing when no superset is found.
      fullSqshName: return '<name>.rootfs-<farch>.sqsh' filenames instead of
         the bare names.

   Only names that also appear in EosVersion.fullSwimFlavors are kept. A
   status line is printed for the none / one / several cases.

   Raises:
      Exception: when nothing is found and failOnMissing is true.
   """
   if os.path.isdir( swiPath ):
      candidates = getRootSqshNamesFromSwiDir( swiPath )
   else:
      candidates = getRootSqshNamesFromSwi( swiPath )

   # Filter out non-superset names
   supersets = sorted(
      set( candidates ).intersection( EosVersion.fullSwimFlavors ) )

   if len( supersets ) == 1:
      print( "Found superset rootfs - %s" % supersets[ 0 ] )
   elif supersets:
      print( "Multiple superset rootfs found in image" )
   elif failOnMissing:
      raise Exception( "Error: no superset optimization found in image" )
   else:
      print( "No superset rootfs found in %s" % swiPath )

   if not fullSqshName:
      return supersets
   sqshSuffix = ".rootfs-%s.sqsh" % farch
   return [ name + sqshSuffix for name in supersets ]

# Parse the overlay map file to create a flavor to lowerdirs map
def parseSwimSqshMap( overlayFn ):
   """Parse the overlay map file into a {flavor: lowerdirs} dict.

   Each non-blank line has the form '<flavor>=<lowerdirs>'; only the first
   '=' separates the two, so the lowerdirs value may itself contain '='.
   The newline is removed from the value. Blank (whitespace-only) lines are
   ignored so that stray or trailing empty lines do not abort the parse.

   Raises:
      ValueError: if a non-blank line contains no '='.
      OSError: if overlayFn cannot be opened.
   """
   flavorToLowerdirs = {}
   with open( overlayFn ) as overlayFile:
      for line in overlayFile:
         if not line.strip():
            # Nothing to parse; unpacking the split below would fail
            continue
         flavor, lowerdirs = line.split( "=", 1 )
         flavorToLowerdirs[ flavor ] = lowerdirs.replace( "\n", "" )

   return flavorToLowerdirs

def onSupportedKernel():
   """Return False when `uname -r` reports a 3.18 kernel, True otherwise.

   Overlay mounting is not supported on 3.18 kernels.
   """
   # NOTE(review): assumes Tac.run with stdout=Tac.CAPTURE returns the
   # command output as a str - confirm against Tac.
   kernelRelease = Tac.run( [ "uname", "-r" ], stdout=Tac.CAPTURE )
   if kernelRelease.startswith( '3.18' ):
      return False
   return True

# Union mount the rootfs of every swim flavor in overlayDirs, or of only
# the given flavor when it is present in overlayDirs
def mountSwimRootfs( overlayDirs, flavor=None ):
   """Overlay-mount the rootfs of the selected swim flavors.

   Args:
      overlayDirs: maps flavor -> dict with 'lower', 'root', 'work' and
         'union' paths. 'root' is used as the overlay upperdir and 'union'
         as the mount point.
      flavor: when given and present in overlayDirs, only that flavor is
         mounted; otherwise every flavor in overlayDirs is mounted.

   The root, work and union directories are created (as root) before each
   mount.

   Raises:
      Exception: when running on a 3.18 kernel.
   """
   if not onSupportedKernel():
      raise Exception( "ERROR: overlay mounting is not supported on 3.18 kernel" )

   selected = [ flavor ] if flavor and ( flavor in overlayDirs ) else overlayDirs

   for name in selected:
      layout = overlayDirs[ name ]
      rootDir, workDir, unionDir = ( layout[ 'root' ], layout[ 'work' ],
                                     layout[ 'union' ] )
      Tac.run( [ "mkdir", "-p", rootDir, workDir, unionDir ], asRoot=True )
      mountOpts = "lowerdir={},upperdir={},workdir={}".format( layout[ 'lower' ],
                                                               rootDir,
                                                               workDir )
      Tac.run( [ "mount", "-t", "overlay", "overlay-%s" % name, "-o",
                 mountOpts, unionDir ], asRoot=True )

def unmountSwimRootfs( overlayDirs, flavor=None ):
   """Unmount the overlay of the selected swim flavors and remove their dirs.

   Args:
      overlayDirs: maps flavor -> dict with at least 'union' and 'work'.
      flavor: when given and present in overlayDirs, only that flavor is
         handled; otherwise every flavor in overlayDirs is.

   The union dir is unmounted only if it is currently a mount point; the
   union and work dirs are removed (as root) in either case.
   """
   selected = [ flavor ] if flavor and ( flavor in overlayDirs ) else overlayDirs
   for name in selected:
      unionDir = overlayDirs[ name ][ "union" ]
      if os.path.ismount( unionDir ):
         Tac.run( [ "umount", unionDir ], asRoot=True )
      # Cleanup dirs created for overlay mounting
      for leftover in ( unionDir, overlayDirs[ name ][ "work" ] ):
         Tac.run( [ "rm", "-rf", leftover ], asRoot=True )

# Mount lowerdirs as read only.
# unionDir will become the union/merged directory of the
# mounted lowerdirs
def mountSwimfsReadOnly( unionDir, lowerdirs, flavor ):
   """Overlay-mount lowerdirs onto unionDir using lower layers only.

   Args:
      unionDir: mount point that becomes the merged view.
      lowerdirs: value passed verbatim as the overlay 'lowerdir=' option.
      flavor: used to name the mount source 'overlay-<flavor>'.

   No upperdir or workdir is passed to the mount.
   """
   # SWIM rootfs should be prepended to lowerdirs before calling
   # this func since overlap of existing files in unionDir
   # and files in lowerdirs will hide the unionDir versions
   overlayName = "overlay-%s" % flavor
   mountOpts = "lowerdir=%s" % lowerdirs
   Tac.run( [ "mount", "-t", "overlay", overlayName, "-o", mountOpts, unionDir ],
            asRoot=True )

def umountSwimFs( rootfs, ignoreReturnCode=False ):
   """Unmount rootfs (as root) if it is a mount point, else just say so.

   Args:
      rootfs: path expected to be a mount point.
      ignoreReturnCode: forwarded to Tac.run for the umount command.
   """
   if os.path.ismount( rootfs ):
      Tac.run( [ "umount", rootfs ], asRoot=True,
               ignoreReturnCode=ignoreReturnCode )
   else:
      print( "%s is not a mountpoint, skipping umount" % rootfs )

def isSwimImage( swi ):
   """Return True when `unzip -l` finds EosVersion.swimHdrFile in swi.

   The result is derived solely from the exit status of `unzip -l <swi>
   <header>`: a zero status means True, anything else False. The listing
   itself is captured and discarded.
   """
   listing = subprocess.run( [ "unzip", "-l", swi, EosVersion.swimHdrFile ],
                             stdout=subprocess.PIPE, check=False )
   return listing.returncode == 0

def isSwimDir( swiDir ):
   """Return True when swiDir contains the EosVersion.swimHdrFile entry."""
   hdrPath = os.path.join( swiDir, EosVersion.swimHdrFile )
   return os.path.exists( hdrPath )

def isPartitionType( dir_, fsType="ext4" ):
   """Return True when `df` reports that dir_ lives on an fsType filesystem.

   Args:
      dir_: path handed to `df -a --output=fstype`.
      fsType: filesystem type name to compare against (default "ext4").

   Raises:
      AssertionError: if the df command exits with a non-zero status; the
         message carries df's stderr.
      ValueError: if df's output is not exactly a header line plus one
         filesystem-type line.
   """
   df = subprocess.run( [ "df", "-a", "--output=fstype", dir_ ],
                        stdout=subprocess.PIPE,
                        stderr=subprocess.PIPE,
                        check=False )
   if df.returncode != 0:
      # Raised explicitly instead of via `assert` so the check is not
      # stripped under `python -O`; AssertionError is kept so existing
      # callers see the same exception type.
      raise AssertionError( "Error running 'df -a --output=fstype':\n%s" %
                            df.stderr.decode( errors='replace' ) )

   # output should be 'Type\n<actualFsType>'
   _, actualFsType = df.stdout.splitlines()
   return actualFsType == fsType.encode()

# swimdir is the extracted SWIM dir
# under which the overlayFS union/work dir paths are placed.
def getSwimSqshMapAndOverlayDirs( swimdir, swimFlavors ):
   """Locate the swim squash map in swimdir and build per-flavor overlay dirs.

   Args:
      swimdir: extracted SWIM directory.
      swimFlavors: comma-separated flavor names to restrict to, or a falsy
         value to use every flavor listed in the map.

   Returns:
      ( swimSqshMap, swimOverlayDirs ). swimSqshMap is the path of the map
      file, or None when it does not exist (swimOverlayDirs is then empty).
      swimOverlayDirs maps each selected flavor that has a
      '<flavor>.rootfs-*.sqsh' file in swimdir to a dict with:
         'union': '<swimdir>/union.<flavor>'
         'work':  '<swimdir>/work.<flavor>'
         'lower': ':'-joined '.dir' paths of the flavor's squashes, excluding
                  its own rootfs squash
         'root':  the first matching rootfs squash path with '.sqsh' -> '.dir'

   Raises:
      Exception: when swimFlavors is given but none of its flavors is in the
         map.
   """
   swimSqshMap = os.path.join( swimdir, EosVersion.swimHdrFile )
   if not os.path.exists( swimSqshMap ):
      return None, {}

   flavorSqshMap = parseSwimSqshMap( swimSqshMap )
   if swimFlavors:
      flavors = []
      for requested in swimFlavors.split( "," ):
         if requested in flavorSqshMap:
            flavors.append( requested )
         else:
            print( "Swim flavor %s is not supported by this image" % requested )
      if not flavors:
         raise Exception( "Error: No valid swim flavor arg specified" )
   else:
      flavors = flavorSqshMap

   swimOverlayDirs = {}
   for flavor in flavors:
      rootfsSqsh = glob.glob( os.path.join( swimdir,
                                            flavor + ".rootfs-*.sqsh" ) )
      if not rootfsSqsh:
         # No rootfs for this flavor in swimdir, nothing to overlay
         continue

      # The flavor's own rootfs is the overlay root, not a lower layer
      lowerdirs = [ os.path.join( swimdir, sqsh.replace( ".sqsh", ".dir" ) )
                    for sqsh in flavorSqshMap[ flavor ].split( ":" )
                    if f'{flavor}.rootfs-' not in sqsh ]

      swimOverlayDirs[ flavor ] = {
         "union": f"{swimdir}/union.{flavor}",
         "work": f"{swimdir}/work.{flavor}",
         "lower" : ":".join( lowerdirs ),
         "root" : rootfsSqsh[ 0 ].replace( ".sqsh", ".dir" )
      }
   return swimSqshMap, swimOverlayDirs

def mountSwimWorkspaceRootfs( wsRoot ):
   """Overlay-mount the first full swim flavor whose rootfs dir exists.

   Builds the overlay dirs for all EosVersion.fullSwimFlavors under wsRoot,
   picks the first flavor whose 'root' path exists, mounts it and returns
   that flavor's 'union' directory.

   Writes an error to stderr and exits with status 1 when no such flavor is
   found.
   """
   flavorsArg = ",".join( EosVersion.fullSwimFlavors )

   # Retrieve overlayDirs (flavor => union/work/lower/root dirs)
   # for all full SWIM flavors. Then see which of those flavor's rootfs
   # are installed in the SWI to determine the dirs to overlay mount
   _, swimOverlayDirs = getSwimSqshMapAndOverlayDirs( wsRoot, flavorsArg )

   flavor = next( ( name for name, dirs in swimOverlayDirs.items()
                    if os.path.exists( dirs[ "root" ] ) ), None )

   if not flavor:
      sys.stderr.write( "Failed to retrieve overlay from swiDir %s" % wsRoot )
      sys.exit( 1 )

   mountSwimRootfs( swimOverlayDirs, flavor )
   return swimOverlayDirs[ flavor ][ "union" ]

def unmountAndCleanupWd( wd, unionDir ):
   """Unmount unionDir if mounted, then remove it and '<wd>/workdir'.

   Args:
      wd: working directory that holds the 'workdir' subdirectory.
      unionDir: overlay mount point to unmount and delete.
   """
   stillMounted = os.path.ismount( unionDir )
   if stillMounted:
      Tac.run( [ "umount", unionDir ], asRoot=True )
   Tac.run( [ "rm", "-rf", unionDir ], asRoot=True )
   overlayWorkDir = os.path.join( wd, "workdir" )
   Tac.run( [ "rm", "-rf", overlayWorkDir ], asRoot=True )

def mountPathOverLowers( wd, upperDir, lowers ):
   """Overlay-mount upperDir on top of lowers and return the union dir.

   Args:
      wd: working directory; '<wd>/union.<basename of upperDir>' becomes the
         mount point and '<wd>/workdir' the overlay workdir.
      upperDir: overlay upperdir; its basename also names the mount source.
      lowers: value passed verbatim as the overlay 'lowerdir=' option.

   The union and work directories are created without asRoot; the mount
   itself is run as root.
   """
   name = os.path.basename( upperDir )
   mountPoint = os.path.join( wd, "union.%s" % name )
   overlayWorkDir = os.path.join( wd, "workdir" )

   Tac.run( [ "mkdir", "-p", mountPoint, overlayWorkDir ] )

   mountOpts = "lowerdir={},upperdir={},workdir={}".format( lowers,
                                                            upperDir,
                                                            overlayWorkDir )
   Tac.run( [ "mount", "-t", "overlay", "overlay-%s" % name, "-o",
              mountOpts, mountPoint ], asRoot=True )
   return mountPoint

def removeOverlayXattrsFromDir( _dir ):
   """Strip overlayfs xattrs from every non-symlink entry below _dir.

   Removes 'trusted.overlay.origin', 'trusted.overlay.impure' and
   'trusted.overlay.opaque' from each directory and file that os.walk
   reports under _dir (the _dir entry itself is not examined), then prints
   how many times each attribute was removed.

   Prints a message and exits with status 1 when not running as root.
   """
   # Overlay xattrs need to be removed from SWI due to increasing complications
   # they're introducing when unsquashing an image to an incompatible filesystem
   if os.geteuid() != 0:
      print( "You need to have root privileges to use the xattr python module" )
      sys.exit( 1 )

   overlayXattrs = {
      'trusted.overlay.origin',
      'trusted.overlay.impure',
      'trusted.overlay.opaque',
   }
   removedCounts = {}
   # pylint: disable=c-extension-no-member
   for parent, subdirs, filenames in os.walk( _dir ):
      for entry in subdirs + filenames:
         entryPath = os.path.join( parent, entry )
         if os.path.islink( entryPath ):
            continue
         for attr in overlayXattrs.intersection( os.listxattr( entryPath ) ):
            os.removexattr( entryPath, attr )
            removedCounts[ attr ] = removedCounts.get( attr, 0 ) + 1
   for attr, count in removedCounts.items():
      print( "Removed %s %d times from %s" % ( attr, count, _dir ) )
