# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

from collections import defaultdict

import Ark
import ArPyUtils
import CliModel
import TableOutput

# Binary size units, expressed in bytes.
KiB = 2 ** 10
MiB = 2 ** 20
GiB = 2 ** 30

# Device-name prefixes in preferred display order. deviceNameKey sorts names
# starting with an earlier prefix ahead of later ones, and all of them ahead of
# names matching none of these prefixes.
DEVICE_ORDER = [ "flash", "drive", "usb" ]

def deviceNameKey( name ):
   '''
   Return a sort key for a storage device name such as "flash:" or "usb1".

   Trailing colons are ignored. Names starting with an entry of DEVICE_ORDER
   come first, grouped in the order of that list; every other name follows.
   Within a group, names are ordered by ArPyUtils.naturalOrderKey (presumably
   a human-friendly "natural" ordering -- confirm in ArPyUtils).
   '''
   bareName = name.rstrip( ':' )
   # False sorts before True, so a matching prefix pulls the name forward.
   prefixMisses = tuple( not bareName.startswith( prefix )
                         for prefix in DEVICE_ORDER )
   return prefixMisses + ( ArPyUtils.naturalOrderKey( bareName ), )

def deviceKey( device ):
   '''
   Sort key for a ( deviceName, value ) pair, e.g. an item of dict.items().

   Only the device name takes part in the ordering; see deviceNameKey.
   '''
   ( deviceName, _value ) = device
   return deviceNameKey( deviceName )

class EmmcLifetime( CliModel.Model ):
   # eMMC lifetime levels for one device; every field is optional. Devices.render
   # prints each present level multiplied by 100 as a percentage.
   slcWearLevel = CliModel.Float(
      optional=True, help='SLC region fractional design wear remaining' )
   mlcWearLevel = CliModel.Float(
      optional=True, help='MLC region fractional design wear remaining' )
   reservesLevel = CliModel.Float(
      optional=True, help='Reserve block fractional level remaining' )

   def isEmpty( self ):
      '''Return True when none of the lifetime levels is set.'''
      levels = ( self.slcWearLevel, self.mlcWearLevel, self.reservesLevel )
      return all( level is None for level in levels )

class SmartHealth( CliModel.Model ):
   # S.M.A.R.T.-derived health summary for one device; every field is optional.
   # NOTE(review): wearLevel's help text says "Percent", but Devices.render
   # multiplies it by 100 before printing it as a percentage, which suggests it
   # actually holds a fraction like the other levels -- confirm with producer.
   wearLevel = CliModel.Float(
      optional=True, help='Percent design lifetime wear remaining' )
   reallocationsRemaining = CliModel.Float(
      optional=True, help='Fractional unused reallocations blocks' )
   health = CliModel.Str( optional=True, help='Overall device health' )

   def isEmpty( self ):
      '''Return True when none of the health fields is set.'''
      fields = ( self.wearLevel, self.reallocationsRemaining, self.health )
      return all( field is None for field in fields )

class StorageHealth( CliModel.Model ):
   # Health information for one storage device. Both submodels are optional;
   # which one is filled in presumably depends on the device type (eMMC vs.
   # SMART-capable drive) -- confirm with the producer.
   emmc = CliModel.Submodel( valueType=EmmcLifetime, optional=True,
                             help='eMMC lifetimes attributes' )
   smart = CliModel.Submodel( valueType=SmartHealth, optional=True,
                              help='S.M.A.R.T. health attributes' )

class Devices( CliModel.Model ):
   # Storage health keyed by device name (rendered with a trailing ':').
   devices = CliModel.Dict( valueType=StorageHealth, help='Storage devices' )

   def render( self ):
      '''
      Print one table row per available health metric of every device.

      Devices are listed in deviceKey order. For each device the eMMC metrics
      come first, then the SMART ones; metrics that are None are skipped.
      Fractional levels are printed as whole percentages.
      '''
      leftFmt = TableOutput.Format( justify="left" )
      rightFmt = TableOutput.Format( justify="right" )
      leftFmt.noPadLeftIs( True )
      leftFmt.padLimitIs( True )
      rightFmt.padLimitIs( True )

      headings = [ "Device", "Type", "Health Metric", "Value" ]
      table = TableOutput.createTable( headings )
      table.formatColumns( leftFmt, leftFmt, leftFmt, rightFmt )

      # ( model attribute, row label ) pairs shown as percentages, in order.
      emmcMetrics = ( ( 'slcWearLevel', "SLC remaining" ),
                      ( 'mlcWearLevel', "MLC remaining" ),
                      ( 'reservesLevel', "Reserves remaining" ) )
      smartMetrics = ( ( 'wearLevel', "Lifetime remaining" ),
                       ( 'reallocationsRemaining', "Reallocations remaining" ) )

      def addPercentRows( label, kind, info, metrics ):
         for attrName, metricName in metrics:
            fraction = getattr( info, attrName )
            if fraction is not None:
               table.newRow( label, kind, metricName,
                             "%.0f%%" % ( fraction * 100.0 ) )

      for name, health in sorted( self.devices.items(), key=deviceKey ):
         label = name + ":"

         emmc = health.emmc
         if emmc is not None:
            addPercentRows( label, "eMMC", emmc, emmcMetrics )

         smart = health.smart
         if smart is not None:
            addPercentRows( label, "SMART", smart, smartMetrics )
            if smart.health is not None:
               table.newRow( label, "SMART", "Health status", smart.health )

      print( table.output() )

class IOStat( CliModel.Model ):
   # Read/write byte counts of one storage device, plus when and how they were
   # measured. DevicesIOStat.render groups devices by measurementType.
   read = CliModel.Int( help='Quantity of bytes read' )
   write = CliModel.Int( help='Quantity of bytes written' )
   measurementTimestamp = CliModel.Float(
      help='Timestamp at which last measurement was done' )
   measurementType = CliModel.Enum(
      help='Type and length of measurement',
      values=[ 'total', 'boot', 'dayAverage', 'lastDay' ] )

class DevicesIOStat( CliModel.Model ):
   devices = CliModel.Dict( help='Storage devices', valueType=IOStat )

   def _headingsForMeasurementType( self, mType, unitName ):
      '''
      Return the column headings of a table holding `mType` measurements whose
      data columns are expressed in `unitName`. Any type other than 'total',
      'boot' or 'dayAverage' gets the "Last Day" headings.
      '''
      dataHeadingBase = {
         'total': 'Total %s (%s)',
         'boot': 'Since Boot %s (%s)',
         'dayAverage': 'Average %s (%s/day)',
      }.get( mType, 'Last Day %s (%s)' )

      return [
         'Device',
         dataHeadingBase % ( 'Read', unitName ),
         dataHeadingBase % ( 'Write', unitName ),
         'Last Collection',
      ]

   def _computeDisplayUnit( self, devices ):
      '''
      Determine what the best unit is to display the measurements passed in
      argument.

      `devices` is an iterable of keys into self.devices. Returns a
      ( divider, unitName ) pair: the smallest unit in which the largest read
      or write byte count has fewer than 5 integer digits, or the largest unit
      (GiB) when none qualifies.
      '''
      biggestMeasurement = 0
      for deviceName in devices:
         ioStat = self.devices[ deviceName ]
         biggestMeasurement = max( biggestMeasurement, ioStat.read, ioStat.write )

      units = [
         ( KiB, 'KiB' ),
         ( MiB, 'MiB' ),
         ( GiB, 'GiB' ),
      ]
      for unit in units:
         # If the unit generates less than 5 digits numbers, use it
         if biggestMeasurement // unit[ 0 ] < 10000:
            return unit
      # No unit keeps the number under 5 digits: use the biggest one we have
      return units[ -1 ]

   def _renderTable( self, mType, devices ):
      '''
      Return the rendered table (as a string) for the devices in `devices`,
      all of which are expected to share measurement type `mType`.
      '''
      fl = TableOutput.Format( justify="left" )
      fr = TableOutput.Format( justify="right" )
      fl.noPadLeftIs( True )
      fl.padLimitIs( True )
      fr.padLimitIs( True )

      unitDivider, unitName = self._computeDisplayUnit( devices )

      headings = self._headingsForMeasurementType( mType, unitName )
      table = TableOutput.createTable( headings )
      table.formatColumns( fl, fr, fr, fr )

      for deviceName in sorted( devices, key=deviceNameKey ):
         deviceIOStat = self.devices[ deviceName ]
         # The CLI displays devices in the file system notation, i.e. with a
         # colon at the end. True division is in effect (__future__ import).
         table.newRow( '%s:' % deviceName,
                       '%.3f' % ( deviceIOStat.read / unitDivider ),
                       '%.3f' % ( deviceIOStat.write / unitDivider ),
                       Ark.timestampToStr( deviceIOStat.measurementTimestamp ) )

      return table.output()

   def render( self ):
      '''
      Print one table per measurement type present in self.devices.

      Tables are printed in a fixed order (the declaration order of
      IOStat.measurementType values) so the output does not depend on dict
      iteration order. Any unexpected type is printed last, sorted.
      '''
      devicesByMeasurementType = defaultdict( list )
      for fsName, device in self.devices.items():
         devicesByMeasurementType[ device.measurementType ].append( fsName )

      knownOrder = [ 'total', 'boot', 'dayAverage', 'lastDay' ]
      mTypes = [ t for t in knownOrder if t in devicesByMeasurementType ]
      mTypes += sorted( t for t in devicesByMeasurementType
                        if t not in knownOrder )

      for mType in mTypes:
         print( self._renderTable( mType, devicesByMeasurementType[ mType ] ) )

class SmartAttr( CliModel.Model ):
   # One SMART attribute of a device. value and unit may be absent;
   # SmartAttrs.render prints them blank in that case.
   name = CliModel.Str( help='SMART attribute name' )
   value = CliModel.Int( optional=True, help='SMART attribute value' )
   unit = CliModel.Str( optional=True, help='Units for value' )
   status = CliModel.Enum( help='SMART attribute status',
                           values=[ "ok", "failed", "n/a" ] )

class DeviceSmartAttrs( CliModel.Model ):
   # The SMART attributes of a single device, in display order.
   attributes = CliModel.List( valueType=SmartAttr,
                               help='SMART attribute list' )

class SmartAttrs( CliModel.Model ):
   devices = CliModel.Dict( valueType=DeviceSmartAttrs,
                            help='SMART attributes, keyed by device name' )

   def render( self ):
      '''
      Print a table with one row per SMART attribute of every device.

      Devices are listed in deviceKey order; within a device, attributes keep
      their list order. A missing value or unit is shown blank.
      '''
      leftFmt = TableOutput.Format( justify="left" )
      rightFmt = TableOutput.Format( justify="right" )
      leftFmt.noPadLeftIs( True )
      leftFmt.padLimitIs( True )
      rightFmt.padLimitIs( True )

      headings = [ "Device", "SMART Attribute", "Value", "Unit", "Status" ]
      table = TableOutput.createTable( headings )
      table.formatColumns( leftFmt, leftFmt, rightFmt, leftFmt, leftFmt )

      def blankIfNone( field ):
         return "" if field is None else field

      for deviceName, deviceAttrs in sorted( self.devices.items(),
                                             key=deviceKey ):
         label = deviceName + ":"
         for attr in deviceAttrs.attributes:
            table.newRow( label, attr.name, blankIfNone( attr.value ),
                          blankIfNone( attr.unit ), attr.status )

      print( table.output() )

class SmartAttrDetail( CliModel.Model ):
   __revision__ = 2
   # One detailed SMART attribute. Since revision 2, id, normalized, worst and
   # threshold are optional; raw is always present.
   id = CliModel.Int( optional=True, help='SMART attribute ID' )
   name = CliModel.Str( help='SMART attribute name' )
   normalized = CliModel.Int( optional=True, help='SMART normalized value' )
   worst = CliModel.Int( optional=True, help='SMART worst normalized value' )
   threshold = CliModel.Int( optional=True,
                             help='SMART threshold (lowest acceptable)' )
   raw = CliModel.Int( help='SMART raw value' )

   def degrade( self, dictRepr, revision ):
      '''
      Adapt this model's dict representation for a client asking for an older
      `revision`, and return it (dictRepr is modified in place).

      In revision 1, id, normalized, worst and threshold were mandatory, so any
      of them missing from dictRepr is filled in with -1.
      '''
      if revision == 1:
         for key in ( 'id', 'normalized', 'worst', 'threshold' ):
            dictRepr.setdefault( key, -1 )
      return dictRepr

class DeviceSmartAttrDetails( CliModel.Model ):
   # The detailed SMART attributes of a single device, in display order.
   attributes = CliModel.List( valueType=SmartAttrDetail,
                               help='SMART attribute list' )

class SmartAttrDetails( CliModel.Model ):
   __revision__ = 2
   devices = CliModel.Dict( help='SMART attributes, keyed by device',
                            valueType=DeviceSmartAttrDetails )

   def render( self ):
      '''
      Print one row per detailed SMART attribute of every device.

      Devices are listed in deviceKey order. If an attribute's id, normalized,
      worst and threshold are all unset (or all zero), those four cells show
      'n/a'. Otherwise only the cells whose value is None show 'n/a'.
      '''
      fl = TableOutput.Format( justify="left" )
      fr = TableOutput.Format( justify="right" )
      fl.noPadLeftIs( True )
      fl.padLimitIs( True )
      fr.padLimitIs( True )

      table = TableOutput.createTable( [ "Device", "ID", "SMART Attribute",
                                         "Normalized", "Worst", "Threshold",
                                         "Raw" ] )
      # One format per column (7 columns): the numeric ones, including Raw,
      # are right-justified.
      table.formatColumns( fl, fr, fl, fr, fr, fr, fr )

      def naIfNone( value ):
         return 'n/a' if value is None else value

      for device, attrs in sorted( self.devices.items(), key=deviceKey ):
         device += ":"
         for attr in attrs.attributes:
            optionalFields = ( attr.id, attr.normalized, attr.worst,
                               attr.threshold )
            if not any( optionalFields ):
               # Preserved behaviour: all-falsy (None or 0) means "not available".
               table.newRow( device, 'n/a', attr.name, 'n/a', 'n/a', 'n/a',
                             attr.raw )
            else:
               # Some fields may still be None; never render them as "None".
               table.newRow( device, naIfNone( attr.id ), attr.name,
                             naIfNone( attr.normalized ), naIfNone( attr.worst ),
                             naIfNone( attr.threshold ), attr.raw )

      print( table.output() )
