#!/usr/bin/env python3
# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import os
import glob
import argparse
import pwd
import subprocess
import re
import time
from collections import defaultdict

def getCmdLine( pid ):
   """Return the command line of process `pid` as a single string.

   /proc/<pid>/cmdline stores argv as NUL-separated strings; they are joined
   with single spaces so arguments stay distinguishable ('sleep 100', not
   'sleep100').

   Returns '' if the file cannot be read (e.g. the process has exited or
   access is denied) or if it is empty.
   """
   try:
      # errors='replace': argv is arbitrary bytes and need not be valid in the
      # locale encoding; one odd byte must not abort the whole report.
      with open( '/proc/%d/cmdline' % pid, errors='replace' ) as f:
         raw = f.read()
   except OSError:
      return ''
   # Splitting on NUL leaves an empty trailing element (and empty args);
   # drop those so the result has no leading/trailing/double spaces.
   return ' '.join( arg for arg in raw.split( '\x00' ) if arg )

# Represents the result of the count query.
class InotifyCount:
   """Result of the count query: total number of inotify fds in use."""

   def __init__( self, count ):
      # Total number of inotify file descriptors found.
      self.count = count

   def print( self ):
      """Print the total fd count as a single line."""
      line = 'Number of FDs Used: %d' % ( self.count, )
      print( line )

# Represents the result of the limits query.
class InotifyLimits:
   """Result of the limits query: mapping of limit label -> maximum value."""

   def __init__( self, limits ):
      self.limits = limits

   def print( self ):
      """Print each limit as '<label>: <value>' followed by a blank line."""
      for labelAndValue in self.limits.items():
         # Each item is already the ( label, value ) tuple the format expects.
         print( '%s: %d\n' % labelAndValue )

class InotifyProc:
   """Inotify fd tally for one process (ps/tree queries) or one user (user query)."""

   def __init__( self, cmdline, count, owner ):
      # Number of inotify fds attributed to this entry.
      self.count = count
      # Command line of the process; the user query passes None.
      self.cmdline = cmdline
      # Owning user name; the ps and tree queries pass None.
      self.owner = owner

# Represents the result of the ps query.
class InotifyPs:
   """Result of the ps query: per-process inotify fd tallies.

   Keys are the per-process identifiers printed in the output (the ps query
   uses the /proc/<PID>/fd directory); values expose `count` and `cmdline`.
   """

   def __init__( self, ps ):
      self.ps = ps

   def print( self ):
      """Print a header, then '<count> <key> <cmdline>' for each process."""
      print( 'FDs Used by Each Process:\n' )
      for procKey in self.ps:
         entry = self.ps[ procKey ]
         print( '%d %s %s' % ( entry.count, procKey, entry.cmdline ) )

class InotifyFile:
   """One watched path together with the watch descriptor that refers to it."""

   def __init__( self, filepath, wd ):
      # Inotify watch descriptor number.
      self.wd = wd
      # Path of the watched file or directory.
      self.filepath = filepath

# Represents the result of the files query.
class InotifyFiles:
   """Result of the files query: the watched files of one process.

   `fileList` holds entries exposing `wd` and `filepath`; `pid` is the
   process they belong to.
   """

   def __init__( self, fileList, pid ):
      self.pid = pid
      self.fileList = fileList

   def print( self ):
      """Print a header, a right-aligned wd / path table, and the file total."""
      header = ( "Watch Descriptor", "File Path" )
      # Right-align the wd column to the width of its heading.
      rowFmt = '{:>' + str( len( header[ 0 ] ) ) + '}   {}'
      print( 'Watch descriptor and files in /var/shmem '
             'watched by the process with PID %d:\n' % self.pid )
      print( rowFmt.format( header[ 0 ], header[ 1 ] ) + '\n' )
      for entry in self.fileList:
         print( rowFmt.format( entry.wd, entry.filepath ) )
      total = len( self.fileList )
      print( '-----\nTotal files: %d' % total )

class InotifyPid:
   """Per-process record: command line plus watched files grouped by inotify fd."""

   def __init__( self, cmdline, fds ):
      # Map( fd -> InotifyFiles ), as populated by InotifyPids.insertFile.
      self.fds = fds
      self.cmdline = cmdline

class InotifyPids:
   """Result of the watch query: watched files of every process.

   `pids` maps pid -> InotifyPid, whose `fds` map each inotify fd to an
   InotifyFiles holding that fd's watched files.
   """

   def __init__( self, pids ):
      self.pids = pids

   def printWatch( self ):
      """Print per-process watch counts (ascending by count), then the total."""
      print( 'Number of inotify watches in /var/shmem used by each process:\n' )
      title = ( "Count", "Pid", "Name" )
      formatStr = '{:>10} {:>7}   {}'
      rows = []
      for pid, pidObj in self.pids.items():
         # A process may own several inotify fds; sum watches over all of them.
         count = sum( len( fd.fileList ) for fd in pidObj.fds.values() )
         rows.append( ( count, pid, pidObj.cmdline ) )

      # Sorting the tuples orders by count, then pid (pids are unique).
      rows.sort()
      print( formatStr.format( *title ) + '\n' )
      for row in rows:
         print( formatStr.format( *row ) )
      print( '-----\nTotal watches: %d' % sum( row[ 0 ] for row in rows ) )

   def insertFile( self, pid, fd, path, wd ):
      """Record that inotify fd `fd` of process `pid` watches `path` as `wd`.

      Creates the InotifyPid entry for `pid` (looking up its command line only
      once) and the InotifyFiles entry for `fd` on first use.
      """
      pidsMap = self.pids
      if pid not in pidsMap:
         pidsMap[ pid ] = InotifyPid( cmdline=getCmdLine( pid ), fds={} )
      fdsMap = pidsMap[ pid ].fds
      if fd not in fdsMap:
         fdsMap[ fd ] = InotifyFiles( fileList=[], pid=pid )
      fdsMap[ fd ].fileList.append( InotifyFile( filepath=path, wd=wd ) )

# Represents the result of the user query.
class InotifyUser:
   """Result of the user query: inotify fd tallies keyed by uid.

   Values expose `count` and `owner` (the user name).
   """

   def __init__( self, users ):
      self.users = users

   def print( self ):
      """Print a header, then '<count> <owner>' for each user."""
      print( 'FDs Used by Each User:\n' )
      for uid in self.users:
         tally = self.users[ uid ]
         print( '%d %s' % ( tally.count, tally.owner ) )

# Represents the result of the tree query.
class InotifyTree:
   """Result of the tree query: inotify fd tallies for a process subtree.

   Keys are the per-process identifiers printed in the output; values expose
   `count` and `cmdline`.
   """

   def __init__( self, tree ):
      self.tree = tree

   def print( self ):
      """Print '<count> <key> <cmdline>' for each process (no header)."""
      for procKey in self.tree:
         entry = self.tree[ procKey ]
         print( '%d %s %s' % ( entry.count, procKey, entry.cmdline ) )

class InotifyQueryLib:
   """Collects inotify usage information by inspecting /proc.

   Each public query returns a result object (InotifyCount, InotifyLimits,
   InotifyPs, InotifyUser, InotifyTree, InotifyFiles or InotifyPids) that the
   caller prints. Processes may exit while they are being inspected, so the
   per-process /proc reads are best-effort.
   """

   def __init__( self ):
      pass

   def _isInotifyFile( self, filename ):
      """Return True if the fd symlink `filename` refers to an inotify instance.

      Any OSError (fd closed or process exited since it was listed,
      permission denied) yields False.
      """
      try:
         return os.readlink( filename ) == 'anon_inode:inotify'
      except OSError:
         return False

   def _stripFilename( self, filename ):
      # Drops the last path component, e.g. '/proc/5/fd/4' -> '/proc/5/fd'
      # (no trailing slash).
      strippedName = filename.split( '/' )[ 0:-1 ]
      return '/'.join( strippedName )

   def _inotifySymlinkFiles( self, pid=None ):
      # Returns every /proc/PID/fd/NUM entry that refers to an inotify
      # instance, either for the given pid or for all pids if none is given.
      pidExpression = str( pid ) if pid is not None else '[0-9]*'
      pathRe = '/proc/%s/fd/*' % pidExpression
      return [ filename for filename in glob.glob( pathRe )
               if self._isInotifyFile( filename ) ]

   def _inotifyFiles( self ):
      # Returns the /proc/PID/fd directory once per inotify fd it contains;
      # the repetition is intentional, callers count the entries.
      return [ self._stripFilename( filename )
               for filename in self._inotifySymlinkFiles() ]

   def _mountDevices( self, mountInfoLines, mountPoint ):
      """Return the set of ( major, minor ) devices mounted exactly at `mountPoint`.

      `mountInfoLines` is an iterable of lines in /proc/PID/mountinfo format,
      e.g. "777 748 0:51 / /var/shmem rw,nosuid ...", whose leading fields are
      mount ID, parent ID, major:minor, root and mount point. The mount point
      must match exactly: '/var/shmem2' does not match '/var/shmem'.
      """
      mountInfoRe = re.compile(
            r'\d+ \d+ (\d+):(\d+) \S+ %s(?:\s|$)' % re.escape( mountPoint ) )
      devices = set()
      for line in mountInfoLines:
         match = mountInfoRe.match( line )
         if match:
            devices.add( ( int( match.group( 1 ) ), int( match.group( 2 ) ) ) )
      return devices

   def _inotifyFileWatches( self, pid, fd=None ):
      """Return ( inode, watch descriptor, '/var/shmem' ) for each watch of `pid`.

      Only inotify fd `fd` is examined if given, otherwise all of them. A
      watch is kept only if its device (sdev in /proc/PID/fdinfo) is the one
      mounted at /var/shmem according to /proc/PID/mountinfo.

      Raises ValueError if an 'inotify' fdinfo line has an unexpected format.
      """
      directoryFilter = '/var/shmem'
      fdExpression = str( fd ) if fd is not None else '*'
      pathRe = '/proc/%d/fd/%s' % ( pid, fdExpression )
      fdInfoRe = re.compile(
            r'inotify wd:([0-9a-f]+) ino:([0-9a-f]+) sdev:([0-9a-f]+)' )
      fdNums = [ filename.split( '/' )[ 4 ] for filename in glob.glob( pathRe )
                 if self._isInotifyFile( filename ) ]
      if not fdNums:
         return []

      # The devices mounted at directoryFilter are the same for every watch
      # of this process, so read mountinfo once rather than once per watch.
      try:
         with open( '/proc/%d/mountinfo' % pid ) as f:
            filterDevices = self._mountDevices( f, directoryFilter )
      except OSError:
         # Process exited or mountinfo unreadable: nothing can be attributed.
         return []

      inos = []
      for n in fdNums:
         try:
            with open( '/proc/%d/fdinfo/%d' % ( pid, int( n ) ) ) as f:
               lines = [ line for line in f if 'inotify' in line ]
         except OSError:
            # fd closed (or process exited) since it was listed; skip it.
            continue
         for line in lines:
            # example line: "inotify wd:4 ino:bcd520c sdev:d5"...
            match = fdInfoRe.match( line )
            if match is None:
               raise ValueError( 'Unexpected inotify fdinfo line for pid %d '
                                 'fd %s: %r' % ( pid, n, line ) )
            wd, ino, sdev = ( int( group, 16 ) for group in match.groups() )
            # Split sdev into major (bits above 20) and minor (low 20 bits)
            # to compare with the major:minor field of mountinfo.
            if ( sdev >> 20, sdev & 0xfffff ) in filterDevices:
               inos.append( ( ino, wd, directoryFilter ) )
      return inos

   def _inotifyFileNames( self, watchesMap, baseDirectory='/var/shmem' ):
      """Resolve watched inodes to paths under `baseDirectory`.

      Param watchesMap: Map( ino -> List( ( pid, fd, wd ) ) ). It is consumed:
         entries are removed as their inode is found.
      Param baseDirectory: root of the tree searched for the watched inodes,
         including the directory itself.
      Returns a populated InotifyPids object whose per-fd file lists are
      sorted by watch descriptor. Watches whose inode is not found under
      `baseDirectory` are dropped.
      """
      output = InotifyPids( pids={} )

      def resolve( path ):
         # Attribute `path` to every watch on its inode, if any.
         try:
            currentIno = os.stat( path ).st_ino
         except OSError:
            # Entry vanished or is unreadable; skip it.
            return
         for ( pid, fd, wd ) in watchesMap.pop( currentIno, () ):
            output.insertFile( pid, fd, path, wd )

      # os.walk() only yields the children of baseDirectory, so check the
      # directory itself explicitly.
      resolve( baseDirectory )
      for root, dirnames, files in os.walk( baseDirectory ):
         for basename in files + dirnames:
            if not watchesMap:
               break
            resolve( os.path.join( root, basename ) )
         if not watchesMap:
            # Every watched inode is resolved; skip the rest of the tree.
            break

      for pidObj in output.pids.values():
         for wdObj in pidObj.fds.values():
            wdObj.fileList.sort( key=lambda x: x.wd )

      return output

   def _inotifyWatchesMap( self, pid=None ):
      # Returns a map from ino to list of ( pid, fd, wd ) that have an inotify
      # watch under /var/shmem, either for a given pid, or for all pids if
      # none is specified.
      watchesMap = defaultdict( list ) # Map(ino -> List( ( pid, fd, wd ) ) )
      for filename in self._inotifySymlinkFiles( pid ):
         # filename is '/proc/PID/fd/FD'
         filenameSplit = filename.split( '/' )
         ownerPid = int( filenameSplit[ 2 ] )
         fd = int( filenameSplit[ 4 ] )
         for ino, wd, _ in self._inotifyFileWatches( ownerPid, fd ):
            watchesMap[ ino ].append( ( ownerPid, fd, wd ) )
      return watchesMap

   def filesModel( self, pid ):
      """Return an InotifyFiles with every /var/shmem file watched by `pid`."""
      watchesMap = self._inotifyWatchesMap( pid )
      inotifyPids = self._inotifyFileNames( watchesMap )

      result = InotifyFiles( [], pid )
      if pid in inotifyPids.pids:
         # Merge the file lists of all of the process's inotify fds.
         for inotifyFiles in inotifyPids.pids[ pid ].fds.values():
            result.fileList += inotifyFiles.fileList
      return result

   def pidsModel( self ):
      """Return an InotifyPids covering the /var/shmem watches of all processes."""
      watchesMap = self._inotifyWatchesMap()
      return self._inotifyFileNames( watchesMap )

   def limits( self ):
      """Read the inotify limits under /proc/sys/fs/inotify.

      Raises OSError if a limit file cannot be read.
      """
      maxNames = {
         'max_user_watches' : 'Max User Watches',
         'max_user_instances' : 'Max User Instances',
         'max_queued_events' : 'Max Queued Events',
      }
      maxLimits = {}
      for name, label in maxNames.items():
         path = os.path.join( '/proc/sys/fs/inotify', name )
         with open( path ) as f:
            maxLimits[ label ] = int( f.read() )
      return InotifyLimits( limits=maxLimits )

   def watch( self ):
      return self.pidsModel()

   def file( self, pid ):
      return self.filesModel( pid )

   def count( self ):
      """Return the total number of inotify fds across all visible processes."""
      return InotifyCount( count=len( self._inotifyFiles() ) )

   def ps( self ):
      """Return inotify fd counts per process, keyed by /proc/PID/fd."""
      procs = {}
      for filename in self._inotifyFiles():
         if filename in procs:
            procs[ filename ].count += 1
         else:
            pid = int( filename.split( '/' )[ 2 ] )
            procs[ filename ] = InotifyProc( count=1, cmdline=getCmdLine( pid ),
                                             owner=None )
      return InotifyPs( ps=procs )

   def user( self ):
      """Return inotify fd counts per owning uid of the /proc/PID/fd directory."""
      users = {}
      for procFdDir in self._inotifyFiles():
         try:
            uid = os.lstat( procFdDir ).st_uid
         except OSError:
            # Process exited after its fds were listed.
            continue
         if uid in users:
            users[ uid ].count += 1
            continue
         try:
            owner = pwd.getpwuid( uid ).pw_name
         except KeyError:
            # No passwd entry for this uid; report the numeric id instead.
            owner = str( uid )
         users[ uid ] = InotifyProc( count=1, owner=owner, cmdline=None )
      return InotifyUser( users=users )

   def _pstreePids( self, pstreeOut ):
      """Extract process IDs from `pstree -p` output, skipping threads.

      pstree prints processes as 'name(PID)' and threads as '{name}(TID)'.
      Threads normally share their process's fd table, so also scanning
      /proc/TID/fd would count the same inotify fds more than once.
      """
      return [ int( pid ) for pid in re.findall( r'(?<!\})\((\d+)\)',
                                                  pstreeOut ) ]

   def tree( self, pid ):
      """Return inotify fd counts for `pid` and every descendant process.

      Raises subprocess.CalledProcessError if pstree fails, and OSError if
      pstree cannot be run.
      """
      pstreeOut = subprocess.check_output( [ 'pstree', '-pl', str( pid ) ],
                                           text=True )
      # The pstree output will look something like this:
      # foo(2539)---bar(6589)-+-{bar}(6590)
      #                       `-baz(11159)---qux(11204)
      inotifyChildren = {}
      for childPid in self._pstreePids( pstreeOut ):
         for filename in self._inotifySymlinkFiles( childPid ):
            strippedName = self._stripFilename( filename )
            if strippedName in inotifyChildren:
               inotifyChildren[ strippedName ].count += 1
            else:
               inotifyChildren[ strippedName ] = InotifyProc(
                     count=1, cmdline=getCmdLine( childPid ), owner=None )
      return InotifyTree( tree=inotifyChildren )

def main():
   def nonNegativeInt( num ):
      # isdigit() returns True if the number is a positive integer (including 0).
      if num.isdigit():
         return int( num )
      else:
         raise argparse.ArgumentTypeError( 'Please provide a non-negative integer' )

   def checkIsRootWarning():
      if os.geteuid() != 0:
         print( "*** Warning ***: Output may not be complete without root access." )


   parser = argparse.ArgumentParser( description='inotify query' )
   parser.add_argument( '--count', type=nonNegativeInt, nargs='?', const=-1,
                        metavar='TIMEOUT',
                        help='Print the total number of inotify fd\'s currently ' \
                              'allocated. The optional integer timeout will ' \
                              'enable looping indefinitely while printing.' )
   parser.add_argument( '--file', type=int, metavar='PID',
                        help='Given a PID, look up and print the paths of every '
                             'file being watched by that PID using inotify' )
   parser.add_argument( '--limits', action='store_true',
                        help='Print the maximum number of inotify watches and ' \
                             'instances for each user.\nAlso print the maximum ' \
                             'number of queued events.' )
   parser.add_argument( '--ps', action='store_true',
                        help='For each process using inotify, print the number ' \
                             'of inotify fd\'s owned by that process, the PID, ' \
                             'and the /proc/PID/cmdline.' )
   parser.add_argument( '--tree', type=int, metavar='PID',
                        help='Given a parent PID, scan the entire tree of child ' \
                             'processes and print out each process usage count. ' \
                             'Great for dumping the inotify cost of an Abuild, ' \
                             'for example.' )
   parser.add_argument( '--user', action='store_true',
                        help='For each process using inotify, print out the ' \
                             'number of fd\'s in use by each user.' )
   parser.add_argument( '--watch', action='store_true',
                        help='For each process using inotify, print the number '
                             'of inotify watches used by that process, the PID, '
                             'and the /proc/PID/cmdline.' )
   args = parser.parse_args()

   if not any( vars( args ).values() ):
      parser.error( 'No arguments provided' )

   checkIsRootWarning()

   inotifyQuery = InotifyQueryLib()
   if args.limits:
      inotifyQuery.limits().print()
   if args.count != None: # pylint: disable=singleton-comparison
      while True:
         inotifyQuery.count().print()
         # We use -1 as the default value when no optional timeout is specified.
         if args.count == -1: # pylint: disable=no-else-break
            break
         else:
            time.sleep( args.count )
   if args.ps:
      inotifyQuery.ps().print()
   if args.user:
      inotifyQuery.user().print()
   if args.tree:
      inotifyQuery.tree( args.tree ).print()
   if args.watch:
      inotifyQuery.watch().printWatch()
   if args.file:
      inotifyQuery.file( args.file ).print()

# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
   main()
