#!/usr/bin/env python3
# Copyright (c) 2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
#

# pylint: disable=consider-using-f-string
from __future__ import absolute_import, division, print_function
from ctypes import CDLL, c_bool, c_uint32, c_uint64, c_char_p, \
   c_void_p, byref, create_string_buffer, memset, c_size_t, POINTER, c_ssize_t
import six
import os, sys, re, inspect
import TableOutput
import Tracing

# Module-wide trace function: the 'trace0' method of the 'libdmamem' trace
# handle. Every t0( ... ) call in this file goes through it.
t0 = Tracing.Handle( 'libdmamem' ).trace0       # pylint: disable-msg=E1101

class dmamem( object ): # pylint: disable=useless-object-inheritance
   """ ctypes wrapper around 'libdmamem.so'.

   Instantiating the class allocates one named chunk through
   'dmamem_alloc_func', maps it into this process and exposes it as a ctypes
   array of 32-bit words ( 'uint32Arr' ). Most static/class methods simply
   forward to the library entry point of the same name.
   """

   # The shared library is loaded when the class body is executed.
   lib = CDLL( 'libdmamem.so' )
   lib.dmamem_alloc_func.restype = c_uint64
   # NOTE(review): __init__ passes a fourth 'flags' argument that is not listed
   # here, so ctypes applies its default integer conversion to it - confirm
   # against the C prototype.
   lib.dmamem_alloc_func.argtypes = [ c_char_p, c_size_t, c_char_p ]
   lib.dmamem_map.restype = c_void_p
   lib.dmamem_map.argtypes = [ c_char_p, c_bool, c_size_t, c_size_t ]
   lib.dmamem_map_all.restype = c_void_p
   lib.dmamem_find.restype = c_char_p
   lib.dmamem_find.argtypes = [ c_uint64, POINTER( c_uint64 ) ]
   lib.dmamem_dump_begin.restype = c_bool
   lib.dmamem_dump_next.restype = c_bool
   lib.dmamem_sizeof.restype = c_ssize_t
   lib.dmamem_dma_addr_none.restype = c_uint64

   # Size of the buffer handed to the dump callback for each line.
   DUMP_MAX_LINE_SIZE = 1024

   def __init__( self, name, size, offset=0, readOnly=False, highIova=False,
                 dmaNoCap=False ):
      """ Allocates a dmamem chunk and mmaps it at a given offset.

      Args:
         name: chunk name ( text or bytes ).
         size: size in bytes passed to 'dmamem_alloc_func'.
         offset: byte offset into the chunk at which the mapping starts;
            map() asserts that it is a multiple of 0x1000.
         readOnly: ORs dmamem_flag_readonly() into the allocation flags.
         highIova: ORs dmamem_flag_64bit() into the allocation flags.
         dmaNoCap: ORs dmamem_flag_no_cap() into the allocation flags.

      Attributes set: name, size and offset ( as passed ), pa ( address
      returned by the allocation plus offset ), va ( mapped address ) and
      uint32Arr ( ( size - offset ) // 4 c_uint32 elements located at va ).

      Raises:
         MemoryError: the allocation returned 0 or the mapping failed.
      """

      agent = os.getenv( 'AGENT_PROCESS_NAME' )
      if not agent:
         # Path to the calling file, <stdin> if there is no such file
         agent = '@' + inspect.stack()[ 1 ][ 1 ]
      self.name = name
      self.size = size
      self.offset = offset
      t0( 'alloc: %s, size: %d' % ( name, size ) )
      flags = 0
      if readOnly:
         flags |= dmamem.lib.dmamem_flag_readonly()
      if highIova:
         flags |= dmamem.lib.dmamem_flag_64bit()
      if dmaNoCap:
         flags |= dmamem.lib.dmamem_flag_no_cap()
      # 'dmamem_alloc_func' will assert if 'dmaNoCap' is combined with other flags.
      pa = dmamem.lib.dmamem_alloc_func( six.ensure_binary( name ), size,
                                         six.ensure_binary( agent ), flags )
      if pa == 0:
         t0( 'nomem: %s, size: %d' % ( name, size ) )
         raise MemoryError
      va = dmamem.map( name, True, offset=offset )
      # Defensive only: map() already raises MemoryError on a NULL mapping.
      if va is None:
         t0( 'map failed: %s, pid: %d, dma: %x'  % ( name, os.getpid(), pa ) )
         raise MemoryError
      self.pa = pa + offset
      self.va = va
      # From here on 'size' is the number of bytes that follow 'offset'.
      size = size - offset
      self.uint32Arr = ( c_uint32 * ( size // 4 )).from_address( int( self.va ) )
      # The arguments follow the order of the labels in the format string
      # ( size first, then offset ).
      t0( 'mem: %s, size: %d, offset: 0x%x, dma: %x, va: %x' %
          ( name, size, offset, pa, va ) )

   def zeroChunk( self ):
      """ memset()s 'size - offset' bytes to 0, starting at 'uint32Arr'. """
      memset( self.uint32Arr, 0, self.size - self.offset )

   @staticmethod
   def iommu_map( domain, bus, devfn, iommu_prot ):
      """ Forwards to dmamem_iommu_map() and returns its result. """
      return dmamem.lib.dmamem_iommu_map( domain, bus, devfn, iommu_prot )

   @staticmethod
   def register_device( domain, bus, devfn ):
      """ Forwards to dmamem_register_device() and returns its result. """
      return dmamem.lib.dmamem_register_device( domain, bus, devfn )

   @staticmethod
   def iommu_unmap( domain, bus, devfn ):
      """ Forwards to dmamem_iommu_unmap() and returns its result. """
      return dmamem.lib.dmamem_iommu_unmap( domain, bus, devfn )

   @staticmethod
   def unregister_device( domain, bus, devfn ):
      """ Forwards to dmamem_unregister_device() and returns its result. """
      return dmamem.lib.dmamem_unregister_device( domain, bus, devfn )

   @classmethod
   def map( cls, name, writable, offset=0, size=0 ):
      """ Maps chunk 'name' through dmamem_map() and returns the address.

      'offset' must be a multiple of 0x1000 ( asserted ). 'writable', 'offset'
      and 'size' are handed to the library unchanged.

      Raises:
         MemoryError: the library returned a NULL pointer.
      """
      # mmap needs to be page aligned
      assert( offset % 0x1000 == 0 ) # pylint: disable=superfluous-parens
      va = dmamem.lib.dmamem_map( six.ensure_binary( name ), writable, offset, size )
      # A c_void_p restype turns a NULL pointer into None.
      if va is None:
         t0( 'map failed: %s, pid: %d, writable: %d, off %x size %x' % \
            ( name, os.getpid(), writable, offset, size ) )
         raise MemoryError
      t0( 'map: %s, writable: %d, va: %x off:%x size: %x' % (
            name, writable, va, offset, size ) )
      return va

   @classmethod
   def map_all( cls ):
      """ Calls dmamem_map_all() and returns the address it reports.

      Raises:
         MemoryError: the library returned a NULL pointer.
      """
      va = dmamem.lib.dmamem_map_all()
      if va is None:
         t0( 'map_all failed: pid: %d' % os.getpid() )
         raise MemoryError
      t0( 'map_all: va: %s' % va )
      return va

   @classmethod
   def unmap_all( cls ):
      """ Calls dmamem_unmap_all(); its result is not inspected. """
      dmamem.lib.dmamem_unmap_all()

   @classmethod
   def unmap( cls, name, va, size = 0, strict = False ):
      """ Unmaps chunk 'name' previously mapped at address 'va'.

      'strict' selects dmamem_unmap_strict() instead of dmamem_unmap(); both
      receive the same arguments.

      Raises:
         MemoryError: the library call returned a non-zero value.
      """
      # NOTE(review): the meaning of the trailing 0 argument is not visible
      # from this file - confirm against the C prototype.
      if strict:
         rc = dmamem.lib.dmamem_unmap_strict( six.ensure_binary( name ),
                                              c_void_p( va ),
                                              size, 0 ) # pylint: disable-msg=W0212
      else:
         rc = dmamem.lib.dmamem_unmap( six.ensure_binary( name ), c_void_p( va ),
                                       size, 0 ) # pylint: disable-msg=W0212
      if rc:
         t0( 'unmap failed: %s, pid: %d, va: %x'  % ( name, os.getpid(), va ) )
         raise MemoryError
      t0( 'unmap: %s, va: %x'  % ( name, va ) )

   @classmethod
   def sizeof( cls, name ):
      """ Returns dmamem_sizeof() for chunk 'name' ( a signed size ). """
      return dmamem.lib.dmamem_sizeof( six.ensure_binary( name ) )

   @classmethod
   def dma_addr_none( cls ):
      """ Returns the 64-bit value reported by dmamem_dma_addr_none(). """
      return dmamem.lib.dmamem_dma_addr_none()

   @classmethod
   def free( cls, name, ignoreError=False ):
      """ Frees chunk 'name' through dmamem_free().

      Raises:
         NameError: the library returned non-zero and 'ignoreError' is false.
      """
      t0( 'free: %s' % name )
      rc = dmamem.lib.dmamem_free( six.ensure_binary( name ) )
      if ( rc != 0 ) and not ignoreError:
         raise NameError
      
   @classmethod
   def free_all( cls, nameRegEx=None ):
      """ Frees every chunk listed by dump() whose name matches 'nameRegEx'.

      A chunk is freed only when the match found by re.match() spans its
      whole name. Without 'nameRegEx' every chunk is freed. Errors from
      free() propagate.
      """
      t0( 'free_all' )
      if not nameRegEx:
         nameRegEx = '^.*$'
      for line in dmamem.dump():
         # The chunk name is the second whitespace-separated field of a line.
         name = line.split()[ 1 ].decode()
         match = re.match( nameRegEx, name )
         if match and match.end() == len( name ):
            dmamem.free( name )

   # Valid 32-bit DMA address limits: (start, end)
   @classmethod
   def limits( cls ):
      """ Returns ( start, end ) from dmamem_get_limits(), or ( 0, 0 ) when
      the call returns non-zero. """
      startDmaAddr, endDmaAddr = c_uint32(), c_uint32()
      if dmamem.lib.dmamem_get_limits( byref(startDmaAddr),
              byref(endDmaAddr) ) != 0:
         return ( 0, 0 )
      return ( startDmaAddr.value, endDmaAddr.value )

   # Valid 64-bit DMA address limits: (start, end)
   @classmethod
   def limits64( cls ):
      """ Returns ( start, end ) from dmamem_get_limits64(), or ( 0, 0 ) when
      the call returns non-zero. """
      startDmaAddr, endDmaAddr = c_uint64(), c_uint64()
      if dmamem.lib.dmamem_get_limits64( byref(startDmaAddr),
              byref(endDmaAddr) ) != 0:
         return ( 0, 0 )
      return ( startDmaAddr.value, endDmaAddr.value )

   @classmethod
   def _dump( cls, callback ):
      """ Collects the lines produced by 'callback' between dmamem_dump_begin()
      and dmamem_dump_end().

      'callback' is called with a reference to a DUMP_MAX_LINE_SIZE byte buffer
      and that size until it returns a false value; the buffer content is
      recorded after every true return. Returns the list of lines as bytes,
      empty when dmamem_dump_begin() returns false.
      """
      buf = create_string_buffer( dmamem.DUMP_MAX_LINE_SIZE )

      result = []
      if dmamem.lib.dmamem_dump_begin():
         try:
            while callback( byref(buf), dmamem.DUMP_MAX_LINE_SIZE ):
               result.append( buf.value )
         finally:
            # Always pair a successful dump_begin() with dump_end().
            dmamem.lib.dmamem_dump_end()
      return result

   @classmethod
   def dump( cls ):
      """ Returns the lines produced by dmamem_dump_next(), as bytes. """
      return dmamem._dump( dmamem.lib.dmamem_dump_next )

   @classmethod
   def dump_iommu( cls ):
      """ Returns the entry names found in /sys/kernel/dmamem/devices. """
      return os.listdir("/sys/kernel/dmamem/devices")

   @classmethod
   def find( cls, dmaAddr ):
      """ Looks 'dmaAddr' up with dmamem_find().

      Returns ( name, baseDmaAddr ): the decoded string returned by the
      library and the value it stored in its second argument.
      """
      baseDmaAddr = c_uint64()
      name = dmamem.lib.dmamem_find( dmaAddr, byref(baseDmaAddr) )
      # NOTE(review): a NULL result from the library comes back as None and
      # makes decode() raise AttributeError - confirm whether the library can
      # return NULL here.
      return ( name.decode(), baseDmaAddr.value )

def bytesToStr( size ):
   """ Formats a byte count ( or an address ) as '0x<hex>(<value><unit>)'.

   The value is divided by 1024 while it is strictly greater than 1024, so an
   exact 1024 is still shown in the smaller unit. Units go up to EiB, which
   covers the whole 64-bit range; larger values stay expressed in EiB rather
   than running past the end of the unit table.

   Args:
      size: non-negative integer.

   Returns:
      The formatted string, e.g. '0x600(1.50KiB)'.
   """
   power = 2**10 # 2**10 = 1024
   power_labels = ( 'B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB' )
   scaled = size
   n = 0
   # Stop at the last label so that large 64-bit values cannot index past it.
   while scaled > power and n < len( power_labels ) - 1:
      scaled /= power
      n += 1
   return "0x{:x}({:.2f}{})".format( size, scaled, power_labels[ n ] )

def main( ):
   """ Command line entry point.

   Without arguments, prints four tables: the chunks, the totals per creator,
   the DMA address ranges and the total allocated bytes. With '-f' followed
   by regular expressions, frees every chunk whose name is matched
   ( re.match ) by one of them. With '-i', prints the entries returned by
   dmamem.dump_iommu(). Anything else prints the usage and exits with -1.

   The process also exits with -1, before looking at the arguments, when both
   32-bit DMA limits are 0.
   """
   argv = sys.argv[1:]

   ( dmaStart32, dmaEnd32 ) = dmamem.limits()
   if dmaEnd32 == 0 and dmaStart32 == 0:
      sys.exit( -1 )
   ( dmaStart64, dmaEnd64 ) = dmamem.limits64()

   chunks = dmamem.dump()
   if not argv:
      usedDma = 0
      usedDma32 = 0
      usedDma64 = 0
      usedSma64 = 0 # SMA chunks have no DMA capabilities and are all 64bits.
      chunkTable = TableOutput.createTable(
            ( "Name", "Creator", "DMA address", "Size" ) )
      # creator -> [ 32-bit DMA bytes, 64-bit DMA bytes, SMA bytes ]
      creatorAccumulator = {}

      for l in chunks:
         # A dump line has five whitespace-separated fields; the first one is
         # not used here.
         ( _, _name, _creator, startAddr, size ) = l.split()
         name = _name.decode()
         creator = _creator.decode()
         size = int( size )
         startAddr = int( startAddr, 16 )
         # SMA chunk has its DMA address set to 'U64_MAX'.
         isDmaCapable = ( startAddr != dmamem.dma_addr_none() )

         if isDmaCapable:
            usedDma += size

         totals = creatorAccumulator.setdefault( creator, [ 0, 0, 0 ] )

         if dmaStart32 <= startAddr <= dmaEnd32:
            usedDma32 += size
            totals[ 0 ] += size
         elif isDmaCapable:
            usedDma64 += size
            totals[ 1 ] += size
         else:
            usedSma64 += size
            totals[ 2 ] += size

         chunkTable.newRow( name, creator,
               bytesToStr( startAddr ) if isDmaCapable else "N/A (SMA chunk)",
               bytesToStr( size ) )

      print( "Chunks:" )
      print( chunkTable.output() )

      CreatorTable = TableOutput.createTable(
         ( "Creator", "32-bit DMA size",
           "Creator", "64-bit DMA size",
           "Creator", "SMA size" )
      )

      def sortedBySize( column ):
         # ( creator, bytes ) pairs for one accumulator column, largest first.
         pairs = [ ( who, sizes[ column ] )
                   for who, sizes in creatorAccumulator.items() ]
         return sorted( pairs, key=lambda i: i[ 1 ], reverse=True )

      # Each column pair of the table is ranked independently of the others,
      # so one row may name three different creators.
      for dma32, dma64, sma64 in zip( sortedBySize( 0 ), sortedBySize( 1 ),
                                      sortedBySize( 2 ) ):
         CreatorTable.newRow(
            dma32[ 0 ], bytesToStr( dma32[ 1 ] ),
            dma64[ 0 ], bytesToStr( dma64[ 1 ] ),
            sma64[ 0 ], bytesToStr( sma64[ 1 ] )
         )
      print( "Creator totals:" )
      print( CreatorTable.output() )

      reservationSize = 0
      if "SIMULATION_VMID" in os.environ:
         # dmamem region in simulation starts at 1 page so we don't give users 0 dma
         # address
         reservationSize = dmaEnd32 - dmaStart32 - 0x1000
         reservationType = 'simulation'
      else:
         with open("/proc/dmamem") as f:
            physLimits = f.read().split("\n")[0]
         # 4 groups, phys start, phys end, size, and allocator identifier
         lims = re.match(r"([0-9a-f]+)\t([0-9a-f]+)\t([0-9a-f]+)\t([A-Z]+)",
                 physLimits)
         # A first line that does not have the expected layout is handled like
         # a non-CMA allocator: no reservation is reported.
         if lims is not None and lims.group( 4 ) == "CMA":
            reservationSize = int( lims.group( 3 ), 16 )
            reservationType = 'CMA'
         else:
            reservationType = ''

      print( "Address ranges:" )
      table = TableOutput.createTable(
            ( "Type", "DMA address range start", "DMA address range end",
              "Allocated bytes" )
         )
      table.newRow( "32-bit DMA", bytesToStr( dmaStart32 ), bytesToStr( dmaEnd32 ),
                    bytesToStr( usedDma32 ) )
      table.newRow( "64-bit DMA", bytesToStr( dmaStart64 ), bytesToStr( dmaEnd64 ),
                    bytesToStr( usedDma64 ) )
      table.newRow( "SMA", "N/A", "N/A",
                    bytesToStr( usedSma64 ) )
      print( table.output() )

      print( "Total allocated bytes:" )
      if reservationSize:
         table = TableOutput.createTable( ( "Type", "Allocated bytes",
                                            f"Reserved {reservationType} bytes" ) )
         table.newRow( "DMA", bytesToStr( usedDma ), "N/A" )
         table.newRow( "SMA", bytesToStr( usedSma64 ), "N/A" )
         table.newRow( "DMA + SMA", bytesToStr( usedDma + usedSma64 ),
                       bytesToStr( reservationSize ) )
      else:
         table = TableOutput.createTable( ( "Type", "Allocated bytes" ) )
         table.newRow( "DMA", bytesToStr( usedDma ) )
         table.newRow( "SMA", bytesToStr( usedSma64 ) )
         table.newRow( "DMA + SMA", bytesToStr( usedDma + usedSma64 ) )
      print( table.output() )

   elif argv[0] == '-f':
      for l in chunks:
         # Chunk names stay as bytes here, so the patterns are converted to
         # bytes as well.
         chunkname = l.split()[1]
         for arg in argv[ 1: ]:
            m = re.match( six.ensure_binary( arg ), chunkname )
            if m:
               dmamem.free( chunkname )
               # Free a chunk once, even if several patterns match it.
               break
   elif argv[0] == '-i':
      print( 'Registered pci devices:' )
      for dev in dmamem.dump_iommu():
         print( dev )

   else:
      print( "usage: dmamem [ -f chunknames .. | -i ]" )
      sys.exit( -1 )

if __name__ == '__main__':
   # Export both switches before main() runs.
   # NOTE(review): presumably they are read by libdmamem - confirm.
   for switch in ( "DMAMEM_NO_CLEANUP", "DMAMEM_STD_ERR" ):
      os.environ[ switch ] = "1"
   main()

