# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import errno
import os
import pickle
import socket
import sys
import time
import traceback

import CliCommon
import CliInputWrapper as CliInput

import FastServUtil

import TerminalUtil
import Tracing

traceHandle = Tracing.Handle( 'CliShellLib' )
# Short module-wide aliases for the handle's trace levels 0 (log) .. 4 (debug).
log, warn, info, trace, debug = ( traceHandle.trace0, traceHandle.trace1,
                                  traceHandle.trace2, traceHandle.trace3,
                                  traceHandle.trace4 )

# pylint: disable=c-extension-no-member

class ConnectTimeout( Exception ):
   """Connection-timeout error.

   CliShell.startReadlineLoop lets this exception propagate instead of
   handing it to the CLI's generic exception handler.
   """

class CliShell:
   """Interactive readline front-end for a CLI session.

   Reads lines with CliInput.readline, runs each non-empty line through the
   wrapped `cliInterface`, and records it in the command history.  Tab and
   '?' keystrokes are routed to the interface via `complete`.
   """

   def __init__( self, cliInterface ):
      self.cliInterface_ = cliInterface
      # Title control sequence from TerminalUtil.terminalCtrlSeq(), filled in
      # by updateTerminalCtrlSeq(); falsy when the title cannot be set.
      self.termTitleCtrlSeq_ = None
      # True only while cliInterface_.runCmd() is executing a line.
      self.runCmd_ = False

   def updateTerminalCtrlSeq( self ):
      """Refresh the cached terminal-title control sequence."""
      self.termTitleCtrlSeq_ = TerminalUtil.terminalCtrlSeq()

   def terminalTitleStr( self, title ):
      """Return the bytes that set the terminal's title to `title`.

      `title` may be str (encoded with the default codec) or bytes.  Returns
      empty bytes (b"") if the terminal cannot be controlled, i.e. no title
      control sequence is cached.  Assumes the cached sequence is a
      (prefix, suffix) pair of bytes -- TODO confirm against TerminalUtil.
      """
      seq = self.termTitleCtrlSeq_
      if not seq:
         return b""
      if not isinstance( title, bytes ):
         title = title.encode()
      return b"%s%s%s" % ( seq[ 0 ], title, seq[ 1 ] )

   def complete( self, key, cmd ):
      """Completion/help callback passed to CliInput.readline.

      On a tab key (b'\\t') returns cliInterface_.tabComplete( cmd ); on
      b'?' echoes "?" and returns cliInterface_.printHelp( cmd ).  Any other
      key is a programming error (asserts).
      """
      result = ""
      if key == b"\t":
         result = self.cliInterface_.tabComplete( cmd )
      elif key == b"?":
         print( "?" )
         result = self.cliInterface_.printHelp( cmd )
      else:
         assert False

      return result

   def prompt( self ):
      """Return the interface's prompt, prefixed by $CLI_PROMPT_PREFIX."""
      promptPrefix = os.environ.get( 'CLI_PROMPT_PREFIX', '' )
      return self.cliInterface_.prompt( promptPrefix )

   @staticmethod
   def _writeAll( fd, data ):
      """Write every byte of `data` to file descriptor `fd`.

      os.write() returns how many bytes it actually wrote, which can be
      fewer than requested (e.g. when a signal arrives mid-write), so loop
      until the whole buffer is out, retrying on EINTR.  At least one
      os.write() call is always made, even when `data` is empty.
      """
      view = memoryview( data )
      while True:
         try:
            written = os.write( fd, view )
         except OSError as e:
            if e.errno == errno.EINTR:
               continue
            raise
         view = view[ written: ]
         if not view:
            return

   def _writeTerminalTitleStr( self, prompt ):
      """Write the terminal-title control sequence for `prompt` to stdout.

      Terminal control characters are written directly to the stdout file
      descriptor, outside of readline: readline includes whatever it prints
      in its terminal-width calculation, which causes odd line wrapping.

      Raises EOFError if the write fails with EIO.
      """
      # Sometimes the first prompt went missing because the write was
      # interrupted (the signal is unknown); _writeAll retries EINTR and also
      # finishes short writes.
      try:
         self._writeAll( sys.stdout.fileno(), self.terminalTitleStr( prompt ) )
      except OSError as e:
         if e.errno != errno.EIO:
            raise
         # write can give an EIO if the primary end of the pseudo-tty has
         # been closed (which can happen if the parent, login or sshd, has
         # been killed, perhaps by the OOM killer).  If this has happened,
         # the Cli process ought to have been sent a SIGHUP, which ought to
         # have killed it.  However, empirically it seems that this is not
         # always the case.
         raise EOFError() from e

   def startReadlineLoop( self, firstPromptCallback=None ):
      """Read and run commands until EOF or SystemExit.

      Each iteration fetches the prompt, writes the terminal title, reads a
      line with CliInput.readline and runs it via cliInterface_.runCmd().
      The line is then added to the history selected by
      cliInterface_.getOrigModeHistoryKeys().  `firstPromptCallback`, if
      given, is called once: after the first title write and before the
      first readline.

      While prompting/reading:
        * KeyboardInterrupt exits config mode and re-prompts;
        * EOFError ends the loop;
        * ConnectTimeout propagates to the caller;
        * anything else goes to cliInterface_.handleCliException.

      While running a command, SystemExit ends the loop and any other
      exception goes to handleCliException.  When the loop ends normally,
      config mode is exited and the session is ended.
      """
      log( "start readline loop" )

      TerminalUtil.enableCtrlZ( False )
      self.updateTerminalCtrlSeq()

      while True:
         trace( "Waiting for command" )
         try:
            prompt = self.prompt()
            trace( "Received prompt", repr( prompt.encode() ) )
            self._writeTerminalTitleStr( prompt )
            if firstPromptCallback:
               firstPromptCallback()
               firstPromptCallback = None
            line = ""
            # Only enable Ctrl-Z during readline so Ctrl-Z can still be
            # turned into KeyboardInterrupt, but disable it while running
            # commands, or we may hang if user presses ctrl-Z (BUG34718).
            with TerminalUtil.CtrlZ( True ):
               historyKeys = self.cliInterface_.getHistoryKeys()
               trace( "Reading line" )
               line = CliInput.readline( prompt, self.complete, *historyKeys )
               trace( "Line read:", line )
         except KeyboardInterrupt:
            trace( "Keyboard interrupt:" )
            print()
            self.cliInterface_.exitConfigMode()
            continue
         except EOFError:
            trace( "EOF ERROR:" )
            print()
            break
         except ConnectTimeout:
            raise
         except: # pylint: disable=bare-except
            log( repr( traceback.format_exc() ) )
            # the call below can blow up while pickling certain exceptions...
            self.cliInterface_.handleCliException( sys.exc_info(),
                                                   '(incomplete)' )
            continue

         if not line:
            continue

         try:
            trace( "Trying to run:", line )
            self.runCmd_ = True
            self.cliInterface_.runCmd( line, expandAliases=True )
         except SystemExit:
            trace( "System exit, leaving readline loop", line )
            break
         except: # pylint: disable=bare-except
            self.cliInterface_.handleCliException( sys.exc_info(), line )
         finally:
            # Runs even on the SystemExit break, so the line is always
            # recorded in history.
            self.runCmd_ = False
            historyKeys = self.cliInterface_.getOrigModeHistoryKeys()
            # pylint: disable-next=no-member
            if not CliInput.addHistory( line, *historyKeys ):
               # parent history mode does not exist
               for key in self.cliInterface_.getParentHistoryKeys():
                  CliInput.newHistoryMode( *key ) # pylint: disable=no-member
               # pylint: disable-next=no-member
               r = CliInput.addHistory( line, *historyKeys )
               assert r

      self.cliInterface_.exitConfigMode()
      self.cliInterface_.endSession()

class _RemoteRequest:
   """Picklable record of one remote call: the method name together with
   its positional (`args`) and keyword (`kwargs`) arguments."""
   def __init__( self, method, args, kwargs ):
      self.method_, self.args_, self.kwargs_ = method, args, kwargs

class _RemoteAttr:
   """Callable stand-in for one method of a remote object.

   Calling it pickles a _RemoteRequest (method name plus arguments), sends
   it over `cliInputSock` with FastServUtil, and waits for the pickled reply.
   An empty reply yields None.  A reply that is an Exception or a SystemExit
   is raised locally; any other reply is returned.
   """
   def __init__( self, method, cliInputSock ):
      self.method_ = method
      self.cliInputSock_ = cliInputSock

   def __call__( self, *args, **kwargs ):
      payload = pickle.dumps( _RemoteRequest( self.method_, args, kwargs ),
                              protocol=FastServUtil.PICKLE_PROTO )
      FastServUtil.writeBytes( self.cliInputSock_, payload )
      reply = FastServUtil.readBytes( self.cliInputSock_ )
      if not reply:
         return None
      # NOTE(review): pickle.loads runs arbitrary code chosen by the peer.
      # Presumably safe because the peer is the other end of a local
      # socketpair handed to the CliServer -- confirm.
      result = pickle.loads( reply )
      # Other BaseExceptions (e.g. KeyboardInterrupt) are returned, not raised.
      if isinstance( result, SystemExit ) or issubclass( type( result ),
                                                         Exception ):
         raise result
      return result

class RemoteCliInput:
   """Proxy that turns attribute lookups into remote calls.

   `cliInputSock_` is the only real attribute.  Looking up any other name
   yields a _RemoteAttr bound to that name and the socket, so
   `proxy.readline( ... )` performs a request/reply exchange over the socket.
   """
   def __init__( self, cliInputSock ):
      self.cliInputSock_ = cliInputSock

   def __getattribute__( self, name ):
      sock = object.__getattribute__( self, 'cliInputSock_' )
      if name == 'cliInputSock_':
         return sock
      return _RemoteAttr( name, sock )

class CliConnector:
   """Base class for clients that open a session with the CliServer.

   The server listens on the AF_UNIX stream socket
   CliCommon.CLI_SERVER_ADDRESS_FMT % sysname.  _connectToBackend() connects
   and sends the common session header; subclasses then write a one-byte
   session type and pass any further file descriptors.
   """

   def _createArgStr( self, argv ):
      """Serialize `argv` as NUL-separated strings."""
      return '\x00'.join( argv )

   def _createEnvStr( self, env ):
      """Serialize mapping `env` as NUL-separated key, value, key, value..."""
      return '\x00'.join( [ f'{key}\x00{value}' for
                          key, value in env.items() ] )

   def _sendFds( self, sock, fds ):
      """Pass file descriptors `fds` over `sock` with a single NUL byte as
      the accompanying data."""
      socket.send_fds( sock, [ b'\0' ], fds )

   def _createAndSendSock( self, sock ):
      """Create a connected AF_UNIX socket pair, send one end over `sock`,
      and return the other end.

      Our copy of the sent end is always closed afterwards (the receiver
      holds its own duplicate).  If the send fails, the returned end is
      closed too, so no descriptor is leaked.
      """
      s1, s2 = socket.socketpair( socket.AF_UNIX, socket.SOCK_STREAM, 0 )
      try:
         self._sendFds( sock, [ s2.fileno() ] )
      except BaseException:
         s1.close()
         raise
      finally:
         s2.close()
      return s1

   def _connectWithRetry( self, sock, address, connectTimeout, retryInterval ):
      """Connect `sock` to `address`, retrying while the peer refuses.

      The CliServer might not be up at the beginning of time, so
      ECONNREFUSED is retried every `retryInterval` seconds until
      `connectTimeout` seconds have elapsed, then re-raised.  Any other
      OSError is raised immediately.
      """
      # monotonic() is immune to wall-clock steps (e.g. the clock being set
      # during boot), which could otherwise cut the wait short or extend it.
      startTime = time.monotonic()
      while True:
         try:
            sock.connect( address )
            return
         except OSError as e:
            if e.errno != errno.ECONNREFUSED:
               raise
            if time.monotonic() - startTime >= connectTimeout:
               raise
            time.sleep( retryInterval )

   def _connectToBackend( self, sysname, argv, env, uid, gid, ctty,
                          connectTimeout=120, retryInterval=0.1 ):
      """Connect to the CliServer for `sysname` and send the session header.

      The header is: a freshly created signal socket (as a passed fd), then
      argv, env, uid, gid and ctty as strings.  See _connectWithRetry for
      the connect/retry policy.

      Returns (sock, signalSock).  If connecting or any part of the header
      fails, both sockets are closed before the exception propagates.
      """
      sock = socket.socket( socket.AF_UNIX, socket.SOCK_STREAM, 0 )
      signalSock = None
      try:
         self._connectWithRetry( sock,
                                 CliCommon.CLI_SERVER_ADDRESS_FMT % sysname,
                                 connectTimeout, retryInterval )
         signalSock = self._createAndSendSock( sock )
         FastServUtil.writeString( sock, self._createArgStr( argv ) )
         FastServUtil.writeString( sock, self._createEnvStr( env ) )
         FastServUtil.writeString( sock, str( uid ) )
         FastServUtil.writeString( sock, str( gid ) )
         FastServUtil.writeString( sock, ctty )
      except BaseException:
         if signalSock is not None:
            signalSock.close()
         sock.close()
         raise
      return sock, signalSock

class EapiCliConnector( CliConnector ):
   """Connector for eAPI sessions.

   `stateless` selects the session-type byte sent to the server: b'c' when
   true, b'd' when false.
   """
   def __init__( self, stateless=True ):
      self.stateless_ = stateless

   def connectToBackend( self, sysname, argv, env, uid, gid ):
      """Open a session and return its sockets.

      Returns (signalSock, responseSock, requestSock, statisticsSock).  The
      control connection is always closed before returning.  On failure,
      every socket created so far is closed and the exception propagates.
      """
      sock, signalSock = self._connectToBackend( sysname, argv, env, uid, gid, '' )
      opened = [ signalSock ]
      try:
         os.write( sock.fileno(), b'c' if self.stateless_ else b'd' )
         responseSock = self._createAndSendSock( sock )
         opened.append( responseSock )
         requestSock = self._createAndSendSock( sock )
         opened.append( requestSock )
         statisticsSock = self._createAndSendSock( sock )
      except BaseException:
         for s in opened:
            s.close()
         raise
      finally:
         sock.close()
      return ( signalSock, responseSock, requestSock, statisticsSock )

class NonTtyCliConnector( CliConnector ):
   """Connector that sends session type b'u' and hands the server the
   caller's stdin/stdout/stderr descriptors (presumably for sessions
   without a terminal, as the name suggests)."""
   def connectToBackend( self, sysname, argv, env, uid, gid, stdinFd, stdoutFd,
                         stderrFd ):
      """Open a session and return its signal socket.

      The control connection is always closed before returning.  On failure,
      the signal socket is closed as well and the exception propagates.  The
      caller's fds are passed, not closed.
      """
      sock, signalSock = self._connectToBackend( sysname, argv, env, uid, gid, '' )
      try:
         os.write( sock.fileno(), b'u' )
         self._sendFds( sock, [ stdinFd, stdoutFd, stderrFd ] )
      except BaseException:
         signalSock.close()
         raise
      finally:
         sock.close()
      return signalSock

class TtyCliConnector( CliConnector ):
   """Connector that sends session type b't' together with `ctty` and the
   secondary pty descriptor."""
   def connectToBackend( self, sysname, argv, env, uid, gid, ctty, secondaryPty ):
      """Open a session and return (signalSock, requestSock, cliInputSock).

      The control connection is always closed before returning.  On failure,
      every socket created so far is closed and the exception propagates.
      `secondaryPty` is passed to the server but not closed here.
      """
      sock, signalSock = self._connectToBackend( sysname, argv, env, uid, gid, ctty )
      opened = [ signalSock ]
      try:
         os.write( sock.fileno(), b't' )
         self._sendFds( sock, [ secondaryPty ] )
         requestSock = self._createAndSendSock( sock )
         opened.append( requestSock )
         cliInputSock = self._createAndSendSock( sock )
      except BaseException:
         for s in opened:
            s.close()
         raise
      finally:
         sock.close()
      return signalSock, requestSock, cliInputSock

class SimpleCliConnector( CliConnector ):
   """Connector that sends session type b's' and the caller's stdout and
   stderr descriptors."""
   def connectToBackend( self, sysname, argv, env, uid, gid, stdoutFd, stderrFd ):
      """Open a session and return (signalSock, requestSock).

      The control connection is always closed before returning.  On failure,
      every socket created so far is closed and the exception propagates.
      The caller's fds are passed, not closed.
      """
      sock, signalSock = self._connectToBackend( sysname, argv, env, uid, gid, '' )
      try:
         os.write( sock.fileno(), b's' )
         self._sendFds( sock, [ stdoutFd, stderrFd ] )
         requestSock = self._createAndSendSock( sock )
      except BaseException:
         signalSock.close()
         raise
      finally:
         sock.close()
      return signalSock, requestSock
