# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function
import json
import os
import re
import Logging
import MmcFlashLib
import EusbLib
import Tac
import Tracing

traceHandle = Tracing.Handle( 'StorageInvLib' )
t0 = Tracing.trace0

# Syslog handles used by this module when a storage device cannot be
# detected or its partition is not mounted as expected.
FRU_SSD_DETECTION_ERROR = Logging.LogHandle(
              "FRU_SSD_DETECTION_ERROR",
              severity=Logging.logError,
              fmt="Failed to detect SSD on bootup. Error msg: %s ",
              explanation="Failed to detect SSD during system bootup. This can "
               "happen if the device is not recognized by the driver or if the "
               "system encountered an error when booting up.",
              recommendedAction=Logging.CONTACT_SUPPORT )

FRU_SSD_PARTITION_NOT_MOUNTED = Logging.LogHandle(
              "FRU_SSD_PARTITION_NOT_MOUNTED",
              severity=Logging.logError,
              fmt="SSD partition %s not mounted as %s",
              explanation=( "The specified SSD partition is not currently mounted. "
                            "This can happen if the partition is not properly "
                            "formatted or if the SSD device was not detected by the "
                            "driver."),
              recommendedAction=Logging.CONTACT_SUPPORT )

FRU_SSD_UNEXPECTED_PARTITION_MOUNTED = Logging.LogHandle(
              "FRU_SSD_UNEXPECTED_PARTITION_MOUNTED",
              severity=Logging.logError,
              fmt=( "Unexpected partition %s mounted as %s. Expected "
                    "SSD partition %s to be mounted instead." ),
              # The trailing space after "mounted" is required: adjacent
              # string literals are concatenated with no separator.
              explanation=( "Expected the specified SSD partition to be mounted at "
                            "the mount point but a different partition is mounted "
                            "instead. This is unexpected." ),
              recommendedAction=Logging.CONTACT_SUPPORT )

class Mounts( object ):
   """Snapshot of /proc/mounts keyed by mount point.

   If /proc/mounts cannot be read, errMsg is set and no devices are recorded.
   """
   def __init__( self ):
      # mount point -> device (first two whitespace-separated fields)
      self.devices = {}
      self.errMsg = None
      self.parse()

   def getDevice( self, mount ):
      """Return the device mounted at `mount`, or None if nothing is."""
      return self.devices.get( mount )

   def parse( self ):
      """Fill self.devices from /proc/mounts; later lines win on duplicates."""
      try:
         with open( '/proc/mounts' ) as mountsFile:
            lines = mountsFile.readlines()
      except IOError:
         self.errMsg = "Failed to get list of mounted filesystems"
         return

      matches = ( re.match( r'(\S+)\s+(\S+)\s+', line ) for line in lines )
      for match in matches:
         if match is None:
            continue
         device, mountPoint = match.groups()
         self.devices[ mountPoint ] = device

class PciEntries( object ):
   """Mount-name to device-address table parsed from /etc/blockdev.

   Entries whose address starts with 'pci' or 'platform/AMDI' go into
   pciEntries; every other entry goes into nonPciEntries. If the file cannot
   be read, errMsg is set and both tables stay empty.
   """
   def __init__( self ):
      self.pciEntries = {}
      self.nonPciEntries = {}
      self.errMsg = None
      self.parse()

   def getPciEntry( self, mount ):
      """Return the PCI-style address recorded for `mount`, or None."""
      return self.pciEntries.get( mount )

   def getNonPciEntry( self, mount ):
      """Return the non-PCI address recorded for `mount`, or None."""
      return self.nonPciEntries.get( mount )

   def parse( self ):
      """Split each '<address> <mount>' line of /etc/blockdev into a table."""
      try:
         with open( '/etc/blockdev' ) as blockdevFile:
            lines = blockdevFile.readlines()
      except IOError:
         self.errMsg = "Failed to parse /etc/blockdev"
         return

      lineRe = re.compile( r'((pci|platform/AMDI)?\S+)\s+(\S+)' )
      for line in lines:
         match = lineRe.match( line )
         if match is None:
            continue
         address, busPrefix, mount = match.groups()
         table = self.pciEntries if busPrefix else self.nonPciEntries
         table[ mount ] = address

class DeviceFile( object ):
   """A block device under /sys/block plus its rotational flag and size.

   isRotational and sizeBytes stay None if the sysfs attributes cannot be
   read or parsed.
   """
   EMMC_FILTER = r'e?mmcblk[0-9]+$'
   HDASDA_FILTER = r'[hs]d[a-z]'
   NVME_FILTER = r'nvme\d+n\d+'
   SUPPORTED_DEV_FILTER = '|'.join( ( EMMC_FILTER, HDASDA_FILTER, NVME_FILTER ) )

   def __init__( self, name ):
      self.name = name
      self.isRotational = None
      self.sizeBytes = None
      self.parseBlockInfo()

   def parseBlockInfo( self ):
      """Read queue/rotational and size from /sys/block/<name>."""
      sysDir = os.path.join( '/sys/block', self.name )
      try:
         with open( os.path.join( sysDir, 'queue/rotational' ) ) as rotFile:
            self.isRotational = bool( int( rotFile.readline() ) )
         with open( os.path.join( sysDir, 'size' ) ) as sizeFile:
            # The sysfs 'size' attribute counts 512-byte sectors.
            self.sizeBytes = int( sizeFile.readline() ) * 512
      except ( IOError, ValueError ):
         pass

   def factory( self, mounts, mount ):
      """Return the Device subclass instance matching this device's name.

      hd*/sd* devices are treated as eUSB when rotational and as SSD
      otherwise. Returns None for names matching no supported pattern.
      """
      name = self.name
      if re.match( self.EMMC_FILTER, name ):
         return EmmcFlash( mounts, mount, self )
      if re.match( self.HDASDA_FILTER, name ):
         deviceClass = Eusb if self.isRotational else Ssd
         return deviceClass( mounts, mount, self )
      if re.match( self.NVME_FILTER, name ):
         return Nvme( mounts, mount, self )
      return None

def bashReMatch( regex, text ):
   """Return True if `expr match text regex` succeeds.

   expr anchors the basic regex at the start of `text` and exits non-zero
   when zero characters match, or on a bad regex. Either case surfaces as
   Tac.SystemCommandError and is reported as False.
   """
   matched = True
   try:
      Tac.run( [ "expr", "match", text, regex ] )
   except Tac.SystemCommandError:
      matched = False
   return matched

class DeviceFiles( object ):
   """Map mount names from /etc/blockdev to the block devices backing them.

   The DeviceFiles parse function uses the following method to determine
   the device file for a given mount:
      1) Get the SSD PCI address from /etc/blockdev. The PciEntries
         class is used for this.
      2) Search /sys/block for a device with the matching PCI address:
          ie. readlink -f /sys/block/<dev>/device
   This is the same method used by EOS-initscripts to mount SSDs"""
   def __init__( self, mounts, deviceFilter=DeviceFile.SUPPORTED_DEV_FILTER ):
      self.mounts = mounts
      self.deviceFilter = deviceFilter
      self.blockdirs = set()
      self.pciEntries = PciEntries()
      # mount name -> DeviceFile
      self.deviceFiles = {}
      self.errMsg = None
      if self.mounts.errMsg:
         self.errMsg = self.mounts.errMsg
      elif self.pciEntries.errMsg:
         self.errMsg = self.pciEntries.errMsg
      else:
         self.parse()

   def getDeviceFile( self, mount ):
      """Return the DeviceFile backing `mount`, or None."""
      return self.deviceFiles.get( mount )

   def getDevicePath( self, mount ):
      """Return '/dev/<name>' for the device backing `mount`, or None.

      deviceFiles holds DeviceFile objects, so the object's name (not the
      object itself) is what gets joined onto /dev.
      """
      deviceFile = self.getDeviceFile( mount )
      if deviceFile is None:
         return None
      return os.path.join( '/dev', deviceFile.name )

   def deviceFactory( self, mount, mountExpected=False ):
      """Return a Device subclass instance for `mount`, or None.

      Parse errors are always logged. When mountExpected is True, a missing
      device file or unmounted /mnt/<mount> is logged as well.
      """
      if self.errMsg:
         Logging.log( FRU_SSD_DETECTION_ERROR, self.errMsg )
         return None

      deviceFile = self.deviceFiles.get( mount )
      if not deviceFile:
         if mountExpected:
            if self.pciEntries.getNonPciEntry( mount ):
               errMsg = 'PCI address for drive not found in /etc/blockdev'
            elif not self.pciEntries.getPciEntry( mount ):
               errMsg = 'SSD drive not found in /etc/blockdev'
            else:
               errMsg = 'Device not found'
            Logging.log( FRU_SSD_DETECTION_ERROR, errMsg )
         return None

      mountDir = os.path.join( '/mnt', mount )
      if not self.mounts.getDevice( mountDir ):
         if mountExpected:
            expectedPartition = os.path.join( '/dev', deviceFile.name + '1' )
            Logging.log( FRU_SSD_PARTITION_NOT_MOUNTED, expectedPartition, mountDir )
         return None

      device = deviceFile.factory( self.mounts, mount )
      if not device:
         # self.errMsg is known to be None here (checked on entry), so log a
         # message that actually identifies the device that was rejected.
         Logging.log( FRU_SSD_DETECTION_ERROR,
                      'Unsupported device type for %s' % deviceFile.name )
         return None

      return device

   def parse( self ):
      """Populate self.deviceFiles by matching each /sys/block device's
      resolved path against the /etc/blockdev PCI address regexes."""
      bootDrive = None
      bootDrives = []
      if os.path.exists( "/etc/multidrive_src" ):
         # Devices that support multiple boot drives will have the
         # /etc/multidrive_src file with a regex that matches the current boot drive.
         with open( "/etc/multidrive_src", "r" ) as srcFile:
            bootDrive = srcFile.read().strip()
         # One regex per non-blank line of /etc/multidrive.
         with open( "/etc/multidrive", "r" ) as multiFile:
            bootDrives = [ line for line in map( str.strip, multiFile ) if line ]

      for devName in os.listdir( '/sys/block' ):
         if not re.match( self.deviceFilter, devName ):
            continue
         deviceFile = os.path.join( '/sys/block', devName, 'device' )
         cmdArgList = [ '/usr/bin/readlink', '-f', deviceFile ]
         try:
            devPciAddress = Tac.run( cmdArgList, asRoot=True, stdout=Tac.CAPTURE )
         except Tac.SystemCommandError:
            continue

         # readlink terminates its output with a newline; drop it so regexes
         # anchored with '$' can match the resolved path.
         devPciAddress = devPciAddress.strip()
         if devPciAddress.startswith( "/sys/devices/" ):
            devPciAddress = devPciAddress[ len( "/sys/devices/" ) : ]

         if ( bootDrive is not None and
              any( bashReMatch( regex, devPciAddress ) for regex in bootDrives ) and
              not bashReMatch( bootDrive, devPciAddress ) ):
            # The system supports multiple boot drives, and this drive is not the one
            # that was booted.
            continue

         # Find the mount for this address.
         for mount, blockdevRe in self.pciEntries.pciEntries.items():
            # The block_flash entry for whiteboxes has an additional string at
            # the end that doesn't match the result of readlink -f. Strip it
            # out if found.
            m = re.match( r'(.*)/\d*\.?\*?\$?$', blockdevRe )
            if m:
               blockdevRe = m.group( 1 )
            if bashReMatch( blockdevRe, devPciAddress ):
               self.deviceFiles[ mount ] = DeviceFile( devName )

class HdParmInfo( object ):
   """Drive identity scraped from `hdparm -I <device>`.

   Parsed values are written directly into the `defaultResults` dict given to
   the constructor. The dict is not copied, so the caller's dict is updated in
   place. On failure errMsg is set and the dict keeps its prior contents.
   """
   def __init__( self, device, defaultResults ):
      self.device = device
      self.results = defaultResults
      self.errMsg = None
      self.parse()

   def getResults( self ):
      """Return the results dict (the same object passed to __init__)."""
      return self.results

   def parse( self ):
      """Run hdparm and record Model, SerialNo, FwRev, SizeGB and WriteCache."""
      # NOTE(review): inside these character classes ' -\.' and ' -\/' are
      # ranges (space through '.' / '/'), not three literal characters.
      identityRe = ( r'\s*Model Number:\s+([\w -\.]*)\n'
                     r'\s*Serial Number:\s+([\w -]*)\n'
                     r'\s*Firmware Revision:\s+([\w -\/]*)\n' )
      capacityRe = r'\((\d+) GB\)'
      hdparmCmd = [ '/sbin/hdparm', '-I', self.device ]
      try:
         output = Tac.run( hdparmCmd, asRoot=True, stdout=Tac.CAPTURE )
      except Tac.SystemCommandError:
         self.errMsg = "hdparm failed for %s" % ( self.device )
         return
      except OSError:
         self.errMsg = "hdparm command not found"
         return

      identity = re.search( identityRe, output )
      capacity = re.search( capacityRe, output )
      if identity is None or capacity is None:
         self.errMsg = "Unexpected output from hdparm command"
         return

      model, serial, firmware = [ field.strip() for field in identity.groups() ]
      self.results[ 'Model' ] = model
      self.results[ 'SerialNo' ] = serial
      self.results[ 'FwRev' ] = firmware
      self.results[ 'SizeGB' ] = int( capacity.group( 1 ) )

      # Only an explicit "disabled" is recorded; otherwise whatever default
      # the caller put in the dict is left alone.
      # NOTE(review): 'WriteCache=' looks like `hdparm -i` output rather than
      # `-I` output -- confirm this pattern ever matches.
      cacheState = re.search( 'WriteCache=((en|dis)abled)', output )
      if cacheState is not None and cacheState.group( 1 ) == 'disabled':
         self.results[ 'WriteCache' ] = 0

class Device( object ):
   """Base class for a storage device whose filesystem is mounted at /mnt/<mount>.

   Subclasses implement readData() to fill paramDict and mibDeviceType() to
   name the storage type reported in inventory.
   """
   BYTES_PER_GB = 1000 ** 3

   def __init__( self, mounts, mount, deviceFile ):
      self.mounts = mounts
      self.mount = mount
      self.device = deviceFile.name
      self.deviceFile = deviceFile
      self.devicePath = os.path.join( '/dev', self.device )
      # Values reported when the device data cannot be read.
      self.paramDict = {}
      self.paramDict[ 'Model' ] = 'Unknown'
      self.paramDict[ 'SerialNo' ] = 'Unknown'
      self.paramDict[ 'FwRev' ] = '0.0'
      self.paramDict[ 'SizeGB' ] = 0
      if self.mount:
         self.readData()

   def mibDeviceType( self ):
      """Storage type string written to the inventory entry."""
      return 'unknown'

   def readData( self ):
      """Fill paramDict from the device; subclasses must override."""
      assert False, 'readData not implemented for class %s' % (
         self.__class__.__name__ )

   def checkMounts( self, expectedPartitionRe=None, expectedPartitionDesc=None ):
      """Syslog if /mnt/<mount> is unmounted or backed by an unexpected partition.

      expectedPartitionRe defaults to '<devicePath>\\d+' (matched with
      re.match); expectedPartitionDesc is the text used in log messages and
      defaults to '<devicePath>*'. LVM (/dev/dm-N) devices are not reported.
      """
      expectedMountedPartition = expectedPartitionDesc
      if not expectedMountedPartition:
         expectedMountedPartition = self.devicePath + '*'
      if not expectedPartitionRe:
         expectedPartitionRe = self.devicePath + r'\d+'
      mountDir = os.path.join( '/mnt', self.mount )
      mountedPartition = self.mounts.getDevice( mountDir )
      if not mountedPartition:
         # This can happen if the FS partition on the device is accidentally
         # deleted or if the partition was never actually created.
         Logging.log( FRU_SSD_PARTITION_NOT_MOUNTED, expectedMountedPartition,
                      mountDir )
      elif not re.match( expectedPartitionRe, mountedPartition ):
         # On platforms using LVM, the partitions are not tied to physical block
         # devices, so don't print out the UNEXPECTED_PARTITION_MOUNTED log message.
         lvmPartitionRe = os.path.join( os.sep, 'dev', r'dm-\d+' )
         if not re.match( lvmPartitionRe, mountedPartition ):
            # This should never happen. An assert would do. But since all exceptions
            # in the FruPlugin are caught, its better to just check and syslog here.
            Logging.log( FRU_SSD_UNEXPECTED_PARTITION_MOUNTED, mountedPartition,
                         mountDir, expectedMountedPartition )

   def populateInventory( self, mibEnt, invSizeGB=0 ):
      """Copy the collected device data onto the inventory entry `mibEnt`."""
      mibEnt.modelName = self.paramDict[ 'Model' ]
      mibEnt.serialNum = self.paramDict[ 'SerialNo' ]
      mibEnt.firmwareRev = self.paramDict[ 'FwRev' ]
      # If sizeGB = 0, use the autodetected size instead of the hardcoded.
      # Do NOT use automatic size detecton with modular supervisors
      # There is no way to determine the real size of the peer sup's ssd
      # Compare by value: `is 0` depends on int-object caching and is False
      # for 0.0 (or a Python 2 long 0L), which would report a size of 0.
      mibEnt.sizeGB = (
         self.paramDict[ 'SizeGB' ] if invSizeGB == 0 else invSizeGB )
      mibEnt.storageType = self.mibDeviceType()
      mibEnt.mount = os.path.join( '/mnt', self.mount )

class Ssd( Device ):
   """SATA SSD whose identity data is read with `hdparm -I` (HdParmInfo)."""
   def __init__( self, mounts, mount, deviceFile ):
      super( Ssd, self ).__init__( mounts, mount, deviceFile )
      # Device.__init__ has already called readData() when a mount was given,
      # and HdParmInfo may have stored WriteCache = 0 there. Only fill in the
      # default of 1 when nothing was recorded; assigning it unconditionally
      # would overwrite a detected "disabled" write cache.
      self.paramDict.setdefault( 'WriteCache', 1 )

   def mibDeviceType( self ):
      return 'ssd'

   def readData( self ):
      """Check the mount and fill paramDict from `hdparm -I` output."""
      self.checkMounts()

      self.paramDict[ 'Device' ] = self.devicePath
      # HdParmInfo writes into the dict it is given (self.paramDict), so the
      # update() below re-applies the same values.
      hdParmInfo = HdParmInfo( self.devicePath, self.paramDict )
      self.paramDict.update( hdParmInfo.getResults() )
      if hdParmInfo.errMsg:
         Logging.log( FRU_SSD_DETECTION_ERROR, hdParmInfo.errMsg )

class Eusb( Device ):
   """eUSB flash drive; identity comes from udev properties via EusbLib."""
   def mibDeviceType( self ):
      return 'eUsb'

   def readData( self ):
      """Fill paramDict from `udevadm info` properties and the sysfs size."""
      assert self.deviceFile.sizeBytes

      udevCmd = [ 'udevadm', 'info', self.devicePath, '--query=property' ]
      udevInfo = EusbLib.parseUdevOutput(
         Tac.run( udevCmd, asRoot=True, stdout=Tac.CAPTURE ) )

      params = self.paramDict
      params[ 'FwRev' ] = EusbLib.getRevision( udevInfo )
      params[ 'SerialNo' ] = EusbLib.getSerial( udevInfo )
      params[ 'Model' ] = EusbLib.getModel( udevInfo )
      sizeInGb = float( self.deviceFile.sizeBytes ) / self.BYTES_PER_GB
      # pylint: disable=round-builtin
      params[ 'SizeGB' ] = int( round( sizeInGb ) )

class EmmcFlash( Device ):
   """eMMC flash; identity comes from the CID register via MmcFlashLib."""
   def mibDeviceType( self ):
      return 'eMmc'

   def readData( self ):
      """Fill paramDict from the CID register and the reported capacity.

      The defaults are left in place if either the CID or the size cannot be
      read.
      """
      self.paramDict[ 'Device' ] = self.devicePath
      cid = MmcFlashLib.readCidRegister( self.device )
      sizeGb = MmcFlashLib.readDeviceSize( self.device ) if cid else None
      if not ( cid and sizeGb ):
         return

      vendor = MmcFlashLib.manufacturerName( cid )
      if vendor == 'Unknown':
         # Fall back to the raw manufacturer ID when the name is not known.
         vendor = MmcFlashLib.manufacturerId( cid )

      params = self.paramDict
      params[ 'SerialNo' ] = MmcFlashLib.serialNumber( cid )
      params[ 'SizeGB' ] = sizeGb
      params[ 'FwRev' ] = MmcFlashLib.firmwareRevision( cid )
      params[ 'Model' ] = '%s %s' % ( vendor, MmcFlashLib.productName( cid ) )

class Nvme( Device ):
   """NVMe drive whose identity data comes from `nvme list -o json`."""
   def mibDeviceType( self ):
      return 'nvme'

   def getDeviceInfo( self ):
      """Find this device's entry in `nvme list -o json` output.

      Returns ( deviceInfo, None ) when an entry's 'DevicePath' equals
      self.devicePath, otherwise ( None, errMsg ) describing the failure.
      """
      try:
         output = Tac.run( [ '/usr/sbin/nvme',  'list',  '-o', 'json' ],
                           asRoot=True, stdout=Tac.CAPTURE )
      except ( Tac.SystemCommandError, OSError ) as e:
         return None, "nvme list failed: %s" % e

      try:
         jsonData = json.loads( output )
      except ValueError:
         return None, "Failed to parse nvme data"

      # Valid JSON that is not an object (e.g. `null` or a bare string) has
      # no 'Devices' list to search.
      if not isinstance( jsonData, dict ) or 'Devices' not in jsonData:
         return None, "Devices not found in nvme data"

      for deviceInfo in jsonData[ 'Devices' ]:
         if deviceInfo.get( 'DevicePath' ) == self.devicePath:
            return deviceInfo, None

      return None, "Device %s not found in nvme data" % self.devicePath

   def readData( self ):
      """Check the mount and fill paramDict from `nvme list` output."""
      # NVMe partitions are named <dev>p<N> (e.g. /dev/nvme0n1p1). The suffix
      # is an optional group anchored with '$' so a different namespace that
      # merely shares the prefix (e.g. /dev/nvme0n10p1) is not accepted.
      self.checkMounts( expectedPartitionRe=self.devicePath + r'(p\d+)?$' )

      deviceInfo, errorMsg = self.getDeviceInfo()
      if not deviceInfo:
         Logging.log( FRU_SSD_DETECTION_ERROR, errorMsg )
         return

      self.paramDict[ 'Device' ] = self.devicePath
      # Strip padding from every identity string, as the hdparm path does.
      self.paramDict[ 'Model' ] = str( deviceInfo.get( 'ModelNumber',
                                                       'Unknown' ) ).strip()
      self.paramDict[ 'SerialNo' ] = str( deviceInfo.get( 'SerialNumber',
                                                          'Unknown' ) ).strip()
      self.paramDict[ 'FwRev' ] = str( deviceInfo.get( 'Firmware', '0.0' ) ).strip()
      # pylint: disable=round-builtin
      self.paramDict[ 'SizeGB' ] = int(
         round( float( deviceInfo.get( 'PhysicalSize', 0 ) )
                / self.BYTES_PER_GB ) )
