#!/usr/bin/env python3
# Copyright (c) 2022 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import argparse
from EDTAccess import ( defaultRootUser, defaultRootPasswd,
                        shellCmd, DutCmdSession, grabbedDut,
                        traceMsg, traceLevelIs,
                        TRACE_FATAL, TRACE_ERROR, TRACE_INFO )
import os
import signal
import socket
import sys

class EDTClient:
   """Command line client that sends one request to an EDT server.

   The client turns its command line into a call-like request string
   ( e.g. ``cmd('a',k=3)`` ), connects to the server over a local unix
   socket or, when --port is given, over TCP to a dut, starts the server
   if the first connection attempt fails, and copies the server response
   to stdout.
   """

   def __init__( self ):
      # Path of the unix socket used when no --port is given.
      self.serverAddress = '/tmp/edt_socket'
      # Host/port of the TCP server; only set when --port is given.
      self.serverHost = None
      self.serverPort = None
      # Parsed command line ( argparse namespace ), set by parseCmdOpts.
      self.args = None
      # Command session on the dut, only created when --port is given.
      self.dutCmd = None
      # The local server executable is expected next to this script.
      path = os.path.abspath( os.path.dirname( sys.argv[ 0 ] ) )
      self.serverPath = path + "/edts"
      sys.path.append( "." )

   # Interrupt handler to deal with ctrl-c to terminate server when stuck
   def sigIntHandler( self, signum, frame ):
      """Stop the server and exit with status 1 on SIGINT."""
      # Stop server since it may be running background threads.
      traceMsg( TRACE_INFO, "Received keyboard interrupt, stopping server" )
      # The interrupt may arrive before the options are parsed or before the
      # dut session exists; stopServer needs both, so skip it in that case.
      if self.args is not None and ( not self.args.port or
                                     self.dutCmd is not None ):
         self.stopServer()
      sys.exit( 1 )

   def parseCmdOpts( self ):
      """Parse sys.argv into self.args and set the trace level.

      When --quit is given the server is stopped and the process exits
      with status 0 without sending any request.
      """
      parser = argparse.ArgumentParser( prog=sys.argv[ 0 ],
                                        usage='%(prog)s [options]',
                                        description="Start a EDT Client "
                                                    "session" )
      parser.add_argument( "cmdName", help="command name" )
      parser.add_argument( "cmdArg", nargs='*',
            help="zero or more command arguments" )
      parser.add_argument( "-s", "--server",
            help="select dut for remote server" )
      parser.add_argument( "-p", "--port", nargs=1,
            help="specify port use for remote TCP socket, if "
                  "not specified, use local unix socket as "
                  "default" )
      parser.add_argument( "-v", "--verbose", type=int, default=TRACE_ERROR,
            help="trace level to display the message" )
      parser.add_argument( "--username", default=defaultRootUser,
            help="username for remote bash session on a dut" )
      parser.add_argument( "--password", default=defaultRootPasswd,
            help="password for remote bash session on a dut" )
      parser.add_argument( "--quit", action='store_true',
            help="force quit the existing server, client will exit right away" )
      self.args = parser.parse_args()
      traceLevelIs( self.args.verbose )
      if self.args.quit:
         # Stopping a remote server goes through the dut session, so it has
         # to exist before stopServer is called.
         self._initRemoteSession()
         self.stopServer()
         sys.exit( 0 )

   def _initRemoteSession( self ):
      """Set up host, port, server path and dut session for --port mode.

      Does nothing when no --port was given or when the session already
      exists. Exits with status 1 when no dut is specified or grabbed.
      """
      if not self.args.port or self.dutCmd is not None:
         return
      # The dut is selected with -s/--server, falling back to grabbedDut().
      dutName = self.args.server or grabbedDut()
      if not dutName:
         traceMsg( TRACE_FATAL, "no dut specified or grabbed" )
         sys.exit( 1 )
      self.serverHost = dutName
      # --port is declared with nargs=1, so the value is a one element list.
      self.serverPort = int( self.args.port[ 0 ] )
      self.serverPath = "/usr/bin/edts"
      # Create DUT command session.
      self.dutCmd = DutCmdSession( dutName, username=self.args.username,
                                   password=self.args.password )

   def fixupCmdArg( self, argStr ):
      """Return argStr with quotes added where it looks like a bare string.

      A value is left untouched when it already starts with a quote, starts
      with a digit, is a signed number ( "-5", "+3" ) or is one of True,
      False, None; otherwise it is wrapped in single quotes. An argument of
      the form name=value is treated as a keyword argument only when name
      is a valid identifier, in which case only the value is fixed up.
      """
      def maybeAddQuotes( x ):
         if x in ( "True", "False", "None" ):
            return x
         if x and ( x[ 0 ] in "'\"" or x[ 0 ].isdigit() ):
            return x
         # Signed numbers are passed through like unsigned ones.
         if len( x ) > 1 and x[ 0 ] in "+-" and x[ 1 ].isdigit():
            return x
         return "'" + x + "'"

      name, sep, value = argStr.partition( '=' )
      # Only split on '=' for real keyword arguments; an '=' inside a quoted
      # string or a URL must not break the argument apart.
      if sep and name.isidentifier():
         return name + "=" + maybeAddQuotes( value )
      return maybeAddQuotes( argStr )

   def getRequest( self ):
      """Convert the parsed command line into the server request string.

      Returns "name()" ( or name unchanged if it already contains a
      parenthesis ) when there are no arguments, else "name(arg1,arg2,...)"
      with every argument passed through fixupCmdArg. Exits with status 1
      when the command name is empty.
      """
      cmdN = self.args.cmdName
      cmdA = self.args.cmdArg
      if not cmdN:
         traceMsg( TRACE_FATAL, "no command" )
         sys.exit( 1 )
      # Return request as cmdN if no cmdA is provided
      if not cmdA:
         if '(' not in cmdN:
            cmdN = cmdN + "()"
         return cmdN
      # If arguments provided with no quote, fix it
      args = [ self.fixupCmdArg( a ) for a in cmdA ]
      return cmdN + "(" + ','.join( args ) + ")"

   def connectToServer( self ):
      """Connect to the server and return the socket, or None on failure.

      Uses TCP to serverHost/serverPort when --port was given, else the
      local unix socket at serverAddress. A socket that fails to connect is
      closed before None is returned.
      """
      if self.args.port:
         sock = socket.socket( socket.AF_INET, socket.SOCK_STREAM )
         try:
            sock.connect( ( self.serverHost, self.serverPort ) )
         except OSError as m:
            traceMsg( TRACE_INFO, m )
            traceMsg( TRACE_INFO, "failed to connect to remote server at "
                      f'{self.serverHost} port {self.serverPort}' )
            sock.close()
            sock = None
      else:
         sock = socket.socket( socket.AF_UNIX, socket.SOCK_STREAM )
         try:
            sock.connect( self.serverAddress )
         except OSError as m:
            traceMsg( TRACE_INFO, m )
            traceMsg( TRACE_INFO, "failed to connect to local server at "
                      f"{self.serverAddress}" )
            sock.close()
            sock = None
      return sock

   def sendRequest( self, sock, request ):
      """Send request on sock and copy the response to stdout.

      Reads until the peer closes the connection. The response is decoded
      as UTF-8 incrementally so that a character split across two reads is
      still decoded correctly; undecodable bytes are replaced rather than
      raising. Stops early if writing to stdout fails with an OSError.
      """
      import codecs
      decoder = codecs.getincrementaldecoder( "utf-8" )( errors="replace" )
      sock.sendall( request.encode() )
      while True:
         resp = sock.recv( 4096 )
         # An empty read means end of stream: flush what the decoder holds.
         text = decoder.decode( resp, final=not resp )
         if text:
            try:
               sys.stdout.write( text )
            except OSError:
               break
         if not resp:
            break

   def stopServer( self ):
      """Run the server executable with --stop, locally or on the dut."""
      cmd = f"{self.serverPath} --stop -v {self.args.verbose}"
      if not self.args.port:
         shellCmd( cmd )
         traceMsg( TRACE_INFO, "stopped local server: " + cmd )
      else:
         self.dutCmd.execRemoteShellCmd( cmd )
         traceMsg( TRACE_INFO, "stopped remote server: " + cmd )

   def startServer( self ):
      """Stop any existing server, then run the executable with --start."""
      self.stopServer()
      cmd = f"{self.serverPath} --start -v {self.args.verbose}"
      if not self.args.port:
         shellCmd( cmd )
         traceMsg( TRACE_INFO, "started local server: " + cmd )
      else:
         # --port is a one element list ( nargs=1 ); pass the value itself.
         cmd += f" --port {self.args.port[ 0 ]}"
         self.dutCmd.execRemoteShellCmd( cmd )
         traceMsg( TRACE_INFO, "started remote server: " + cmd )

   def run( self ):
      """Parse options, connect ( starting the server if needed ), send the
      request and print the response. Exits with status 1 when no server
      can be reached."""
      self.parseCmdOpts()
      # Get dut for remote server.
      self._initRemoteSession()
      sock = self.connectToServer()
      if sock is None:
         self.startServer()
         sock = self.connectToServer()
         if sock is None:
            traceMsg( TRACE_FATAL, "server not found" )
            sys.exit( 1 )
      # Close the socket even if building or sending the request fails.
      with sock:
         req = self.getRequest()
         self.sendRequest( sock, req )

if __name__ == "__main__":
   # Script entry point: build the client, route ctrl-c to its interrupt
   # handler, then process the command line.
   myClient = EDTClient()
   signal.signal( signalnum=signal.SIGINT, handler=myClient.sigIntHandler )
   myClient.run()
