#!/usr/bin/env python3
#
# Copyright (c) 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
#
# Development/debug utility for Broadcom CMIC SBUSDMA Table DMA.
#
# In Arad/T2 each CMICM has 3 CMC's (CMC0, CMC1 and CMC2).
# Each CMC# has 3 SBUSDMA channels which support contiguous table dma.
# Each SBUSDMA channel has about 20 control registers.
#
# Examples:
# Read and print the first 4 Trident2 MAC table entries:
#     sbusdma -n 4 02:00 read 4 1 0x1c000000 16
#
# Write the first 100 Trident2 MAC table entries to 0x1 0x2 0x3 0x4:
#     sbusdma -n 100 02:00 write 4 1 0x1c000000 16 0x1 0x2 0x3 0x4

# pylint: disable=consider-using-f-string
# pylint: disable=consider-using-sys-exit

import Tac
import Pci
import time
import schan_accel
from optparse import OptionParser # pylint: disable=deprecated-module
from dmamem import dmamem
from ctypes import Structure, c_uint32, pointer
from SchanMsgLib import readMemory, writeMemory
from TacUtils import Timeout

# SBUSDMA Channel register set
class CmicMSbusDmaRegs( Structure ):
   """ctypes overlay for one SBUSDMA channel register set (cmcgen 'm').

   Twenty consecutive 32-bit words, declared in register order, so field N
   sits at byte offset 4 * N from the start of the channel.
   """
   _fields_ = [ ( regName, c_uint32 ) for regName in (
      'control',
      'request',
      'count',
      'opcode',
      'sbus_start_address',
      'hostmem_start_addr',
      'desc_start_address',
      'status',
      'cur_hostmem_address',
      'cur_sbusaddr',
      'cur_desc_address',
      'cur_sbusdma_config_request',
      'cur_sbusdma_config_count',
      'cur_sbusdma_config_start_address',
      'cur_sbusdma_config_hostmem_start_address',
      'cur_sbusdma_config_opcode',
      'sbusdma_debug',
      'sbusdma_debug_clr',
      'sbusdma_eccerr_addr',
      'sbusdma_eccerr_control',
   ) ]

class CmicXSbusDmaRegs( Structure ):
   """ctypes overlay for one SBUSDMA channel register set (cmcgen 'x').

   Twenty-six consecutive 32-bit words, declared in register order.  Compared
   with the 'm' layout, the host-memory, descriptor and ECC-error address
   registers are each followed by an extra '_hi' word.
   """
   _fields_ = [ ( regName, c_uint32 ) for regName in (
      'control',
      'request',
      'count',
      'opcode',
      'sbus_start_address',
      'hostmem_start_addr',
      'hostmem_start_addr_hi',
      'desc_start_address',
      'desc_start_address_hi',
      'status',
      'cur_hostmem_address',
      'cur_hostmem_address_hi',
      'cur_sbusaddr',
      'cur_desc_address',
      'cur_desc_address_hi',
      'cur_sbusdma_config_request',
      'cur_sbusdma_config_count',
      'cur_sbusdma_config_start_address',
      'cur_sbusdma_config_hostmem_start_address',
      'cur_sbusdma_config_hostmem_start_address_hi',
      'cur_sbusdma_config_opcode',
      'sbusdma_debug',
      'sbusdma_debug_clr',
      'sbusdma_eccerr_addr',
      'sbusdma_eccerr_addr_hi',
      'sbusdma_eccerr_control',
   ) ]

# Name of the dmamem buffer this tool frees/allocates for itself when neither
# --dmabuf nor --host-address is given; also the name freed by tearDown().
dmabufname = 'sbusdma'
# Left-shift applied to the --incrshift value before it is OR'd into the
# channel's request register in write().
INCRSHIFT = 24

# CMICM:
# CMC0: 0x31600 ch0: +0x0, ch1: +0x50, ch2: +0xa0
# CMC1: 0x32600 ..
# CMC2: 0x33600 ..
# CMICX:
# CMC0: 0x3000 ch0: +0x0, ch1: +0x100, ch2: +0x200, ch3: +0x300
# CMC1: 0x6000 ..
#
def sbusdmaOffset( cmcgen, cmc, ch ):
   """Return the byte offset of an SBUSDMA channel's first register.

   cmcgen -- 'm' or 'x', selecting the CMIC generation
   cmc    -- CMC index: 0-2 for 'm', 0-1 for 'x'
   ch     -- channel index within the CMC: 0-2 for 'm', 0-3 for 'x'

   'm': 0x31600 + cmc * 0x1000 + ch * 0x50
   'x': ( cmc + 1 ) * 0x3000 + ch * 0x100

   Out-of-range indices and unknown generations trip an assert.
   """
   if cmcgen == 'm':
      assert 0 <= cmc <= 2
      assert 0 <= ch <= 2
      cmcBase = 0x31600 + ( 0x1000 * cmc )
      return cmcBase + ( 0x50 * ch )

   if cmcgen == 'x':
      assert 0 <= cmc <= 1
      assert 0 <= ch <= 3
      cmcBase = 0x3000 * ( cmc + 1 )
      return cmcBase + ( 0x100 * ch )

   assert False, "Unknown CMIC Generation %s" % ( cmcgen )

def error( msg ):
   """Print msg to stdout and terminate the process with exit status 1."""
   print( msg )
   exit( 1 )

# Usage text handed to OptionParser; optparse substitutes the program name
# for %prog.  The backslash inside the literal is a line continuation, so the
# two 'write' lines come out as a single line of help text.
usage = \
'''
%prog [ options ] pci-device read acc_type dst_blk address data_byte_len
%prog [ options ] pci-device write acc_type dst_blk address data_byte_len \
0xdata [ 0xdata .. ]
%prog [ options ] pci-device abort
%prog [ options ] pci-device dump'''

class SbusDma:
   def __init__( self ):
      """Parse the command line, set up the DMA buffer and map the registers.

      Reads sys.argv through optparse.  On return the instance carries:
         parser, opts, args -- the OptionParser and its parse results
         cmd       -- the sub-command word (second positional argument)
         buf       -- dmamem buffer object, or None when --host-address is
                      given
         regs      -- ctypes view of the selected SBUSDMA channel registers
         incrshift -- value of --incrshift

      Prints an error and exits with status 1 when fewer than two positional
      arguments are given, or when the buffer named by --dmabuf is not found
      in the dmamem.dump() listing.
      """
      parser = OptionParser( usage=usage )
      parser.add_option( "-n", type="int",
         help="number of entries to read or write (default: %default)",
         default=1 )
      parser.add_option( "--bar", type="int",
         help="Device BAR number (0-2) (default: %default)",
         default=0 )
      parser.add_option( "--cmc", type="int",
         help="CMICM CMC# (0-2) or CMICX CMC# (0-1) (default: %default)",
         default=1 )
      parser.add_option( "--ch", type="int",
         help="SBUSDMA channel number (0-3) (default: %default)",
         default=1 )
      parser.add_option( "--hver", type="int",
         help="schan message header version (3-4) (default: %default)",
         default=3 )
      parser.add_option( "--cmcgen", type="choice", choices=[ 'm', 'x' ],
         help="CMIC Generation (m,x) (default: %default)",
         default='m' )
      parser.add_option( "--dmabuf", type="string",
         help="Name of preexisting dmamem allocated buffer to be used",
         default='' )
      # -1 is the "not given" sentinel tested below.
      parser.add_option( "--host-address", type="int", dest="host_address",
         help="physical address in main memory that should be read/written "
              "from/to. If this is used, the address taken from --dmabuf is "
              "not used.",
         default=-1 )
      parser.add_option( "--dmaOffset", type="int",
         help="Offset within dmamem buffer", default=0 )
      parser.add_option( "--incrshift", type="int",
         help="Number of entries to be interleaved at destination", default=0 )
      parser.add_option( "--silent-timeouts", action="store_true",
         dest="silent_timeouts",
         help="Handle transaction timeouts silently, "
              "return non-zero exit code on failure.",
         default=False )
      parser.add_option( "--stop-dma-prem", action="store_true",
         dest="stop_dma_prem",
         help="Stop the Dma operation prematurely. Do not use this option if "
              "you want the Dma operation to complete successfully.",
         default=False )

      self.parser = parser

      ( self.opts, self.args ) = parser.parse_args()

      if len( self.args ) < 2:
         print( "ERROR: insufficient number of args" )
         parser.print_help()
         exit( 1 )

      pcidev = self.args[0]

      self.cmd = self.args[1]

      if self.opts.host_address != -1:
         # Caller supplied a raw physical address: no dmamem buffer is held.
         self.buf = None
      elif self.opts.dmabuf:
         # Find chunk size
         size = 0
         # Example output of dmamem.dump():
         # [ '  1 SandTcam-cmdring   SandTcam   0x000058b0c000   1048576' ]
         for chunk in dmamem.dump():
            fields = chunk.split()
            # Skip lines too short to carry both a name and a size column
            # instead of failing with an IndexError on them.
            if len( fields ) > 4 and fields[ 1 ] == self.opts.dmabuf:
               size = int( fields[ 4 ] )
               break

         if size == 0:
            print( "ERROR: dmamem buffer name doesn't exist" )
            exit( 1 )

         # Map dmamem region
         self.buf = dmamem( self.opts.dmabuf, size, offset=self.opts.dmaOffset )
      else:
         # allocate a dmamem buffer, dropping any buffer of the same name
         # first
         dmamem.free( dmabufname, ignoreError=True )
         self.buf = dmamem( dmabufname, 512 * 1024, offset=self.opts.dmaOffset )

      # map device registers
      dev = Pci.Device( pcidev )
      mmap = dev.resource( self.opts.bar ).mmap_
      offset = sbusdmaOffset( self.opts.cmcgen, self.opts.cmc, self.opts.ch )
      # pylint: disable-msg=E1101
      if self.opts.cmcgen == 'm':
         ptr = pointer( CmicMSbusDmaRegs.from_buffer( mmap, offset ) )
      elif self.opts.cmcgen == 'x':
         ptr = pointer( CmicXSbusDmaRegs.from_buffer( mmap, offset ) )
      else:
         assert False, "Unknown CMIC Generation %s" % ( self.opts.cmcgen )
      self.regs = ptr[0]
      self.incrshift = self.opts.incrshift

   def stop( self ):
      regs = self.regs
      regs.control = 0
      time.sleep(0.001)
      assert regs.control == 0

   def start( self ):
      self.regs.control = 1

   def read( self ):
      args = self.args
      opts = self.opts

      # validate command line args
      if len( args ) != 6:
         print( "ERROR: insufficient number of args" )
         self.parser.print_help()
         exit( 1 )

      acc_type = int( args[2] )
      dst_blk = int( args[3] )
      address = int( args[4], 0 )
      data_byte_len = int( args[5], 0 )
      assert ( data_byte_len % 4 ) == 0, "data_byte_len must be a word-multiple"

      regs = self.regs

      # construct a READ_MEMORY_CMD_MSG
      m = readMemory( opts.hver, acc_type, dst_blk, address )

      schan_accel.schan_header_dmaIs( m.header, opts.hver, 1 )

      dwc = data_byte_len // 4
      n = opts.n
      mask = 0xFFFFFFFF

      if self.buf:
         dstAddr = self.buf.pa
      else:
         dstAddr = opts.host_address

      dstAddrLo = mask & dstAddr

      self.stop()

      regs.request = dwc
      regs.count = n
      regs.opcode = m.dwords[0]  # really header word
      regs.sbus_start_address = address
      regs.hostmem_start_addr = dstAddrLo
      if opts.cmcgen == "x":
         dstAddrHi = mask & ( dstAddr >> 32 )
         dstAddrHi |= 0x10000000
         regs.hostmem_start_addr_hi = dstAddrHi

      # go
      self.start()

      if opts.stop_dma_prem:
         regs.control = 0
         assert regs.control == 0
         self.dump()
         return

      try:
         Tac.waitFor( lambda : regs.status & 1, timeout=3.0, maxDelay=0.001,
                      sleep=True )
      except Timeout:
         if opts.silent_timeouts:
            exit( 1 )
         else:
            raise

      if regs.status != 1:
         print( "ERROR: status 0x%x" % regs.status )
         self.dump()
         exit( 1 )

      if self.buf:
         for i in range( n ):
            print( "%u:" % i, end=' ' )
            for offset in range( dwc ):
               print( "0x%x" % self.buf.uint32Arr[ ( i * dwc ) + offset ], end=' ' )
            print()

      self.stop()


   def write( self ):
      args = self.args
      opts = self.opts

      # validate command line args
      if len( args ) < 7:
         print( "ERROR: insufficient number of args" )
         self.parser.print_help()
         exit( 1 )

      acc_type = int( args[2] )
      dst_blk = int( args[3] )
      address = int( args[4], 0 )
      data_byte_len = int( args[5], 0 )
      assert ( data_byte_len % 4 ) == 0, "data_byte_len must be a word-multiple"

      dwc = data_byte_len // 4

      if len( args ) != ( 6 + dwc ):
         print( "ERROR: data_byte_len arg not consistent with number of data words" )
         self.parser.print_help()
         exit( 1 )

      regs = self.regs

      data = [ int( arg, 0 ) for arg in args[6:6+dwc] ]

      # construct a WRITE_MEMORY_CMD_MSG
      m = writeMemory( opts.hver, acc_type, dst_blk, address, data )
      schan_accel.schan_header_dmaIs( m.header, opts.hver, 1 )

      n = opts.n
      mask = 0xFFFFFFFF

      if self.buf:
         dstAddr = self.buf.pa
      else:
         dstAddr = opts.host_address

      dstAddrLo = mask & dstAddr
      dstAddrHi = mask & ( dstAddr >> 32 )

      if self.buf:
         # put data in dmabuf
         for row in range( opts.n ):
            for i in range( dwc ):
               self.buf.uint32Arr[ ( row * dwc ) + i ] = data[ i ]

      self.stop()

      regs.request = dwc << 5
      regs.request = regs.request | ( self.incrshift << INCRSHIFT )
      regs.count = n
      regs.opcode = m.dwords[0]  # header word, not opcode
      regs.sbus_start_address = address
      regs.hostmem_start_addr = dstAddrLo
      if opts.cmcgen == "x":
         dstAddrHi |= 0x10000000
         regs.hostmem_start_addr_hi = dstAddrHi

      # go
      self.start()

      if opts.stop_dma_prem:
         regs.control = 0
         assert regs.control == 0
         self.dump()
         return

      try:
         Tac.waitFor( lambda : regs.status & 1, timeout=3.0, maxDelay=0.001,
                      sleep=True )
      except Timeout:
         if opts.silent_timeouts:
            exit( 1 )
         else:
            raise

      if regs.status != 1:
         print( "ERROR: status 0x%x" % regs.status )
         self.dump()
         exit( 1 )

      self.stop()

   def abort( self ):
      regs = self.regs
      regs.control = 2
      time.sleep(0.01)
      regs.control = 0

   def tearDown( self ):
      """Free the dmamem buffer named dmabufname if a buffer is held.

      Does nothing when self.buf is unset (--host-address mode).  Errors from
      the free are ignored.
      """
      if not self.buf:
         return
      # NOTE(review): the free is always by the fixed name dmabufname, also
      # when self.buf maps a caller-supplied --dmabuf buffer - confirm that
      # is intended.
      self.buf.free( dmabufname, ignoreError=True )

   def dump( self ):
      regs = self.regs
      print( "control 0x%x" % regs.control )
      print( "request 0x%x" % regs.request )
      print( "count 0x%x" % regs.count )
      print( "opcode 0x%x" % regs.opcode )
      print( "sbus_start_address 0x%x" % regs.sbus_start_address )
      print( "hostmem_start_addr 0x%x" % regs.hostmem_start_addr )
      if self.opts.cmcgen == 'x':
         print( "hostmem_start_addr_hi 0x%x" % regs.hostmem_start_addr_hi )
      print( "desc_start_address 0x%x" % regs.desc_start_address )
      if self.opts.cmcgen == 'x':
         print( "desc_start_address_hi 0x%x" % regs.desc_start_address_hi )
      print( "status 0x%x" % regs.status )
      print( "cur_hostmem_address 0x%x" % regs.cur_hostmem_address )
      if self.opts.cmcgen == 'x':
         print( "cur_hostmem_address_hi 0x%x" % regs.cur_hostmem_address_hi )
      print( "cur_sbusaddr 0x%x" % regs.cur_sbusaddr )
      print( "cur_desc_address 0x%x" % regs.cur_desc_address )
      if self.opts.cmcgen == 'x':
         print( "cur_desc_address_hi 0x%x" % regs.cur_desc_address_hi )
      print( "cur_sbusdma_config_request 0x%x" % regs.cur_sbusdma_config_request )
      print( "cur_sbusdma_config_count 0x%x" % regs.cur_sbusdma_config_count )
      print( "cur_sbusdma_config_start_address 0x%x" %
             regs.cur_sbusdma_config_start_address )
      print( "cur_sbusdma_config_hostmem_start_address 0x%x" %
             regs.cur_sbusdma_config_hostmem_start_address )
      if self.opts.cmcgen == 'x':
         print( "cur_sbusdma_config_hostmem_start_address_hi 0x%x" %
                regs.cur_sbusdma_config_hostmem_start_address_hi )
      print( "cur_sbusdma_config_opcode 0x%x" %
             regs.cur_sbusdma_config_opcode )
      print( "sbusdma_debug 0x%x" % regs.sbusdma_debug )
      print( "sbusdma_debug_clr 0x%x" % regs.sbusdma_debug_clr )
      print( "sbusdma_eccerr_addr 0x%x" % regs.sbusdma_eccerr_addr )
      if self.opts.cmcgen == 'x':
         print( "sbusdma_eccerr_addr_hi 0x%x" % regs.sbusdma_eccerr_addr_hi )
      print( "sbusdma_eccerr_control 0x%x" % regs.sbusdma_eccerr_control )

if __name__ == "__main__":
   # Constructing SbusDma parses the command line, sets up the DMA buffer
   # and maps the channel registers.
   s = SbusDma()

   cmd = s.cmd

   # Sub-command word -> bound method that carries it out.
   dispatch = {
      'read': s.read,
      'write': s.write,
      'abort': s.abort,
      'dump': s.dump,
   }

   if cmd not in dispatch:
      print( "unknown command", cmd )
      s.parser.print_help()
      exit( 1 )

   dispatch[ cmd ]()

   # Reached only when the sub-command returned normally.
   s.tearDown()
