#!/usr/bin/env python3
# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string
# pylint: disable=consider-using-with

import os
import re
import signal
import threading
import QuickTrace
import Tracing

# Tracing handles for this module: level 0 is used for kills and failures,
# level 1 for routine child bookkeeping.
traceHandle = Tracing.Handle( 'CliSubprocMgr' )
t0, t1 = traceHandle.trace0, traceHandle.trace1

# QuickTrace counterparts; qv wraps the values passed to qt0/qt1.
qt0, qt1, qv = QuickTrace.trace0, QuickTrace.trace1, QuickTrace.Var

def dbfile( sysname ):
   # Path of the per-system file in which SubprocMgr records child pids.
   return "/var/run/CliSubproc.db.{}".format( sysname )

class SubprocMgr:
   # Managed subprocesses created by CLI.
   # In case of ConfigAgent restart, we need to be able to find lingering
   # subprocesses and clean them up.
   #
   # Note all child processes are created with pdeathsig() set to SIGKILL,
   # but this is cleared whenever the subprocess changes its credentials
   # such as setuid binaries. So we keep track of child pids in a file
   # and use it to cleanup after restart.
   #
   # The pid file (dbfile( sysname )) holds the registered child pids as
   # space-separated decimal numbers; doCleanup() makes it writable by
   # everyone (mode 0666).
   #
   # NOTE(review): pids can be reused after a restart, so doCleanup() may
   # SIGKILL an unrelated process group whose leader happens to hold a
   # recorded pid and have ppid 1. Because the file is world-writable, any
   # user can also add pids to it. Consider recording e.g. the process start
   # time next to each pid and verifying it before killing.

   def __init__( self, sysname ):
      # sysname selects the pid file path via dbfile().
      self.dbfile_ = dbfile( sysname )
      # Serializes addPid()/removePids() updates of children_ and the file.
      self.lock_ = threading.Lock()
      # pid -> skipSigInt flag for each child registered via addPid() and
      # not yet passed to removePids().
      self.children_ = {}
      self.ppidRe_ = re.compile( r"PPid:\s+(\d+)" )
      # create the file now (so doCleanup can set perms)
      with open( self.dbfile_, "a" ):
         pass
      # Kill orphans left over in the file, then rewrite it (children_ is
      # still empty here, so the file ends up empty).
      self.doCleanup()

   def addPid( self, pid, skipSigInt ):
      # Register a newly started child and persist the updated pid list.
      with self.lock_:
         qt1( "add child", qv( pid ), "skipSigInt", skipSigInt )
         t1( "add child", pid, "skipSigInt", skipSigInt )
         self.children_[ pid ] = skipSigInt
         self.syncDb()

   def removePids( self, pids ):
      # Unregister children (unknown pids are ignored) and persist the
      # updated pid list. This is called by individual threads; an empty
      # pids skips both the lock and the file write.
      if pids:
         with self.lock_:
            for pid in pids:
               qt1( "remove child", qv( pid ) )
               t1( "remove child", pid )
               self.children_.pop( pid, None )
            self.syncDb()

   def skipChildSigInt( self, pid ):
      # Return the skipSigInt flag given to addPid() for pid, or False if
      # pid is not registered. Reads children_ without taking lock_.
      return self.children_.get( pid, False )

   def syncDb( self ):
      # Overwrite the pid file with the registered pids (space-separated).
      # addPid()/removePids() call this with lock_ held. Write failures are
      # traced and otherwise ignored.
      try:
         with open( self.dbfile_, "w" ) as f:
            f.write( ' '.join( str( x ) for x in self.children_ ) )
      except OSError as e:
         t0( "fail to write dbfile:", e.strerror )
         qt0( "fail to write dbfile:", qv( e.strerror ) )

   @staticmethod
   def _parsePids( text ):
      # Return the pids found in text (as written by syncDb()), in order.
      # Tokens that do not parse as integers (a truncated or hand-edited
      # file) are skipped instead of aborting cleanup of the remaining pids.
      # Values <= 1 are skipped too: they are never child pids, and POSIX
      # leaves killpg() undefined for them (on Linux 0 means our own group).
      pids = []
      for token in text.split():
         try:
            pid = int( token )
         except ValueError:
            continue
         if pid > 1:
            pids.append( pid )
      return pids

   def _isOrphan( self, pid ):
      # Return True if /proc/<pid>/status reports PPid 1 (init, where
      # orphaned processes are normally reparented). Returns False if the
      # process is gone or its status has no PPid line.
      # errors="replace": the Name: line may hold bytes that are not valid
      # text, and UnicodeDecodeError is not caught by the OSError handler.
      try:
         with open( "/proc/%s/status" % pid, errors="replace" ) as f:
            for line in f:
               if line.startswith( "PPid:" ):
                  ppid = 'unknown'
                  m = self.ppidRe_.match( line )
                  if m:
                     ppid = m.group( 1 )
                  qt1( "child", qv( pid ), "ppid", qv( ppid ) )
                  t1( "child", pid, "ppid", ppid )
                  return ppid == '1'
      except OSError:
         # just exited?
         pass

      return False

   def _kill( self, pid ):
      # Log the command line of pid and SIGKILL the process group whose id
      # is pid. Best effort: if /proc/<pid>/cmdline cannot be opened (e.g.
      # the process already exited) or killpg() fails, nothing happens.
      # errors="replace" keeps non-text argument bytes from raising
      # UnicodeDecodeError, which the OSError handler would not catch.
      # NOTE(review): killpg() assumes pid leads its own process group; if
      # it does not, the call fails and the process keeps running --
      # confirm children are started in a new process group.
      try:
         with open( "/proc/%s/cmdline" % pid, errors="replace" ) as f:
            cmdline = f.read().replace( '\0', ' ' ).strip()
         qt0( "kill child", qv( pid ), repr( cmdline ) )
         t0( "kill child", pid, repr( cmdline ) )
         os.killpg( pid, signal.SIGKILL )
      except OSError:
         pass

   def doCleanup( self ):
      # Kill the process group of every pid recorded in the pid file whose
      # process has ppid 1, make the file writable by everyone, then
      # rewrite it from children_ (which also drops malformed entries).
      try:
         with open( self.dbfile_, errors="replace" ) as f:
            contents = f.read()
      except OSError:
         contents = ''
      for pid in self._parsePids( contents ):
         # only kill pid if its ppid is 1
         if self._isOrphan( pid ):
            self._kill( pid )
      # make sure the file is writable by everyone
      try:
         os.chmod( self.dbfile_, 0o666 )
      except OSError:
         pass
      # clean up the file
      self.syncDb()
