#!/usr/bin/env python3
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from ArPyUtils import AsanHelper
import Tac, Tracing
import os
from pprint import pformat
from math import log, floor

# Under STEST, force simulation mode: detectLeakedMappings() treats the mere
# presence of SIMULATION_VMID in the environment as simulation mode.
if 'STEST' in os.environ:
   os.environ[ 'SIMULATION_VMID' ] = '1'

# Trace handle for this module; t0 (the handle's trace0 function) emits the
# traces used throughout this file.
tracingHandle = 'CheckVmemLeak'
t0 = Tracing.Handle( tracingHandle ).trace0      # pylint: disable-msg=E1101

def getAllocatedChunks( dmamem, simulationMode ):
   """Parse ``dmamem.dump()`` into a lookup table of allocated chunks.

   Each dump line is split on whitespace (fields are expected to be bytes, they
   are ``.decode()``d): field 1 is the chunk name, field 3 the DMA address in hex
   and field 4 the size.

   Returns a dict keyed by DMA address in simulation mode (values hold "name" and
   "size") or by chunk name otherwise (values hold "address" and "size").
   """
   chunks = {}
   for entry in dmamem.dump():
      fields = entry.split()
      name = fields[ 1 ].decode( "utf-8" )
      dmaAddr = int( fields[ 3 ].decode( "utf-8" ), 16 )
      size = int( fields[ 4 ].decode( "utf-8" ) )
      if not simulationMode:
         # Mappings carry the chunk name, so that is the lookup key.
         chunks[ name ] = { "address" : dmaAddr, "size" : size }
         continue
      # Mappings can only be matched by DMA address here.
      chunks[ dmaAddr ] = { "name" : name, "size" : size }
   return chunks

def getDmamemChunkMappings( procMaps, simulationMode ):
   """Group the dmamem file mappings found in ``procMaps`` by pid.

   ``procMaps`` holds whitespace-split lines of ``grep -H dmamem /proc/*/maps``,
   for example:

      /proc/11063/maps:7f83d1a7b000-7f83d255c000 rw-s 00000000 00:12 31935
                             /sys/kernel/dmamem/chunks/SandMcast

   Returns ``{pid: [ {"size", "vAddr", "offset", "chunkName"}, ... ]}`` where
   "chunkName" is the mapped file's basename exactly as it appears in the maps
   line (still in its sanitized, on-disk form; callers decode it with
   unsanitizeMappingName()). Mappings of the dmamem library itself are skipped.
   ``simulationMode`` is not used.
   """
   ignore = { "dmamem_lib", "libdmamem.so" }
   chunkMappings = {}
   for mapping in procMaps:
      if len( mapping ) < 6:
         # Anonymous mapping, unrelated to dmamem
         continue
      # Keep the basename untouched. unsanitizeMappingName() reverses the filename
      # escaping; replacing "\\s" here as well decoded escapes twice and corrupted
      # chunk names containing a literal "\s" (the kernel does not escape
      # backslashes in the pathname column of /proc/<pid>/maps).
      chunkName = mapping[ 5 ].split( '/' )[ -1 ]
      if chunkName in ignore:
         continue
      # mapping[ 0 ] looks like "/proc/<pid>/maps:<start>-<end>"
      pid = int( mapping[ 0 ].split( '/' )[ 2 ] )
      addresses = mapping[ 0 ].split( ":" )[ 1 ].split( "-" )
      startAddr = int( addresses[ 0 ], 16 )
      endAddr = int( addresses[ 1 ], 16 )
      chunkMappings.setdefault( pid, [] ).append( {
         "size" : endAddr - startAddr,
         "vAddr" : startAddr,
         "offset" : int( mapping[ 2 ], 16 ),
         "chunkName" : chunkName,
      } )
   return chunkMappings

def isMappingLeakingUsingOffset( mappingData, allocatedChunks ):
   """Return True when no allocated chunk matches this mapping's file offset.

   In simulation mode dmamem offsets each chunk's DMA address by 0x1000 relative
   to its offset in the backing file, so a mapping is still backed by a live chunk
   only if ``offset + 0x1000`` is a key of ``allocatedChunks``.
   """
   expectedDmaAddr = mappingData[ "offset" ] + 0x1000
   stillAllocated = expectedDmaAddr in allocatedChunks
   return not stillAllocated

def unsanitizeMappingName( mappingName ):
   r"""Recover the original chunk name from a dmamem mapping filename.

   When dmamem creates a chunk's filename (see sanitize_filename() in dmamem.cpp),
   every "\" is replaced by "\\" and every "/" by "\s". This reverses that in a
   single left-to-right pass, so each escape is decoded exactly once: "\\s" is an
   escaped backslash followed by a literal "s" and decodes to "\s", not "/".
   Backslash sequences that the sanitizer never produces are left untouched.

   Examples: "a\sb" -> "a/b", "a\\b" -> "a\b", "a\\sb" -> "a\sb".
   """
   import re  # local import: only this helper needs regular expressions
   # re.sub scans left to right without overlap, so an escaped backslash is
   # consumed before the character after it can be misread as the start of "\s".
   return re.sub( r"\\([\\s])",
                  lambda match: "/" if match.group( 1 ) == "s" else "\\",
                  mappingName )

def isMappingLeakingUsingChunkName( mappingData, allocatedChunks ):
   """Return True when no allocated chunk carries this mapping's chunk name.

   The name stored in ``mappingData`` comes from the mapped filename, so it is
   decoded with unsanitizeMappingName() before the lookup.
   """
   decodedName = unsanitizeMappingName( mappingData[ "chunkName" ] )
   isAllocated = decodedName in allocatedChunks
   return not isAllocated

def compareMappedAndAllocatedChunks( isMappingLeaking, chunkMappings,
                                     allocatedChunks ):
   """Collect, per pid, the mappings that ``isMappingLeaking`` flags as leaked.

   Returns ``{pid: {"mappings": set, "totalLeaked": int}}`` holding only pids with
   at least one leaked mapping. Each element of "mappings" is a sorted tuple of
   (key, value) pairs describing one mapping (chunk name, or DMA address for the
   simulation "dmamem" file, plus hex virtual address and size); "totalLeaked" is
   the sum of the leaked mappings' sizes.
   """
   # dmamem in simulation mode offsets the dmaAddr of each chunk by 0x1000
   simulationOffset = 0x1000
   leaked = {}
   for pid, processMappings in chunkMappings.items():
      for mappingData in processMappings:
         if not isMappingLeaking( mappingData, allocatedChunks ):
            continue
         chunkName = unsanitizeMappingName( mappingData[ "chunkName" ] )
         vAddr = hex( mappingData[ "vAddr" ] )
         size = mappingData[ "size" ]
         if chunkName == "dmamem":
            # Simulated chunks all map the single "dmamem" file, so the name says
            # nothing; identify the chunk by its DMA address instead.
            report = { "vAddr" : vAddr, "size" : size,
                       "dmaAddr" : mappingData[ "offset" ] + simulationOffset }
         else:
            report = { "chunkName" : chunkName, "vAddr" : vAddr, "size" : size }
         pidEntry = leaked.setdefault( pid, { "mappings" : set(), "totalLeaked" : 0 } )
         pidEntry[ "mappings" ].add( tuple( sorted( report.items() ) ) )
         pidEntry[ "totalLeaked" ] += size
   return leaked

def addProcessNamesToLeaked( leaked ):
   """Re-key ``leaked`` from bare pids to "<command line>-<pid>" strings.

   All pids are resolved with a single ``ps`` invocation (one Tac.run), because
   running more commands was timing out tests. ``leaked`` is modified in place
   and also returned. A pid that ``ps`` does not report keeps its integer key.
   """
   if not leaked:
      # Without any "-p" option ps would list every process on the system.
      return leaked
   pidArgs = []
   for pid in leaked:
      pidArgs += [ "-p", str( pid ) ]

   with AsanHelper.disableSubprocessAsanPreload():
      # The leaking processes may exit between reading /proc/*/maps and this call;
      # ps then exits non-zero (no requested pid left), which must not abort the
      # leak report. Unresolved pids simply keep their integer key.
      output = Tac.run( [ 'ps' ] + pidArgs + [ '-o', 'pid=', '-o', 'command=' ],
                        stdout=Tac.CAPTURE, ignoreReturnCode=True ).strip()
   for line in output.splitlines():
      # "pid=" is right-aligned, so strip the padding before splitting off the pid.
      pidStr, processName = line.strip().split( ' ', 1 )
      pid = int( pidStr )
      leaked[ f"{processName}-{pid}" ] = leaked.pop( pid )
   return leaked

def getProcMaps( getFromAllProcs, isRoot ):
   """Return the /proc/<pid>/maps lines mentioning dmamem, each split on whitespace.

   Every process is scanned only when ``getFromAllProcs`` is set and we run as
   root; otherwise only this process's maps are read. Thanks to ``grep -H`` each
   line starts with a "/proc/<pid>/maps:" prefix from which the pid is recovered.
   """
   target = '*' if getFromAllProcs and isRoot else str( os.getpid() )
   command = f'grep -H dmamem /proc/{target}/maps'
   with AsanHelper.disableSubprocessAsanPreload():
      # A non-zero grep exit (e.g. no match) is tolerated.
      procMaps = Tac.run( [ command ], stdout=Tac.CAPTURE,
                          ignoreReturnCode=True, shell=True )
   return [ line.split() for line in procMaps.splitlines() ]

def detectLeakedMappings( dmamem ):
   """Find dmamem chunk mappings whose chunk is no longer allocated.

   Returns the dict built by compareMappedAndAllocatedChunks(), re-keyed by
   process name through addProcessNamesToLeaked() when it is non-empty.
   """
   simulationMode = 'SIMULATION_VMID' in os.environ
   if simulationMode:
      # All simulated chunks live in one file, so mapping filenames are identical;
      # match each mapping's file offset against the chunks' DMA addresses instead.
      isMappingLeakingFunc = isMappingLeakingUsingOffset
   else:
      # Outside simulation every chunk mapping's filename carries the chunk name.
      isMappingLeakingFunc = isMappingLeakingUsingChunkName

   allocatedChunks = getAllocatedChunks( dmamem, simulationMode )
   t0( "allocatedChunks=", allocatedChunks )

   # Other processes' maps are inspected only outside simulation, and only as root.
   procMaps = getProcMaps( not simulationMode, os.getuid() == 0 )
   chunkMappings = getDmamemChunkMappings( procMaps, simulationMode )
   t0( "chunkMappings=", chunkMappings )

   leaked = compareMappedAndAllocatedChunks( isMappingLeakingFunc, chunkMappings,
                                             allocatedChunks )
   if leaked:
      leaked = addProcessNamesToLeaked( leaked )
   t0( "leaked=", leaked )
   return leaked

def humanReadableBytes( size ):
   """Format a byte count with binary (IEC) units, e.g. 1536 -> "1.5 KiB".

   Uses exact repeated division by 1024 instead of floating-point logarithms, so
   0 is accepted ("0.0 B") and the unit never runs past the table: anything of
   1024 TiB or more is still reported in TiB. A value that would be displayed as
   "1024.0" in one unit is promoted to the next unit instead.
   """
   units = [ "", "Ki", "Mi", "Gi", "Ti" ]
   value = size
   unitIdx = 0
   # Compare the *rounded* value so e.g. 1048575 prints "1.0 MiB", not "1024.0 KiB".
   while round( value, 1 ) >= 1024 and unitIdx < len( units ) - 1:
      value /= 1024
      unitIdx += 1
   return f"{value:3.1f} {units[unitIdx]}B"

def checkVmemLeak( dmamem ):
   """Report dmamem chunk mappings that outlived their chunks.

   Runs detectLeakedMappings() and, only if something leaked, appends a
   BUG956800 summary followed by a per-process breakdown (leaked size and the
   leaked mappings) to /tmp/debug_dmamem_vmem_leak.log. Processes are listed in
   increasing order of leaked bytes. Returns None.
   """
   t0( "checkVmemLeak()" )
   leaked = detectLeakedMappings( dmamem )
   if not leaked:
      return
   # Anything written to /tmp/debug_ files gets appended to test logs so it can be
   # detected by logres in bugs
   output_filename = "/tmp/debug_dmamem_vmem_leak.log"
   amountLeaked = sum( len( data[ "mappings" ] ) for data in leaked.values() )
   leakingProcesses = sorted( leaked,
         key=lambda proc : leaked[ proc ][ "totalLeaked" ] )
   # Process names are ps command lines and may contain arbitrary characters:
   # pin the encoding instead of inheriting the locale's, and never let an
   # unencodable character abort the report.
   with open( output_filename, "a", encoding="utf-8",
              errors="backslashreplace" ) as output_file:
      print( f"BUG956800: Virtual memory leak detected. {amountLeaked} chunks "
             "were freed but not unmapped by these processes: "
             f"{leakingProcesses}", file=output_file )
      for process in leakingProcesses:
         leakingData = leaked[ process ]
         sizeStr = humanReadableBytes( leakingData[ "totalLeaked" ] )
         leakedChunks = leakingData[ "mappings" ]
         print( f"{sizeStr} of virtual memory leaked by {process}. Leaked chunk "
                f"mappings :\n{pformat(leakedChunks)}", file=output_file )
