# Copyright (c) 2007-2010, 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import errno
import mmap
import os
import re
import sys

import _MemVolatileAccess

class Address:
   """Represents a PCI device address as a (domain, bus, slot, function) tuple.

   Addresses compare equal (and hash) by their canonical DDDD:BB:SS.F text.
   Two Addresses are ordered numerically by (domain, bus, slot, function);
   comparisons against any other kind of object fall back to comparing reprs."""

   def __init__( self, domain=0, bus=0, slot=0, function=0 ):
      """Construct an Address from separate domain, bus, slot, function values.
      Also accepts a single string with format [[DDDD:]BB:]SS[.F], an
      Inventory::PciAddress value, or an Address.  String fields are parsed as
      hex; fields omitted from the string form default to 0.

      The Inventory::PciAddress only supports 16-bit domains, but with string
      and numeric arguments 32-bit domains are supported.  The value() method
      will throw a TypeError if you try to get the PciAddress when the domain
      is larger than 16 bits.

      Raises AssertionError if a field is out of range (domain: 32 bits,
      bus: 0..0xff, slot: 0..0x1f, function: 0..7).
      NOTE(review): a string that does not match the format raises
      AttributeError (re.match returns None), not ValueError."""
      try:
         domain, bus, slot, function = re.match(
            r"(?:(?:(\w+):)?(\w+):)?(\w+)(?:\.(\w+))?$", domain ).groups()
      except TypeError:
         # Not a string: an int, an Address or an Inventory::PciAddress.
         pass
      # handle first arg of type Inventory::PciAddress without dragging in Tac
      self.domain, self.bus, self.slot, self.function = (
         getattr( domain, "domain", _hexnum( domain ) ),
         getattr( domain, "bus", _hexnum( bus ) ),
         getattr( domain, "slot", _hexnum( slot ) ),
         getattr( domain, "function", _hexnum( function ) ) )
      assert 0 <= self.domain <= 0xffffffff
      # The bus number is an 8-bit field.
      assert 0 <= self.bus <= 0xff
      assert 0 <= self.slot <= 0x1f
      assert 0 <= self.function <= 7

   def value( self ):
      """Convert the address to an Inventory::PciAddress value."""
      import Tac # pylint: disable=import-outside-toplevel
      # If you get a TypeError here, like:
      # Tac::TypeException("Error processing constructor parameter 0 ('domain'):
      #     expected U16, got int: 65536")
      # the domain is a 32-bit value, which is valid in some systems,
      # but the Inventory types do not support 32-bit values.  Consider
      # catching the TypeError and ignoring this device, or filtering the
      # items before you pass them to Address().
      return Tac.Value(
         "Inventory::PciAddress",
         domain=self.domain, bus=self.bus, slot=self.slot, function=self.function )

   def devfn( self ):
      """Return the slot and function packed into a single byte."""
      return ( self.slot << 3 ) | self.function

   def __str__( self ):
      return "{:04x}:{:02x}:{:02x}.{:01x}".format( self.domain, self.bus,
                                                   self.slot, self.function )

   def __repr__( self ):
      return "Pci.Address( %s )" % self

   # __hash__ and __eq__ are defined together (both via repr) so that Addresses
   # work as set elements and dict keys.
   def __hash__( self ):
      return hash( repr( self ) )

   def __eq__( self, other ):
      return repr( self ) == repr( other )

   def __ne__( self, other ):
      return repr( self ) != repr( other )

   def _key( self ):
      """Return the numeric sort key for this address."""
      return ( self.domain, self.bus, self.slot, self.function )

   def _orderKeys( self, other ):
      """Return a ( mine, theirs ) pair of keys for ordering comparisons.

      Two Addresses are compared numerically: comparing reprs would misorder
      domains wider than 16 bits (e.g. "10000:..." sorts before "ffff:...").
      Anything else is compared by repr."""
      if isinstance( other, Address ):
         return self._key(), other._key()
      return repr( self ), repr( other )

   def __lt__( self, other ):
      mine, theirs = self._orderKeys( other )
      return mine < theirs

   def __le__( self, other ):
      mine, theirs = self._orderKeys( other )
      return mine <= theirs

   def __gt__( self, other ):
      mine, theirs = self._orderKeys( other )
      return mine > theirs

   def __ge__( self, other ):
      mine, theirs = self._orderKeys( other )
      return mine >= theirs

class Id:
   """A PCI ( vendor, device ) identifier pair; also used for subsystem IDs."""

   def __init__( self, vendor=0, device=0 ):
      """Build an Id from a vendor value and a device value.

      vendor may instead be a single "VVVV:DDDD" string, an Inventory::PciId
      value, or another Id; in those cases the device argument is ignored.
      String fields are parsed as hex.  Raises AssertionError if either field
      is outside 0..0xffff.
      NOTE(review): any string vendor that does not match "VVVV:DDDD" (including
      a lone "VVVV") raises AttributeError, because re.match returns None."""
      try:
         parsed = re.match( r"(\w+):(\w+)$", vendor )
      except TypeError:
         # vendor is not a string (an int, an Id, an Inventory::PciId, ...).
         pass
      else:
         vendor, device = parsed.groups()
      # Duck-type Inventory::PciId (and Id) through its attributes so that Tac
      # does not have to be imported here.
      vendorValue = getattr( vendor, "vendor", _hexnum( vendor ) )
      deviceValue = getattr( vendor, "device", _hexnum( device ) )
      self.vendor, self.device = vendorValue, deviceValue
      assert 0 <= self.vendor <= 0xffff
      assert 0 <= self.device <= 0xffff

   def value( self ):
      """Return this ID as an Inventory::PciId value (imports Tac lazily)."""
      import Tac # pylint: disable=import-outside-toplevel
      return Tac.Value( "Inventory::PciId", vendor=self.vendor, device=self.device )

   def __str__( self ):
      return "{:04x}:{:04x}".format( self.vendor, self.device )

   def __repr__( self ):
      return "Pci.Id( {} )".format( self )

   # Equality, hashing and ordering all go through repr(), so Ids behave
   # consistently as set members and dict keys.  The fields are fixed-width
   # hex, which makes the textual order match the numeric order.
   def __hash__( self ):
      text = repr( self )
      return hash( text )

   def __eq__( self, other ):
      mine, theirs = repr( self ), repr( other )
      return mine == theirs

   def __ne__( self, other ):
      mine, theirs = repr( self ), repr( other )
      return mine != theirs

   def __lt__( self, other ):
      mine, theirs = repr( self ), repr( other )
      return mine < theirs

   def __le__( self, other ):
      mine, theirs = repr( self ), repr( other )
      return mine <= theirs

   def __gt__( self, other ):
      mine, theirs = repr( self ), repr( other )
      return mine > theirs

   def __ge__( self, other ):
      mine, theirs = repr( self ), repr( other )
      return mine >= theirs

class Device:
   """Represents a PCI device and provides convenient access to sysfs properties."""

   def __init__( self, address ):
      """Construct a PCI device for the given address.  Attempts to convert
      address to an Address object if necessary."""
      self.address_ = Address( address )
      # pylint: disable-next=redefined-outer-name
      sys = os.environ.get( 'SIMULATION_SYS', '/sys' )
      self.sysfsBase_ = os.path.join( sys,"bus/pci/devices", str(self) )

   def address( self ):
      """Return the device's address as an Address object."""
      return self.address_

   def devfn( self ):
      """Return the device's slot and function packed into a single byte."""
      return self.address_.devfn()

   def sysfsPath( self, p ):
      """Return the full path to a sysfs property for this device."""
      return os.path.join( self.sysfsBase_, p )

   def id( self ):
      """Return the device's vendor and device ID as an Id object."""
      try:
         return Id( _hexnum( open( self.sysfsPath( "vendor" ) ).read() ),
                    _hexnum( open( self.sysfsPath( "device" ) ).read() ) )
      except OSError:
         # With hotplug this can fail if the device is in the process of being
         # detected/removed when we scan for Pci devices.
         return None

   def subsystemId( self ):
      """Return the device's subsystem vendor and device ID as an Id object."""
      try:
         return Id( _hexnum( open( self.sysfsPath( "subsystem_vendor" ) ).read() ),
                    _hexnum( open( self.sysfsPath( "subsystem_device" ) ).read() ) )
      except OSError:
         # With hotplug this can fail if the device is in the process of being
         # detected/removed when we scan for Pci devices.
         return None

   def classCode( self ):
      """Return the device's class code as an int."""
      try:
         return _hexnum( open( self.sysfsPath( "class" ) ).read() )
      except OSError:
         # With hotplug this can fail if the device is in the process of being
         # detected/removed when we scan for Pci devices.
         return None

   # pylint: disable-next=inconsistent-return-statements
   def resource( self, index, readOnly=False, filename=None,
         startOffset=None, endOffset=None):
      """Return a Resource object for one of the device's I/O or memory resources."""
      if filename == None: # pylint: disable=singleton-comparison
         filename = "resource%d" % index
      p = self.sysfsPath( filename )
      if os.path.exists( p ):
         return MmapResource( p, readOnly, startOffset, endOffset )

   # pylint: disable-next=inconsistent-return-statements
   def config( self, readOnly=False ):
      """Return a Resource object for the device's configuration space."""
      if os.getuid():
         sys.stderr.write( "WARNING: You are not running as root. You may only be "
                           "able to access the first 64 bytes of PCI configuration "
                           "space.\n" )
      p = self.sysfsPath( "config" )
      if os.path.exists( p ):
         return PseudoMmapResource( p, readOnly )

   def __str__( self ):
      return str( self.address_ )

   def __repr__( self ):
      return "Pci.Device( %s )" % self

def allDevices():
   """List every PCI device that sysfs reports, as Device objects.

   The device directory is <root>/bus/pci/devices, where <root> comes from the
   SIMULATION_SYS environment variable and defaults to /sys.  Entries are
   returned in os.listdir() order."""
   sysRoot = os.environ.get( "SIMULATION_SYS", "/sys" )
   devicesDir = os.path.join( sysRoot, "bus/pci/devices" )
   entries = os.listdir( devicesDir )
   return [ Device( Address( entry ) ) for entry in entries ]

def deviceById( id ): # pylint: disable=redefined-builtin
   """Return a Device object for a system device with the specified ID.

   id may be anything Id() accepts.  If several devices match, the first one in
   allDevices() order is returned; None is returned if no device matches."""
   wanted = Id( id )
   # We cannot insist on a unique match: NorCalInit requires this to work, and
   # on modular systems there are multiple scds if the system rebooted with
   # cards powered on.  We are lucky that the supe scd is always first in the
   # list.
   # Stop at the first match instead of reading the sysfs vendor/device files of
   # every remaining device.
   return next( ( d for d in allDevices() if d.id() == wanted ), None )

def allDevicesById( id ): # pylint: disable=redefined-builtin
   """Return every system Device whose vendor/device ID equals id.

   id may be anything Id() accepts.  Matches are listed in allDevices() order."""
   wanted = Id( id )
   matches = [ dev for dev in allDevices() if dev.id() == wanted ]
   if not matches:
      # NOTE(review): "no match" is reported as [ None ], not [], so callers
      # must be prepared for a None entry.  Callers may rely on this; confirm
      # before changing it.
      matches = [ None ]
   return matches

class Resource:
   """Abstract class representing a PCI memory space (config or resource).  Provides
   functions for reading and writing values.

   Derived classes implement _readCommonImpl/_writeCommonImpl.  The public
   read*/write* methods refuse writes when readOnly is set and, when
   check=True, validate alignment, address range and value range before
   delegating to those implementations."""

   # Common attribute which needs to be set by the derived class:
   # This needs to be set because there is existing code which will directly access
   # the mmap_ object instead of calling mmap()...
   mmap_ = None

   # Common attributes which may be set by the derived class:
   # Causes write8/16/32 to raise ValueError
   readOnly = False
   # The acceptable address range [minAddr, maxAddr): an access must start at or
   # after minAddr and end at or before maxAddr.  None means no upper bound.
   minAddr = 0
   maxAddr = None

   # Abstract methods which need to be implemented by the derived class:
   def _readCommonImpl( self, accBytes, addr ):
      """Perform a read of the specified number of bytes - internal use only"""
      raise NotImplementedError( "Please use a derived class" )

   def _writeCommonImpl( self, accBytes, addr, value ):
      """Perform a write of the specified number of bytes - internal use only"""
      raise NotImplementedError( "Please use a derived class" )

   # Common methods shared between subclasses:
   def mmap( self ):
      """Get a mmap object which represents the memory and can be sliced. Accesses
      may access the hardware multiple times due to Python's use of memcpy to obtain
      the data."""
      return self.mmap_

   def unmap( self ):
      """Safely closes the underlying resource.  Calling it again is a no-op."""
      if self.mmap_:
         self.mmap_.close()
         self.mmap_ = None

   def _checkAddr( self, accBytes, addr ):
      """Validates an address used for accessing the underlying resource.

      Raises ValueError if addr is not a multiple of accBytes, or if the access
      [addr, addr + accBytes) is not within [minAddr, maxAddr)."""
      if addr % accBytes != 0:
         raise ValueError( f'Address {addr:#010x} is not a multiple of {accBytes}' )
      # Compare maxAddr against None explicitly: a maxAddr of 0 describes an
      # empty range and must reject every access, not disable the check.
      if ( ( addr < self.minAddr ) or
           ( self.maxAddr is not None and addr + accBytes > self.maxAddr ) ):
         maxAddrStr = 'inf' if self.maxAddr is None else f'{self.maxAddr:#010x}'
         raise ValueError( f'Address {addr:#010x} out of range'
                           f' ({self.minAddr:#010x} - {maxAddrStr})' )

   def _readCommon( self, accBytes, addr, check ):
      """Perform a read of the specified number of bytes - internal use only"""
      if check:
         self._checkAddr( accBytes, addr )
      return self._readCommonImpl( accBytes, addr )

   def _writeCommon( self, accBytes, addr, value, check ):
      """Perform a write of the specified number of bytes - internal use only.

      The read-only check is always applied; the address and value checks only
      when check is true."""
      if self.readOnly:
         raise ValueError( "Attempting to write to a read-only resource" )
      if check:
         self._checkAddr( accBytes, addr )
         if not 0 <= value < ( 2 ** ( accBytes * 8 ) ):
            raise ValueError( 'Value %s out of range' % value )
      self._writeCommonImpl( accBytes, addr, value )

   def read8( self, addr, check=True ):
      """Reads an 8-bit value from the specified address in a single access."""
      return self._readCommon( 1, addr, check )

   def read16( self, addr, check=True ):
      """Reads a 16-bit value from the specified address in a single access."""
      return self._readCommon( 2, addr, check )

   def read32( self, addr, check=True ):
      """Reads a 32-bit value from the specified address in a single access."""
      return self._readCommon( 4, addr, check )

   def write8( self, addr, value, check=True ):
      """Writes an 8-bit value to the specified address in a single access."""
      return self._writeCommon( 1, addr, value, check )

   def write16( self, addr, value, check=True ):
      """Writes a 16-bit value to the specified address in a single access."""
      return self._writeCommon( 2, addr, value, check )

   def write32( self, addr, value, check=True ):
      """Writes a 32-bit value to the specified address in a single access."""
      return self._writeCommon( 4, addr, value, check )

class MmapResource( Resource ):
   """Resource implementation for a directly-mapped memory region.

   Maps the file at path from startOffset (default 0) up to endOffset (default:
   the file's size).  Addresses passed to read*/write* are offsets into the
   file, not into the mapping; with check=True they must lie within
   [startOffset, endOffset)."""

   def __init__( self, path, readOnly=False, startOffset=None, endOffset=None ):
      """Raises ValueError if startOffset is not a multiple of mmap.PAGESIZE, or
      if mmap fails with EINVAL (the message suggests an I/O region).  Other
      OSErrors from opening or mapping the file propagate."""
      self.readOnly = readOnly
      prot = mmap.PROT_READ
      if not readOnly:
         prot |= mmap.PROT_WRITE
      # A shared mapping with PROT_READ only needs a read-only descriptor;
      # requesting O_RDWR would needlessly fail on files that are readable but
      # not writable.
      openFlags = os.O_RDONLY if readOnly else os.O_RDWR
      fd = os.open( path, openFlags )
      try:
         size = os.fstat( fd ).st_size

         if startOffset is None:
            startOffset = 0

         if endOffset is None:
            endOffset = size

         self.minAddr = startOffset
         self.maxAddr = endOffset
         length = endOffset - startOffset

         if startOffset % mmap.PAGESIZE != 0:
            raise ValueError( 'Start Offset 0x%x, must be aligned to 0x%x'
                              % ( startOffset, mmap.PAGESIZE ) )

         # read*/write* take file offsets; subtract this to index the mapping.
         self._mmapOffset = startOffset

         try:
            self.mmap_ = mmap.mmap( fd, length, mmap.MAP_SHARED, prot,
                                    offset=startOffset )
         except OSError as e:
            if e.errno == errno.EINVAL:
               msg = ( 'Cannot memory-map resource file (is it an'
                       ' I/O region rather than a memory region?)' )
               raise ValueError( msg ) from e
            raise

      finally:
         try:
            # Note that closing the file descriptor has no effect on the memory map
            os.close( fd )
         except OSError:
            pass

   def _readCommonImpl( self, accBytes, addr ):
      addr = addr - self._mmapOffset
      return _MemVolatileAccess.read( accBytes, self.mmap_, addr )

   def _writeCommonImpl( self, accBytes, addr, value ):
      addr = addr - self._mmapOffset
      _MemVolatileAccess.write( accBytes, self.mmap_, addr, value )


class _PseudoMmap:
   """A class that emulates an mmap object for files that are not mmap-able, by
   implementing __getitem__ and __setitem__ in terms of seek(), read() and
   write()."""

   def __init__( self, f, ignoreLength=False ):
      self.f_ = f
      self.ignoreLength_ = ignoreLength

   def _translateIndex( self, index ):
      if isinstance( index, slice ):
         if self.ignoreLength_:
            indices = ( index.start, index.stop, index.step or 1 )
         else:
            indices = index.indices( len( self ) )
         if indices[ 2 ] != 1:
            raise TypeError( "_PseudoMmap doesn't support extended slices" )
         return ( indices[ 0 ], indices[ 1 ] - indices[ 0 ] )
      else:
         return ( index, 1 )

   def __getitem__( self, index ):
      ( start, num ) = self._translateIndex( index )
      self.f_.seek( start )
      value = self.f_.read( num )
      if num == 1:
         # Emulate py3 byte indexing with an int
         return value[ 0 ]

      return value

   def __setitem__( self, index, value ):
      ( start, num ) = self._translateIndex( index )
      if num == 1 and isinstance( value, int ):
         # Emulate py3 byte indexing/setting with an int
         value = value.to_bytes( 1, 'little' )

      assert len( value ) == num
      self.f_.seek( start )
      self.f_.write( value )
      self.f_.flush()

   def __len__( self ):
      return os.fstat( self.f_.fileno() ).st_size or sys.maxsize

   def close( self ):
      # We don't need to get rid of our reference to f_ - if we leave it, then if
      # the user tries to do any accesses after closing, it will give a nice
      # "trying to operate on a closed file" error.
      self.f_.close()

class PseudoMmapResource( Resource ):
   """Resource implementation for a file descriptor-backed device.

   Accesses go through a _PseudoMmap, i.e. seek()/read()/write() on an
   unbuffered file, with values encoded little-endian."""

   def __init__( self, path, readOnly=False, ignoreLength=False ):
      self.readOnly = readOnly
      mode = "rb" if readOnly else "rb+"
      # File is opened as unbuffered as this class is mainly used
      # for the PCI configuration and I/O spaces and we do not want to buffer them
      # since they contain some status registers
      # pylint: disable-next=consider-using-with
      self.mmap_ = _PseudoMmap( open( path, mode, 0 ), ignoreLength=ignoreLength )
      # ignoreLength is only used as a workaround for accessing files larger than
      # 2GiB (/dev/mem), in which case all functions have to be called with
      # check=False. Thus maxAddr should not matter in that case, so we set it
      # unconditionally here.
      self.maxAddr = len( self.mmap_ )

   def _readCommonImpl( self, accBytes, addr ):
      value = self.mmap_[ addr : addr + accBytes ]
      # _PseudoMmap returns an int rather than bytes for a one-byte read.
      if isinstance( value, int ):
         return value
      # NOTE(review): a short read (fewer than accBytes bytes, e.g. past EOF)
      # yields a silently truncated value here; confirm that is acceptable.
      return int.from_bytes( value, 'little' )

   def _writeCommonImpl( self, accBytes, addr, value ):
      self.mmap_[ addr : addr + accBytes ] = value.to_bytes( accBytes, 'little' )

def _hexnum( num ):
   if type( num ) is int: # pylint: disable=unidiomatic-typecheck
      return num
   elif type( num ) is str: # pylint: disable=unidiomatic-typecheck
      return int( num, 16 )
   else:
      return 0
