# Copyright (c) 2006-2010, 2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function
import os
import Plugins
import re
import StorageInvLib
import Tac
from FirmwareUtils import FirmwareVersion, Inventory, Device, Tee, Marker
import six

# Directory that StorageInventory.populate() passes to getABinFile() when it
# looks for a model's .abin firmware file.
fw_path = '/usr/share/firmware'
# Pattern parseHeader() matches against the header text: group 1 is used as
# the firmware version and group 2 as the checksum (md5sum).
regex = r'NOTE "ARISTA_FW_VERSION" (.*);\nNOTE "ARISTA_FW_CKSUM" (.*);\n'

class ParserError( Exception ):
   """Raised by parseHeader() when the header text does not match `regex`."""

class BinHeader( object ):
   """Plain record holding a model, firmware version, md5sum and abin value."""

   def __init__( self, model, version, md5sum, abin ):
      # Each constructor argument is stored under the attribute of the same
      # name; nothing is validated or converted.
      self.model, self.version = model, version
      self.md5sum, self.abin = md5sum, abin

class StorageInventory ( Inventory ):
   """Inventory of SSD storage devices that have upgradable firmware.

   `devices` holds the SSD device objects found by findDevices(). `models`
   maps a model name to the class that populate() instantiates for it;
   nothing in this class fills it in - presumably the plugins loaded by
   Plugin() register their models there (TODO confirm).
   """

   def __init__( self ):
      self.devices = []
      self.models = {}
      super( StorageInventory, self ).__init__()

   def findDevices( self ):
      """Rebuild self.devices with the SSDs found on the candidate mounts.

      Spaces in each device's 'Model' parameter are replaced with
      underscores and the result is written back into device.paramDict.
      """
      self.devices = []
      mounts = StorageInvLib.Mounts()
      devices = StorageInvLib.DeviceFiles( mounts )
      possibleMounts = [ 'flash', 'drive', 'drive2' ]
      for mount in possibleMounts:
         device = devices.deviceFactory( mount )
         if device and device.mibDeviceType() == 'ssd':
            model = device.paramDict[ 'Model' ].replace( ' ', '_' )
            device.paramDict[ 'Model' ] = model
            self.devices.append( device )

   def populate( self ):
      """Add one inventory entry for every detected SSD that can be upgraded.

      A device is added only when an .abin file for its model exists under
      fw_path and a class for that model is present in self.models. In every
      other case a "not implemented" message is written to self.tee and the
      device is skipped.
      """
      tee = self.tee
      self.findDevices()
      for device in self.devices:
         name = device.paramDict[ 'Model' ]
         dev = device.paramDict[ 'Device' ]
         fwRev = device.paramDict[ 'FwRev' ]
         aBinf = getABinFile( name, fw_path )
         # A firmware file without a registered model class cannot be handled
         # either; report it like a missing file rather than raising KeyError
         # and abandoning the remaining devices.
         if os.path.exists( aBinf ) and name in self.models:
            self.inventory.append( self.models[ name ]( dev, name, fwRev ) )
         else:
            msg = ( "Automatic firmware upgrade for storage model %s is not "
                  "implemented." % name )
            tee.out( msg )

   def reset( self ):
      """Run the "reboot" command, with Tac.runActivities() before and after."""
      Tac.runActivities( 0.5 )
      Tac.run( [ "reboot", ] )
      Tac.runActivities( 1 )

def _getHeader( abin ):
   """Return the leading lines found in the first 100 bytes of `abin`.

   The bytes are split on newlines and at most the first two pieces are
   returned, as a list of bytes objects. The list has a single element when
   those bytes contain no newline.
   """
   with open( abin, 'rb' ) as abinFile:
      leading = abinFile.read( 100 )
   pieces = leading.split( b'\n' )
   return pieces[ : 2 ]

def getHeader( abin ):
   """Return the lines from _getHeader() as bytes, each newline-terminated."""
   headerLines = _getHeader( abin )
   return b"".join( line + b'\n' for line in headerLines )

def getHeaderLen( abin ):
   """Return the byte length of the two header lines of `abin`.

   The count covers both lines plus the single newline between them; the
   newline that ends the second line is not included. Raises IndexError when
   _getHeader() yields fewer than two lines.
   """
   lines = _getHeader( abin )
   firstLen = len( lines[ 0 ] )
   secondLen = len( lines[ 1 ] )
   return firstLen + 1 + secondLen

def parseHeader( header, model ):
   """Build a BinHeader for `model` from the header text of an .abin file.

   `header` (str or bytes) must begin with text matching the module-level
   `regex`; its two captured groups become the version and the md5sum. The
   abin field of the returned BinHeader is set to ''.

   Raises ParserError when the pattern does not match.
   """
   matched = re.match( regex, six.ensure_str( header ) )
   if matched is None:
      raise ParserError
   version, md5sum = matched.groups()
   return BinHeader( model, version, md5sum, '' )

def getABinFile( name, path ):
   """Return '<path>/<name>.abin', with spaces in `name` turned into '_'."""
   fileName = '%s.abin' % ( name.replace( ' ', '_' ), )
   return '%s/%s' % ( path, fileName )

def readHeader( name, path ):
   """Read and parse the header of the .abin file for model `name`.

   The file is located with getABinFile( name, path ). Returns the BinHeader
   produced by parseHeader(), with `name` as its model.
   """
   rawHeader = getHeader( getABinFile( name, path ) )
   return parseHeader( rawHeader, name )

def readBin( name, path ):
   """Copy the payload of a model's .abin file into a .bin file under /tmp.

   The payload is everything that follows the two header lines. It is
   written to /tmp/<version>_<model>.bin, replacing any file already at that
   path; the output file is created even when the payload is empty.

   Returns the path of the .bin file. Errors from getHeaderLen() and
   readHeader() (for example ParserError for a header that does not match)
   propagate to the caller.
   """
   abin = getABinFile( name, path )
   hlen = getHeaderLen( abin )
   header = readHeader( name, path )

   binf = ( "/tmp/%s_%s.bin" % ( header.version, header.model ) )
   if os.path.exists( binf ):
      os.remove( binf )
   # Open the output once for the whole copy instead of re-opening it in
   # append mode for every chunk.
   with open( abin, 'rb' ) as f, open( binf, mode='wb' ) as fw:
      # getHeaderLen() does not count the newline ending the second header
      # line, hence the extra byte skipped here.
      f.seek( hlen + 1 )
      while True:
         abinData = f.read( 10240 )
         if not abinData:
            break
         fw.write( abinData )
   return binf

class StorageFirmwareVersion( FirmwareVersion ):
   """Firmware version string that is compared through its integer value.

   Equality and ordering both go through int(), so subclasses must provide
   __int__(); the implementation here raises NotImplementedError.
   """

   def __init__( self, version ):
      # Trailing whitespace is dropped from the stored version string.
      self.version = version.rstrip()

   def __int__( self ):
      raise NotImplementedError()

   def __eq__( self, other ):
      mine = int( self )
      theirs = int( other )
      return mine == theirs

   def __ne__( self, other ):
      isEqual = self.__eq__( other )
      return not isEqual

   def __lt__( self, other ):
      mine = int( self )
      theirs = int( other )
      return mine < theirs

   def __key( self ):
      return self.version

   def __hash__( self ):
      # NOTE(review): equality is int-based while the hash uses the raw
      # string; confirm no subclass maps different strings to the same int.
      key = self.__key()
      return hash( key )

   def __str__( self ):
      return self.version


class StorageDevice ( Device ):
   """Base class for a storage device whose firmware can be upgraded.

   Subclasses supply getUpgradeCommand(), getLocalVersion() and
   getHwVersion(); the versions here raise NotImplementedError.
   """

   def init( self, dev, name, fwRev ):
      """Store the device parameters and create the Tee used for output."""
      self.name = name
      self.dev = dev
      self.fwRev = fwRev
      self.tee = Tee()

   def __init__( self, dev, name, fwRev ):
      self.init( dev, name, fwRev )
      super( StorageDevice, self ).__init__( name )

   def needsUpgrade( self ):
      """Return whether getHwVersion() compares lower than getLocalVersion()."""
      hwVersion = self.getHwVersion()
      localVersion = self.getLocalVersion()
      return hwVersion < localVersion

   def getUpgradeCommand( self ):
      # implement it for each vendor
      raise NotImplementedError

   def getLocalVersion( self ):
      raise NotImplementedError

   def getHwVersion( self ):
      raise NotImplementedError

   def getDescription( self ):
      return "Storage Device %s" % self.name

   def _upgrade( self, upgradeCmd ):
      """Run upgradeCmd as root unless the device's skip marker exists.

      When Marker( "/mnt/flash/<name>_skip_upgrade" ).exists() is true the
      command is only reported through self.tee and nothing is run.
      """
      tee = self.tee
      marker = Marker( "/mnt/flash/%s_skip_upgrade" % self.name )
      if marker.exists():
         # Wrap the command in a 1-tuple: if upgradeCmd is itself a tuple, a
         # bare "%" would spread its items over the single %s and raise
         # TypeError.
         tee.out( "StorageDevices FirmwareUtils: Upgrade command is %s"
               % ( upgradeCmd, ) )
         return
      Tac.run( upgradeCmd, asRoot=True )

   def upgrade( self ):
      """Fetch the subclass's upgrade command and pass it to _upgrade()."""
      upgradeCmd = self.getUpgradeCommand()
      self._upgrade( upgradeCmd )

   def isRebootReq( self ):
      return True

def Plugin( context ):
   """Create a StorageInventory, load the storage plugins, and register it.

   The new inventory is handed to the "StoragePlugin" plugins as their
   context before it is registered with `context`.
   """
   inventory = StorageInventory()
   Plugins.loadPlugins( "StoragePlugin", context=inventory )
   context.register( inventory )
