# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

"""Support for running agents under Valgrind.

In order to run one or more agents under valgrind control, make sure the Valgrind
RPM is installed - if not, run:
   a4 yum install -y valgrind

Note that we build Valgrind (version 3.8.1 currently) ourselves from source,
including some required patches (including one to make Valgrind recognise the 'setns'
syscall).

Example usage:
   Runs a namespace DUT with Rib and Arp under Valgrind:
      AGENTS_UNDER_VALGRIND='Rib,Arp' \
         Art on

   Runs a namespace DUT with Rib under Valgrind:
      AGENTS_UNDER_VALGRIND='Rib' \
          RIB_VALGRIND_ARGS='--leak-check=full -v' \
          Art on

Note that the XXX_VALGRIND_ARGS variable is optional as a default set of arguments
will be used if it's not present. These default arguments include a sensible
location for the Valgrind log file, along with including any relevant suppression
files (which we expect to find at /usr/share/Artools/valgrind-*.supp). If the
XXX_VALGRIND_EXTRA_ARGS variable is defined, any arguments specified will be appended
to the default arguments - for example:
   AGENTS_UNDER_VALGRIND='Rib' \
       RIB_VALGRIND_EXTRA_ARGS='--show-reachable=yes' \
       Art on

You can set agent-specific environment variables when running under
Valgrind with the XXX_VALGRIND_ENV variable. This is a comma-separated
list of NAME=VAL environment variables that will be added to the
environment of the agent when it is started.
"""

import functools
import operator
import os
import re
import stat
import subprocess
import sys
import ArPyUtils
from pprint import pformat
from collections import defaultdict
from TableOutput import createTable

def isRunning():
   """ Returns True if the 'AGENTS_UNDER_VALGRIND' environment variable is set,
   whatever its value (an empty value still counts as set). """
   return os.environ.get( 'AGENTS_UNDER_VALGRIND' ) is not None

def isValgrindAgent( agentName ):
   """ Tells whether the named agent has been selected to run under Valgrind.

   The selection is the comma separated list held in the 'AGENTS_UNDER_VALGRIND'
   environment variable; names are compared exactly. Always returns False when
   'HEAPCHECK' is present in the environment. """

   env = os.environ

   # the heapchecker and Valgrind cannot be used at the same time
   if 'HEAPCHECK' in env:
      return False

   selectedAgents = env.get( 'AGENTS_UNDER_VALGRIND' )
   if selectedAgents is None:
      return False
   return agentName in selectedAgents.split( ',' )

def getValgrindArgs( agentName ):
   """ Looks up the explicit Valgrind arguments for an agent.

   The arguments are read from the '<AGENTNAME>_VALGRIND_ARGS' environment
   variable (agent name upper-cased) and split on whitespace. Returns the list of
   arguments, or None when the variable is unset or holds no arguments. """

   envName = '%s_VALGRIND_ARGS' % agentName.upper()
   explicitArgs = os.environ.get( envName, '' ).split()
   return explicitArgs or None

def getDefaultValgrindArgs( agentName, logFilePath=None ):
   """ Builds the default Valgrind argument list for the specified agent.

   The list holds, in order: the fixed memcheck options, the tagged pointer mask
   (only when ArPyUtils.arch() reports 64), the '--log-file' option (only when
   logFilePath is given), one '--suppressions' option per default suppression
   file that exists on disk, and finally whatever whitespace separated arguments
   the "<AGENTNAME>_VALGRIND_EXTRA_ARGS" environment variable contains. """

   args = [
      '-v',                      # verbose output
      '--num-callers=40',        # show longer stack frames
      '--gen-suppressions=all',  # generate suppressions for convenience
      '--tool=memcheck',         # check for memory issues
      '--leak-check=full',       # show details of each leak
      '--track-origins=yes',     # show origin of uninitialised values
      # Some agents fork() themselves to implement read-only show commands.
      # Valgrind needn't follow these processes as they may not clean up all
      # the parent process' allocations, and this is acceptable. Therefore,
      # silence Valgrind regarding these children.
      '--child-silent-after-fork=yes',
   ]

   # On 64-bit, boost's tagged pointers keep tagging information in the "spare"
   # 16 bits of the pointer value and mask it off before dereferencing. Handing
   # Valgrind the same mask lets it apply it to the potential pointers it finds
   # in the heap before deciding whether they reference allocated memory.
   # See bug/389998
   if ArPyUtils.arch() == 64:
      mask = "0xffffffffffff"
      args.append( "--tagged-pointer-mask=%s" % mask )

   # log file location, when the caller supplied one
   if logFilePath:
      args.append( '--log-file=%s' % ( logFilePath ) )

   # default suppression files, the last one being specific to the agent
   for suppName in ( 'all', 'python', 'tacc', agentName ):
      suppFile = '/usr/share/Artools/valgrind-%s.supp' % ( suppName )
      if os.path.isfile( suppFile ):
         args.append( '--suppressions=%s' % ( suppFile ) )

   # extra arguments for this agent provided via the environment
   extraEnvName = '%s_VALGRIND_EXTRA_ARGS' % agentName.upper()
   args.extend( os.environ.get( extraEnvName, '' ).split() )
   return args

def getValgrindCmd( agentName, defaultLogFilePath=None ):
   """ Builds the Valgrind command prefix (as a list) for the specified agent.

   Arguments given explicitly through the '<AGENT>_VALGRIND_ARGS' environment
   variable take precedence; otherwise the default arguments are used, and only
   in that case is defaultLogFilePath taken into account. """

   explicitArgs = getValgrindArgs( agentName )
   if explicitArgs:
      return [ 'valgrind' ] + explicitArgs
   return [ 'valgrind' ] + getDefaultValgrindArgs( agentName, defaultLogFilePath )

def addValgrindEnv( vgEnv, agentName ):
   """ Adds any environment variables (to the specified dictionary) that are needed
   when running an agent under Valgrind.

   vgEnv is modified in place: 'GLIBCXX_FORCE_NEW' and 'PYTHONMALLOC' are always
   (over)written, then every NAME=VAL entry of the comma separated
   '<AGENTNAME>_VALGRIND_ENV' environment variable is added unless NAME is already
   present in vgEnv. VAL may itself contain '=' characters. Empty entries (such as
   the one produced by a trailing comma) are ignored.

   Raises ValueError if a non-empty entry does not contain '='. """
   # Force STL to stop using its pool allocator, as it can result in lots of
   # spurious Valgrind output - see:
   #   https://gcc.gnu.org/onlinedocs/libstdc++/manual/debug.html#debug.memory
   vgEnv[ 'GLIBCXX_FORCE_NEW' ] = "1"

   # Similarly, force Python to use malloc instead of its own custom allocators
   # (available in Python 3.6 and later) - see:
   #   https://svn.python.org/projects/python/trunk/Misc/README.valgrind
   vgEnv[ 'PYTHONMALLOC' ] = "malloc"

   # Pass along any env variables from XXX_VALGRIND_ENV
   vgEnvName = f'{agentName.upper()}_VALGRIND_ENV'
   for envNameAndValue in os.environ.get( vgEnvName, '' ).split( ',' ):
      if not envNameAndValue.strip():
         continue
      # Split on the first '=' only, everything after it belongs to the value.
      envName, separator, envValue = envNameAndValue.partition( '=' )
      if not separator:
         raise ValueError( "Malformed entry %r in %s (expected NAME=VAL)" %
                           ( envNameAndValue, vgEnvName ) )
      # Keep existing value, if any.
      vgEnv.setdefault( envName, envValue )

def getValgrindLogFilePath( cmd ):
   """ Extracts the value of the first '--log-file=' option of a Valgrind command.

   cmd is either a list of arguments or a single string, in which case it is split
   on single spaces. Returns None if the command has no such option. """

   option = '--log-file='
   words = cmd.split( ' ' ) if isinstance( cmd, str ) else cmd
   return next( ( word[ len( option ) : ] for word in words
                  if word.startswith( option ) ), None )

def expandValgrindLogFiles( logFilePattern ):
   """ Returns the list of Valgrind log files that match the specified pattern. This
   takes care of any %p tokens (which Valgrind expands to PID) that might be used in
   the log file name, and of %% (a literal percent sign).

   When the file name holds a %p token, the directory of the pattern is listed and
   the sorted paths of the entries whose name matches (with one or more digits in
   place of each %p) are returned. Every other character of the name is matched
   literally. Otherwise a single-element list with the pattern itself (with %%
   unescaped) is returned, whether or not that file exists. Only the file name part
   of the pattern is examined for %p. """

   logDir, logName = os.path.split( logFilePattern )

   # Tokenise left to right so that '%%p' is read as an escaped percent sign
   # followed by a 'p', and not as a percent sign followed by a PID token.
   tokens = re.split( r'(%%|%p)', logName )
   if '%p' not in tokens:
      return [ logFilePattern.replace( '%%', '%' ) ]

   regexParts = []
   for token in tokens:
      if token == '%p':
         regexParts.append( '[0-9]+' )
      elif token == '%%':
         regexParts.append( re.escape( '%' ) )
      else:
         # literal text: characters such as '.' must not act as regex operators
         regexParts.append( re.escape( token ) )
   logNameRe = re.compile( ''.join( regexParts ) )

   # A pattern without a directory part refers to the current directory.
   # os.listdir( '' ) would raise, but joining with '' keeps the names relative.
   return sorted( os.path.join( logDir, logFile )
                  for logFile in os.listdir( logDir or os.curdir )
                  if logNameRe.fullmatch( logFile ) )


def _getTableHeaderVars():
   """ Returns a ( legend, columns ) pair describing the statistics tables: the
   legend is a multi-line string spelling out each abbreviation, and columns is
   the tuple of column headings that gets passed to createTable(). """
   legendLines = (
      "IR=Invalid Reads, IW=Invalid Writes",
      "UR=Uninitialised Reads, US=Uninitialised Syscalls",
      "IF=Invalid Frees, MF=Mismatched Frees",
      "OB=Overlapping Blocks, FA=Fishy Args",
      "DL=Definite Leaks, PL=Possible Leaks, IL=Indirect Leaks",
      "SR=Still Reachable",
   )
   # every legend line is indented with a tab and ends with a newline
   legend = "".join( "\t%s\n" % text for text in legendLines )

   columns = ( ( "File", "l" ),
               "IR", "IW", "UR", "US", "IF", "MF",
               "OB", "FA", "DL", "PL", "IL", "SR" )
   return legend, columns


class ValgrindException( Exception ):
   '''
   Raised to report that errors or leaks were found in the valgrind logs
   '''

# pylint: disable-msg=E1101
# Pylint complains about member variables not declared
# They are declared via setattr in init
class ValgrindFileStats:
   '''
   Memcheck error and leak statistics extracted from one Valgrind log file.

   Every name listed in statNames becomes an integer attribute of the instance
   (initialised to 0 and incremented while parsing). tracebackDict maps an error
   type to the set of distinct traceback texts recorded for it.
   '''
   statNames = [ 'invalidReads', 'invalidWrites', 'uninitReads',
                 'uninitSyscalls', 'invalidFrees', 'mismatchedFrees',
                 'overlappingBlocks', 'fishyArgs', 'definiteLeakCount',
                 'definiteLeakBytes', 'possibleLeakCount',
                 'possibleLeakBytes', 'indirectLeakCount',
                 'indirectLeakBytes', 'reachableCount',
                 'reachableBytes' ]

   # matches the SUMMARY line Valgrind prints when the process terminates
   _reSummaryLine = re.compile( r'^==\d+==\s\w+ SUMMARY:\s*$', re.MULTILINE )

   # matches the initial error line of a traceback
   _reErrorLine = re.compile( r'^==\d+==\s\w+.*$' )

   # matches the final empty line of a traceback
   _reTracebackEnd = re.compile( r'^==\d+==\s*$' )

   # matches error lines containing leak details
   _reLeak = re.compile(
      r'^==\d+==\s([\d,]+).* bytes in [\d,]+ blocks are (\w+) lost' )
   _reReach = re.compile(
      r'^==\d+==\s([\d,]+).* bytes in [\d,]+ blocks are still reachable' )

   # matches the ERROR SUMMARY line (used by printErrorSummary)
   _reErrorSummaryLine = re.compile(
      r'^[=-]{2}\d+[=-]{2}\s*ERROR\s*SUMMARY:\s*' )

   # memcheck error messages that only need counting, checked in this order - see:
   #     http://valgrind.org/docs/manual/mc-manual.html#mc-manual.errormsgs
   # Each entry is ( text to look for, counter attribute, tracebackDict key ).
   _simpleErrors = (
      ( "Invalid read", 'invalidReads', 'invalidRead' ),
      ( "Invalid write", 'invalidWrites', 'invalidWrite' ),
      ( "Conditional jump or move depends on uninitialised value",
        'uninitReads', 'uninitReads' ),
      ( "Syscall param", 'uninitSyscalls', 'uninitSyscalls' ),
      ( "Invalid free", 'invalidFrees', 'invalidFrees' ),
      ( "Mismatched free", 'mismatchedFrees', 'mismatchedFrees' ),
      ( "Source and destination overlap",
        'overlappingBlocks', 'overlappingBlocks' ),
      ( "has a fishy", 'fishyArgs', 'fishyArgs' ),
   )

   # leak kind as worded in the log -> ( count attribute, bytes attribute,
   # tracebackDict key )
   _leakKinds = {
      'definitely' : ( 'definiteLeakCount', 'definiteLeakBytes', 'definiteLeak' ),
      'possibly' : ( 'possibleLeakCount', 'possibleLeakBytes', 'possibleLeak' ),
      'indirectly' : ( 'indirectLeakCount', 'indirectLeakBytes', 'indirectLeak' ),
   }

   def __init__( self, logFile, contents=None ):
      '''
      Parses the log and fills in the statistics.

      logFile is the path of the log file. contents, when given, is the text of
      the log and is parsed instead of reading logFile.

      Raises AssertionError if the log has no summary section.
      '''
      self.fileName_ = logFile

      # set of tracebacks per error type
      self.tracebackDict = defaultdict( set )
      print( "Processing valgrind log file %s" % self.fileName_ )

      # initialize all stat variables to 0
      for varName in self.statNames:
         setattr( self, varName, 0 )

      if contents is None:
         # Tolerate undecodable bytes rather than failing on the whole log
         with open( self.fileName_, errors='replace' ) as logf:
            contents = logf.read()

      # skip everything until we find a SUMMARY line - errors are reported when
      # they occur AND in the summary, so this avoids double counting
      summaryMatch = self._reSummaryLine.search( contents )
      if not summaryMatch:
         # Raised explicitly (an assert statement is skipped under python -O)
         raise AssertionError(
            "Valgrind log %s does not have a summary section" % self.fileName_ )
      self._parseErrors( contents[ summaryMatch.end() : ].splitlines( True ) )

   def _parseErrors( self, lines ):
      '''
      Walks the lines looking for Valgrind errors and their tracebacks
      '''
      traceback = ""
      errorType = None
      for line in lines:
         if errorType:
            # we're parsing a traceback - just append to the current traceback
            traceback += line

            # check if this is the last line of the traceback
            if self._reTracebackEnd.match( line ):
               self.tracebackDict[ errorType ].add( traceback )
               traceback = ""
               errorType = None
            continue

         # check if this is the first line of a traceback - however, even if we
         # get a match, we must check if it's one of the recognised errors
         if not self._reErrorLine.match( line ):
            continue
         errorType = self._classifyErrorLine( line )
         if errorType:
            # we have found the first line of a traceback
            traceback = line

      if errorType:
         # The input ended before the empty line closing the traceback: the error
         # has been counted already, so keep what we have of its traceback too.
         self.tracebackDict[ errorType ].add( traceback )

   def _bump( self, statName, amount=1 ):
      '''
      Increases the named statistic by amount
      '''
      setattr( self, statName, getattr( self, statName ) + amount )

   def _classifyErrorLine( self, errorline ):
      '''
      Updates the counters for the error reported by errorline and returns the
      matching tracebackDict key, or None if the line is not a recognised error
      '''
      for marker, statName, errorType in self._simpleErrors:
         if marker in errorline:
            self._bump( statName )
            return errorType

      if "lost in loss record" in errorline:
         leakMatch = self._reLeak.match( errorline )
         if leakMatch and leakMatch.group( 2 ) in self._leakKinds:
            countName, bytesName, errorType = \
               self._leakKinds[ leakMatch.group( 2 ) ]
            self._bump( countName )
            self._bump( bytesName,
                        int( leakMatch.group( 1 ).replace( ',', '' ) ) )
            return errorType
         return None

      if "still reachable in loss record" in errorline:
         reachMatch = self._reReach.match( errorline )
         if reachMatch:
            self._bump( 'reachableCount' )
            self._bump( 'reachableBytes',
                        int( reachMatch.group( 1 ).replace( ',', '' ) ) )
            return 'stillReachable'

      return None

   def tableStatsValues( self ):
      '''
      Returns the list of valgrind statistic values displayed in the table
      in the order of self.statNames
      statVariables ending with Bytes are ignored in this case
      '''
      # To display stats in table, we don't display the Byte values
      return [ getattr( self, statName ) for statName in self.statNames
               if not statName.endswith( 'Bytes' ) ]

   def tracebackDump( self ):
      '''
      Returns a string with all the recorded tracebacks grouped by error type.
      Within a group the tracebacks are sorted, to keep the output stable from
      one run to the next.
      '''
      parts = []
      for errorType, tracebackSet in self.tracebackDict.items():
         parts.append( "========= Tracebacks of type %s =========\n" % errorType )
         for traceback in sorted( tracebackSet ):
            parts.append( traceback )
            parts.append( "\n" )
      return "".join( parts )

   def hasErrors( self ):
      '''
      Returns True if any of the statistics shown in the table is non zero
      '''
      return any( self.tableStatsValues() )

   def getErrorDict( self ):
      '''
      Returns the non zero stat values in a dictionary
      '''
      return { statName : getattr( self, statName )
               for statName in self.statNames
               if getattr( self, statName ) }

   def addToTable( self, table ):
      '''
      Adds a new row with the current file statistics to the table
      '''
      logFileName = os.path.basename( self.fileName_ )
      table.newRow( logFileName, *self.tableStatsValues() )

   def printErrorSummary( self ):
      '''
      Prints the error summary from the valgrind log file, i.e. the
      ERROR SUMMARY line and every line after it. The log file is read from disk.
      '''
      print( 'Error summary in %s' % self.fileName_ )
      errorSummaryFound = False
      with open( self.fileName_, errors='replace' ) as logf:
         for line in logf:
            if not errorSummaryFound:
               if not self._reErrorSummaryLine.match( line ):
                  continue
               errorSummaryFound = True

            print( line.rstrip() )

   def __str__( self ):
      '''
      Returns the valgrind statistics of the valgrind logfile as a legend followed
      by a one-row table
      '''
      statStr = "Valgrind stats for log file %s\n" % self.fileName_
      legendStr, rowNames = _getTableHeaderVars()
      statStr += legendStr
      statsTable = createTable( rowNames, tableWidth=80 )
      self.addToTable( statsTable )
      statStr += statsTable.output()
      return statStr

class ValgrindDirStats:
   '''
   Holds the ValgrindFileStats of the log files processed for one log directory,
   keyed by log file path
   '''
   def __init__( self, logDir ):
      self.hasErrors_ = False
      self.dirStats_ = {}
      self.logDir_ = logDir

   def hasErrors( self ):
      '''
      Returns True once any of the files added so far reported errors
      '''
      return self.hasErrors_

   def addValgrindFile( self, filePath ):
      '''
      Parses filePath into a ValgrindFileStats and stores it in the dirStats
      dictionary
      '''
      # make the file readable by everyone (for convenience)
      readableMode = os.stat( filePath ).st_mode | stat.S_IROTH | stat.S_IRGRP
      os.chmod( filePath, readableMode )

      parsedLog = ValgrindFileStats( filePath )
      if parsedLog.hasErrors():
         self.hasErrors_ = True
      self.dirStats_[ filePath ] = parsedLog

   def addValgrindFiles( self, logFiles ):
      '''
      Adds the stats of every log file of logFiles to this object
      '''
      print( 'Adding valgrind files for dir %s' % self.logDir_ )
      for path in logFiles:
         self.addValgrindFile( path )

   def __str__( self ):
      '''
      Returns the statistics table covering all the files of the directory,
      followed (when there are errors) by the non zero counters per file and the
      tracebacks
      '''
      if not self.dirStats_:
         return "No Valgrind log files processed in %s" % self.logDir_

      legend, columns = _getTableHeaderVars()
      table = createTable( columns, tableWidth=80 )
      errorsByLog = {}
      tracebackDumps = []
      for path, parsedLog in self.dirStats_.items():
         fileErrors = parsedLog.getErrorDict()
         if fileErrors:
            errorsByLog[ os.path.basename( path ) ] = fileErrors
            tracebackDumps.append( parsedLog.tracebackDump() )
         parsedLog.addToTable( table )

      pieces = [ "\nValgrind stats for log files in %s:\n" % self.logDir_,
                 legend,
                 table.output() ]
      if errorsByLog:
         pieces.append( "\nErrors:\n" + pformat( errorsByLog, indent=3 ) + "\n\n" )
         pieces.extend( tracebackDumps )
      return "".join( pieces )

class ValgrindStatsCol:
   '''
   Collection of the valgrind statistics of the agents run under
   AGENTS_UNDER_VALGRIND: extracts them from the logs, and raises if any of them
   reported errors
   '''
   def __init__( self ):
      # maps a log directory to the ValgrindDirStats holding the stats of the log
      # files processed in that directory:
      # { <logdir> :
      #       <file1> : { <memcheckstats1> : count .. }
      #    ...
      # }
      self.stats = {}

   def getDirStats( self, logFileDir ):
      '''
      Returns the ValgrindDirStats of logFileDir, creating it on first use
      '''
      if logFileDir not in self.stats:
         self.stats[ logFileDir ] = ValgrindDirStats( logFileDir )
      return self.stats[ logFileDir ]

   def processValgrindOutput( self, cmd ):
      '''
      Finds the valgrind log file option among the cmd-line args of the agent and
      adds the stats of the log files it designates. Nothing is added when the
      command names no log file.
      '''
      print( 'Processing Valgrind output from cmd %s' % cmd )
      logFilePattern = getValgrindLogFilePath( cmd )
      if not logFilePattern:
         return
      matchingLogs = expandValgrindLogFiles( logFilePattern )
      logDirStats = self.getDirStats( os.path.dirname( logFilePattern ) )
      logDirStats.addValgrindFiles( matchingLogs )

   def checkErrors( self ):
      '''
      Raises a ValgrindException if any valgrind errors were reported in any logs.
      The message is the concatenation of the reports of the directories that
      have errors.
      '''
      # note: when the exception is raised, Python will dump the local variables -
      # the Valgrind errors are deliberately never bound to a local name here, as
      # they would also be dumped (in an ugly fashion that adds a lot of clutter
      # to the output)
      if any( dirStats.hasErrors() for dirStats in self.stats.values() ):
         raise ValgrindException(
            ''.join( str( dirStats ) for dirStats in self.stats.values()
                     if dirStats.hasErrors() ) )

# Module-level ValgrindStatsCol instance shared by processValgrindOutput() and
# maybeRaiseValgrindException(). It keeps the state of all the processed agents
# from all namespace duts.
_valgrindStats = ValgrindStatsCol()

def processValgrindOutput( cmd ):
   '''
   Examines any Valgrind output files produced by the specified command and
   updates the global Valgrind statistics collection.

   cmd is the Valgrind command line; it is handed unchanged to the
   processValgrindOutput() method of the module-level _valgrindStats object.
   '''
   _valgrindStats.processValgrindOutput( cmd )

def maybeRaiseValgrindException():
   '''
   Raises a ValgrindException if any errors were present in any of the
   Valgrind output files that were previously examined using
   processValgrindOutput().

   Delegates to checkErrors() of the module-level _valgrindStats object.
   '''
   _valgrindStats.checkErrors()

def runPythonScriptUnderValgrind( scriptFileName, scriptArgs=None, tool='memcheck',
                                  logDir='/tmp/Valgrind/' ):
   """
   Runs the python script <scriptFileName> under valgrind, and validates that there
   were no errors in the summary.
   The script file must exist either in $PATH, or in sys.path. It need not be
   executable. A name starting with '/' is used as is, without any lookup.

   scriptArgs is an optional sequence of arguments handed to the script, tool is
   the Valgrind tool to run, and logDir is the directory (created if missing) that
   receives the '<script base name>-<pid of this process>.log' Valgrind log.

   Raises AssertionError if the script cannot be found, and ValgrindException if
   errors were found in the Valgrind logs processed so far.
   """
   # Needed to read the log file after the test.
   ArPyUtils.runMeAsRoot()

   if scriptFileName.startswith( '/' ):
      scriptPath = scriptFileName
   else:
      binPaths = [ d for d in os.environ.get( 'PATH', '' ).split( ':' ) if d ]
      pathsToSearch = sys.path + binPaths
      candidates = ( os.path.join( pathDir, scriptFileName )
                     for pathDir in pathsToSearch )
      scriptPath = next( ( path for path in candidates
                           if os.path.exists( path ) ), None )
      if not scriptPath:
         # Raised explicitly: an assert statement is skipped under python -O,
         # which would let a None script path through to the command below.
         raise AssertionError(
            "Unable to find %s in %r" % ( scriptFileName, pathsToSearch ) )

   # Make sure the directory receiving the Valgrind log exists. exist_ok avoids
   # failing if something else creates it between a check and the creation.
   os.makedirs( logDir, exist_ok=True )
   logBaseName = os.path.splitext( os.path.basename( scriptFileName ) )[ 0 ]
   logPath = os.path.join( logDir, '%s-%d.log' % ( logBaseName, os.getpid() ) )

   # For some undetermined reason, sudo is required here and runMeAsRoot is
   # insufficient to resolve whatever environment problems are causing valgrind
   # and other Sand plugins to interact badly.
   cmd = [ 'sudo', '/usr/bin/valgrind',
           '--suppressions=/usr/share/Artools/valgrind-python.supp',
           '--suppressions=/usr/share/Artools/valgrind-all.supp',
           '--trace-children=yes',
           '--tool=' + tool,
           '--log-file=' + logPath,
           '-s',
           'python', scriptPath,
           ] + list( scriptArgs or [] )
   subprocess.call( cmd )

   # Hand over the argument list itself: joining it with spaces (to be split on
   # spaces again when looking for the log file) mangles paths containing spaces.
   processValgrindOutput( cmd )
   maybeRaiseValgrindException()
