#!/usr/bin/env python3
# Copyright (c) 2009-2012 Arastra, Inc.  All rights reserved.
# Arastra, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

""" Simple memory mfg verification test which checks that the total amount of
system memory installed equals the expected memory size.
If 2 memory DIMM modules are installed, it also checks that they have the same
characteristics (verified with the same vendor id and part number). """

import seeprom, array                   
import optparse, sys, Tac, re # pylint: disable=deprecated-module

# Command-line interface. The usage text is handed to the constructor; options
# without an explicit default are None when they are omitted.
parser = optparse.OptionParser( usage="""%prog [options] size-in-MB
'size' is an integer representing the number of MB of memory we expect, where
MB is defined as 1024*1024 bytes. e.g. %prog 4096 (for 4GB)
%prog checks the total amount of system memory installed and verifies if all DIMM
modules have the same characteristics.
""" )

# Expected module population and DDR generation.
parser.add_option( "-d", "--dimms", action="store",
                   help="number of dimms (2 as default)" )
parser.add_option( "-t", "--type", action="store", default="DDR2",
                   help="DDR type (DDR2 or DDR3, DDR2 as default)" )

# Behaviour switches.
parser.add_option( "-i", "--ignorecompare", action="store_true",
                   help="only do size checking, but ignore vendor detail comparison" )
parser.add_option( "-v", "--verbose", action="store_true",
                   help="verbose mode to dump memory SPD information" )

# Two integer SPD EEPROM addresses overriding the built-in defaults.
parser.add_option( "-a", dest="spdAddrs", type="int", nargs=2,
                   help="spd addresses if it is not 0x50 and 0x51" )

# Parse the command line: one positional argument (expected size in MB) and
# an optional DIMM count restricted to 1 or 2 (default 2).
( options, args ) = parser.parse_args()
if len( args ) < 1:
   parser.error( "at least one argument expected, the size of memory in MB" )
try:
   expectedSizeMB = int( args[ 0 ] )
except ValueError:
   parser.error( "Size argument must be an integer" )

if options.dimms:
   # A non-numeric value must produce the same usage error as an out-of-range
   # one instead of an uncaught ValueError traceback.
   try:
      dimms = int( options.dimms )
   except ValueError:
      dimms = None
   if dimms not in ( 1, 2 ):
      parser.error( "dimms needs to be an integer: 1 or 2" )  # exits
else:
   dimms = 2 # default


# Values taken straight from the command line.
expectedDdrType = options.type
ignoreCompare = options.ignorecompare

# Human-readable memory type names, indexed by the value of SPD byte 2.
MemoryTypeStr = [
   "None",
   "FPM DRAM",
   "EDO",
   "Pipeline Nibble",
   "SDRAM",
   "ROM",
   "DDR SGRAM",
   "DDR SDRAM",
   "DDR2 SDRAM",
   "DDR2 SDRAM FB",
   "DDR2 SDRAM FB PROBE",
   "DDR3 SDRAM",
]

busId = 0   # the SDRAM SPD EEPROMs are on i2c bus 0
offset = 0
spdSeepromAddr = [ 0x50, 0x51 ]   # default i2c addresses for DIMM0 and DIMM1

# When -a is given, its addresses replace the defaults for each expected DIMM.
for dimm in range( dimms ):
   if not options.spdAddrs:
      continue
   spdSeepromAddr[ dimm ] = options.spdAddrs[ dimm ]


def memSize( capacityMb, busWidth, width, ranks, verbose=None ):
   """ Calculate and return a module's memory size in MB from SPD fields.

   capacityMb -- SDRAM capacity value supplied by the caller; 0 or 0xff marks
                 an invalid / not-installed module and yields 0.
   busWidth   -- data bus width used in the formula below.
   width      -- SDRAM device width used in the formula below.
   ranks      -- number of ranks on the module.
   verbose    -- print the inputs; None (default) means use the -v option.

   Raises ValueError if busWidth, width or ranks is zero for a module whose
   capacity is valid (an explicit check, unlike assert, survives python -O
   and avoids a ZeroDivisionError).
   """
   if verbose is None:
      verbose = options.verbose
   if verbose:
      print( "capacity=%d, busWidth=%d, width=%d, ranks=%d" % (
         capacityMb, busWidth, width, ranks ) )

   if capacityMb in ( 0, 0xff ): # not valid, not installed
      return 0
   if not ( width and busWidth and ranks ):
      raise ValueError( "invalid SPD geometry: busWidth=%d, width=%d, ranks=%d"
                        % ( busWidth, width, ranks ) )
   # totalModuleSize = sdram capacity (Mb) * busWidth / 8 bit / width * ranks
   return capacityMb * busWidth // 8 // width * ranks

class DimmSpd:
   """ Base class holding the raw SPD EEPROM contents of one DIMM.

   Subclasses provide the type-specific decoders (timingInfo, mfgJedecIdCode,
   mfgPartNo, memSizeMB); the base versions print an error and raise
   NotImplementedError.
   """

   def __init__( self, bus, deviceId ):
      """ Read the SPD contents of `deviceId` on `bus` via seeprom.doRead.

      presumably the call reads 256 bytes starting at offset 0 -- TODO
      confirm against the seeprom module.
      """
      result = seeprom.doRead( bus, deviceId, 0, 256 )
      self.spd = array.array( 'B', result )

   def type ( self ):
      """ Return the memory type code stored in SPD byte 2. """
      return self.spd[ 2 ]

   def typeStr( self ):
      """ Return the human-readable name of the memory type code.

      Codes outside MemoryTypeStr (e.g. a blank 0xff byte or a newer type)
      yield "Unknown (0x..)" instead of raising IndexError.
      """
      memType = self.type()
      if memType < len( MemoryTypeStr ):
         return MemoryTypeStr[ memType ]
      return "Unknown (0x%02x)" % memType

   def timingInfo( self ):
      print( "Error Not Implemented" )
      raise NotImplementedError

   def mfgJedecIdCode( self ):
      print( "Error Not Implemented" )
      raise NotImplementedError

   def mfgPartNo( self ):
      print( "Error Not Implemented" )
      raise NotImplementedError

   def memSizeMB( self ):
      print( "Error Not Implemented" )
      raise NotImplementedError

class DimmSpdDDR2( DimmSpd ):
   """ Decoder for DDR2 SPD contents. """

   @staticmethod
   def _spdKey( spd ):
      """ Return the SPD bytes used for comparison and hashing as a tuple:
      all registers through byte 90 except byte 72 (mfg location), i.e.
      0:72 and 73:91 ( See spd definition from Smart Modular Technologies ).
      A tuple is used because array.array is unhashable. """
      return tuple( spd[ 0 : 72 ] ) + tuple( spd[ 73 : 91 ] )

   def _compare( self, other, op ):
      """ Apply tuple comparison `op` to the keys of self and other.

      Returns NotImplemented for objects that are not DimmSpd instances so
      Python can fall back instead of raising AttributeError on other.spd.
      """
      if not isinstance( other, DimmSpd ):
         return NotImplemented
      return op( self._spdKey( self.spd ), self._spdKey( other.spd ) )

   def __eq__( self, other ):
      """ Return whether the 2 spds are equivalent. This means
      checking all registers through byte 90, but skip byte72 (mfg location)
      i.e. 0:72, 73:91
      ( See spd definition
      from Smart Modular Technologies )."""
      return self._compare( other, tuple.__eq__ )

   def __hash__( self ):
      # Consistent with __eq__: equal SPDs hash equally.
      return hash( self._spdKey( self.spd ) )

   def __ne__( self, other ):
      return self._compare( other, tuple.__ne__ )

   def __lt__( self, other ):
      return self._compare( other, tuple.__lt__ )

   def __le__( self, other ):
      return self._compare( other, tuple.__le__ )

   def __gt__( self, other ):
      return self._compare( other, tuple.__gt__ )

   def __ge__( self, other ):
      return self._compare( other, tuple.__ge__ )

   def timingInfo( self ):
      return self.spd[ 0 : 47 ] # Not actually sure what is important here

   def mfgJedecIdCode( self ):
      return self.spd[ 64:72 ]

   def mfgPartNo( self ):
      return self.spd[ 73:91 ]

   def memSizeMB( self ):
      """ Return the module size in MB computed from SPD bytes 5, 6, 13, 14
      and 31 via memSize(). """
      ranks = ( self.spd[ 5 ] & 0x07 ) + 1
      totalBusWidth = self.spd[ 6 ]
      width = self.spd[ 13 ]
      errBusWidth = self.spd[ 14 ] # assumed to be the ECC part -- TODO confirm
      busWidth = totalBusWidth - errBusWidth
      density = self.spd[ 31 ]
      # NOTE(review): memSize() scales capacityMb by busWidth / 8 / width,
      # which is 1 only for x8 devices on a 64-bit bus. If byte 31 is already
      # a per-rank density, x4/x16 modules are mis-sized -- confirm.
      if density in ( 0, 0xff ):
         # 0 matches neither branch below (previously an UnboundLocalError),
         # and 0xff (presumably an erased EEPROM) would decode as 255GB.
         # Treat both as not installed: memSize() returns 0 for capacity 0.
         capacityMb = 0
      elif density & 0x1f: # in 1GB
         capacityMb = density * 1024 # in GB->MB
      else: # density & 0xe0: < 1GB
         capacityMb = density * 4 # 128MB unit
      return memSize( capacityMb, busWidth, width, ranks )

class DimmSpdDDR3( DimmSpd ):
   """ Decoder for DDR3 SPD contents. """

   @staticmethod
   def _spdKey( spd ):
      """ Return SPD registers 0 through 118 as a tuple; these take part in
      comparison and hashing ( See spd definition from Smart Modular
      Technologies ). A tuple is used because array.array is unhashable.

      NOTE(review): the part number (bytes 128-145, see mfgPartNo) is not in
      this key, so modules that differ only in part number compare equal and
      the caller's part-number check is skipped for them. Confirm intended.
      """
      return tuple( spd[ 0 : 119 ] )

   def _compare( self, other, op ):
      """ Apply tuple comparison `op` to the keys of self and other.

      Returns NotImplemented for objects that are not DimmSpd instances so
      Python can fall back instead of raising AttributeError on other.spd.
      """
      if not isinstance( other, DimmSpd ):
         return NotImplemented
      return op( self._spdKey( self.spd ), self._spdKey( other.spd ) )

   def __eq__( self, other ):
      """ Return whether the 2 spds are equivalent. This means
      checking all registers through byte 118 ( See spd definition
      from Smart Modular Technologies )."""
      return self._compare( other, tuple.__eq__ )

   def __hash__( self ):
      # Consistent with __eq__: equal SPDs hash equally.
      return hash( self._spdKey( self.spd ) )

   def __ne__( self, other ):
      return self._compare( other, tuple.__ne__ )

   def __lt__( self, other ):
      return self._compare( other, tuple.__lt__ )

   def __le__( self, other ):
      return self._compare( other, tuple.__le__ )

   def __gt__( self, other ):
      return self._compare( other, tuple.__gt__ )

   def __ge__( self, other ):
      return self._compare( other, tuple.__ge__ )

   def timingInfo( self ):
      return self.spd[ 0 : 39 ] # Not actually sure what is important here

   def mfgJedecIdCode( self ):
      # Byte 118 followed by byte 117.
      return self.spd[118:119] + self.spd[117:118]

   def mfgPartNo( self ):
      return self.spd[ 128:146 ]

   def memSizeMB( self ):
      """ Return the module size in MB decoded from SPD bytes 4, 7 and 8. """
      capacityMb = 2**( ( self.spd[ 4 ] & 0xf ) + 8 )      # byte 4 bits 3:0
      ranks = ( ( self.spd[ 7 ] & 0x38 ) >> 3 ) + 1        # byte 7 bits 5:3
      busWidth = 2**( ( self.spd[ 8 ] & 0x07 ) + 3 )       # byte 8 bits 2:0
      width = 2**( ( self.spd[ 7 ] & 0x07 ) + 2 )          # byte 7 bits 2:0
      return memSize( capacityMb, busWidth, width, ranks )
 
# Read and check every expected DIMM; successfully decoded modules are
# collected in spds and their sizes summed into totalSizeMB.
err = 0
totalSizeMB = 0
spds = []

for dimm in range( dimms ):
   spdAddr = spdSeepromAddr[ dimm ]
   try:
      ddrTypeField = seeprom.doRead( busId, spdAddr, 2, 1 )
      ddrType = ord( ddrTypeField )
      if ddrType == 11: # DDR3
         spd = DimmSpdDDR3( busId, spdAddr )
      elif ddrType in ( 8, 9, 10 ): # DDR2
         spd = DimmSpdDDR2( busId, spdAddr )
      else:
         # Must be reset explicitly: otherwise the previous DIMM's object
         # would be re-counted and appended again (or NameError on DIMM 0).
         spd = None
   # pylint: disable-msg=W0703
   except Exception as e:
      sys.stdout.write( "Read Error: %s  (Invalid device or not installed?)\n" % e )
      continue

   if spd is None:
      # Not added to spds, so the installed-module count check after this
      # loop reports the failure.
      sys.stderr.write( "Error: SDRAM %d has unsupported memory type %d\n"
                        % ( dimm, ddrType ) )
      continue

   memTypeStr = spd.typeStr()
   if not re.search( expectedDdrType, memTypeStr ):
      sys.stderr.write( "Error: SDRAM %d type is %s, not as expected %s\n" % (
         dimm, memTypeStr, expectedDdrType ) )
      err += 1
   else:
      print( "SDRAM %d type is %s as expected %s" % (
         dimm, memTypeStr, expectedDdrType ) )

   if options.verbose: # dump spdseeprom if -v
      Tac.run( [ 'spdseeprom', 'read', '-d', '%d.%d' % ( busId, spdAddr ) ] )

   totalSizeMB += spd.memSizeMB()
   spds.append( spd )

# The summed module sizes must equal the size given on the command line.
if totalSizeMB != expectedSizeMB:
   sys.stderr.write(
      "Error: total installed memory size is %dMB, expected %dMB\n" % (
         totalSizeMB, expectedSizeMB ) )
   err += 1

# Every expected DIMM must have produced an entry in spds.
dimmsInstalled = len( spds )
if dimmsInstalled != dimms:
   sys.stderr.write(
      "Error: %d memory module(s) installed, expected %d modules\n" % (
         dimmsInstalled, dimms ) )
   err += 1

 
## make sure all DIMMs have the same characteristics, from the same vendor with
##   the same part no.
if not ignoreCompare:
   sys.stderr.write( "Checking mfg info in Spds...\n" )
   spdMismatch = False
   # Every module is compared against the first one read; modules are
   # numbered from 1 in the order they appear in spds.
   spdBase = spds[ 0 ] if spds else None
   for moduleNo, spdOther in enumerate( spds[ 1: ], start=2 ):
      if spdOther != spdBase:
         spdMismatch = True
         err += 1
         # Implicit literal concatenation: a backslash continuation inside
         # the string would embed the next line's indentation in the message.
         if spdBase.mfgJedecIdCode() != spdOther.mfgJedecIdCode():
            sys.stderr.write( "Error: JEDEC ID of the memory module %d is "
                              "mismatched\n" % moduleNo )
         if spdBase.mfgPartNo() != spdOther.mfgPartNo():
            sys.stderr.write( "Error: Part number of memory module %d is "
                              "mismatched\n" % moduleNo )
   # Report the number of SPDs actually compared, not the number expected.
   if spdMismatch:
      sys.stderr.write( "%d Spds mismatch. Dumping spd contents\n" % dimmsInstalled )
      for moduleNo, spdOther in enumerate( spds, start=1 ):
         sys.stderr.write( "SPD Dump DIMM %d:\n%s\n" % ( moduleNo, spdOther.spd ) )
   else:
      sys.stderr.write( "%d Spds match\n" % dimmsInstalled )
   
# Summarize the run; the error count doubles as the exit status (0 == pass).
if err:
   print( "Failed: %d of %d memory modules installed with %dMB (expected %s %dMB)"
          % ( dimmsInstalled, dimms, totalSizeMB, expectedDdrType, expectedSizeMB ) )
else:
   print( "Passed: %d of %d memory modules installed and verified with %s %dMB"
          % ( dimmsInstalled, dimms, expectedDdrType, totalSizeMB ) )

sys.exit( err )
