#!/usr/bin/python3
#
# Do not use '/usr/bin/env' in the shebang!
# Otherwise, `setproctitle` fails, because `len( 'env' ) < len( 'CliShell' )`.

# pylint: disable=consider-using-f-string

# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# Use our own tracing to avoid importing Tac
import CliShellTracing
CliShellTracing.install()

# pylint: disable=wrong-import-position
import curses.ascii
import os
import pickle
import pty
import select
import signal
import socket
import sys
import termios
import threading
import time
import traceback
import tty

import CliArgParser
import CliInputInterface
import CliInterface
import CliShellLib

import FastServUtil
import Plugins
import TerminalUtil

from CliTacLite import setproctitle # pylint: disable=no-name-in-module
from CliTacLite import threadsigmask_setall # pylint: disable=no-name-in-module
from CliTacLite import threadsigmask_isset # pylint: disable=no-name-in-module

import Tracing

# pylint: enable=wrong-import-position
th = Tracing.Handle( 'CliShell' )
t0, t1 = th.trace0, th.trace1

# Signals relayed to the backend while startSignalForwarding() is in effect.
SIGNALS_TO_FWD = ( signal.SIGINT, )
# Upper bound, in seconds, on each select() wait in the fd copy threads; on
# expiry the copy loop simply polls again.
SELECT_TIMEOUT = 5
# The Ctrl-C control character as bytes (b'\x03').
CTRL_C = curses.ascii.ctrl( 'C' ).encode()

def checkLibTac():
   """Exit with status 99 if libtac.so is mapped into this process.

   CliShell must not pull in the Tac library. This scans /proc/self/maps,
   prints an error on stderr and traces the full mapping table when the
   library is found.
   """
   with open( "/proc/self/maps" ) as mapsFile:
      mappings = mapsFile.read()
   if '/libtac.so' not in mappings:
      return
   sys.stderr.write( "ERROR: libtac linked in CliShell\r\n" )
   t0( "ERROR: libtac linked in CliShell:", mappings )
   sys.exit( 99 )

# NOTE for usage of threadsigmask_setall
# The main thread (the CLI readline thread) is responsible for
# handling SIGINT and SIGTSTP. When a signal is delivered to the process,
# the thread that receives it is effectively arbitrary (Python only
# guarantees that its own Python-level handlers run in the main thread,
# which does not cover handlers installed at the C level). If another
# thread receives the signal, it runs the signal handler installed by
# the main thread, which longjmps into the context of the main thread's
# execution.
# We must therefore ensure signals are only handled by the appropriate
# thread: every helper thread blocks all signals with threadsigmask_setall()
# as the very first thing it does. If a thread wishes to handle a specific
# signal, it must first unblock that signal.
# These signals are handled by libedit; we do not want the other threads
# to receive them.

class CliShell:
   """Frontend of a CLI session whose commands are executed by a backend.

   The connection is made through the CliShellLib connectors. When stdin is
   a tty, a pseudo terminal is placed between the user's terminal and the
   backend, and two daemon threads copy bytes between the original terminal
   descriptors and the primary side of that pty. In readline mode the
   CliShellLib.CliShell readline loop runs in this process, and a helper
   thread serves CliInputInterface requests arriving from the backend.
   """

   def __init__( self, options ):
      """Store the parsed options and initialize per-session state.

      Args:
         options: the parsed command line (CliArgParser.parseArgs); this
            reads sysname, input_file, command and completions.
      """
      self.sysname_ = options.sysname
      self.options_ = options
      # Carries signal numbers to the backend and the exit status back.
      self.signalSock_ = None
      self.cliInputSock_ = None # CliInputInterface requests (tty mode only)
      self.requestSock_ = None # handed to RemoteCliInterface (tty mode only)
      self.cliShell_ = None # local readline loop, only when useReadline_
      self.primaryPty_ = None # this is the primary pseudo terminal
      self.secondaryPty_ = None # this is the secondary pseudo terminal
      self.stdoutCopyThread_ = None
      # Wakeup channels for the copy threads: _wakeupThread writes to [ 0 ],
      # the thread selects on [ 1 ].
      self.stdoutCopyThreadSockPair_ = self._createSocketPair()
      self.stdinCopyThread_ = None
      self.stdinCopyThreadSockPair_ = self._createSocketPair()
      self.originalTerminalControl_ = None
      self.savedSignalHandler_ = {}
      # Private copies of the original stdin/stdout; setupConnection may
      # redirect fds 0 and 1 to the pty, cleanup points them back here.
      self.origStdinFd_ = os.dup( sys.stdin.fileno() )
      self.origStdoutFd_ = os.dup( sys.stdout.fileno() )
      self.standaloneShellArgs_ = None
      # Local line editing is only used for an interactive tty session.
      self.useReadline_ = ( not options.input_file and
                            not options.command and
                            not options.completions and
                            os.isatty( self.origStdinFd_ ) )

   def loadPlugins( self ):
      """Load all CliShellPlugin plugins, passing this object as context."""
      Plugins.loadPlugins( "CliShellPlugin", context=self )

   def setupConnection( self ):
      """Connect to the backend and wire this process's I/O to it.

      Without a tty on stdin, this process's stdin/stdout/stderr descriptors
      are handed to NonTtyCliConnector.connectToBackend. With a tty, a new
      pty pair is opened. The secondary side (and its name) is passed to
      TtyCliConnector.connectToBackend, and fds 0 and 1 of this process are
      redirected to it. Two daemon threads then copy between the original
      terminal descriptors and the primary side. In readline mode the user's
      terminal is first put in raw mode with ISIG re-enabled.
      """
      self._setupEnv()
      # Ignore SIGQUIT so that Ctrl-\ doesn't terminate the CLI session. Note
      # that we must do this after importing the Tac library because the
      # backtrace module overwrites the SIGQUIT handler.
      signal.signal( signal.SIGQUIT, signal.SIG_IGN )
      signal.signal( signal.SIGWINCH, self._sigWinchHandler )
      if not os.isatty( self.origStdinFd_ ):
         cliConnector = CliShellLib.NonTtyCliConnector()
         self.signalSock_ = cliConnector.connectToBackend( self.sysname_,
                                                           sys.argv[ 1: ],
                                                           os.environ,
                                                           os.getuid(),
                                                           os.getgid(),
                                                           sys.stdin.fileno(),
                                                           sys.stdout.fileno(),
                                                           sys.stderr.fileno() )
      else:
         if self.useReadline_:
            self.originalTerminalControl_ = termios.tcgetattr( self.origStdinFd_ )
            tty.setraw( self.origStdinFd_, termios.TCSADRAIN )
            TerminalUtil.enableIsig( self.origStdinFd_ )
         self.primaryPty_, self.secondaryPty_ = pty.openpty()
         TerminalUtil.copyWinSize( self.origStdinFd_, self.primaryPty_ )
         TerminalUtil.copyTtySpeed( self.origStdinFd_, self.primaryPty_ )
         if not self.useReadline_:
            # We do not want the intermediate pty to alter output (BUG238875)
            TerminalUtil.disableOutputProcessing( self.primaryPty_ )
         ctty = os.ttyname( self.secondaryPty_ )
         cliConnector = CliShellLib.TtyCliConnector()
         self.signalSock_, self.requestSock_, self.cliInputSock_ = \
               cliConnector.connectToBackend( self.sysname_,
                                              sys.argv[ 1: ],
                                              os.environ,
                                              os.getuid(),
                                              os.getgid(),
                                              ctty, self.secondaryPty_ )
         if self.useReadline_:
            cliInterface = CliInterface.RemoteCliInterface( self,
                                                            self.requestSock_ )
            self.cliShell_ = CliShellLib.CliShell( cliInterface )
         # From here on fds 0 and 1 of this process are the pty's secondary
         # side; the copy threads bridge its primary side to the original
         # terminal descriptors.
         os.dup2( self.secondaryPty_, sys.stdin.fileno() )
         sys.stdin = os.fdopen( sys.stdin.fileno(), 'r' )
         sys.__stdin__ = sys.stdin
         os.dup2( self.secondaryPty_, sys.stdout.fileno() )
         sys.stdout = os.fdopen( sys.stdout.fileno(), 'w' )
         sys.__stdout__ = sys.stdout
         # primary pty -> original stdout
         self.stdoutCopyThread_ = self._newCopyThread(
            self.primaryPty_, self.origStdoutFd_, False,
            self.stdoutCopyThreadSockPair_ )
         self.stdoutCopyThread_.start()
         # original stdin -> primary pty (Ctrl-C may also be signalled to the
         # backend, see _copyFd)
         self.stdinCopyThread_ = self._newCopyThread(
            self.origStdinFd_, self.primaryPty_, True,
            self.stdinCopyThreadSockPair_ )
         self.stdinCopyThread_.start()

   def checkForSignal( self, blocking=False ):
      """Read the session's exit status from the backend over signalSock_.

      Args:
         blocking: when False, give up immediately unless data is already
            readable on the socket.
      Returns:
         The integer read from the socket, or 255 when nothing is readable
         (non-blocking mode) or FastServUtil.readInteger returns None.
      """
      if not blocking and not self._isFdReadable( self.signalSock_ ):
         t0( "checkForSignal: socket not readable" )
         return 255

      status = FastServUtil.readInteger( self.signalSock_ )
      if status is None:
         t0( "readInteger error" )
         return 255

      return status

   def execStandaloneShell( self ):
      """Replace this process with the standalone rCli shell, if configured.

      Only done in readline mode, and only when setStandaloneShellTimeout
      stored the rCli arguments. Returns if skipped or if the exec fails.
      """
      if not self.useReadline_:
         return
      args = self.standaloneShellArgs_
      if not args:
         return
      self.noStandaloneShellTimeout()
      print( "Entering standalone shell." )
      # restore ctrl-c handler so rCli startup isn't interrupted
      signal.signal( signal.SIGINT, signal.SIG_DFL )
      # os.execv() replaces the process image without flushing Python's
      # buffered streams; flush them first (best effort, a failed flush must
      # not prevent the exec) so the message above is not lost.
      for stream in ( sys.stdout, sys.stderr ):
         try:
            stream.flush()
         except ( OSError, ValueError ):
            pass
      try:
         os.execv( sys.executable, [ sys.executable ] + args )
      except OSError as e:
         print( "Cannot execute standalone shell:", os.strerror( e.errno ) )

   def setStandaloneShellTimeout( self, opts ):
      """Arm a fallback to the standalone shell if no prompt arrives in time.

      The timeout is opts.standalone_shell_timeout. When that is unset, it
      defaults to 120 seconds in readline mode outside A4_CHROOT. If a
      timeout applies, the rCli command line is stored for
      execStandaloneShell. A SIGALRM is then scheduled whose handler raises
      CliShellLib.ConnectTimeout; noStandaloneShellTimeout cancels it.
      """
      # Specify a timeout to get back the first prompt,
      # otherwise we drop to a standalone shell

      timeout = opts.standalone_shell_timeout
      if not timeout and self.useReadline_ and 'A4_CHROOT' not in os.environ:
         # default to 120 on physical duts
         timeout = 120
      if not timeout:
         return

      args = [ "/usr/bin/rCli", "--sysname", opts.sysname,
               "--privilege", str( opts.privilege ) ]
      if opts.standalone or opts.disable_aaa:
         args.append( "--disable-aaa" )
      self.standaloneShellArgs_ = args

      def _timeoutHandler( signum, stack ):
         raise CliShellLib.ConnectTimeout( "Cannot connect to ConfigAgent" )

      signal.signal( signal.SIGALRM, _timeoutHandler )
      signal.alarm( timeout )

   def noStandaloneShellTimeout( self ):
      """Cancel the pending SIGALRM fallback and forget the rCli arguments."""
      if self.standaloneShellArgs_:
         signal.signal( signal.SIGALRM, signal.SIG_IGN )
         signal.alarm( 0 )
         self.standaloneShellArgs_ = None

   def waitForExit( self ):
      """Run the session until it ends and return its exit status.

      In readline mode this starts a daemon thread serving CliInputInterface
      requests and runs the local readline loop. Otherwise it forwards
      SIGNALS_TO_FWD to the backend. If anything raises, the final status
      read is made non-blocking. In all cases signal forwarding is undone
      and both copy threads are woken up and joined before returning.

      Returns:
         The status from checkForSignal, or 255 if that raised.
      """
      blocking = True
      try:
         # make sure we do not import Tac
         assert 'Tac' not in sys.modules
         assert '_Tac' not in sys.modules
         if self.cliShell_:
            t0( "creating input thread" )
            cliInputLoop = threading.Thread( target=self._cliInputLoop,
                                             daemon=True )
            cliInputLoop.start()
            # We don't want the frontend to get KeyboardInterrupt as it may
            # disrupt the data exchange with the backend. Ctrl-C is already
            # handled by libedit/CliInput and stdin copy thread properly.
            signal.signal( signal.SIGINT, signal.SIG_IGN )
            self.cliShell_.startReadlineLoop(
               firstPromptCallback=self.noStandaloneShellTimeout )
         else:
            self.startSignalForwarding()
      except Exception as e: # pylint: disable=broad-except
         t0( "waitForExit exception:", str( e ) )
         t0( repr( traceback.format_exc() ) )
         blocking = False

      try:
         return self.checkForSignal( blocking=blocking )
      except Exception as e: # pylint: disable=broad-except
         t0( "checkForSignal exception:", str( e ) )
         return 255
      finally:
         self.stopSignalForwarding()
         # Wake and join stdout first, then stdin (same order as before).
         for copyThread in ( self.stdoutCopyThread_, self.stdinCopyThread_ ):
            if copyThread:
               self._wakeupThread( copyThread )
               copyThread.join()

   def cleanup( self ):
      """Restore the original stdin/stdout and, if saved, terminal settings."""
      os.dup2( self.origStdinFd_, sys.stdin.fileno() )
      os.dup2( self.origStdoutFd_, sys.stdout.fileno() )
      if self.originalTerminalControl_:
         termios.tcsetattr( self.origStdinFd_, termios.TCSADRAIN,
                            self.originalTerminalControl_ )

   def _wakeupThread( self, thread ):
      """Wake one of the two copy threads so it drains its input and exits.

      Raises:
         AssertionError: if thread is neither copy thread.
      """
      if thread == self.stdoutCopyThread_:
         self.stdoutCopyThreadSockPair_[ 0 ].sendall( b'a' )
      elif thread == self.stdinCopyThread_:
         self.stdinCopyThreadSockPair_[ 0 ].sendall( b'a' )
      else:
         # Explicit raise rather than `assert False`, so it survives -O.
         raise AssertionError( 'not a copy thread' )

   def _sigWinchHandler( self, signum, frame ):
      """SIGWINCH handler: copy the terminal window size onto the pty."""
      if self.primaryPty_:
         TerminalUtil.copyWinSize( self.origStdinFd_, self.primaryPty_ )

   def _signalBackEnd( self, signum, frame ):
      """Send signal number signum to the backend over signalSock_."""
      t0( "signal backend", signum )
      try:
         FastServUtil.writeInteger( self.signalSock_, signum )
      except OSError:
         pass # best effort, e.g. the backend connection may already be gone

   def _createSocketPair( self ):
      """Return a connected pair of AF_UNIX stream sockets."""
      return socket.socketpair( socket.AF_UNIX, socket.SOCK_STREAM, 0 )

   def _newCopyThread( self, fromFd, toFd, xlateCtrlC, sockPair ):
      """Return an unstarted daemon thread running _copyFdLoop.

      The thread selects on sockPair[ 1 ]; _wakeupThread writes to
      sockPair[ 0 ] to stop it.
      """
      return threading.Thread( target=self._copyFdLoop,
                               args=( fromFd, toFd, xlateCtrlC,
                                      sockPair[ 1 ] ),
                               daemon=True )

   def _isFdReadable( self, fd ):
      """Return True if fd (a descriptor or socket) is readable right now."""
      filesReadyToRead, _, _ = select.select( [ fd ], [], [], 0 )
      return fd in filesReadyToRead

   def disableIsig( self ):
      """Disable ISIG on the original stdin terminal (via TerminalUtil)."""
      TerminalUtil.disableIsig( self.origStdinFd_ )

   def enableIsig( self ):
      """Enable ISIG on the original stdin terminal (via TerminalUtil)."""
      TerminalUtil.enableIsig( self.origStdinFd_ )

   def startSignalForwarding( self ):
      """Route SIGNALS_TO_FWD to the backend, remembering previous handlers.

      Signals whose current handler was not installed from Python
      (signal.getsignal() returns None) are left untouched.
      """
      for s in SIGNALS_TO_FWD:
         sigHandler = signal.getsignal( s )
         if sigHandler is not None:
            self.savedSignalHandler_[ s ] = sigHandler
            signal.signal( s, self._signalBackEnd )

   def stopSignalForwarding( self ):
      """Reinstall the handlers saved by startSignalForwarding."""
      for s in SIGNALS_TO_FWD:
         if s in self.savedSignalHandler_:
            signal.signal( s, self.savedSignalHandler_[ s ] )

   def ungracefulExit( self ):
      """Close the backend sockets and send SIGTERM to this process."""
      if self.signalSock_:
         self.signalSock_.close()
      if self.cliInputSock_:
         self.cliInputSock_.close()

      os.kill( os.getpid(), signal.SIGTERM )

   def _copyFd( self, fromFd, toFd, xlateCtrlC ):
      """Copy one chunk (at most 1024 bytes) from fromFd to toFd.

      When xlateCtrlC is set and cliShell_.runCmd_ is truthy (presumably
      while a command runs), a Ctrl-C byte in the chunk also sends SIGINT to
      the backend after a short delay.

      Returns:
         False on EOF or OSError, True otherwise.
      """
      try:
         buf = os.read( fromFd, 1024 )
         if not buf:
            return False
         wroteSoFar = 0
         while wroteSoFar < len( buf ):
            # os.write may write only part of the buffer
            wroteSoFar += os.write( toFd, buf[ wroteSoFar : ] )
         if ( xlateCtrlC and self.cliShell_ and self.cliShell_.runCmd_ and
              CTRL_C in buf ):
            # Wait a little bit for data to be sent out so ^C can be
            # printed first by the backend before getting the signal
            t0( 'Ctrl-C detected, sending to backend' )
            time.sleep( 0.01 )
            self._signalBackEnd( signal.SIGINT, None )
         return True
      except OSError:
         return False

   def _copyFdLoop( self, fromFd, toFd, xlateCtrlC, wakeupSock ):
      """Copy-thread body: copy fromFd to toFd until woken up or EOF/error.

      A byte on wakeupSock means the session is over: whatever is still
      readable on fromFd is copied and the thread returns. EOF or an error
      on fromFd, or any unexpected exception, calls ungracefulExit (which
      sends SIGTERM to this process).
      """
      threadsigmask_setall() # see Note at the top as to why we do this
      try:
         wakeupFd = wakeupSock.fileno()
         while True:
            # SELECT_TIMEOUT only bounds each wait; on timeout we poll again.
            filesReadyToRead, _, _ = select.select( [ fromFd, wakeupFd ], [],
                                                    [], SELECT_TIMEOUT )
            if wakeupFd in filesReadyToRead:
               wakeupSock.recv( 1 )
               while self._isFdReadable( fromFd ):
                  if not self._copyFd( fromFd, toFd, xlateCtrlC ):
                     return
               return

            if fromFd in filesReadyToRead:
               if not self._copyFd( fromFd, toFd, xlateCtrlC ):
                  self.ungracefulExit()
                  return
      except Exception as e: # pylint: disable=broad-except
         t0( 'terminating due to: %s' % e )
         self.ungracefulExit()

   def _cliInputLoop( self ):
      """Serve CliInputInterface requests from cliInputSock_ until it is empty.

      Each request names a CliInputInterface function (request.method_) and
      its arguments. The function's return value, or the exception it
      raised, is pickled and written back on the same socket.
      """
      threadsigmask_setall() # see Note at the top as to why we do this
      while True:
         # we make sure that our threadmask here is still blocked
         assert threadsigmask_isset( signal.SIGINT )
         assert threadsigmask_isset( signal.SIGTSTP )
         requestData = FastServUtil.readBytes( self.cliInputSock_ )
         if not requestData:
            break
         # NOTE(review): pickle.loads() can run arbitrary code; this is only
         # acceptable if the backend peer of cliInputSock_ is trusted --
         # confirm.
         request = pickle.loads( requestData )
         try:
            procedure = getattr( CliInputInterface, request.method_ )
            try:
               response = procedure( *request.args_, **request.kwargs_ )
            except: # pylint: disable=bare-except
               # ship any failure back to the caller as the response
               response = sys.exc_info()[ 1 ]
            # remove protocol pin when everything is python3
            responseData = pickle.dumps(
               response,
               protocol=FastServUtil.PICKLE_PROTO )
            FastServUtil.writeBytes( self.cliInputSock_, responseData )
         except: # pylint: disable=bare-except
            traceback.print_exc()
         # check that we do not pull in CliPlugins here
         excludedFiles = [ x for x in sys.modules if 'CliPlugin' in x ]
         assert not excludedFiles, 'importing CliPlugins disallowed: %s' % \
            ( ','.join( excludedFiles ) )

   def _setupEnv( self ):
      """Adjust os.environ before it is sent to the backend.

      Sets REALTTY (only when the original stdin is a tty) and PWD, and
      removes TRACE/TRACEFILE.
      """
      try:
         os.environ[ 'REALTTY' ] = os.ttyname( self.origStdinFd_ )
      except OSError:
         pass # original stdin is not a terminal

      os.environ[ 'PWD' ] = os.getcwd()
      # We don't pass TRACE/TRACEFILE to ConfigAgent since it's already started.
      # If someone already enabled TRACE for CliShell, it's already honored, so
      # let's just pop them out
      os.environ.pop( 'TRACE', None )
      os.environ.pop( 'TRACEFILE', None )

def _connectAndWait( cliShell ):
   """Connect cliShell to the backend and wait for the session to end.

   Setup and connection failures are reported on stderr. Exceptions other
   than OSError raised by checkLibTac() or waitForExit() propagate.

   Returns:
      The status returned by waitForExit(), or 255 when the connection
      could not be set up or waitForExit() raised OSError.
   """
   try:
      cliShell.setupConnection()
   except OSError as e:
      print( 'Unable to connect:', e.strerror, file=sys.stderr )
      return 255
   except Exception as e: # pylint: disable=broad-except
      print( e, file=sys.stderr )
      return 255

   if 'ABUILD' in os.environ:
      checkLibTac()
   try:
      return cliShell.waitForExit()
   except OSError as e:
      print( 'Server closed connection:', e.strerror, file=sys.stderr )
      return 255

def main():
   """Run one CliShell session and exit the process with its status.

   Status 255 is also what this frontend uses for setup/connection failures.
   In that case execStandaloneShell() is given the chance to replace the
   process with the standalone shell.
   """
   setproctitle( "CliShell" )

   parsedOptions = CliArgParser.parseArgs( standaloneGuards=False )
   session = CliShell( parsedOptions )
   # Plugins must be loaded before the connection is set up.
   session.loadPlugins()
   session.setStandaloneShellTimeout( parsedOptions )

   try:
      exitStatus = _connectAndWait( session )
   finally:
      session.cleanup()

   if exitStatus == 255:
      session.execStandaloneShell()
   sys.exit( exitStatus )

if __name__ == '__main__':
   main()
