#!/usr/bin/env python3
# Copyright (c) 2006-2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import Tac
import PyWrappers.Mkfs as mkfs
import os, re, shutil, stat, tempfile, time, pexpect
import six

# Sector size, in bytes, assumed by every sector <-> byte conversion below.
bytesPerSector = 512
MB = 1024 * 1024
# Regex (used with pexpect) matching fdisk's interactive command prompt.
FDISK_COMMAND_LINE = 'Command \\(m for help\\): '

# Depend on the rpms that provide mkfs for vfat, ext3, ext4
# pkgdeps: rpm dosfstools
# pkgdeps: rpm e2fsprogs

# Per-filesystem settings. Each entry is a list of:
#   0. fdisk partition type code (hex) for the partition,
#   1. filesystem name (the suffix of the mkfs.<name> command),
#   2. extra options passed to mkfs,
#   3. size of the crash partition in bytes,
#   4. extra commands (argv lists) to run on the device after mkfs.
# NOTE(review): these fields line up with formatHd's fsPartType, fsType,
# fsOpts, crashPartSize and tuneCmds parameters (where a crashPartSize of 0
# means "no crash partition"); callers presumably unpack an entry into
# formatHd — confirm against the callers.
filesystems = {
   'vfat' : [ 'c', 'vfat', [ '-F', '32' ], 0, [] ],
   'ext3' : [ '83', 'ext3', [], 0, [] ],
   'ext4' : [ '83', 'ext4', [ '-m', '0' ], 64 * MB,
              [ [ 'tune2fs', '-c', '1' ] ] ],
}

# mkfs option that sets the volume label, per filesystem type.
mkfsLabelOption = {
    'vfat': '-n',
    'ext4': '-L',
    'ext3': '-L',
}

# NOTE: fdisk specifies a preferred partition alignment/granularity of 1 MiB.
# This granularity, and therefore the minimum partition size, is valid unless
# either of two conditions occurs:
#       1. The disk block size is over 1 MiB.
#       2. The disk is extremely small (see update_sector_offset() in the
#          util-linux version of fdisk.c)
# These cases are not handled for now because most modern disk block sizes are
# not over 1 MiB, and the disks in our systems are not so small that case 2
# applies. makeRecovFile never reports a recovery archive smaller than this.
minPartitionSize = 1024 * 1024

def diskSizeBytes( disk ):
   """Return the size of *disk* in bytes.

   For a block device the size is found by seeking to its end; for anything
   else (e.g. a disk image file) the stat size is used.
   """
   info = os.stat( disk )
   if not stat.S_ISBLK( info.st_mode ):
      return info.st_size

   # st_size is not meaningful for block devices; ask the device instead.
   devFd = os.open( disk, os.O_RDONLY )
   try:
      return os.lseek( devFd, 0, os.SEEK_END )
   finally:
      os.close( devFd )

def trySetupLoop( cmd ):
   """Run the losetup command *cmd* once, as root.

   Returns the captured output on success, or None when the command failed
   with "Resource temporarily unavailable" (EAGAIN) so the caller can retry.
   Any other failure is re-raised.
   """
   try:
      return Tac.run( cmd, stdout=Tac.CAPTURE, stderr=Tac.CAPTURE,
                      asRoot=True )
   except Tac.SystemCommandError as err:
      if 'Resource temporarily unavailable' in err.output:
         return None
      raise

def setupLoop( disk, opts=None ):
   """Attach *disk* to a free loop device and return the device path.

   Runs ``/sbin/losetup <opts> --show -vf <disk>`` as root. Tac.waitFor
   (timeout=5) keeps retrying while trySetupLoop returns None, i.e. while
   losetup fails with EAGAIN. The device node (e.g. /dev/loop0) is parsed
   from the start of losetup's output.
   """
   extraOpts = [] if opts is None else opts
   losetupCmd = [ "/sbin/losetup" ] + extraOpts + [ "--show", "-vf", disk ]

   def attempt():
      return trySetupLoop( losetupCmd )

   output = Tac.waitFor( attempt,
                         description='losetup to not return EAGAIN',
                         timeout=5 )
   found = re.match( r"(/dev/\w+)", output.strip() )
   assert found, output
   return found.group( 1 )

def deleteLoop( dev ):
   """Detach loop device *dev*, then flush buffers to the backing device.

   losetup -d is retried (up to 10 attempts, 1 second apart) while it reports
   "Device or resource busy"; after the last attempt the function carries on
   without raising.
   """
   # There doesn't seem to be any reliable way to sync a loop device
   # before attempting to delete it (and on F12, even the trivial
   # 'losetup /dev/loop1 /dev/sdc; losetup -d /dev/loop1' fails
   # intermittently) so we try a bunch of times before giving up
   attemptsLeft = 10
   while attemptsLeft > 0:
      attemptsLeft -= 1
      output = Tac.run( [ "/sbin/losetup", "-d", dev ],
                        stdout=Tac.CAPTURE, stderr=Tac.CAPTURE,
                        asRoot=True, ignoreReturnCode=True )
      if "Device or resource busy" not in output:
         break
      time.sleep( 1 )

   # Ensure the loopback device buffers get flushed to the real block device
   os.system( "sync" )
   time.sleep( 2 )

def mountLoop( disk, loopOpts=None, mountOpts=None, _dir=None ):
   """Attach *disk* to a loop device and mount it.

   Args:
      disk: disk image file or block device to attach.
      loopOpts: extra losetup arguments (e.g. [ "-o", offset ]).
      mountOpts: extra mount arguments (e.g. [ "-r" ]).
      _dir: mount point; a new temporary directory is created when not given.

   Returns:
      The mount point. The loop device name is stored in a "loopdev" file
      written into the directory *before* mounting: it is hidden beneath the
      mount while mounted and visible again after unmounting, which is where
      unmountLoop reads it back.

   On any failure, whatever was set up is torn down and the exception is
   re-raised. A temporary directory created by this call is removed too;
   a caller-supplied directory is left in place.
   """
   loopOpts = [] if loopOpts is None else loopOpts
   mountOpts = [] if mountOpts is None else mountOpts
   createdDir = not _dir
   if createdDir:
      _dir = tempfile.mkdtemp()
   try:
      dev = setupLoop( disk, loopOpts )
      with open( "%s/loopdev" % _dir, "w" ) as f:
         f.write( dev )
      Tac.run( ["/bin/mount"] + mountOpts + [dev, _dir],
               stdout=Tac.CAPTURE, asRoot=True )
      return _dir
   except BaseException:
      # Clean up on any failure (KeyboardInterrupt included), then re-raise.
      # Only delete the directory if we created it, so it does not leak.
      unmountLoop( _dir, removeDir=createdDir )
      raise

def unmountLoop( _dir, removeDir=True ):
   """Undo mountLoop: unmount *_dir* and detach its loop device.

   The umount result is ignored. The loop device is detached only when the
   "loopdev" file recorded by mountLoop can be read. With *removeDir* the
   directory is removed afterwards (errors ignored).
   """
   Tac.run( ["/bin/umount", _dir], asRoot=True, ignoreReturnCode=True )

   loopDevFile = "%s/loopdev" % _dir
   loopDev = None
   try:
      with open( loopDevFile ) as handle:
         loopDev = handle.read()
   except OSError:
      # Nothing recorded (e.g. setup failed early): no loop device to free.
      loopDev = None
   if loopDev is not None:
      deleteLoop( loopDev )

   if removeDir:
      shutil.rmtree( _dir, ignore_errors=True )

def fixupDirPerms( fsDir ):
   """Recursively give *fsDir* owner root, group 88, mode ug=rwX,o= (as root).

   Failures of either command are ignored.
   """
   # extX filesystems support POSIX permissions properly, so we need to be
   # careful: in EOS everything on the flash must be owned by root:eosadmin
   # (gid=88). The numeric gid is used instead of "eosadmin" because Arora
   # doesn't know about the eosadmin group.
   # These commands have no effect on vfat, which guarantees correct
   # permissions via mount options.
   permCommands = (
      [ "chown", "-R", "root:88", fsDir ],
      [ "chmod", "-R", "ug=rwX,o=", fsDir ],
   )
   for argv in permCommands:
      Tac.run( argv, stdout=Tac.DISCARD, stderr=Tac.DISCARD, asRoot=True,
               ignoreReturnCode=True )

# Offset defaults to sector 2048, which is the usual start of partition 1
def mountHdFs( disk, offset=2048*bytesPerSector, writable=True, _dir=None ):
   """Loop-mount the filesystem that starts *offset* bytes into *disk*.

   Mounted read-write when *writable*, read-only otherwise. Returns the mount
   point (see mountLoop); undo with unmountHdFs.
   """
   mountMode = "-w" if writable else "-r"
   return mountLoop( disk, [ "-o", str( offset ) ], [ mountMode ], _dir )

def unmountHdFs( _dir, removeDir=True ):
   """Undo mountHdFs by delegating to unmountLoop (same arguments)."""
   unmountLoop( _dir, removeDir=removeDir )

def emptyHd( disk ):
   """Return True if the first 2048 bytes of *disk* are all zero.

   The disk is read in binary mode, so arbitrary sector contents cannot
   trigger a text-decoding error. A disk or image shorter than 2048 bytes is
   reported as not empty.
   """
   with open( disk, "rb" ) as f:
      return f.read( 2048 ) == b"\0" * 2048

def mountedHd( disk ):
   """Return the /proc/mounts entries for *disk* or its numbered partitions.

   The result is a list of matched prefixes such as "/dev/sda1 " (device
   name followed by a space); it is empty when nothing is mounted.
   """
   entryPattern = re.compile( "^%s[0-9]* " % re.escape( disk ), re.MULTILINE )
   with open( "/proc/mounts" ) as mounts:
      mountTable = mounts.read()
   return entryPattern.findall( mountTable )

def createFileSystem( disk, fsType, fsPartType, fsOpts, tuneCmds, parts,
                      label=None ):
   """Create a *fsType* filesystem in a region of *disk*.

   Args:
      disk: disk image file or block device.
      fsType: filesystem type; runs mkfs.<fsType>.
      fsPartType: partition type code; not used here (kept for interface
         compatibility).
      fsOpts: extra mkfs options. The list is copied, never modified.
      tuneCmds: optional commands (argv lists) run after mkfs, each with the
         loop device appended.
      parts: [ offset, length ] of the region, in bytes.
      label: optional volume label, passed with the option from
         mkfsLabelOption.

   ext* filesystems are then mounted once so fixupDirPerms can set ownership
   and permissions. The loop device is always detached afterwards.
   """
   # Build a fresh option list. Appending to fsOpts in place would alter the
   # caller's list, which may be shared (e.g. an entry of `filesystems`), so
   # label options would pile up across calls.
   mkfsOpts = list( fsOpts or [] )
   if label is not None:
      # Resolved before the loop device exists, so an unknown fsType cannot
      # leak a loop device.
      mkfsOpts += [ mkfsLabelOption[ fsType ], label ]

   loopDev = setupLoop( disk, [ "-o", str( parts[ 0 ] ) ] )
   try:
      # The size argument is in KiB and keeps mkfs inside the region, since
      # the loop device itself only has an offset.
      Tac.run( [ mkfs.name() + '.' + fsType ] + mkfsOpts
               + [ loopDev, str( parts[ 1 ] // 1024 ) ],
               stdout=Tac.CAPTURE, stderr=Tac.CAPTURE, asRoot=True )
      for tuneCmd in tuneCmds or []:
         Tac.run( tuneCmd + [ loopDev ],
                  stdout=Tac.CAPTURE, stderr=Tac.CAPTURE, asRoot=True )

      if fsType.startswith( 'ext' ):
         fsDir = mountHdFs( disk, offset=parts[ 0 ] )
         try:
            fixupDirPerms( fsDir )
         finally:
            unmountHdFs( fsDir )
   finally:
      deleteLoop( loopDev )

def createNewPartition( fdisk, fsPartType, sectorSize=0, sectorLeft=0 ):
   """Drive an interactive fdisk session to add one primary partition.

   Waits for the fdisk command prompt, creates a new primary partition with
   fdisk's default partition number, starting at the lowest sector fdisk
   offers, then sets its partition type.

   Args:
      fdisk: pexpect child running fdisk; its next prompt should be the
         command prompt (FDISK_COMMAND_LINE).
      fsPartType: partition type code sent at the "Hex code" prompt
         (e.g. 'c', '83', '12').
      sectorSize: if > 0, partition length in sectors, clamped so the
         partition does not run past the last sector fdisk offers. When set,
         sectorLeft is ignored.
      sectorLeft: sectors to leave free after the partition (for later
         partitions) when sectorSize is 0.

   Returns:
      ( firstSector, lastSector ) of the new partition, both inclusive.

   Returns right after sending the type code, without waiting for the next
   prompt; the next expect( FDISK_COMMAND_LINE ) consumes it.
   """
   fdisk.expect( FDISK_COMMAND_LINE )
   fdisk.sendline( "n" )
   fdisk.expect( "Select \\(default p\\):" )
   fdisk.sendline( "p" )
   fdisk.expect( "Partition number \\(\\d+[-,]\\d+, default (\\d+)\\):" )
   # Remember the default number; the "t" command below may ask for it.
   partitionNumber = int( fdisk.match.group( 1 ) )
   fdisk.sendline( "" )
   fdisk.expect( "First sector \\((\\d+)-(\\d+), default.*:" )

   # We get the first & last sectors from the range fdisk offers for the
   # first sector; its upper bound is treated as the last usable sector.
   firstSector, lastSector = fdisk.match.groups()
   firstSector, lastSector = int( firstSector ), int( lastSector )

   # We start on the first sector
   fdisk.sendline( str( firstSector ) )
   fdisk.expect( "Last.*:" )

   # By default end sectorLeft sectors before the last usable sector, leaving
   # room for the partitions created after this one.
   fsEndSector = lastSector - sectorLeft
   if sectorSize > 0:
      # - 1 because we can use the last sector
      fsEndSector = firstSector + sectorSize - 1
      if fsEndSector > lastSector: # pylint: disable=consider-using-min-builtin
         fsEndSector = lastSector

   fdisk.sendline( str( fsEndSector ) )
   i = fdisk.expect( [ FDISK_COMMAND_LINE,
                       r'Do you want to remove the signature\?.*:' ] )
   if i == 1:
      # This is new in CentOS 8 util-linux
      fdisk.sendline( "Y" )
      fdisk.expect( FDISK_COMMAND_LINE )
   # Set the partition type. Depending on the fdisk version / partition count,
   # fdisk either selects the partition itself or asks for its number.
   fdisk.sendline( "t" )
   fdisk.expect( 'Selected partition \\d+|Partition number .*:' )
   if str( fdisk.match.group( 0 ) ).startswith( 'Partition number' ):
      fdisk.sendline( str( partitionNumber ) )

   fdisk.expect( "Hex code (or alias )?\\(type L to list( all)?( codes)?\\):" )
   fdisk.sendline( fsPartType )

   return firstSector, fsEndSector

def bytesToSector( size, align=0 ):
   """Convert a byte count into a whole number of sectors, rounding up.

   Args:
      size: size in bytes; 0 yields 0.
      align: if non-zero, *size* is first rounded up to a multiple of *align*
         bytes (e.g. MB for whole-MiB partitions). Any positive alignment is
         supported, not only powers of two.

   Returns:
      The number of bytesPerSector-sized sectors needed to hold the
      (aligned) size.
   """
   if size == 0:
      return 0

   if align:
      # Ceiling to a multiple of align. Integer arithmetic is used because a
      # bitmask with ~( align - 1 ) is only correct for powers of two.
      size = -( -size // align ) * align
   return ( size - 1 ) // bytesPerSector + 1

def computeStartAndLength( startInSectors, lastInSectors ):
   """Turn an inclusive [ start, last ] sector range into byte values.

   Returns ( offset, length ) in bytes: where the partition starts and how
   long it is. The last sector is part of the partition.
   """
   sectorCount = lastInSectors - startInSectors + 1
   return startInSectors * bytesPerSector, sectorCount * bytesPerSector

def formatHd( disk, fsPartType, fsType, fsOpts=None, crashPartSize=0, tuneCmds=None,
              swi=None, bootConfig=None, startupConfig=None, zeroTouchConfig=None,
              kickStartConfig=None, recoverySwi=None, addToRecovery=None,
              flashSize=0 ):
   """ Set up the partition table and create the filesystems.

   Creates the flash partition, an optional crash partition (when
   crashPartSize is non-zero) and the recovery partition, all rounded up to
   whole MiB. Then creates filesystems on the flash and crash partitions.

   Args:
      disk: disk image file or block device to format.
      fsPartType: fdisk partition type code for the flash/crash partitions.
      fsType: filesystem type to create (mkfs.<fsType>).
      fsOpts: extra mkfs options for the flash filesystem.
      crashPartSize: crash partition size in bytes; 0 for none.
      tuneCmds: commands (argv lists) run on each filesystem after mkfs.
      swi, bootConfig, startupConfig, zeroTouchConfig, kickStartConfig,
      recoverySwi, addToRecovery: recovery archive contents. Here they are
         only used to size the recovery partition (see makeRecovFile).
      flashSize: requested flash size in MiB. If 0, or if the partitions
         would not fit, the flash takes all space the others do not need.

   Returns [[offset, len], [offset, len]] for the flash and recovery
   partitions, in bytes, which you can use to loopback mount them. The crash
   partition, if any, is not included.
   """
   fsOpts = [] if fsOpts is None else fsOpts

   # Zero the first 2048 bytes of the disk, which include the MBR and its
   # partition table. Binary mode: these are raw bytes.
   with open( disk, "rb+" ) as f:
      f.write( b"\0" * 2048 )

   recovSize = makeRecovFile( "/dev/null", swi, bootConfig, startupConfig,
                              zeroTouchConfig, kickStartConfig, recoverySwi,
                              addToRecovery )

   flashSizeSectors = bytesToSector( flashSize * MB, MB )
   crashSizeSectors = bytesToSector( crashPartSize, MB )
   recovSizeSectors = bytesToSector( recovSize, MB )

   fdisk = pexpect.spawn(
      "/sbin/fdisk",
      [ "-c=dos", disk ],
      encoding=( 'utf-8' if six.PY3 else None )
   )
   # Presumably a large window so fdisk does not wrap the lines matched below.
   fdisk.setwinsize( 200, 200 )
   try:
      # Get the size of the disk to be sure that the requested size for the
      # flash will fit. If it doesn't fit, we set flashSize to 0 to use as
      # much space as possible.
      if flashSize:
         fdisk.expect( FDISK_COMMAND_LINE )
         fdisk.sendline( "p" )
         fdisk.expect( "Disk [^\\s]+: \\d+(\\.\\d+)? (Mi?B|Gi?B), (\\d+) bytes" )
         # pylint: disable-msg=maybe-no-member
         diskSizeSectors = bytesToSector( int( fdisk.match.group( 3 ) ) )
         totalEstimatedSizeSectors = flashSizeSectors + crashSizeSectors + \
                                     recovSizeSectors
         if totalEstimatedSizeSectors > diskSizeSectors:
            print( 'WARNING: Requested flash size is bigger than the available size' )
            print( 'Going to use all the space available' )
            flashSizeSectors = 0

      # Start a new, empty DOS partition table.
      fdisk.expect( FDISK_COMMAND_LINE )
      fdisk.sendline( "o" )

      # If we want the flash to have a fixed size then do it. Otherwise,
      # leave room for the crash and recovery partitions, if they are needed.
      if flashSizeSectors > 0:
         flashFirstSector, flashLastSector = \
            createNewPartition( fdisk, fsPartType, sectorSize=flashSizeSectors )
      else:
         flashFirstSector, flashLastSector = \
            createNewPartition( fdisk, fsPartType,
                                sectorLeft=crashSizeSectors + recovSizeSectors )

      # Toggle the bootable flag on partition 1 (the flash).
      fdisk.expect( FDISK_COMMAND_LINE )
      fdisk.sendline( "a" )
      # CentOS 7 fdisk automatically selects the only partition, FC18 version asks
      prompt = fdisk.expect( [ "Partition number \\(1-4\\):",
                               "Selected partition 1" ] )
      if prompt == 0:
         fdisk.sendline( "1" )

      # Create the crash partition if needed. With a fixed-size flash it gets
      # exactly crashPartSize. Otherwise it only leaves room for the recovery
      # partition, since its space was reserved when the flash was created.
      if crashPartSize:
         if flashSizeSectors > 0:
            crashFirstSector, crashLastSector = \
               createNewPartition( fdisk, fsPartType, sectorSize=crashSizeSectors )
         else:
            crashFirstSector, crashLastSector = \
               createNewPartition( fdisk, fsPartType, sectorLeft=recovSizeSectors )

      # Create the recovery partition
      recovFirstSector, recovLastSector = \
         createNewPartition( fdisk, '12', sectorSize=recovSizeSectors )

      fdisk.expect( FDISK_COMMAND_LINE )
      fdisk.sendline( "w" )
      fdisk.expect( "The partition table has been altered[!.]" )
      # Let fdisk finish and exit before the partitions are accessed through
      # loop devices below.
      fdisk.expect( pexpect.EOF )
   finally:
      # Reap the fdisk child and release its pty, also when the dialog fails.
      fdisk.close()

   fsStart, fsLength = computeStartAndLength( flashFirstSector, flashLastSector )
   recovStart, recovLength = computeStartAndLength( recovFirstSector,
                                                    recovLastSector )
   if crashPartSize:
      crashStart, crashLength = computeStartAndLength( crashFirstSector,
                                                       crashLastSector )

   parts = [ [ fsStart, fsLength ], [ recovStart, recovLength ] ]

   # Create our filesystem for the flash
   createFileSystem( disk, fsType, fsPartType, fsOpts, tuneCmds, parts[ 0 ],
                     label='eos_flash' )

   # Create our filesystem for the crash partition if it exists
   if crashPartSize:
      # ext4 crash partitions get fsOpts plus a small journal ('-J size=1').
      # NOTE(review): non-ext4 crash partitions get no mkfs options at all,
      # not even fsOpts — confirm that is intended.
      crashOpts = ( fsOpts + [ '-J', 'size=1' ] ) if fsType == 'ext4' else []
      createFileSystem( disk, fsType, fsPartType, crashOpts, tuneCmds,
                        [ crashStart, crashLength ], label='eos_crash' )

   return parts

def writeRecovFile( disk, parts, swi, bootConfig=None, startupConfig=None,
                    copyTreeFrom=None, zeroTouchConfig=None, kickStartConfig=None,
                    recoverySwi=None, addToRecovery=None ):
   """Write the recovery archive and populate the flash filesystem.

   Args:
      disk: disk image file or block device previously set up by formatHd.
      parts: [ [ flashOffset, flashLen ], [ recovOffset, recovLen ] ] in
         bytes, as returned by formatHd.
      swi: software image copied onto the flash, if given.
      bootConfig, startupConfig, zeroTouchConfig, kickStartConfig,
      recoverySwi, addToRecovery: recovery archive contents (see
         makeRecovFile).
      copyTreeFrom: directory whose contents are copied onto the flash.

   Steps:
      1. Write the cpio archive onto the recovery partition.
      2. Extract it into the flash filesystem. The recoverySwi copy is then
         deleted from the flash.
      3. Copy swi and copyTreeFrom onto the flash. When recoverySwi is set
         and no bootConfig was given, boot-config is pointed at swi.
      4. Fix permissions, then sync everything to the device.
   """
   import shlex  # local import: only used to quote shell arguments below

   recovSize = makeRecovFile( "/dev/null", swi, bootConfig, startupConfig,
                              zeroTouchConfig, kickStartConfig, recoverySwi,
                              addToRecovery )
   # The archive must fit in the recovery partition laid out by formatHd.
   assert recovSize <= parts[1][1]
   loopDev = setupLoop( disk, ["-o", str(parts[1][0])] )
   try:
      # Write the archive straight onto the recovery partition.
      makeRecovFile( loopDev, swi, bootConfig, startupConfig, zeroTouchConfig,
                     kickStartConfig, recoverySwi, addToRecovery )
      fsDir = mountHdFs( disk, offset=parts[0][0] )
      try:
         # Seed the flash by extracting the archive back out of the recovery
         # partition. Paths are quoted because this goes through a shell.
         Tac.run( [ "sh", "-c",
                    f"(cd {shlex.quote( fsDir )}; cpio -i 2>/dev/null) "
                    f"<{shlex.quote( loopDev )}" ],
                  stdout=Tac.CAPTURE,
                  asRoot=True )

         if recoverySwi:
            # If recoverySwi is provided, the cpio cmd above copies it to flash
            # from recovery partition, remove that swi from flash
            Tac.run( [ "rm", "-rf", ( os.path.join( fsDir,
                     os.path.basename( recoverySwi ) ) ) ],
                     stdout=Tac.CAPTURE, asRoot=True )

         if swi:
            # Copy the swi to flash and update boot-config if required
            Tac.run( [ "cp", swi, fsDir ], stdout=Tac.CAPTURE, asRoot=True )
            if recoverySwi and not bootConfig:
               with open( os.path.join( fsDir, "boot-config" ), "w" ) as f:
                  f.write( "SWI=flash:/%s\n" % os.path.basename( swi ) )

         if copyTreeFrom:
            # "<dir>/." copies the directory's contents, not the directory.
            # No shell is involved, so unusual path characters are safe.
            Tac.run( [ "cp", "-r", os.path.join( copyTreeFrom, "." ), fsDir ],
                     stdout=Tac.CAPTURE,
                     asRoot=True )
         fixupDirPerms( fsDir )
      finally:
         unmountHdFs( fsDir )
   finally:
      deleteLoop( loopDev )

   # Make sure all data makes it to hardware before returning. This prevents
   # the unfortunate situation where a user decides to power cycle the box
   # immediately after formatting the hd, and ends up with a completely
   # corrupted hd, without even a recovery partition to use.
   fd = os.open( disk, os.O_RDONLY )
   try:
      os.fsync( fd )
   finally:
      os.close( fd )
   Tac.run( [ "/bin/sync" ] )

def makeRecovFile( recovFile, swi, bootConfig=None, startupConfig=None,
                   zeroTouchConfig=None, kickStartConfig=None,
                   recoverySwi=None, addToRecovery=None ):
   """Create a cpio archive comprising these files:
         - the recovery software image (recoverySwi), if specified
           (swi itself is not archived; it only names the boot image).
         - a "boot-config" file, if recoverySwi, swi or bootConfig is given.
         - a "startup-config" file (always; empty unless startupConfig is
           given).
         - a "kickstart-config" file, if given; see: http://cl/2455454
         - a "zerotouch-config" file, if given; see http://aid/723
         - every file found recursively under addToRecovery, stored at the
           top level of the archive under its basename. Two files with the
           same basename make os.symlink raise FileExistsError.
      If bootConfig is not specified, then the "boot-config" file points
   to the software image (recoverySwi if given, otherwise swi).
      Note that startupConfig must contain the complete contents of the
   startup-config file, including the trailing newline.

   The archive is written to *recovFile* ("/dev/null" when only the size is
   wanted). Returns its size in bytes, but never less than minPartitionSize.
   """
   import shlex  # local import: only used to quote shell arguments below

   fsDir = tempfile.mkdtemp()
   try:
      files = []

      def writeFile( name, contents ):
         # Create fsDir/name holding *contents* and list it for the archive.
         with open( os.path.join( fsDir, name ), "w" ) as f:
            f.write( contents )
         files.append( name )

      def linkFile( target ):
         # Add *target* under its basename through a symlink; cpio -L below
         # archives the file the link points to.
         name = os.path.basename( target )
         os.symlink( os.path.abspath( target ), os.path.join( fsDir, name ) )
         files.append( name )

      if recoverySwi:
         linkFile( recoverySwi )
      if recoverySwi or swi or bootConfig:
         writeFile( "boot-config", bootConfig or "SWI=flash:/%s\n" %
                    os.path.basename( recoverySwi or swi ) )
      writeFile( "startup-config", startupConfig or "" )
      if kickStartConfig:
         writeFile( "kickstart-config", kickStartConfig )
      if zeroTouchConfig:
         writeFile( "zerotouch-config", zeroTouchConfig )

      if addToRecovery:
         # add files recursively to recovFile (flattened by basename)
         for dirPath, _, fileNames in os.walk( addToRecovery ):
            for fileName in fileNames:
               linkFile( os.path.join( dirPath, fileName ) )

      fixupDirPerms( fsDir )
      # cpio reads the file list from stdin. tee writes the archive to
      # recovFile while wc counts its bytes. Paths are quoted for the shell.
      archiveCmd = "(cd %s; cpio -ocL 2>/dev/null) | tee %s | wc -c" % (
         shlex.quote( fsDir ), shlex.quote( recovFile ) )
      recovSize = Tac.run( [ "sh", "-c", archiveCmd ],
                           input="\n".join( files ), stdout=Tac.CAPTURE,
                           asRoot=True )

      return max( minPartitionSize, int( recovSize.strip() ) )

   finally:
      shutil.rmtree( fsDir, ignore_errors=True )

def blockDeviceNode( name ):
   """Return the /dev node of the block device registered under *name*.

   /etc/blockdev holds whitespace-separated "<pattern> <name>" lines; blank
   lines are ignored. The pattern for *name* is matched with re.match
   against the target of each /sys/block/<dev>/device link, relative to
   /sys/devices/. The first match in os.listdir order wins.

   Raises:
      KeyError: *name* is not listed in /etc/blockdev.
      ValueError: a non-blank line does not have exactly two fields.

   Returns:
      "/dev/<dev>" for the first matching device, or None if none matches.
   """
   with open( "/etc/blockdev" ) as blockdev:
      # Map name -> pattern; skipping blank lines avoids a ValueError from
      # dict() on an empty (zero-length) entry.
      patterns = dict( reversed( fields )
                       for fields in ( line.split() for line in blockdev )
                       if fields )
   pat = patterns[ name ]
   sysDevices = "/sys/devices/"
   for dev in os.listdir( "/sys/block" ):
      if re.match( "mmcblk.*(boot.*|rpmb)", dev ):
         # eMMC device creates a mmcblk*boot0/1 partitions by default.
         # and Linux allows access to the RPMB partition since 3.8 (090d25f),
         # skip these, as we really want to manage the actual
         # device instead of these partitions
         continue
      devid = os.path.realpath( os.path.join( "/sys/block", dev, "device" ) )
      if devid.startswith( sysDevices ) and \
            re.match( pat, devid[ len( sysDevices ): ] ):
         return "/dev/%s" % dev
   return None
