#!/usr/bin/env python3
# Copyright (c) 2010, 2011 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import codecs
import io
import json
import os
import select
import six
import socket
import sys
import threading

import Cell
import InterruptiblePoller
import PyClient
import Tac
import TacSigint
import Tracing
from AgentCommandSocketProvider import socketProvider

# Module-level trace helper (Tracing.trace0); presumably the most important
# trace level -- confirm against the Tracing module.
t0 = Tracing.trace0

# threadLock guards requestNum, the counter getRequestNum() uses to hand out
# request numbers that are unique within this process.
threadLock = threading.Lock()
requestNum = 1

# Test hook: when set, runSocketCommand() forwards the request to this object's
# runSocketCommand() method instead of contacting an agent.
testAgentCommandRequest = None

def getRequestNum():
   """Return a fresh request number, unique within this process.

   The counter is shared by all threads, so the read-and-increment is done
   under threadLock. Using the lock as a context manager guarantees it is
   released even if something inside the critical section raises.
   """
   global requestNum
   with threadLock:
      rNo = requestNum
      requestNum += 1
   return rNo

def _configPath( dirName ):
   """Return the entity path (relative to the Sysdb root) of the
   AgentCommandRequest config directory `dirName`, rooted at the location
   produced by Cell.path()."""
   configRoot = Cell.path( "agent/commandRequest/config/" )
   return configRoot + dirName

class RunSocketCommandException( Exception ):
   """Raised by runSocketCommand() when throwException is set and the request
   fails, or when the agent's output matches one of the caller's error strings.
   The single argument is the error text."""

# Create a socket to be used by processes to communicate. The most typical usage
# for this is the execution of CLI commands. The sequence of events in this
# case is the following:
# - the cli creates the socket, specifying the command to run
# - the platform agent reacts to the socket creation and runs the desired command
# - the platform agent sends the command output to the socket
# - the cli process receives the output and prints it, then closes the socket
def createSocket( pyClientOrEm, sysname, dirName, requestName,
                  command, commandType,
                  keepalive, asyncCommand, timeout, outputFormat, revision=1 ):
   """Publish an AgentCommandRequest and wait for the agent to connect back.

   A listening socket is obtained from socketProvider(). Then an
   'Agent::AgentCommandRequest' entity named `requestName` is created in the
   `dirName` config directory. This goes through PyClient when `pyClientOrEm`
   is a PyClient, otherwise directly through the entity manager.
   `initialized` is set last.

   Returns ( connection, targetSpec ) once a peer connects. Returns
   ( None, None ) if the provider could not create a socket, or if accept()
   timed out; in the timeout case the request is deleted again.

   The listening socket is closed on every path. Any exception other than an
   accept timeout propagates. If it happens while waiting for the connection
   (e.g. KeyboardInterrupt), the request is deleted first so it is not left
   behind.
   """
   sock, targetSpec = socketProvider().createSocket( requestName, timeout )
   if sock is None:
      return None, None

   # Everything after socket creation runs under try/finally so the listening
   # socket is released even if publishing the request fails.
   try:
      # Add socket to the agentCommandRequest directory. This directory must be
      # mounted by the agent using the cli. It is not created by Sysdb.
      t0( "Creating sysdb entity" )

      assert isinstance( revision, int )
      assert revision > 0
      # used by cli show commands with a capi model based on 'deferredModel'
      tacOutputFormat = "of" + outputFormat.title()
      if isinstance( pyClientOrEm, PyClient.PyClient ):
         cmd = """import Tac
r = Tac.root.entity[ '{}/Sysdb/{}' ].newEntity(
      'Agent::AgentCommandRequest', {!r})
r.target = {!r}
r.commandType = {!r}
r.commandString = {!r}
r.keepalive = {!r}
r.asyncCommand = {!r}
r.outputFormat = {!r}
r.revision = {!r}
r.initialized = True""".format( sysname,
                            _configPath( dirName ),
                            requestName,
                            targetSpec,
                            commandType,
                            command,
                            keepalive,
                            asyncCommand,
                            tacOutputFormat,
                            int( revision ) )
         pyClientOrEm.execute( cmd )
      else:
         agentRoot = pyClientOrEm.root().entity[ _configPath( dirName ) ]
         request = agentRoot.newEntity( 'Agent::AgentCommandRequest', requestName )
         request.target = targetSpec
         request.commandType = commandType
         request.commandString = command
         request.keepalive = keepalive
         request.asyncCommand = asyncCommand
         request.revision = revision
         request.outputFormat = tacOutputFormat
         # This must be the last attribute to change as the agentCommandRequest
         # reacts to it
         request.initialized = True

      # Listen for incoming connections
      t0( "Waiting for connection" )
      try:
         connection, clientAddress = sock.accept()
      except socket.timeout:
         t0( "Error: connection timeout" )
         deleteSocket( pyClientOrEm, sysname, dirName, requestName, targetSpec )
         return None, None
      except BaseException:
         # Anything else (e.g. Ctrl-C while blocked in accept): remove the
         # request so it is not left behind in Sysdb, then re-raise.
         deleteSocket( pyClientOrEm, sysname, dirName, requestName, targetSpec )
         raise
      t0( f"Received connection from {clientAddress}" )
   finally:
      # close listening socket asap as it is no longer needed. This avoids btest
      # failure in AgentCliTest.py due to 'ResourceWarning: unclosed socket.socket'
      sock.close()

   t0( "Connection received" )
   return connection, targetSpec

def deleteSocket( pyClientOrEm, sysname, dirName, requestName, targetSpec ):
   """Tear down a request created by createSocket().

   Releases the socket through socketProvider(). Then deletes the
   AgentCommandRequest entity `requestName` from the `dirName` config
   directory, via PyClient or directly through the entity manager.
   """
   socketProvider().deleteSocket( requestName, targetSpec )

   t0( "Deleting sysdb entity" )
   if isinstance( pyClientOrEm, PyClient.PyClient ):
      # Quote the name with repr(), exactly as createSocket() does when it
      # creates the entity, so names containing quotes or backslashes cannot
      # break the generated Python.
      cmd = """import Tac
Tac.root.entity[ '{}/Sysdb/{}' ].deleteEntity( {!r} )""".format(
         sysname, _configPath( dirName ), requestName )
      pyClientOrEm.execute( cmd )
   else:
      agentRoot = pyClientOrEm.root().entity[ _configPath( dirName ) ]
      agentRoot.deleteEntity( requestName )

# Same as runSocketCommand, but sets up the environment for 'cliprint', a combined
# way of printing text or json in a streamed fashion (no intermediate model).
# The print APIs (CliPrint.h) used by the handler function, running in either Cli or
# Agent context, will pick the relevant args from the env (unbeknownst/transparent
# to the API user).
#
# Similarly to the Cli handlers, returns model if it was rendered (if provided),
# or None if it was not rendered and there were only errors.
# Note that if your model does not inherit from DeferredModel, you must call
# cliPrinted( model ) to create a special instance of the model to return from the
# handler (this isn't done automatically because Agent cannot depend on Cli).
def runCliPrintSocketCommand( entityManager, dirName, commandType, command, mode,
                              keepalive=False, asyncCommand=False, timeout=120,
                              connErrMsg=None, stringBuff=None,
                              forceOutputFormat=None, model=None ):
   """Run `command` via runSocketCommand() with throwException=True.

   The output format is `forceOutputFormat` when it is truthy, otherwise the
   session's outputFormat_. The session's requested model revision is forwarded
   as well. Returns `model` on success. If RunSocketCommandException is raised,
   its text is added to `mode` as an error and None is returned.
   """
   outputFormat = forceOutputFormat or mode.session_.outputFormat_
   try:
      runSocketCommand( entityManager, dirName, commandType, command,
                        keepalive=keepalive,
                        asyncCommand=asyncCommand,
                        timeout=timeout,
                        stringBuff=stringBuff,
                        outputFormat=outputFormat,
                        throwException=True,
                        connErrMsg=connErrMsg,
                        revision=mode.session_.requestedModelRevision() )
   except RunSocketCommandException as exc:
      mode.addError( str( exc ) )
      return None
   return model

def _handleOutput( chunks, stringBuff, errorResponses,
                   throwException, error=False ):
   """
   Emit output received from (or generated for) an agent command request.

   chunks:         pieces of output, normally bytes read from the socket. str
                   pieces (e.g. locally generated timeout messages) are UTF-8
                   encoded first so they can be mixed with the socket bytes.
   stringBuff:     if set, the decoded output is written to it; otherwise it
                   goes to stdout (or to stderr when `error` is set).
   errorResponses: strings which, when throwException is set, mark the output
                   as an error if any of them occurs in it.
   throwException: when set, error output is handed back to the caller instead
                   of being written to stdout/stderr. With stringBuff, it has
                   already been written to stringBuff by then.
   error:          the output is already known to be an error.

   returns tuple(
      error: Was an error found, which should be thrown by the caller
      message: The message (always a str) to be thrown by the caller
   )
   """
   def toBytes( data ):
      return data.encode( 'utf-8' ) if isinstance( data, str ) else data

   def toText( data ):
      if isinstance( data, bytes ):
         # 'replace' so malformed output cannot turn an error report into a
         # UnicodeDecodeError.
         return data.decode( 'utf-8', errors='replace' )
      return data

   def bufferedWriteToFd( chunks, fd ):
      with io.BufferedWriter( io.FileIO( fd, mode='wb', closefd=False ) ) as bw:
         for chunk in chunks:
            bw.write( chunk )
         bw.flush()

   chunks = [ toBytes( chunk ) for chunk in chunks ]
   if stringBuff:
      # Decode incrementally: a multi-byte UTF-8 character may be split across
      # two recv'd chunks.
      decoder = codecs.getincrementaldecoder( 'utf-8' )()
      for chunk in chunks:
         stringBuff.write( decoder.decode( chunk ) )
      stringBuff.write( decoder.decode( b'', final=True ) )
      if throwException:
         response = stringBuff.getvalue()
         if error or any( toText( er ) in response for er in errorResponses ):
            return ( True, response )
   elif error:
      if throwException:
         return ( True, toText( b"".join( chunks ) ) )
      sys.stderr.flush()
      bufferedWriteToFd( chunks, sys.stderr.fileno() )
   else:
      if throwException:
         response = b"".join( chunks )
         if any( toBytes( er ) in response for er in errorResponses ):
            return ( True, toText( response ) )
      sys.stdout.flush()
      bufferedWriteToFd( chunks, sys.stdout.fileno() )
   return ( False, None )

def _handleConnection( connection, chunks, stringBuff, errorResponses,
                       throwException, doBuffering ):
   """Service one readable event on `connection`.

   Reads up to 4096 bytes. Non-empty data is either appended to `chunks` (when
   buffering) or emitted immediately via _handleOutput(). An empty read means
   the peer has finished sending. Any buffered chunks are then emitted.

   Returns ( msgComplete, error, message ). error/message carry _handleOutput()'s
   verdict once msgComplete is True, and are None while more data is expected.
   """
   payload = connection.recv( 4096 )
   if not payload:
      t0( "Receive completed" )
      if not doBuffering:
         return True, False, None
      # Finally emit all of the buffered chunks into the stdout or stringBuff
      error, message = _handleOutput( chunks, stringBuff, errorResponses,
                                      throwException )
      return True, error, message

   if doBuffering:
      chunks.append( payload )
   else:
      # It should not be possible to hit this assert if we've passed the checks
      # in runSocketCommand.
      assert stringBuff is None
      _handleOutput( [ payload ], stringBuff, errorResponses, throwException )
   return False, None, None

def _cohabLoop( connection, chunks, stringBuff, errorResponses, throwException,
                doBuffering, timeout ):
   """Receive the agent's reply when running with a local (cohab) entity manager.

   The socket is checked with a very short select() timeout so that
   Tac.runActivities() can run between reads. `timeout` bounds the total time
   spent waiting, as measured by Tac.now(). A falsy timeout waits indefinitely,
   matching _nonCohabLoop.

   On timeout, a "Cli connection timeout" message is appended to `chunks` and
   everything collected is emitted as error output.

   Returns ( error, message ) as produced by _handleOutput().
   """
   # SandDanz uses AgentCommandRequest.runSocketCommand in virtualTime btests
   # that require nonzero time to handle reads and writes from the os socket
   activityTime = 0.1 if os.environ.get( 'SIMULATION_SAND' ) else 0
   startTime = Tac.now()
   while True:
      # small timeout value to run activities if select gets blocked
      readable, _, _ = select.select( [ connection ], [], [], 0.001 )
      if readable:
         msgComplete, error, message = _handleConnection( connection, chunks,
               stringBuff, errorResponses, throwException, doBuffering )
         if msgComplete:
            return error, message
      elif not timeout or Tac.now() - startTime < timeout:
         # These mini timeouts are only to flush activity loop
         Tac.runActivities( activityTime )
      else:
         # Must be bytes: every other chunk is bytes read from the socket, and
         # _handleOutput joins/decodes/writes them as bytes.
         chunks.append( b"Cli connection timeout" )
         return _handleOutput( chunks, stringBuff, errorResponses,
                               throwException, error=True )

def _nonCohabLoop( connection, chunks, stringBuff, errorResponses, throwException,
                   doBuffering, timeout ):
   """Receive the agent's reply when the entity manager is remote (PyClient).

   Blocks on an InterruptiblePoller until `connection` is readable. `timeout`
   (presumably seconds, since it is multiplied by 1000 for poll()) applies to
   each individual wait, not to the whole reply. A falsy timeout waits
   indefinitely.

   If a wait times out, a "Cli connection timeout" message is appended to
   `chunks` and everything collected is emitted as error output.

   Returns ( error, message ) as produced by _handleOutput().
   """
   poller = InterruptiblePoller.getPoller()
   poller.register( connection, select.POLLIN )
   pollTimeout = timeout * 1000 if timeout else None
   while True:
      if poller.poll( pollTimeout ):
         msgComplete, error, message = _handleConnection( connection, chunks,
               stringBuff, errorResponses, throwException, doBuffering )
         if msgComplete:
            return error, message
      else:
         # Must be bytes: every other chunk is bytes read from the socket, and
         # _handleOutput joins/decodes/writes them as bytes.
         chunks.append( b"Cli connection timeout" )
         return _handleOutput( chunks, stringBuff, errorResponses,
                               throwException, error=True )

def _dirExists( pyClientOrEm, sysname, dirName ):
   """Return True if the AgentCommandRequest config dir `dirName` exists.

   With a PyClient, the lookup runs remotely and reports a missing directory by
   printing a marker string, which is searched for in the returned output.
   Otherwise the entity manager's root is consulted directly.
   """
   dirPath = _configPath( dirName )
   if isinstance( pyClientOrEm, PyClient.PyClient ):
      missingMarker = "Error: no Tac::Dir"
      dirCmd = """import Tac
if Tac.root.entity.get( '{}/Sysdb/{}' ) is None:
   print( \"{}\" ) """.format( sysname, dirPath, missingMarker )
      exists = missingMarker not in pyClientOrEm.execute( dirCmd )
   else:
      exists = pyClientOrEm.root().entity.get( dirPath ) is not None
   if not exists:
      t0( "%s does not exist" % dirPath )
   return exists

def _connErrorPayload( errText, outputFormat ):
   """Render a locally generated error message as the bytes to emit.

   With outputFormat 'json' this is {"errors": [errText]} built with json.dumps.
   Quotes, backslashes or control characters in errText (which may come from a
   caller-supplied connErrMsg) therefore still produce valid JSON. For any
   other format it is the CLI-style line '% errText' followed by a newline.
   """
   if outputFormat == 'json':
      # Note that this shouldn't end up being used in conjunction with cli handler
      # that returns a model. Generally those Clis will have throwException=True
      # and no stringBuff, so this won't actually render to the stdout.
      # If your cli is encountering problems because of this, consider using
      # runCliPrintSocketCommand and using the return value as needed.
      text = json.dumps( { "errors": [ str( errText ) ] }, ensure_ascii=False )
   else:
      text = "%% %s\n" % errText
   return text.encode( 'utf-8' )

def runSocketCommand( entityManager, dirName, commandType, command, keepalive=False,
                      asyncCommand=False, timeout=120, stringBuff=None,
                      forceBuffering=False,
                      outputFormat="unset", throwException=False, errors=None,
                      connErrMsg=None, revision=1 ):
   """Wrapper around createSocket/deleteSocket. It is mainly used by the Cli.
   It creates a socket and a corresponding entity of type AgentCommandRequest
   in sysdb (the socket name and the sysdb directory are passed as parameter).
   The platform agent has to create the sysdb directory, and it reacts to the
   entity creation through a AgentCommandRequestDir state machine. This state
   machine will call a callback on the agent to run the command and send the
   output via the socket.
   runSocketCommand also prints the output received and deletes the socket once
   the command has completed.

   stringBuff: (a file-like object) If specified, all output will be sent to it
               rather than to stdout/stderr. Passing stringBuff will cause all output
               to be buffered before being written to stringBuff (output will NOT be
               streamed directly into this buffer from the socket). See
               forceBuffering for details about buffering behavior.
   forceBuffering: Setting this to True will cause output from the socket to only be
                   printed if no late connection errors occur (Though the error
                   itself may be printed. See throwException).
                   Leaving this as False will cause buffering to occur as-needed,
                   based on other arguments (stringBuff or errors).
   throwException: Setting to True allows RunSocketCommandException to be raised if
                   any error occurs. The error text (a str) is propagated back
                   through the exception and is not written to stdout/stderr.
                   Note that when stringBuff is given, the output has already been
                   written to stringBuff by the time the error is detected.
                   Leaving this as False will cause the error to be written to
                   stderr (or to stringBuff, if given).
   errors: (str or iterable of strs) Message(s) the agent/socket can spit out,
           which are matched against and considered errors.
           throwException MUST be True if errors is set.
           Setting errors will also automatically cause internal buffering to occur,
           since it is required for matching.

   *** NOTE: Maximum response size supported is 1Mb. Please refer to Bug113060
   """

   # Used only for breadth test verification
   if 'AGENT_COMMAND_REQUEST_DRY_RUN' in os.environ:
      print( "Dir: %s, Type: %s, Command: %s" %
         ( dirName, commandType, command ) )
      return

   if testAgentCommandRequest:
      testAgentCommandRequest.runSocketCommand(
         entityManager, dirName, commandType, command, stringBuff=stringBuff,
         outputFormat=outputFormat )
      return

   doBuffering = bool(
      # Detecting errors in the output requires buffering, since the error may
      # be split across multiple recv'd chunks.
      errors or
      # When stringBuff is provided, the user never expects a garbled combination
      # of good and error output. They should receive only one or the other. This
      # means we need to buffer the good output aside from stringBuff, in case we
      # encounter an error mid-stream.
      stringBuff or
      # Set when the user may want to ensure we never write to both 'stdout' and
      # 'stderr' in the same command to prevent premature output that would be
      # rendered moot/syntactically incorrect if an error later happens.
      # Should not be used by commands with large scale (ConfigAgent can run out
      # of memory).
      forceBuffering )
   assert throwException or not errors, \
         "throwException can't be False when errors is set"

   # `errors` may be a single message or any iterable of messages.
   if not errors:
      errorResponses = []
   elif isinstance( errors, ( str, bytes ) ):
      errorResponses = [ errors ]
   else:
      errorResponses = list( errors )

   if entityManager.isLocalEm():
      t0( "Running in cohab mode" )
      # local entity manager avoids pyClient for cohabiting tests
      pc = entityManager
   else:
      pc = PyClient.PyClient( entityManager.sysname(), "Sysdb" )

   def _handleError( errText ):
      # Report errText as error output; raise it when _handleOutput says the
      # caller asked for exceptions.
      payload = _connErrorPayload( errText, outputFormat )
      error, _ = _handleOutput( [ payload ], stringBuff, errorResponses,
                                throwException, error=True )
      if error:
         raise RunSocketCommandException( errText )

   # Make sure AgentCommandRequest directory exists
   if not _dirExists( pc, entityManager.sysname(), dirName ):
      _handleError( connErrMsg or "Error: Agent has not been started" )
      return

   sysdbSafeCommandType = commandType.replace( '/', '_' )
   requestName = sysdbSafeCommandType + '-' + str( os.getpid() ) + '-' + \
       str( getRequestNum() )

   connection, targetSpec = createSocket(
      pc, entityManager.sysname(), dirName, requestName,
      command, commandType, keepalive=keepalive,
      asyncCommand=asyncCommand, timeout=timeout,
      outputFormat=outputFormat, revision=revision )

   error, message = False, None
   if not connection:
      _handleError( connErrMsg or "Cli connection error" )
      return

   t0( "Receiving output" )
   chunks = []
   try:
      if entityManager.isLocalEm():
         error, message = _cohabLoop( connection, chunks, stringBuff,
                                      errorResponses, throwException, doBuffering,
                                      timeout )
      else:
         error, message = _nonCohabLoop( connection, chunks, stringBuff,
                                         errorResponses, throwException,
                                         doBuffering, timeout )
   except KeyboardInterrupt: # pylint: disable=try-except-raise
      raise
   except Exception as e: # pylint: disable=W0703
      t0( "Failed to read: %s" % str( e ) )

      _handleError( connErrMsg or "Cli connection exception" )
   else:
      if error:
         # The message may be raw socket bytes; raise with text so that
         # str( exception ) is the message itself rather than its b'...' repr.
         raise RunSocketCommandException(
               six.ensure_str( message, errors='replace' ) )
   finally:
      t0( "Closing connection" )
      connection.close()

      t0( "Deleting socket" )
      deleteSocket( pc, entityManager.sysname(), dirName, requestName, targetSpec )

class AcrCommand:
   """Builds the parameter string used as the command of an AgentCommandRequest.

   Parameters are collected as "keyword" or "keyword=argument" strings and
   joined with Agent::AgentCommandRequestCommandParams.commandDelimiter.
   """
   def __init__( self ):
      CommandParams = Tac.Type( "Agent::AgentCommandRequestCommandParams" )
      self.delimiter = CommandParams.commandDelimiter

      # A set, so adding the same parameter twice has no effect.
      self.params = set()

   def addArgs( self, args ):
      """
      Calls self.addParam on every item in the given dictionary

      :param args: a dictionary of keyword to argument
      """
      for keyword, argument in args.items():
         self.addParam( keyword, argument )

   def addParam( self, keyword, argument=None ):
      """
      Adds keyword and argument to the internal set of parameters. The parameter
      is stored in the form "keyword=argument" if an argument is provided, else only
      the keyword is stored.

      :param keyword: the keyword to add. Required.
      :param argument: the argument to add. Optional.
      """
      if argument is None:
         self.params.add( keyword )
      else:
         self.params.add( f"{keyword}={argument}" )

   def commandParams( self ):
      """
      From the stored parameters, generate a string that fits the following EBNF:
      parameter
            : keyword
            | keyword "=" argument
            | parameter "\t" parameter
            ;

      keyword : STRING ;
      argument : STRING ;

      Parameters are emitted in sorted order so the result is deterministic; the
      iteration order of a set of strings varies between interpreter runs.
      """
      # Sanitize the params by replacing any delimiters (tabs) with spaces
      return self.delimiter.join(
         sorted( p.replace( self.delimiter, ' ' ) for p in self.params ) )

def runCliPrintSocketCommandWithArgs( entityManager, dirName,
                                      command, args, mode, **kwargs ):
   """
   Encode the `args` dictionary with AcrCommand and pass the result on to
   runCliPrintSocketCommand.

   `command` becomes the request's commandType, while the encoded parameter
   string becomes the command itself. Extra keyword arguments are forwarded
   unchanged, and runCliPrintSocketCommand's result is returned.
   """
   builder = AcrCommand()
   builder.addArgs( args )
   encodedParams = builder.commandParams()
   return runCliPrintSocketCommand( entityManager, dirName, command, encodedParams,
                                    mode, **kwargs )
