#!/usr/bin/env python3
#
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
#
# This script mounts a squashfs payload at a specified mount point directly
# from an RPM or a SWIX that contains an RPM with an uncompressed CPIO payload.

import argparse
import subprocess
import os
import sys
import zipfile
import struct

# ZIP local file header: magic, offset of the 2-byte little-endian file-name
# length (immediately followed by the 2-byte extra-field length), and size of
# the fixed-length part of the header that precedes the file name.
PKZIP = b"PK\x03\x04"
PKZIP_NAMELEN_O = 26
PKZIP_HDR_SZ = 30

# RPM: file magic, header-structure magic, offset of the first header
# structure from the start of the RPM, distance from a header structure's
# start to its two 4-byte big-endian counts (the full fixed preamble is twice
# this), and size of one header index entry.
RPM = b"\xed\xab\xee\xdb"
RPM_HDR = b"\x8e\xad\xe8"
RPM_HDR_O = 96
RPM_HDR_SZ = 8
RPM_ENT_SZ = 16

# "newc" ASCII CPIO entry header: magic, offsets of the 8-hex-digit file-size
# and name-size fields, and the fixed header length after which the entry
# name starts.
CPIO_NEWC = b"070701"
CPIO_FSIZE_O = 54
CPIO_NSIZE_O = 94
CPIO_HDR_SZ = 110

# Exit codes handed to sys.exit() through fatalException().
# NOTE(review): negative codes are outside the portable 0-127 range; on POSIX
# the observed status is typically 256 + code (e.g. -1 -> 255) — confirm
# callers expect that.
RET_INVALIDSWIX = -1
RET_INVALIDRPM = -2
RET_INVALIDHDR = -3
RET_INVALIDCPIO = -4
RET_MOUNTPOINT = -5
RET_MOUNTFAILURE = -6

def fatalException( reason, retcode ):
   """
   Report `reason` on stderr and terminate the process with `retcode`.

   Never returns: raises SystemExit( retcode ), exactly as sys.exit() does.
   """
   print( reason, file=sys.stderr )
   raise SystemExit( retcode )

def assertMagic( value, size, file, errReason, errCode ):
   """
   Read `size` bytes at the current position of `file` and exit via
   fatalException( errReason, errCode ) unless they equal `value`.

   An empty read (end of file) always counts as a mismatch. On success the
   file position is left just past the bytes that were read.
   """
   found = file.read( size )
   if found and found == value:
      return
   fatalException( errReason, errCode )

def getZipHeaderParams( zipname, filename ):
   """
   Return ( headerOffset, fileSize ) for member `filename` of zip `zipname`.

   headerOffset is the offset of the member's local file header inside the
   archive and fileSize is its uncompressed size. The member's bytes are
   addressed in place inside the archive (see mountSqfs), so it must be
   stored uncompressed (ZIP_STORED).

   Exits via fatalException( ..., RET_INVALIDSWIX ) if the archive cannot be
   opened as a zip, the member is missing, or the member is compressed.
   """
   try:
      z = zipfile.ZipFile( zipname, "r" )
   except ( zipfile.BadZipFile, OSError ) as e:
      # Report unreadable / non-zip input with the SWIX error code instead of
      # a raw traceback.
      fatalException( f"{zipname} is not a readable ZIP archive: {e}",
                      RET_INVALIDSWIX )

   with z:
      try:
         zi = z.getinfo( filename )
      except KeyError:
         fatalException( f"{filename} not in {zipname}", RET_INVALIDSWIX )

      if zi.compress_type != zipfile.ZIP_STORED:
         fatalException( f"{filename} should be stored uncompressed",
                         RET_INVALIDSWIX )

      return zi.header_offset, zi.file_size

def readHdrTotalSize( file, offset ):
   """
   Return the body size of the RPM header structure starting at `offset`.

   The structure starts with the RPM_HDR magic. At offset + RPM_HDR_SZ it
   holds two big-endian unsigned 32-bit counts: the number of index entries
   (RPM_ENT_SZ bytes each) and the size of the data store. The returned size
   is entries * RPM_ENT_SZ + data size. It does not include the fixed
   2 * RPM_HDR_SZ byte preamble.

   Exits via fatalException( ..., RET_INVALIDHDR ) if the magic does not match
   or the counts are truncated.
   """
   file.seek( offset )
   assertMagic( RPM_HDR, len( RPM_HDR ), file, "File doesn't look like an RPM.",
                RET_INVALIDHDR )

   file.seek( offset + RPM_HDR_SZ )
   counts = file.read( 8 )
   if len( counts ) != 8:
      # A short read would otherwise be converted into a bogus (smaller) size.
      fatalException( "Truncated RPM header.", RET_INVALIDHDR )
   ents, dataSz = struct.unpack( ">II", counts )

   return ents * RPM_ENT_SZ + dataSz

def calculateCpioOffset( file, baseOffset ):
   """
   Return the absolute offset of the CPIO payload of the RPM at `baseOffset`.

   Layout walked:
   - RPM_HDR_O bytes (the RPM lead);
   - the signature header structure, whose body is padded to a multiple of
     8 bytes;
   - the main header structure;
   - the payload.

   Each header structure occupies a 2 * RPM_HDR_SZ byte preamble plus the
   body size reported by readHdrTotalSize().
   """
   preamble = 2 * RPM_HDR_SZ

   sigStart = baseOffset + RPM_HDR_O
   sigBody = readHdrTotalSize( file, sigStart )

   # -n % 8 is the number of bytes that rounds n up to a multiple of 8.
   mainStart = sigStart + preamble + sigBody + ( -sigBody % 8 )
   mainBody = readHdrTotalSize( file, mainStart )

   return mainStart + preamble + mainBody

def calculateSqfsOffset( file, cpioPosition, payload ):
   """
   Walk the newc CPIO archive starting at `cpioPosition` in `file`. Return
   the absolute offset of the data of the entry named "." + payload.

   Entry names are compared as raw bytes (payload is encoded with
   os.fsencode), so non-ASCII names elsewhere in the archive do not abort
   the walk.

   Exits via fatalException( ..., RET_INVALIDCPIO ) in three cases:
   - an entry header is malformed or truncated;
   - the "TRAILER!!!" end-of-archive entry is reached without a match;
   - the data runs out before either happens.
   """
   # NOTE(review): RPM payloads usually store names as "./abs/path", so
   # payload is presumably an absolute path such as "/x/y.sqsh" — confirm.
   wanted = os.fsencode( f".{payload}" )
   badCpio = f"Invalid CPIO or {payload} not in CPIO."
   offset = cpioPosition

   while True:
      file.seek( offset )
      hdr = file.read( CPIO_HDR_SZ )
      if len( hdr ) < CPIO_HDR_SZ or not hdr.startswith( CPIO_NEWC ):
         fatalException( badCpio, RET_INVALIDCPIO )

      # Both sizes are 8 ASCII hex digits inside the fixed header.
      try:
         filesize = int( hdr[ CPIO_FSIZE_O : CPIO_FSIZE_O + 8 ], 16 )
         namesize = int( hdr[ CPIO_NSIZE_O : CPIO_NSIZE_O + 8 ], 16 )
      except ValueError:
         fatalException( badCpio, RET_INVALIDCPIO )

      # namesize includes the terminating NUL; drop it before comparing.
      fname = file.read( namesize )[ : -1 ]
      # Header + name are padded to a multiple of 4 bytes before the data.
      namepad = ( 4 - ( ( CPIO_HDR_SZ + namesize ) % 4 ) ) % 4
      dataOffset = offset + CPIO_HDR_SZ + namesize + namepad

      if fname == wanted:
         return dataOffset
      if fname == b"TRAILER!!!":
         fatalException( f"{payload} not in CPIO.", RET_INVALIDCPIO )

      # File data is likewise padded to a multiple of 4 bytes.
      offset = dataOffset + filesize + ( -filesize % 4 )

def _findSqfsOffset( inFile, rpm, extension, swix ):
   """
   Return the absolute offset of the `extension` squashfs image in `inFile`.

   When `swix` is given, `inFile` is the SWIX and `rpm` names the RPM member
   inside it. getZipHeaderParams requires that member to be stored
   uncompressed so its bytes can be addressed in place. Otherwise `inFile`
   is the RPM itself.
   """
   rpmOffset = 0
   if swix:
      rpmHdrOffset, rpmSize = getZipHeaderParams( swix, rpm )

      print( f"RPM header offset: {rpmHdrOffset} RPM filesize: {rpmSize}" )

   with open( inFile, "rb" ) as f:
      if swix:
         f.seek( rpmHdrOffset )
         assertMagic( PKZIP, 4, f,
                      f"Invalid ZIP entry header at offset {rpmHdrOffset}.",
                      RET_INVALIDSWIX )

         # Member data follows the fixed local header plus the file-name and
         # extra fields, whose lengths are read from this local header.
         f.seek( rpmHdrOffset + PKZIP_NAMELEN_O )
         filenameSize, extraSize = struct.unpack( "<HH", f.read( 4 ) )
         rpmOffset = rpmHdrOffset + PKZIP_HDR_SZ + filenameSize + extraSize

      f.seek( rpmOffset )
      assertMagic( RPM, 4, f, "File doesn't look like an RPM.",
                   RET_INVALIDRPM )

      cpioOffset = calculateCpioOffset( f, rpmOffset )
      print( f"CPIO offset: {cpioOffset}" )

      sqfsOffset = calculateSqfsOffset( f, cpioOffset, extension )
      print( f"Squashfs offset: {sqfsOffset}" )

   return sqfsOffset

def _unmountAll( mounted ):
   """Best-effort unmount of `mounted`, newest first; failures are reported."""
   for target in reversed( mounted ):
      try:
         ok = subprocess.run( [ "umount", target ] ).returncode == 0
      except OSError:
         ok = False
      if not ok:
         print( f"Could not unmount {target} during cleanup.", file=sys.stderr )

def _makeDirsOrDie( dirs, mounted ):
   """
   Create each directory in `dirs` (existing ones are fine). On failure,
   unmount everything in `mounted` and exit with RET_MOUNTPOINT.
   """
   for d in dirs:
      try:
         os.makedirs( d, exist_ok=True )
      except OSError as e:
         print( e, file=sys.stderr )
         _unmountAll( mounted )
         fatalException( f"Unable to create directory {d}.", RET_MOUNTPOINT )

def _mountOrDie( cmd, target, failReason, mounted ):
   """
   Run mount command `cmd` and, on success, append `target` to `mounted`.

   On failure (non-zero exit, or the command cannot be started), unmount
   everything already in `mounted` and exit with RET_MOUNTFAILURE. This
   keeps a failed run from leaving partial mounts behind.
   """
   try:
      # No check=True: we inspect the return code ourselves so failures map
      # to RET_MOUNTFAILURE instead of a CalledProcessError traceback.
      ok = subprocess.run( cmd ).returncode == 0
   except OSError as e:  # e.g. the mount executable cannot be found
      print( e, file=sys.stderr )
      ok = False
   if not ok:
      _unmountAll( mounted )
      fatalException( failReason, RET_MOUNTFAILURE )
   mounted.append( target )

def mountSqfs( rpm, extension, mountpoint, swix=None ):
   """
   Calculates the extension squashfs offset in the rpm file, or in the swix
   file if provided, and mounts it directly from that file at mountpoint.

   Layout created under mountpoint:
     sqfs/              read-only squashfs, loop-mounted at the offset
     overlay-tmpfs/     tmpfs holding the overlay's writable state
       upper/, work/    overlayfs upperdir and workdir
       merged/          overlayfs mount of sqfs/ (lower) + upper/

   Exit codes:
   - RET_MOUNTPOINT if a directory cannot be created;
   - RET_MOUNTFAILURE if any mount fails.
   Before exiting on either error, mounts already made by this call are
   unmounted.
   """
   inFile = swix if swix else rpm
   sqfsOffset = _findSqfsOffset( inFile, rpm, extension, swix )

   mounted = []  # targets mounted so far, for cleanup on failure

   # Setup directories for squashfs and tmpfs
   tmpfsdir = os.path.join( mountpoint, "overlay-tmpfs" )
   lowerdir = os.path.join( mountpoint, "sqfs" )
   _makeDirsOrDie( [ mountpoint, tmpfsdir, lowerdir ], mounted )

   print( f"Created the directory structure at {mountpoint}" )

   # Setup tmpfs
   _mountOrDie( [ "mount", "-t", "tmpfs", "tmpfs", tmpfsdir ], tmpfsdir,
                "Tmpfs mount unsuccessful.", mounted )
   print( f"Tmpfs mount successful at {tmpfsdir}" )

   # Create rest of overlay directory structure inside the tmpfs
   upperdir = os.path.join( tmpfsdir, "upper" )
   workdir = os.path.join( tmpfsdir, "work" )
   mergeddir = os.path.join( tmpfsdir, "merged" )
   _makeDirsOrDie( [ upperdir, workdir, mergeddir ], mounted )

   print( "Created the directory structure for overlayfs." )

   # Mount squashfs read-only straight out of the package file
   _mountOrDie( [ "mount", "-t", "squashfs", "-o",
                  f"ro,loop,offset={sqfsOffset}", inFile, lowerdir ],
                lowerdir, "Squashfs mount unsuccessful.", mounted )

   print( f"Sqfs mount successful at {lowerdir}" )

   # Mount the overlay, with the tmpfs-backed upper layer over the squashfs
   _mountOrDie( [ "mount", "-t", "overlay", "overlay", "-o",
                  f"lowerdir={lowerdir},upperdir={upperdir},"
                  f"workdir={workdir}", mergeddir ],
                mergeddir, "Squashfs overlay mount unsuccessful.", mounted )

   print( f"Overlay mount successful at {mergeddir}" )

if __name__ == "__main__":
   # Command-line entry point: parse the arguments, then mount the extension.
   parser = argparse.ArgumentParser()
   parser.add_argument( "-s", "--swix", default=None, help="SWIX filename" )
   parser.add_argument(
      "rpm",
      help="RPM filename, if SWIX is provided RPM filename within SWIX archive" )
   parser.add_argument(
      "extension",
      help="Extension squash filename in the RPM CPIO payload" )
   parser.add_argument(
      "mountpoint",
      help="Mount point for the extension squash" )
   cli = parser.parse_args()

   mountSqfs( cli.rpm, cli.extension, cli.mountpoint, swix=cli.swix )
