#!/usr/bin/env python3
# Copyright (c) 2010, 2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

# pylint: disable-next=deprecated-module
import optparse, sys, time, os, struct, threading, re
#import Tac
# pylint: disable-msg=W0401

def pciAddrToPath( _addr ):
   """Map a 'bus:slot.func' PCI address (hex fields) to its /proc config file.

   e.g. '2:0.0' -> '/proc/bus/pci/02/00.0'
   Raises ValueError if the address is not of that form.
   """
   busStr, devFn = _addr.split( ':' )
   slotStr, funcStr = devFn.split( '.' )
   bus, slot, func = ( int( field, 16 ) for field in ( busStr, slotStr, funcStr ) )
   return '/proc/bus/pci/%02x/%02x.%x' % ( bus, slot, func )

###################################################
# Advanced Error Reporting (AER) registers from C headers

# Each entry: ( statusOffset, maskOffset, severityOffset, bitDescr ). Offsets
# are relative to the AER extended capability (used as ECAP_AER+offset with
# setpci); bitDescr names each status bit. -1 marks "no severity register"
# (correctable errors have none).
pciAerRegs = {
   'UNCOR' : ( 4, 8, 12,
               { 0x00000001 : 'Training',
                 0x00000010 : 'Data Link Protocol',
                 0x00001000 : 'Poisoned TLP',
                 0x00002000 : 'Flow Control Protocol',
                 0x00004000 : 'Completion Timeout',
                 0x00008000 : 'Completer Abort',
                 0x00010000 : 'Unexpected Completion',
                 0x00020000 : 'Receiver Overflow',
                 0x00040000 : 'Malformed TLP',
                 0x00080000 : 'ECRC Error Status',
                 0x00100000 : 'Unsupported Request',
               }
             ),
   'CORR'  : ( 16, 20, -1,     # the real register offsets
   #'CORR'  : ( 0x18, 20, -1,  # TESTING ONLY: Point to another register that
                               # typically has 1's in some error bit positions,
                               # to exercise the 'setpci' code path
               {
                 0x00000001 : 'Receiver Error Status',
                 0x00000040 : 'Bad TLP Status',
                 0x00000080 : 'Bad DLLP Status',
                 0x00000100 : 'REPLAY_NUM Rollover',   # the real bit
                 #0x00000020 : 'REPLAY_NUM Rollover',  # TESTING ONLY
                 0x00001000 : 'Replay Timer Timeout',
               }
             )
}

pciDevStaReg = {   # error status bits in CAP_EXP Device Status reg
   0x0001 : 'Correctable Error',
   0x0002 : 'Non-fatal Error',
   0x0004 : 'Fatal Error',
   0x0008 : 'Unsupported Request',
}

# Top-level Advanced Error Reporting types defined by PCIe
aerErrorTypes = [ 'UNCOR', 'CORR' ]

# Bit mask of "Correctable" errors that we nevertheless treat as bad; a CORR
# status with any of these bits set is logged as BAD_CORR.
badCorrBits = [ 'REPLAY_NUM Rollover' ]
badCorrMask = 0x0
for _corrBit, _corrName in pciAerRegs[ 'CORR' ][ 3 ].items():
   if _corrName in badCorrBits:
      badCorrMask |= _corrBit

class PciDevice:
   """A PCIe device under test.

   Keeps an open handle on the device's config-space file, the first four
   config-space bytes we expect to read back (vendor ID then device ID, low
   byte first), and the error strings collected during the run. Register
   accesses other than the raw config-space reads go through the external
   'setpci' and 'lspci' tools.
   """

   def __init__( self, _name, _addr, _path, _vendId, _devId ):
      self.name = _name
      self.addr = _addr      # PCI address string, passed to setpci/lspci -s
      self.path = _path      # config-space file opened below
      # Binary, unbuffered handle: config space is raw bytes (text mode would
      # try to decode it), and with a buffered reader 'seek(0); read(64)' can
      # be served from Python's own buffer, so repeated reads would stop
      # reaching the device.
      self.fd = open( _path, 'rb', buffering=0 ) # pylint: disable=consider-using-with
      self.errors = []       # error strings collected via addError()
      self.pid = None
      self.reads = 0         # read counter, maintained by callers
      # Expected config-space bytes 0-3: vendor ID, then device ID, LSB first
      self.cfgSpace = [ _vendId & 0xff, (_vendId >> 8) & 0xff,
                        _devId & 0xff,  (_devId >> 8) & 0xff ]
      self.hasAer = None     # hasPcieAer() cache; None until probed

   @staticmethod
   def _shell( cmd ):
      """Run shell command 'cmd', wait for it to exit, and return its stdout."""
      # The context manager closes the pipe AND waits for the child, so
      # register writes have completed (and no zombie is left) on return.
      with os.popen( cmd ) as pipe:
         return pipe.read()

   def _setpci( self, reg ):
      """Run 'setpci -s <addr> <reg>' and return its stripped stdout."""
      return self._shell( 'setpci -s %s %s' % ( self.addr, reg ) ).strip()

   def closeFd( self ):
      """Close the config-space file handle."""
      self.fd.close()

   def addError( self, err ):
      """Record one error string for this device."""
      self.errors.append( err )

   def checkIfPcie( self ):
      """Exit the script with status 1 unless the device is PCIe."""
      if not self.isPcie():
         print( 'ERROR: Device %s at %s does not appear to be PCIe'
                % ( self.name, self.addr ) )
         sys.exit( 1 )

   def clearErrorBits( self ):
      """Clear latched error status bits (they are write-1-to-clear).

      Clears Device Status (CAP_EXP+0x0A) and, if the device has AER, the
      status register of every type in aerErrorTypes. Each setpci write has
      finished before this returns. Exits the script if the device is not
      PCIe.
      """
      self.checkIfPcie()
      self._setpci( 'CAP_EXP+0x0A.W=0xFFFF' )
      if self.hasPcieAer():
         for errorType in aerErrorTypes:
            statOffset = pciAerRegs[ errorType ][ 0 ]
            self._setpci( 'ECAP_AER+0x%X.L=0xFFFFFFFF' % statOffset )

   def readCfgSpace( self, asString=False ):
      """Read up to the first 64 bytes of config space.

      Returns the raw bytes when asString is True (the name is historical),
      otherwise a list of ints, one per byte. A short read gives a shorter
      result rather than raising, so callers can check len() themselves.
      """
      self.fd.seek( 0 )
      buf = self.fd.read( 64 )
      if asString:
         return buf
      return list( buf )

   def parseCfgSpace( self, byteString ):
      '''Parse vendor, device, rev from basic config space, and gen (speed) and
         width from the EXPress capability.

         byteString holds at least the first 9 bytes of config space; those
         fields sit at the same offsets in Type 0 and Type 1 headers. Returns
         a dict with keys 'vend', 'dev', 'rev', 'gen', 'width'.
      '''
      cfg = {}
      # Offsets 0 vendor, 2 device, 4 command, 6 status, 8 revision. Config
      # space is little-endian, so say so rather than relying on host order.
      ( cfg['vend'], cfg['dev'], _, _, cfg['rev'] ) \
         = struct.unpack( '<HHHHB', byteString[0:9] )
      # Dword at CAP_EXP+0x10 (setpci walks the capability list for us); its
      # upper half is Link Status: speed in bits 3:0, width in bits 9:4.
      linkRegs = int( self._setpci( 'CAP_EXP+0x10.L' ), 16 )
      linkSta = ( linkRegs >> 16 ) & 0xffff
      cfg['gen'] = linkSta & 0xf
      cfg['width'] = ( linkSta >> 4 ) & 0x3f
      return cfg

   def isPcie( self, verbose=False ):
      '''Return True if setpci reads the EXPress capability ID 0x10.

      'verbose' is accepted for compatibility and unused.
      '''
      return self._setpci( 'CAP_EXP+0x0.B' ) == '10'

   def hasPcieAer( self, verbose=False ):
      '''Return True if the device has the AER extended capability (ID 0x0001).

      The probe result is cached in self.hasAer; a non-PCIe device returns
      False without caching. 'verbose' is accepted for compatibility and
      unused.
      '''
      if self.hasAer is None:
         if not self.isPcie():
            return False
         self.hasAer = ( self._setpci( 'ECAP_AER+0x0.W' ) == '0001' )
      return self.hasAer

   @staticmethod
   def _firstLspciError( lspciText ):
      """Return (reg, flag) for the first error flag shown set ('+') in lspci
      -vv text, scanning 'Status:', 'UESta:' and 'CESta:' lines; else None.
      """
      what = { 'Status' : ['ParErr', 'TAbort', 'MAbort', 'SERR', 'PERR'],
               'UESta'  : ['DLP', 'SDES', 'TLP', 'FCP', 'CmpltTO', 'CmpltAbrt',
                           'UnxCmplt', 'RxOF', 'MalfTLP', 'ECRC', 'UnsupReq',
                           'ACSViol' ],
               'CESta'  : ['RxErr', 'BadTLP', 'BadDLLP', 'Rollover', 'Timeout',
                           'NonFatalErr']
             }
      for line in lspciText.splitlines():
         for ( reg, bits ) in what.items():
            if re.match( r'^\s*%s:' % reg, line ):
               for bit in bits:
                  # Flags sit mid-line, possibly after '<'/'>' (e.g. '>SERR+');
                  # \b keeps 'TLP' from matching inside 'MalfTLP+'.
                  if re.search( r'\b%s\+' % re.escape( bit ), line ):
                     return ( reg, bit )
      return None

   def checkLspci( self, msg='', verbose=False ):
      ''' Run 'lspci -vv' for this device and look for set error flags.

      Returns False (after printing the flag and the full lspci output) if
      any flag is set, else True. 'msg' is added to the report for context;
      'verbose' is accepted for compatibility and unused. Sample input:

      bash-4.1# lspci -vv -s 4:0.0
      04:00.0 System peripheral: Arastra Inc. Device 0001 (rev 01)
       Status: Cap+ 66MHz- UDF- FastB2B- ParErr- DEVSEL=fast >TAbort- <TAbort-
               <MAbort- >SERR- <PERR- INTx-
       :::
       Capabilities: [1a0 v1] Advanced Error Reporting
       :::
       UESta:  DLP- SDES- TLP- FCP- CmpltTO- CmpltAbrt- UnxCmplt- RxOF- MalfTLP-
               ECRC- UnsupReq- ACSViol-
       CESta:  RxErr- BadTLP- BadDLLP- Rollover- Timeout- NonFatalErr-
      '''
      bfr = self._shell( 'lspci -vv -s %s' % self.addr )
      found = self._firstLspciError( bfr )
      if found is None:
         return True
      ( reg, bit ) = found
      context = f' ({msg})' if msg else ''
      print( f'### Found Error {reg}/{bit}{context} ###\n' )
      print( bfr )
      return False

   def checkErrorStatus( self, verbose=False ):
      """Check PCIe error status registers, record AER errors, then clear.

      Reads Device Status and, if the device has AER, each AER status
      register in aerErrorTypes. A non-zero AER status is appended to
      self.errors as '<TYPE>:0x<stat>'; a correctable status with any
      badCorrMask bit set is logged as 'BAD_CORR' instead of 'CORR'. All
      error bits are cleared before returning. Exits the script if the
      device is not PCIe.

      Returns True if no error bits were set.
      """
      self.checkIfPcie()
      isOk = True
      # Device Status (CAP_EXP+0x0A): the low four bits are error flags
      stat = int( self._setpci( 'CAP_EXP+0x0A.W' ), 16 )
      print( 'PCIe Device %s (%s) DevSta=0x%X (%s)'
             % ( self.addr, self.name, stat, 'ERROR' if ( stat & 0xf ) else 'OK' ) )
      if stat & 0xf:
         # NOTE(review): DevSta errors only affect the return value; they are
         # not added to self.errors. Confirm that is intended.
         isOk = False
         if verbose:
            print( f'PCIe Device {self.addr} ({self.name}) Errors:' )
            for bit, descr in pciDevStaReg.items():
               if stat & bit:
                  print( f'   0x{bit:04X}: {descr}' )
      # Check PCIe Advanced Error Reporting registers if present
      if self.hasPcieAer():
         for errorType in aerErrorTypes:
            ( statOffset, _, _, bitDescr ) = pciAerRegs[ errorType ]
            setpciCmd = f'setpci -s {self.addr} ECAP_AER+0x{statOffset:X}.L'
            stat = int( self._shell( setpciCmd ).strip(), 16 )

            # Error type that we log may differ from the "raw" type, for
            # "correctable" errors that we consider bad.
            logErrorType = errorType
            if stat != 0x0:
               if errorType == 'CORR' and ( stat & badCorrMask ):
                  logErrorType = 'BAD_CORR'   # don't forgive
               self.addError( f'{logErrorType}:0x{stat:X}' )
               isOk = False

            if verbose:
               print( 'checkErrorStatus (AER) cmd="%s"' % ( setpciCmd ) )
               print( 'PCIe Device %s (%s): AER ErrorType=%s Status=0x%X'
                      % ( self.addr, self.name, logErrorType, stat ) )
               # Break down error by bits
               for bit, descr in bitDescr.items():
                  if stat & bit:
                     print( f'   0x{bit:08X}: {descr}' )
         if verbose:
            # FWIW dump the AER dwords at 0x1C-0x34 (header log & error source)
            aerRegs = [ self._setpci( 'ECAP_AER+0x%X.L' % offset )
                        for offset in range( 0x1c, 0x34 + 1, 4 ) ]
            print( f'AER regs({self.addr}) 0x1C-0x34: {aerRegs}' )
      self.clearErrorBits()
      return isOk

class PciStressThread( threading.Thread ):
   """Thread that hammers one device's config space with reads.

   For each of 'iters' iterations it re-reads config space for 'runTime'
   seconds. A short read or a vendor/device ID mismatch records a 'Read'
   error and ends that iteration early. It then checks the device's error
   status, via lspci or setpci as selected by the module-level 'opts'.
   'running' is True while run() is executing.
   """

   def __init__( self, _dev, runTime, iters=1 ):
      super().__init__()
      self.dev = _dev
      self.dev.reads = 0
      self.runTime = runTime   # honour the argument (was opts.duration)
      self.iters = iters
      self.running = False

   def run( self ):
      self.running = True
      dev = self.dev
      for itr in range( self.iters ):
         print( '\nStressing PCIe dev %s for %s seconds; iter=%d...'
                % ( dev.addr, self.runTime, itr ) )
         # Monotonic clock: wall-clock jumps must not stretch or cut the run
         tStart = time.monotonic()
         while time.monotonic() - tStart < self.runTime:
            bfr = dev.readCfgSpace()
            dev.reads += 1
            if len( bfr ) < 64 or bfr[ 0:4 ] != dev.cfgSpace[ 0:4 ]:
               print( 'ERROR: addr=%s itr=%d reads=%d'
                      % ( dev.addr, itr, dev.reads ) )
               print( '  Expected {}; Read {}'.format( dev.cfgSpace[ 0:4 ],
                  bfr[ 0:4 ] ) )
               dev.addError( 'Read' )
               break
         print( '...done (dev=%s iter=%d reads=%d errs=%d)\n'
                % ( dev.addr, itr, dev.reads, len( dev.errors ) ) )
         if opts.lspci:
            dev.checkLspci( msg='iter=%d' % itr )
         else:
            dev.checkErrorStatus( verbose=opts.verbose )
      # Deliberately not in a 'finally': if an iteration raises, 'running'
      # stays True, so the run still looks unfinished to anyone polling it.
      self.running = False

###################################################
# main

# Command Parsing

# Build the option parser; the parsed values land in the module-level 'opts'.
cmdParser = optparse.OptionParser( usage='%prog -t <seconds> [-v|--verbose]' )
cmdParser.add_option( '-i', '--iters', type='int', action='store', default=1,
                      help='Number of iterations (after which we check errors), '
                           'for stress test' )
cmdParser.add_option( '-t', '--duration', type='float', action='store',
                      default=0.0, help='duration, for stress test' )
cmdParser.add_option( '-v', '--verbose', action='store_true', default=False,
                      help='verbosity' )
cmdParser.add_option( '-b', '--buses', type='string', default='',
                      help=( 'List of buses to check, separated by "+" (NO SPACES)\n'
                             '  a=addr n=name v=vend d=dev r=rev g=gen(1|2) w=width(lanes).\n'
                             '  unspecified=dontcare. "a" is mandatory.\n'
                             '  eg. Trident+ (BCM56846) at 2:0.0; scd at 4:0.0\n'
                             '    -b a=2:0.0,n=switch0,v=0x14E4,d=0xb846,r=1,g=2,w=2'
                             '+a=0:4.0,n=nb-switch0,v=0x1022,d=0x9604,g=2,w=2'
                             '+a=4:0.0,n=scd,v=0x3475,d=0x0001,g=1,w=1'
                             '+a=0:9.0,n=nb-scd,v=0x1022,d=0x9608,g=1,w=1' ) )
cmdParser.add_option( '-c', '--pre-clear', dest='preClear', action='store_true',
                      default=False,
                      help='Clear error status registers before test' )
cmdParser.add_option( '-l', '--lspci', action='store_true', default=False,
                      help='Show AER status from lspci -vv' )
cmdHelp = cmdParser.format_help()
opts, args = cmdParser.parse_args( sys.argv[ 1: ] )


# Parse --buses (misnomer; these are devices) into a dict, keyed by device
# name, of values we expect to find in each device's PCIe config space; eg.
# { 'switch0': {'addr' : '2:0.0', 'name' : 'switch0', 'vend' : 0x14e4, ...},
#   'scd':     {'addr' : '4:0.0', 'name' : 'scd', 'vend' : 0x3475, ...}
# }
# Arg format is in help for --buses (really device) option. Unknown keys are
# ignored. addr, vend and dev must be given because the device setup below
# needs them. name defaults to addr.
#
devs = [ spec for spec in opts.buses.split( '+' ) if spec ]  # '+' is bash-friendly
attrNames = { 'a':'addr', 'n':'name', 'v':'vend', 'd':'dev', 'r':'rev',
              'g':'gen', 'w':'width' }
expectedParams = {}
for dev in devs:
   devInfo = {}
   for attr in dev.split( ',' ):
      abbrev, _, val = attr.partition( '=' )
      if abbrev not in attrNames:
         continue
      attrName = attrNames[ abbrev ]
      # TODO Handle multiple possible values '|'
      if attrName not in ( 'addr', 'name' ):
         # remaining attrs are binary values in cfg space registers
         try:
            val = int( val, 0 )
         except ValueError:
            cmdParser.error( 'bad value in --buses: %r' % attr )
      devInfo[ attrName ] = val
   missing = [ key for key in ( 'addr', 'vend', 'dev' ) if key not in devInfo ]
   if missing:
      cmdParser.error( '--buses device %r is missing %s'
                       % ( dev, ', '.join( missing ) ) )
   devInfo.setdefault( 'name', devInfo[ 'addr' ] )
   expectedParams[ devInfo['name'] ] = devInfo    # using name field as key
if not expectedParams:
   cmdParser.error( '--buses is required; see -h' )

# If skipPcieErrorMonitoring file not present, or it's newer than PciBus agent,
# create file and kill agent. Agent can restart, but won't interfere with us.

def psEtimeToDays( etime ):
   """Convert a ps elapsed time '[[dd-]hh:]mm:ss' to (fractional) days."""
   dayPart, _, clock = etime.rpartition( '-' )
   days = int( dayPart, 10 ) if dayPart else 0
   hhmmss = clock.split( ':' )
   hrs = int( hhmmss[-3], 10 ) if len( hhmmss ) >= 3 else 0
   mins = int( hhmmss[-2], 10 )
   secs = int( hhmmss[-1], 10 )
   return ( ( ( 24 * days + hrs ) * 60 + mins ) * 60 + secs ) / ( 24 * 60 * 60.0 )

restartAgent = False
pid = os.popen('pgrep PciBus').read().strip()
if pid == '':
   restartAgent = True
else:
   # pgrep may print several PIDs; use the first. Match ps lines on the exact
   # PID field: grepping for the digits could hit another PID or a time
   # column, and a multi-line PID would break the shell command.
   agentPid = pid.split()[ 0 ]
   etime = None
   for psLine in os.popen( 'ps -eo%p%t' ).read().splitlines():  # "pid etime"
      fields = psLine.split()
      if len( fields ) >= 2 and fields[ 0 ] == agentPid:
         etime = fields[ 1 ]
         break
   if etime is None:
      restartAgent = True   # agent exited since pgrep; treat as not running
   else:
      procT = psEtimeToDays( etime )
      try:
         st = os.stat('/mnt/flash/skipPcieErrorMonitoring')
         fileT = ( time.time() - st.st_mtime ) / ( 24 * 60 * 60.0 )
         print( f'fileT={fileT:f} procT={procT:f}' )
         if fileT < procT:   # file created after the agent started
            restartAgent = True
      except OSError:  # assume file not there
         restartAgent = True

if restartAgent:
   print( 'Creating skipPcieErrorMonitoring' )
   oldPid = os.popen('pgrep PciBus').read().strip()
   os.system( "echo '' > /mnt/flash/skipPcieErrorMonitoring" )
   os.system( "sync" )
   if oldPid == '':
      # Nothing to kill; waiting for an empty PID to change would only time
      # out and fail the run.
      print( 'PciBus agent not running' )
   else:
      os.system( "killall PciBus" )
      print( 'Waiting for PciBus agent to die...' )
      newPid = oldPid
      for _ in range( 100 ):   # poll for up to ~10 seconds
         newPid = os.popen('pgrep PciBus').read().strip()
         if newPid != oldPid:
            print( '...done' )
            break
         time.sleep(0.1)
      if newPid == oldPid:
         print( 'ERROR: Failed to kill PciBus agent; pid=%s' % newPid )
         sys.exit( 1 )

devs = []

# Cycle through the selected PCIe devs: open each one, optionally pre-clear
# its error bits, and compare its config space with the --buses expectations.
for name, expectedPrm in expectedParams.items():
   addr = expectedPrm['addr']
   path = pciAddrToPath( addr )
   dev = PciDevice( name, addr, path, expectedPrm['vend'], expectedPrm['dev'] )
   devs.append( dev )

   # clearErrorBits() clears DevSta itself and touches AER only if present,
   # so pre-clear every PCIe device, not just those with AER.
   if opts.preClear and dev.isPcie():
      dev.clearErrorBits( )

   # Baseline config-space read, checked field by field below
   cfgByteStr = dev.readCfgSpace( asString=True )
   rd = dev.parseCfgSpace( cfgByteStr )
   if opts.verbose:
      print( 'PCIe %s(%s): vend/dev/rev = %04X/%04X/%02X; gen/width = %s/%s'
            % ( dev.addr, dev.name, rd[ 'vend' ], rd[ 'dev' ], rd[ 'rev' ],
                rd[ 'gen' ], rd[ 'width' ] ) )

   # Check selected config space items against what we expect.
   for param in [ 'vend', 'dev', 'rev', 'gen', 'width' ]:
      if param in expectedPrm and rd[ param ] != expectedPrm[ param ]:
         print( 'ERROR: dev.addr %s (%s): Expected %s=0x%04X; read 0x%04X'
                % ( dev.addr, dev.name, param, expectedPrm[ param ],
                  rd[ param ] ) )
         dev.addError( 'Bad %s' % param )

# Start a stress thread for each device that passed the config-space checks.
if opts.duration > 0.0:
   for dev in devs:
      if dev.errors:
         continue   # skip the ones that already have errors
      dev.thrd = PciStressThread( dev, opts.duration, opts.iters )
      # Daemon so that a hung thread cannot keep the script from exiting
      dev.thrd.daemon = True
      dev.thrd.start()

# Wait up to duration + 1 s per iteration, plus 10 s of slack, for all threads.
deadline = time.monotonic() + ( opts.duration + 1 ) * opts.iters + 10
for dev in devs:
   if not hasattr( dev, 'thrd' ):
      continue
   dev.thrd.join( max( 0.0, deadline - time.monotonic() ) )
   # 'running' stays True if the thread is hung or died mid-run
   if dev.thrd.running:
      if dev.thrd.is_alive():
         print( 'ERROR: Stress Test timed out for dev {}({})'.format( dev.addr,
            dev.name ) )
      dev.addError( 'Stress Test Incomplete' )

# Final passive error check on each PCIe device, then release its config fd.
for dev in devs:
   # FIXME Use DMA if supported
   # FIXME This is really a passive check for end of test run.
   if dev.isPcie():
      print( '\nPCIe dev %s final check for errors...' % dev.addr )
      dev.checkErrorStatus( verbose=opts.verbose )
      print( '...done' )
   dev.closeFd()

# Verdict: any recorded error on any device fails the run.
allOk = not any( dev.errors for dev in devs )
if opts.verbose:
   for dev in devs:
      print( f'Bus {dev.addr}({dev.name}) Errors: {dev.errors}' )

print( 'PASS' if allOk else 'FAIL' )
sys.exit( 0 if allOk else 1 )

