# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import ctypes
from fcntl import ioctl
import os
import posix
import struct
import time

# Default EEPROM size, in bytes, used when --skipHeader is given without
# --eepromSize (the prefdl is then zero-padded to this many bytes on write and
# read back as this many raw bytes).
DEFAULT_EEPROM_SIZE = 256

# Values for the read_write field of i2c_smbus_ioctl_data (see smbusAccess).
I2C_SMBUS_READ = 1
I2C_SMBUS_WRITE = 0

# ioctl request codes issued on the i2c-dev descriptor.
# NOTE(review): these should match <linux/i2c-dev.h>; confirm if changed.
I2C_SLAVE = 0x0703       # bind fd to a slave address
I2C_SLAVE_FORCE = 0x0706 # same, even if a kernel driver claims the address
I2C_SMBUS = 0x0720 # SMBus-level access

# SMBus transaction types, passed as the size field of i2c_smbus_ioctl_data.
I2C_SMBUS_BYTE = 1
I2C_SMBUS_BYTE_DATA = 2
I2C_SMBUS_I2C_BLOCK_DATA = 8

class i2c_smbus_data( ctypes.Union ):
   """ctypes mirror of the kernel's SMBus data union.

   The field order and types define the binary layout shared with the
   kernel, so they must not be changed. ``byte``/``word`` hold single-value
   results (receiveByte reads ``byte``); ``block`` is the buffer for block
   transfers, where block[ 0 ] holds the byte count (see writeChunk).
   NOTE(review): 34 presumably equals a 32-byte block maximum plus the count
   byte and one spare -- confirm against the kernel header.
   """
   _fields_ = [ ( 'byte', ctypes.c_uint8 ),
                ( 'word', ctypes.c_uint16 ),
                ( 'block', ctypes.c_uint8 * 34 ) ]

class i2c_smbus_ioctl_data( ctypes.Structure ):
   """Argument structure for the I2C_SMBUS ioctl (filled in by smbusAccess).

   read_write is I2C_SMBUS_READ or I2C_SMBUS_WRITE, command is the SMBus
   command byte, size is the transaction type (an I2C_SMBUS_* value) and data
   points at the i2c_smbus_data buffer that the transfer reads from or fills.
   Field order and types are the binary layout handed to the kernel and must
   not be reordered.
   """
   _fields_ = [ ( 'read_write', ctypes.c_uint8 ),
                ( 'command', ctypes.c_uint8 ),
                ( 'size', ctypes.c_int ),
                ( 'data', ctypes.POINTER( i2c_smbus_data ) ) ]

class TwoByteIdpromHam:
   """Byte-level access to an I2C EEPROM with two-byte (16-bit) addressing.

   Every transfer goes through the i2c-dev SMBus ioctl (smbusAccess). The
   high byte of an EEPROM address is sent as the SMBus command byte and the
   low byte as the first payload byte (see writeChunk).
   """

   def __init__( self, busNumber, addr, force=False ):
      """Open bus ``busNumber`` and bind the descriptor to slave ``addr``.

      ``force`` selects I2C_SLAVE_FORCE instead of I2C_SLAVE; it defaults to
      False, the value that used to be hard-coded.

      Raises OSError if no bus device node can be opened, or if setting the
      slave address fails (the descriptor is closed before re-raising).
      """
      self.addr = addr
      self.busNumber = busNumber
      self.fd = -1
      self.filename = ''
      # not read anywhere in this class; presumably kept for external users
      self.addrSize = 1

      self.open()

      try:
         ioctl( self.fd, I2C_SLAVE_FORCE if force else I2C_SLAVE, self.addr )
      except OSError:
         # don't leak the descriptor when the slave address can't be set
         self.close()
         raise

   def open( self ):
      """Open the i2c-dev node for ``self.busNumber`` read/write.

      Tries ``/dev/i2c/<bus>`` first, then ``/dev/i2c-<bus>``. On success
      ``self.fd`` is the open descriptor and ``self.filename`` its path.

      Raises OSError if neither node can be opened.
      """
      # pylint: disable-next=consider-using-f-string
      primary = '/dev/i2c/%d' % self.busNumber
      # pylint: disable-next=consider-using-f-string
      fallback = '/dev/i2c-%d' % self.busNumber
      for filename in ( primary, fallback ):
         # recorded even on failure (as before); smbusAccess errors report it
         self.filename = filename
         try:
            self.fd = posix.open( filename, os.O_RDWR )
            return
         except OSError:
            self.fd = -1
      # pylint: disable-next=consider-using-f-string
      raise OSError( 'Unable to open %s or %s' % ( primary, fallback ) )

   def close( self ):
      """Close the bus descriptor if it is open; repeated calls are no-ops."""
      if self.fd >= 0:
         os.close( self.fd )
         self.fd = -1

   # pylint: disable-msg=W0201
   def smbusAccess( self, read_write, command, size, data=None ):
      """Issue one I2C_SMBUS ioctl on this bus.

      read_write: I2C_SMBUS_READ or I2C_SMBUS_WRITE.
      command: SMBus command byte.
      size: transaction type (an I2C_SMBUS_* constant).
      data: i2c_smbus_data to send/fill; a zeroed one is used when None.

      Returns the ctypes pointer to the data buffer (``.contents`` holds any
      bytes read back). Raises Exception if the ioctl fails.
      """
      ar = i2c_smbus_ioctl_data()
      ar.read_write = read_write
      ar.command = command
      ar.size = size
      if data is None:
         data = i2c_smbus_data()
      ar.data = ctypes.pointer( data )

      try:
         # mutate flag set so the kernel can write results back into ar
         ioctl( self.fd, I2C_SMBUS, ar, 1 )
      except OSError as e:
         # pylint: disable-next=consider-using-f-string
         raise Exception( 'Smbus access to %s failed. Linux reports %r' %
            ( self.filename, str( e ) ) ) from e

      return ar.data

   def writeChunk( self, address, data=b'' ):
      """Write ``data`` (possibly empty) starting at EEPROM ``address``.

      Sent as one SMBus I2C-block write: the command byte is the high address
      byte, the payload is the low address byte followed by ``data``. With
      empty ``data`` only the address is sent, which setAddressPointer relies
      on to position subsequent reads.
      """
      # fixed delay before every transfer, presumably to let a previous
      # EEPROM write cycle finish -- TODO confirm the required duration
      time.sleep( 0.25 )
      addrHigh = ( address >> 8 ) & 0xff
      addrLow = address & 0xff
      smbusData = i2c_smbus_data()
      # Byte count (low address byte + data) is not transmitted itself; the
      # controller uses it to know how much of the buffer to send
      smbusData.block[ 0 ] = len( data ) + 1
      smbusData.block[ 1 ] = addrLow
      for offset, byte in enumerate( data ):
         smbusData.block[ 2 + offset ] = byte
      self.smbusAccess( I2C_SMBUS_WRITE, addrHigh, I2C_SMBUS_I2C_BLOCK_DATA,
                        smbusData )

   def writeSequenceStr( self, address, data, chunkSize=16 ):
      """Write the bytes ``data`` starting at EEPROM ``address``.

      The data is split into writes of at most ``chunkSize`` bytes, and each
      write ends at a multiple of ``chunkSize``. Typical I2C EEPROM page
      writes wrap around inside a page, so a write straddling a page boundary
      would overwrite the start of the page; aligned chunks avoid that as long
      as chunkSize is a power of two no larger than the device's page size.

      Raises ValueError if chunkSize is not in 1..31: one payload byte is the
      low address byte, and the SMBus block count is limited to 32.
      """
      if not 0 < chunkSize <= 31:
         raise ValueError( 'chunkSize must be between 1 and 31' )
      end = address + len( data )
      cursor = address
      while cursor < end:
         # stop at the next chunkSize-aligned address (or the end of data)
         chunkEnd = min( ( cursor // chunkSize + 1 ) * chunkSize, end )
         self.writeChunk( cursor, data[ cursor - address : chunkEnd - address ] )
         cursor = chunkEnd

   def setAddressPointer( self, addr ):
      """Load ``addr`` into the EEPROM's address pointer for the next reads."""
      self.writeChunk( addr )
      time.sleep( 0.01 )

   def receiveByte( self ):
      """Read one byte at the current address pointer (SMBus receive byte)."""
      data = self.smbusAccess( I2C_SMBUS_READ, 0, I2C_SMBUS_BYTE )
      return data.contents.byte & 0xff

   def readSequenceStr( self, address, count ):
      """Return ``count`` bytes read sequentially starting at ``address``."""
      self.setAddressPointer( address )
      data = bytearray( count )
      for offset in range( count ):
         data[ offset ] = self.receiveByte()
      return bytes( data )

class TwoByteIdprom:
   """prefdl/fdl-aware view of a two-byte-addressed EEPROM.

   Layout as read by this class: an 8-byte prefdl header ``>LL`` (format 3,
   length including the header) at offset 0 followed by the prefdl payload;
   then, at offset prefdlLength, an 8-byte fdl header (format 2 or 3, length)
   followed by the fdl payload.
   """

   # prefdl/fdl headers are two big-endian uint32s: ( format, length )
   _HEADER_FMT = ">LL"
   _HEADER_LEN = 8

   def __init__( self, busNumber, addr ):
      self.ham = TwoByteIdpromHam( busNumber, addr )

   def prefdlLength( self ):
      """Return the prefdl length (header included) from offset 0, or -1.

      -1 means the header's format field is not 3.
      """
      header = self.ham.readSequenceStr( 0x0, self._HEADER_LEN )
      fformat, length = struct.unpack( self._HEADER_FMT, header )
      if fformat != 3:
         return -1
      return length

   def fdlLength( self ):
      """Return the fdl length from the header that follows the prefdl, or -1.

      readFdl/readAll treat this length as including the 8-byte fdl header.
      -1 is returned when the prefdl header is invalid (so the fdl location is
      unknown) or when the fdl header's format is neither 2 nor 3.
      """
      offset = self.prefdlLength()
      if offset < 0:
         # no valid prefdl -> no known fdl location; reading at offset -1
         # would address 0xffff on the bus and parse garbage
         return -1
      header = self.ham.readSequenceStr( offset, self._HEADER_LEN )
      fformat, length = struct.unpack( self._HEADER_FMT, header )
      if fformat not in ( 2, 3 ):
         return -1
      return length

   def readPrefdl( self, skipHeader, eepromSize ):
      """Return the prefdl payload bytes, or b"" if the header is invalid.

      With skipHeader the first ``eepromSize`` bytes are returned as-is;
      otherwise the header is parsed and the payload after it is returned.
      """
      if skipHeader:
         offset = 0
         length = eepromSize
      else:
         offset = self._HEADER_LEN
         # read the prefdl header to determine the length of the prefdl
         length = self.prefdlLength()
      # the length includes the header, so a value shorter than the header is
      # corrupt (and would otherwise yield a negative read count)
      if length > 0 and length >= offset:
         return self.ham.readSequenceStr( offset, length - offset )
      print( "Error: invalid prefdl Length. The prefdl may be corrupted." )
      return b""

   def readFdl( self ):
      """Return the fdl payload (without its header), or b"" on error."""
      prefdlLength = self.prefdlLength()
      fdlLength = self.fdlLength()
      if prefdlLength > 0 and fdlLength >= self._HEADER_LEN:
         return self.ham.readSequenceStr( prefdlLength + self._HEADER_LEN,
                                          fdlLength - self._HEADER_LEN )
      print( "Error: invalid fdl Length. The fdl may be corrupted." )
      return b""

   def readAll( self, length=0 ):
      """Return ``length`` raw bytes from offset 0, or prefdl+fdl if length<=0.

      When the headers are used and either is invalid, an error is printed for
      each bad header and b"" is returned.
      """
      if length > 0:
         return self.ham.readSequenceStr( 0, length )

      prefdlLength = self.prefdlLength()
      fdlLength = self.fdlLength()
      if prefdlLength > 0 and fdlLength > 0:
         return self.ham.readSequenceStr( 0, prefdlLength + fdlLength )

      if prefdlLength < 0:
         print( "Error: invalid prefdl Length. The prefdl may be corrupted." )
      if fdlLength < 0:
         print( "Error: invalid fdl Length. The fdl may be corrupted." )
      return b""

   def unlock( self ):
      """Hook called before writing; a no-op for this device."""

   def lock( self ):
      """Hook called after writing; a no-op for this device."""

   def writePrefdl( self, prefdlStr, skipHeader, eepromSize ):
      """Write ``prefdlStr`` as the new prefdl, preserving any existing fdl.

      Without skipHeader a header (format 3, length including the header) is
      prepended to the UTF-8 payload; with skipHeader the payload is instead
      zero-padded to ``eepromSize`` bytes. A valid existing fdl is read first
      and rewritten directly after the new prefdl.
      """
      # copy off fdl before writing new prefdl to prevent overwriting
      prefdlLength = self.prefdlLength()
      newPrefdl = prefdlStr.encode( 'utf-8' )
      if not skipHeader:
         # length = prefdl + 8 byte header
         header = struct.pack( self._HEADER_FMT, 0x3,
                               len( newPrefdl ) + self._HEADER_LEN )
         newPrefdl = header + newPrefdl
      else:
         # if skipping header, pad out to eeprom bytes
         newPrefdl += bytes( eepromSize - len( newPrefdl ) )
      newEepromContents = newPrefdl
      # fdlLength() is -1 whenever prefdlLength is -1, so a valid fdl always
      # has a valid prefdlLength offset here
      fdlLength = self.fdlLength()
      if fdlLength > 0:
         # NOTE(review): readFdl/readAll treat fdlLength as already including
         # the 8-byte header, so "+ 8" copies 8 bytes past the fdl's end. Kept
         # rather than risk truncating the fdl -- confirm before trimming.
         oldFdl = self.ham.readSequenceStr( prefdlLength, fdlLength + 8 )
         newEepromContents += oldFdl
      self.unlock()
      try:
         self.ham.writeSequenceStr( 0x0, newEepromContents )
      finally:
         self.lock()

def doRead( busId, seepromId, offset, length ):
   """Return ``length`` raw bytes read from ``offset`` of the given EEPROM."""
   ham = TwoByteIdprom( busId, seepromId ).ham
   return ham.readSequenceStr( offset, length )

def doReadPrefdl( busId, seepromId, skipHeader, eepromSize=DEFAULT_EEPROM_SIZE ):
   """Return the prefdl bytes of the given EEPROM (see readPrefdl)."""
   device = TwoByteIdprom( busId, seepromId )
   prefdl = device.readPrefdl( skipHeader, eepromSize )
   return prefdl

def doWrite( busId, seepromId, offset, data, unlock=None ):
   """Write the bytes ``data`` at ``offset`` and verify them by reading back.

   ``unlock`` is accepted for interface compatibility but is not used.
   Raises Exception if the read-back differs from ``data``.
   """
   ham = TwoByteIdprom( busId, seepromId ).ham
   ham.writeSequenceStr( offset, data )
   readBack = ham.readSequenceStr( offset, len( data ) )
   if readBack == data:
      return
   raise Exception( "Data verification mismatch" )

def doWritePrefdl( busId, seepromId, data, skipHeader,
                   eepromSize=DEFAULT_EEPROM_SIZE ):
   """Write the string ``data`` as the prefdl and verify it by reading back.

   Raises Exception("Prefdl data verification mismatch") if the bytes read
   back do not start with the UTF-8 encoding of ``data``.
   """
   idprom = TwoByteIdprom( busId, seepromId )
   idprom.writePrefdl( data, skipHeader, eepromSize )
   expected = data.encode( 'utf-8' )
   verifyData = idprom.readPrefdl( skipHeader, eepromSize )
   # Compare raw bytes instead of decoding the read-back: a corrupted EEPROM
   # may return bytes that are not valid UTF-8, and decoding them would raise
   # UnicodeDecodeError instead of reporting the mismatch.
   if verifyData[ : len( expected ) ] != expected:
      raise Exception( "Prefdl data verification mismatch" )
