# Copyright (c) 2012 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import os
import sys
import threading
import queue

class Task:
   """A unit of work whose outcome can be collected later.

   Callers must not assume anything about which thread runs the work.  By default
   every Task gets its own daemon thread, started as soon as the Task is
   constructed.  If the ARPYUTILS_THREADING_TASK_SYNCHRONOUS environment variable
   is set, no thread is created; the work instead runs in the calling thread the
   first time completed() or result() is called, which can make debugging easier.
   The API is meant to stay the same if Tasks are later backed by a thread pool.

   Example:
      def factorial( n ):
         time.sleep( 1 )
         return 1 if ( n == 0 ) else ( n * factorial( n - 1 ) )
      tasks = [ Task( factorial, i ) for i in range( 10 ) ]
      results = [ t.result() for t in tasks ]
   """

   def __init__( self, fn, *args, **kwargs ):
      """Create a Task that will invoke fn( *args, **kwargs )."""
      self.fn_ = fn
      self.args_ = args
      self.kwargs_ = kwargs
      # None until the work finishes; afterwards a ( kind, payload ) pair where
      # kind is 'returned' (payload is the return value) or 'raised' (payload is
      # the sys.exc_info() triple captured when fn raised).
      self.result_ = None
      self.synchronous_ = 'ARPYUTILS_THREADING_TASK_SYNCHRONOUS' in os.environ
      self.thread_ = None
      if not self.synchronous_:
         worker = threading.Thread( target=self._run )
         worker.daemon = True
         self.thread_ = worker
         worker.start()

   def _execute( self ):
      # result() strips exactly the _run and _execute frames from a captured
      # traceback, so fn must be called directly from here.
      fn, args, kwargs = self.fn_, self.args_, self.kwargs_
      return fn( *args, **kwargs )

   def _run( self ):
      # Capture the outcome, including any exception, so result() can replay it.
      try:
         outcome = ( 'returned', self._execute() )
      except: # pylint: disable=W0702
         outcome = ( 'raised', sys.exc_info() )
      self.result_ = outcome

   def _runInlineIfPending( self ):
      # In synchronous mode the work is deferred until someone asks about it.
      if self.synchronous_ and self.result_ is None:
         self._run()

   def completed( self ):
      """Return True once the Task has finished (whether it returned or raised).

      In synchronous mode this first runs the work if it has not run yet."""
      self._runInlineIfPending()
      return self.result_ is not None

   def result( self ):
      """Wait for the Task to finish, then return fn's return value or re-raise
      the exception fn raised.

      Safe to call repeatedly: the work runs only once and every call yields the
      same outcome."""
      if self.synchronous_:
         self._runInlineIfPending()
      else:
         self.thread_.join()

      assert self.result_ is not None
      kind, payload = self.result_
      if kind != 'returned':
         assert kind == 'raised'
         _, exc, trace = payload
         # Skip the _run and _execute frames; they only add noise.
         raise exc.with_traceback( trace.tb_next.tb_next )
      return payload

class JobRunner:
   """A JobRunner splits a piece of work into Tasks that pull their arguments from
   a queue and push their results onto an output queue.

   Each worker Task repeatedly takes one entry from argQueue (without blocking),
   calls fn on it, and puts the return value on the output queue, stopping once
   argQueue is found empty.  The number of worker Tasks defaults to the number of
   available CPU cores, which helps prevent huge backtraces on
   KeyboardInterrupt, and is never less than one.

   Example:
      numbers = queue.Queue()
      for i in range( 10 ):
         numbers.put( i )

      def foo( n ):
         time.sleep( n * 2 )
         return n, "I slept %d seconds" % ( n * 2 )

      jobRunner = JobRunner( foo, numbers )
      output = jobRunner.run( blocking=True )

      while not output.empty():
         n, msg = output.get( block=False )
         print( n, msg )
   """
   def __init__( self, fn, argQueue, maxThreads=None ):
      """fn is called with a single argument taken from argQueue.  maxThreads caps
      the number of worker Tasks; a falsy value means availableCores()."""
      self.argQueue_ = argQueue
      self.fn_ = fn
      # Clamp to at least one worker: with zero workers (e.g. a negative
      # maxThreads, or a core count that could not be determined) run() would
      # spawn nothing and silently leave every argument unprocessed.
      self.maxThreads_ = max( 1, maxThreads if maxThreads else availableCores() )
      self.workers_ = None
      self.output_ = queue.Queue()

   def run( self, blocking=False ):
      """Start the worker Tasks.

      If blocking is True, wait for them and return the output queue (see
      result()); otherwise return None immediately."""
      numThreads = self.maxThreads_

      def taskLoop( fn, argQueue, outputQueue ):
         # Drain argQueue; the non-blocking get makes a worker exit as soon as it
         # observes the queue empty.
         while True:
            try:
               args = argQueue.get( block=False )
            except queue.Empty:
               break
            outputQueue.put( fn( args ) )

      self.workers_ = [ Task( taskLoop, self.fn_, self.argQueue_, self.output_ )
                        for _ in range( numThreads ) ]

      if blocking:
         return self.result()
      return None

   def result( self ):
      """Block until every worker Task has finished, then return the output queue.

      If any worker raised, the first exception (in worker order) is re-raised
      with its original traceback, but only after all workers have been waited
      on.  May be called multiple times; the job is not re-executed."""
      exceptionInfo = None

      for worker in self.workers_:
         # Don't raise immediately: keep waiting on the remaining workers so this
         # call really blocks until the whole job is done, and remember only the
         # first failure.
         try:
            worker.result()
         except: # pylint: disable=bare-except
            if exceptionInfo is None:
               # sys.exc_info() keeps the (type, value, traceback) triple so the
               # exact same exception can be re-raised below.
               exceptionInfo = sys.exc_info()

      if exceptionInfo is not None:
         raise exceptionInfo[ 1 ].with_traceback( exceptionInfo[ 2 ] )
      return self.output_

def availableCores():
   """Return the number of CPU cores available on this machine (always >= 1).

   On macOS this queries `sysctl -n hw.ncpu`; elsewhere it counts the processor
   entries in /proc/cpuinfo.  If that platform-specific probe fails (e.g. the
   command or file is missing) or finds no processors (e.g. /proc/cpuinfo uses a
   layout without tab-separated "processor" lines), fall back to os.cpu_count().
   Callers use this to size thread pools, so 0 is never returned."""
   count = 0
   try:
      if sys.platform == 'darwin':
         import subprocess # pylint: disable=import-outside-toplevel
         sysctl = subprocess.run( [ 'sysctl', '-n', 'hw.ncpu' ],
                                  stdout=subprocess.PIPE,
                                  universal_newlines=True,
                                  check=False )
         # A failed sysctl leaves stdout empty, which int() rejects below.
         count = int( sysctl.stdout )
      else:
         import re # pylint: disable=import-outside-toplevel
         with open( "/proc/cpuinfo" ) as f:
            count = len( re.findall( "processor\t", f.read() ) )
   except ( OSError, ValueError ):
      count = 0
   if count > 0:
      return count
   return os.cpu_count() or 1

def availableRam():
   """Return the amount of physical RAM (in bytes) in the machine.

   On macOS this queries `sysctl -n hw.memsize`, which reports bytes.  Elsewhere
   it reads the MemTotal line of /proc/meminfo and returns None if no such line
   is found."""
   if sys.platform == 'darwin':
      import subprocess # pylint: disable=import-outside-toplevel
      sysctl = subprocess.run( [ 'sysctl', '-n', 'hw.memsize' ],
                               stdout=subprocess.PIPE,
                               universal_newlines=True,
                               check=False )
      return int( sysctl.stdout )

   import re # pylint: disable=import-outside-toplevel
   regex = re.compile( r"MemTotal:\s*(\d+)\s*kB" )
   with open( "/proc/meminfo" ) as f:
      for line in f:
         m = regex.match( line )
         if m:
            # /proc/meminfo reports the value in kB (units of 1024 bytes).
            return int( m.group( 1 ) ) * 1024
   return None

