#!/usr/bin/env python3
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import os
import re
import signal
import sys
from Syscall import gettid, tgkill
import threading
import traceback

import Ark
import CliCommon
import CliThreadCtrl
import InterruptiblePoller
import Tac
import TacSigint
import Tracing

th = Tracing.Handle( "CliPatchSigint" )
t0 = th.trace0

# Serializes kill() against deleteState() (both are synchronized on it).
_stateLock = threading.Lock()
# KeyboardInterruptState objects, keyed by the thread id gettid() returns.
threadState_ = {}
# Signal tgkill()ed at a thread by kill() to break it out of blocking calls.
useSignum = signal.SIGUSR2

# Monkey patch SIGINT handler in Tac.py:
#    enableSigintKeyboardInterrupt
#    disableSigintKeyboardInterrupt
#    setManualSigintHandler
#
# So they can work in a multi-threading environment.
#
# The new mechanism:
#
# A global per-thread KeyboardInterrupt state is maintained, which contains
#
# 1. enabled (set by enable/disableKeyboardInterrupt)
# 2. interrupted (causes Tac.checkInterrupt to raise KeyboardInterrupt)
# The per-thread interrupted state is maintained in C code because it also
# needs to be accessed from C (see the CliPrint feature); Python code reaches
# it through the CliThreadCtrl extension.
#
# First of all, setManualSigintHandler() installs its own global interrupt
# handler which just does nothing.
#
# enable/disableSigintKeyboardInterrupt would set enabled
# for the thread.
#
# In CliServer's signal-forwarding code, when it gets a SIGINT, it
# checks the current thread's enabled state.
#
# 1) If it's enabled, it sets the interrupted flag and also sends a signal (*)
# to the thread. This triggers -EINTR for any blocking calls and interrupts
# the caller. However, it is not a KeyboardInterrupt (Python code expecting
# it needs to handle OSError).
#
# 2) If it's disabled, it just sets the interrupted flag for the thread.
#
# checkInterrupt() would check and clear the per-thread interrupted flag.
#
# (*) We could use SIGINT, but for now we use SIGUSR2 because it is nice to
# still be able to interrupt CliServer itself during interactive testing.

class KeyboardInterruptState:
   """Per-thread KeyboardInterrupt state used by the patched TacSigint hooks.

   Holds the thread's ``enabled`` flag and a non-blocking pipe that kill()
   writes to so an InterruptiblePoller can wake up. The ``interrupted`` flag
   itself is kept in C and accessed through CliThreadCtrl.
   """

   def __init__( self, tid, tag ):
      # tid: thread id as returned by gettid(); tag: free-form label shown in
      # debugging output (see __str__)
      self.tid = tid
      self.tag = tag
      self.enabled = False
      self.signalPipe_ = os.pipe() # ( readFd, writeFd )
      for fd in self.signalPipe_:
         os.set_blocking( fd, False ) # pylint: disable=no-member
      # start off with an un-interrupted state
      self.clearKeyboardInterrupt()

   def checkInterrupt( self ):
      """Consume a pending interrupt, raising KeyboardInterrupt if there was one.

      The interrupted flag is only reset after it has been seen set: resetting
      it unconditionally could discard an interrupt that kill() delivers from
      another thread between the check and the reset.
      """
      if self.isInterrupted():
         self.clearKeyboardInterrupt()
         raise KeyboardInterrupt
      # No interrupt to consume; still keep the wakeup pipe empty.
      self._drainSignalPipe()

   def _drainSignalPipe( self ):
      """Read and discard everything currently buffered in the signal pipe."""
      while True:
         try:
            data = os.read( self.signalPipe_[ 0 ], 4096 )
         except BlockingIOError:
            break
         if not data:
            # EOF (write end closed): os.read keeps returning b"" rather than
            # raising BlockingIOError, so stop instead of spinning forever.
            break

   def clearKeyboardInterrupt( self ):
      """Reset the interrupted flag and drain the signal pipe."""
      # pylint: disable-msg=c-extension-no-member
      CliThreadCtrl.resetInterrupted( self.tid )
      self._drainSignalPipe()

   def enabledIs( self, enabled ):
      """Enable or disable signal delivery for this (the calling) thread.

      Must be called from the thread that owns this state (asserted).
      Tac.threadsigmask presumably blocks ``useSignum`` for the thread while
      disabled -- confirm against Tac. On disable, a pending interrupt is
      raised as KeyboardInterrupt via checkInterrupt().
      """
      assert gettid() == self.tid
      Tac.threadsigmask( useSignum, not enabled )
      self.enabled = enabled
      if not enabled:
         self.checkInterrupt()

   def interruptedIs( self, interrupted ):
      """Set the interrupted flag, or clear it (also draining the pipe)."""
      # pylint: disable-msg=c-extension-no-member
      if interrupted:
         CliThreadCtrl.setInterrupted( self.tid )
      else:
         self.clearKeyboardInterrupt()

   def isInterrupted( self ):
      """Return the thread's interrupted flag as reported by CliThreadCtrl."""
      # pylint: disable-msg=c-extension-no-member
      return CliThreadCtrl.isInterrupted( self.tid )

   def __str__( self ):
      return "TID: %d Enabled: %s Interrupted: %s - %s" % \
         ( self.tid,
           self.enabled,
           self.isInterrupted(),
           self.tag )

   def close( self ):
      """Close both pipe ends and delete the thread's CliThreadCtrl entry.

      Each step runs even if an earlier one raises, so a failing close of
      one descriptor cannot leak the other descriptor or the C-side entry.
      """
      # pylint: disable-msg=c-extension-no-member
      readFd, writeFd = self.signalPipe_
      try:
         os.close( writeFd )
      finally:
         try:
            os.close( readFd )
         finally:
            CliThreadCtrl.delInterrupted( self.tid )

def getState( tid ):
   """Return the KeyboardInterruptState registered for ``tid``, or None."""
   try:
      return threadState_[ tid ]
   except KeyError:
      return None

def createState( tag='' ):
   """Create and register a KeyboardInterruptState for the calling thread.

   A state already registered under the same thread id is unexpected; if one
   is found, diagnostics are printed and it is deleted before the new state
   is registered.
   """
   myTid = gettid()
   leftover = threadState_.get( myTid )
   if leftover:
      # This should be impossible, but escalation 432734 hit it,
      # so let's print some debugging output
      print( "WARN: state already exists: %s" % str( leftover ) )
      printThreads()
      deleteState()
   newState = KeyboardInterruptState( myTid, tag )
   threadState_[ myTid ] = newState

# BUG722277: need interlock between state deletion and kill()
@Ark.synchronized( _stateLock )
def deleteState():
   """Unregister and close the calling thread's KeyboardInterruptState.

   Runs under _stateLock, which kill() also holds, so the two cannot
   interleave. The entry is removed from threadState_ before close() runs,
   so even if closing raises, kill() can never find a state whose pipe
   descriptors are already (partially) closed.
   """
   tid = gettid()
   state = threadState_.pop( tid, None )
   if state is None:
      print( "WARN: state for tid", tid, "does not exist" )
      return
   state.close()

def enableSigintKeyboardInterrupt():
   """Enable signal delivery for the calling thread's interrupt state.

   stdout and stderr are flushed first. Raises KeyError if the calling
   thread has no registered state.
   """
   for stream in ( sys.stdout, sys.stderr ):
      stream.flush()
   current = threadState_[ gettid() ]
   current.enabledIs( True )

def disableSigintKeyboardInterrupt():
   """Disable signal delivery for the calling thread.

   Any interrupt that is already pending is raised as KeyboardInterrupt.
   Raises KeyError if the calling thread has no registered state.
   """
   current = threadState_[ gettid() ]
   current.enabledIs( False )

def checkInterrupt():
   """Raise KeyboardInterrupt if the calling thread has a pending interrupt.

   Threads without registered state (e.g., PyServer threads) are ignored.
   """
   current = threadState_.get( gettid() )
   if current is None:
      return
   current.checkInterrupt()

def clearKeyboardInterrupt():
   """Discard any pending interrupt for the calling thread.

   Raises KeyError if the calling thread has no registered state.
   """
   current = threadState_[ gettid() ]
   current.clearKeyboardInterrupt()

@Ark.synchronized( _stateLock )
def kill( tid ):
   """Deliver a CLI interrupt to thread ``tid``.

   Marks the thread interrupted and writes a byte to its signal pipe so an
   InterruptiblePoller can wake up. If the thread has delivery enabled, it
   is also sent ``useSignum`` so blocking calls in C get interrupted. Runs
   under _stateLock so deleteState() cannot remove the state mid-way.
   """
   target = threadState_.get( tid, None )
   if not target:
      t0( "kill no state found for", tid )
      return
   t0( "kill", tid, str( target ) )
   target.interruptedIs( True )
   try:
      os.write( target.signalPipe_[ 1 ], b"K" )
   except BlockingIOError:
      # The pipe buffer is already full, so it is readable anyway; ignore.
      pass
   if not target.enabled:
      return
   tgkill( os.getpid(), tid, int( useSignum ) )

def sigintHandler( signum, frame ):
   """No-op handler installed for ``useSignum``.

   The signal only exists to interrupt blocking calls (EINTR); the interrupt
   itself is recorded in the per-thread state by kill().
   """
   del signum, frame  # unused

def interruptiblePollerProvider():
   """Build an InterruptiblePoller watching the calling thread's signal pipe."""
   readFd = threadState_[ gettid() ].signalPipe_[ 0 ]
   return InterruptiblePoller.InterruptiblePoller( readFd )

# The 'show agent logs crash' command looks for crash indications in log files.
# ConfigAgent sometimes prints debug information that could look like a crash,
# so such text is escaped by printing it through printMangleCrashIndications.
crashIndicationRE = re.compile( str.join( "|", CliCommon.crashIndications ) )

def printMangleCrashIndications( lines, pattern=None ):
   """Print ``lines`` joined together, with crash signatures broken up.

   Every non-empty match of ``pattern`` gets "__mangled__" inserted after its
   first character, so the printed text does not reproduce the matched
   crash signature verbatim.

   Args:
      lines: iterable of strings, printed back to back (e.g. the output of
         traceback.format_stack()).
      pattern: optional regex (compiled or string); defaults to
         crashIndicationRE.
   """
   if pattern is None:
      pattern = crashIndicationRE

   def mangler( match ):
      text = match.group()
      if not text:
         # A zero-length match (e.g. from an empty pattern) has no first
         # character to split on; leave it unchanged instead of raising.
         return text
      return f"{text[ 0 ]}__mangled__{text[ 1 : ]}"

   print( re.sub( pattern, mangler, "".join( lines ) ) )

def printThreads():
   """Print every thread's stack, followed by every registered interrupt state.

   This is a debugging aid, so it deliberately takes no lock, to avoid
   being blocked.
   """
   tidByIdent = {}
   for thr in threading.enumerate():
      tidByIdent[ thr.ident ] = thr.native_id

   frames = sys._current_frames() # pylint: disable=protected-access
   for ident, frame in frames.items():
      print()
      tid = tidByIdent.get( ident )
      if not tid:
         print( 'XXXXXX Start THREAD ID', ident )
      else:
         state = getState( tid )
         print( 'XXXXXX Start TID', tid, '-', state.tag if state else '' )
      sys.stdout.flush()
      printMangleCrashIndications( traceback.format_stack( frame ) )
      print( 'XXXXXX END' )
      print()
      sys.stdout.flush()

   print( 'CLI Patch sigint' )
   # Snapshot the values so concurrent registration cannot break iteration.
   snapshot = list( threadState_.values() )
   for state in snapshot:
      print( state )
   sys.stdout.flush()

def init():
   """Install this module's interrupt machinery (Python 3 only).

   Installs the no-op handler for ``useSignum`` and makes that signal
   interrupt system calls. Points the TacSigint hooks at this module's
   per-thread implementations and registers the InterruptiblePoller
   provider.
   """
   signal.signal( useSignum, sigintHandler )
   signal.siginterrupt( useSignum, True )
   hooks = (
      ( "check", checkInterrupt ),
      ( "clear", clearKeyboardInterrupt ),
      ( "_setImmediate", enableSigintKeyboardInterrupt ),
      ( "_unsetImmediate", disableSigintKeyboardInterrupt ),
   )
   for attrName, impl in hooks:
      setattr( TacSigint, attrName, impl )
   InterruptiblePoller.setPollerProvider( interruptiblePollerProvider )
