#!/usr/bin/env python3
# Copyright (c) 2012 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
#
# S-channel Accelerator show/debug command

# pylint: disable=consider-using-f-string
from optparse import OptionParser # pylint: disable=deprecated-module
from ctypes import POINTER, cast, sizeof
import schan_accel, SchanAccelLib
import time, sys, os
import dmamem
import mmap

# Usage text handed to OptionParser in Cmd.__init__ and shown by print_help().
# It has one line per subcommand registered in Cmd.cmdlist; keep the two in sync.
usage = '''
%prog [ options ] list
%prog [ options ] dump [ accel|cmdring|stats [ <name> ] ]
%prog [ options ] resetreq <copName>
%prog [ options ] injecterror <copName> <accelindex> <errorcode> <errorarg>'''

def error( s ):
   '''Print the message s to stdout and terminate by raising SystemExit( -1 ).

   sys.exit() is used instead of the site-provided exit() builtin: the latter
   is intended for interactive sessions and does not exist when the site
   module is not loaded ( e.g. python -S ).
   '''
   print( s )
   sys.exit( -1 )

def pcidevname( hamImpl ):
   '''Return a printable device name for the hardware access method hamImpl.

   A ham of kind 'hamTypeDmamem' is identified by its name attribute. Any
   other kind is rendered from its address as "bus:slot.function" in hex
   ( bus and slot zero-padded to two digits ).
   '''
   if hamImpl.kind != 'hamTypeDmamem':
      pci = hamImpl.address
      return "%02x:%02x.%x" % ( pci.bus, pci.slot, pci.function )
   return hamImpl.name

class Cmd:
   def __init__( self ):
      '''Parse the command line and build the subcommand dispatch table.

      Options and positional arguments come from optparse's parse_args(),
      which reads the process command line. No hardware is touched here;
      cops, hams and accelcops stay None until init() is called.
      '''
      self.parser = OptionParser( usage = usage )
      self.parser.add_option( "--sysname", default="ar",
                              help="system name (default: %default)" )
      self.parser.add_option( "--verbose", "-v", action="store_true",
                              help="enable more verbose msgs" )
      self.parser.add_option( "-t", type="int", dest="sec", default=0,
                              help="delay sec seconds between each" )

      self.opts, self.args = self.parser.parse_args()

      # Not assigned anywhere else in this class.
      self.pciDeviceStatusDir = None
      self.coprocessor = None

      # Filled in by init().
      self.cops = None
      self.hams = None
      self.accelcops = None

      # Subcommand name -> bound handler method.
      # remember to update the usage string when adding a command
      self.cmdlist = {
         'list': self.list,
         'dump': self.dump,
         'resetreq': self.resetreq,
         'injecterror': self.injecterror,
      }

   def init( self ):
      '''Load the coprocessor tables and drop coprocessors without hardware.

      Fetches the cops and hams dicts from SchanAccelLib.getCops() for the
      configured sysname and maps the accelerator state with getaccelcops().
      Every coprocessor that getaccelcops() did not return is then reported
      as "hardware not present" and removed from both cops and hams.
      '''
      self.cops, self.hams = SchanAccelLib.getCops( sysname=self.opts.sysname )
      self.accelcops = self.getaccelcops()

      # Walk a snapshot of the keys because entries are deleted inside the loop.
      for name in tuple( self.cops ):
         if name in self.accelcops:
            continue
         print( "%s hardware not present" % ( name ) )
         del self.cops[ name ]
         del self.hams[ name ]

   def usage( self ):
      self.parser.print_help()
      exit( -1 ) # pylint: disable=consider-using-sys-exit

   def cmd( self ):
      if os.geteuid() != 0:
         print( "Error: Not a privileged user" )
         exit( -1 ) # pylint: disable=consider-using-sys-exit

      if len( self.args ) == 0:
         self.usage()

      f = self.cmdlist.get( self.args[0] )

      if f is None:
         self.usage()

      self.init()

      if self.opts.sec:
         while True:
            f()
            sys.stdout.flush()
            try:
               time.sleep( self.opts.sec )
            except KeyboardInterrupt:
               break
            print( "========================================================" )
      else:
         f()

   def list( self ):
      '''Print one line per coprocessor in cops: name, device name and id.'''
      for name in self.cops:
         devname = pcidevname( self.hams[ name ] )
         copId = self.cops[ name ].coprocessorId
         print( "%s %s id %u" % ( name, devname, copId ) )

   def dump( self ):
      args = self.args
      if len( args ) == 1:
         self.dumpCmdring( None )
         print()
         self.dumpAccel( None )
      elif args[1] == 'accel':
         copName = args[2] if len( args ) == 3 else None
         self.dumpAccel( copName )
      elif args[1] == 'cmdring':
         cmdringName = args[2] if len( args ) == 3 else None
         self.dumpCmdring( cmdringName )
      elif args[1] == 'stats':
         cmdringName = args[2] if len( args ) == 3 else None
         self.dumpStats( cmdringName )
      else:
         self.usage()

   def dumpAccel( self, copName ):
      '''Print the accelerator state of one coprocessor, or of all of them.

      A true copName selects that coprocessor only and is reported through
      error() when it is not in cops; a false copName ( None ) selects every
      coprocessor. For each selected coprocessor a header line is printed,
      then every accelerator that has state, statereq, error or bufaddr set
      ( all accelerators with --verbose ), then a blank line.
      '''
      if copName and copName not in self.cops:
         error( "%s not found" % copName )

      for name, cop in self.cops.items():
         if not copName or name == copName:
            header = ( name, pcidevname( self.hams[ name ] ),
                       cop.coprocessorId, cop.fapmask, cop.genId )
            print( "%s: %s id %u fapmask %s genId %u:" % header )

            hwcop = self.accelcops[ name ]
            for idx in range( schan_accel.NUM_SCHAN_ACCEL_PER_COPROCESSOR ):
               accel = hwcop.accel[ idx ]
               inUse = ( accel.state or accel.statereq or accel.error or
                         accel.bufaddr )
               if inUse or self.opts.verbose:
                  self.printAccel( "accel%u" % idx, accel )

            print()

   def printAccel( self, name, a ):
      print( "  %s: state %u statereq %u error %u errorarg 0x%x "
         "error_sbusaddr 0x%x errorline %u magic 0x%x" %
         ( name, a.state, a.statereq, a.error, a.errorarg, a.error_sbusaddr,
         a.errorline, a.magic ) )
      print( "    bufaddr 0x%x bufsz %u curaddr 0x%x head %u tail %u seqno %u" %
             ( a.bufaddr, a.bufsz, a.curaddr, a.head, a.tail, a.seqno ) )
      print( "    processed %u skipped %u nak %u" %
             ( a.processed, a.skipped, a.nak ) )

   def dumpStats( self, copName ):
      '''Print the statistics table of one coprocessor, or of all of them.

      A true copName selects that coprocessor only and is reported through
      error() when it is not in cops; a false copName ( None ) selects every
      coprocessor. For each selected coprocessor the column header is printed,
      then one row per accelerator that has state, statereq, error or bufaddr
      set ( all accelerators with --verbose ), then a blank line.
      '''
      if copName and copName not in self.cops:
         error( "%s not found" % copName )

      for name, cop in self.cops.items():
         if not copName or name == copName:
            print( "-----accel-----  total(us)  max(us)  "
                   "processed    skipped maxcmd" )

            hwcop = self.accelcops[ name ]
            for idx in range( schan_accel.NUM_SCHAN_ACCEL_PER_COPROCESSOR ):
               accel = hwcop.accel[ idx ]
               inUse = ( accel.state or accel.statereq or accel.error or
                         accel.bufaddr )
               if inUse or self.opts.verbose:
                  self.printStats( name, idx, cop.coprocessorId, accel )

            print()

   def printStats( self, name, accel, fapId, a ):
      '''Print one statistics row for accelerator index accel of coprocessor name.

      The row shows a.total_us, a.max_us, a.processed and a.skipped followed by
      the name of a.max_opcode from schan_accel.schanOpcodeNames, or
      "unknown (<hex opcode>)" when the opcode has no entry there. fapId is
      accepted but not used.
      '''
      fallback = f'unknown ({a.max_opcode:x})'
      opcodeName = schan_accel.schanOpcodeNames.get( a.max_opcode, fallback )
      row = ( name, accel, a.total_us, a.max_us, a.processed, a.skipped,
              opcodeName )
      print( "%13s:%x %10u %8u %10u %10u %s" % row )

   def dumpCmdring( self, cmdringName ):
      '''Print every command ring referenced by an active accelerator.

      Rings are found through the bufaddr/bufsz of each accelerator that has
      state, statereq or error set; the ring header starts CMDRING_HDRSZ bytes
      below bufaddr. A true cmdringName restricts the output to the ring with
      that name. For each ring the header fields are printed, followed by
      every active accelerator attached to it and a blank line.

      Returns without output when there are no coprocessors or when the
      dmamem limits cannot be read ( RuntimeError ). Rings whose address lies
      outside both user dmamem ranges, or whose dmamem chunk cannot be found,
      are reported and skipped.
      '''
      if not self.cops:
         return

      hdrsz = schan_accel.CMDRING_HDRSZ
      numAccel = schan_accel.NUM_SCHAN_ACCEL_PER_COPROCESSOR

      # construct the set of SchanAccelCmdRings referenced by all accelerators,
      # as ( header address, size including header ) pairs
      cmdrings = set()
      for name in self.cops:
         c = self.accelcops[ name ]
         for i in range( numAccel ):
            a = c.accel[i]
            if a.state or a.statereq or a.error:
               cmdrings.add( ( a.bufaddr - hdrsz, a.bufsz + hdrsz ) )

      # get user dmamem range
      try:
         ( udma_start, udma_end ) = dmamem.dmamem.limits()
         ( udma64_start, udma64_end ) = dmamem.dmamem.limits64()
      except RuntimeError:
         return

      for ( bufaddr, bufsz ) in cmdrings:
         inUdma = udma_start <= bufaddr < udma_end
         inUdma64 = udma64_start <= bufaddr < udma64_end
         if not ( inUdma or inUdma64 ):
            print( "accel bufaddr invalid udma address (0x%x)" % bufaddr )
            continue

         ( chunkname, basepa ) = dmamem.dmamem.find( bufaddr )
         if not chunkname:
            print( "dmamem chunk for udma address 0x%x not found" % bufaddr )
            continue

         # map command ring header
         va = dmamem.dmamem.map( chunkname, False )
         try:
            addr = va + ( bufaddr - basepa )
            cmdring = cast( addr, POINTER( schan_accel.cmdring ) )[0]

            # NOTE(review): if cmdring.name is a ctypes char array it reads
            # back as bytes and would never equal a str cmdringName - confirm
            # the field type.
            if cmdringName and ( cmdring.name != cmdringName ):
               continue

            print( "%s:" % cmdring.name )
            print( "  magic 0x%x head %u tail %u postedtail %u seqno %u cmds %u" %
               ( cmdring.magic, cmdring.head, cmdring.tail, cmdring.postedtail,
               cmdring.seqno, cmdring.cmds ) )
            print( "  size %u dmaAddress 0x%x" % ( bufsz, bufaddr ) )

            # list the active accelerators that use this ring
            for name in self.cops:
               c = self.accelcops[ name ]
               for i in range( numAccel ):
                  a = c.accel[i]
                  if not ( a.state or a.statereq or a.error ):
                     continue
                  if ( a.bufaddr - hdrsz ) == bufaddr:
                     hamImpl = self.hams[name]
                     self.printAccel( "%s %s accel%u" %
                        ( name, pcidevname( hamImpl ), i ), a )
         finally:
            # Release the mapping on every path. The name-mismatch 'continue'
            # above used to skip the unmap and leak the mapping.
            dmamem.dmamem.unmap( chunkname, va )

         print()

   def resetreq( self ):
      args = self.args
      copName = self.args[1] if len( args ) > 1 else None
      if copName not in self.cops:
         error( "%s not found" % copName )
      c = self.accelcops.get( copName )
      c.resetreq = 1

   def injecterror( self ):
      args = self.args
      if len( args ) != 5:
         self.usage()
      copName = self.args[1]
      index = int( self.args[2], 0 )
      errorcode = int( self.args[3], 0 )
      errorarg = int( self.args[4], 0 )

      c = self.accelcops.get( copName )
      if c is None:
         error( "%s not found" % copName )
      if ( index < 0 ) or ( index > schan_accel.NUM_SCHAN_ACCEL_PER_COPROCESSOR ):
         error( "accelerator index %d out of range" % index )
      a = c.accel[ index ]
      # blam
      a.errorarg = errorarg
      a.error = errorcode
      a.statereq = schan_accel.SCHAN_ACCEL_DISABLE

   # return dict of name to hardware schan_coprocessor ctype
   def getaccelcops( self ):
      accelcops = {}
      for name, h in self.hams.items():
         if not h.hardwarePresent:
            continue
         if h.kind == 'hamTypeDmamem':
            addr = dmamem.dmamem.map( h.name, False )
            schancop = schan_accel.schan_coprocessor.from_address( addr )
         else:
            path = "/sys/bus/pci/devices/%s/%s" % \
                ( h.address.stringValue(), h.filename )

            with open( path, 'r+b' ) as f:
               maplen = h.offset + sizeof( schan_accel.schan_coprocessor )
               buf = mmap.mmap( f.fileno(), maplen )
               schancop = schan_accel.schan_coprocessor.from_buffer( buf, h.offset )

         accelcops[ name ] = schancop

      return accelcops

if __name__ == "__main__":
   # Script entry point: parse the command line, then run the subcommand.
   tool = Cmd()
   tool.cmd()
