#!/usr/bin/env python3
# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

from collections import namedtuple
import functools
import itertools
import os
import pickle
import re
import resource
import signal
import sys
import time

# Sentinel for the `attempts` argument of retry(): passing it requests an
# unbounded number of attempts, leaving retryTimeout as the only limit.
INFINITE_ATTEMPTS = -1

# Note: Converting this to new-style classes breaks stuff.  (This note predates
# Python 3; there, every class -- including this one -- is a new-style class.)
class Extendable:
   """Derive from this class if you want to make your class, and any subclasses,
   extendable.  "Extendable" here means that other code (probably in a plugin) can
   add methods and baseclasses to the class by means of the @cls.extension decorator.

   Do not use this class lightly!  Using this class is probably the wrong thing to do.
   Adding methods to a class from outside the class definition can lead to code that
   is very hard to understand and maintain, as someone reading code that uses this
   class may have to search the entire codebase to find the definition of a method
   that is being called.

   An alternative, more C++-like design that uses multiple classes, rather than
   putting all the methods in one single class, would almost certainly be better."""

   @classmethod
   def extension( cls, ext ):
      """A decorator that adds the decorated function to this class's dict, or adds
      the decorated class to this class's list of baseclasses.  For example::

         class Foo( Extendable ): pass
         class Bar( Foo ): pass

         @Bar.extension
         def myFunction( self, x, y ):
            pass

         @Bar.extension
         class MyClass:
            pass

      causes myFunction to become a method of Bar, and causes MyClass to become a
      baseclass of Bar.

      A function may not replace a member already present in the class's own
      __dict__ unless that member was marked with @overridable.  If the class
      has a traceExtendableExtensionCall attribute, the function is installed
      wrapped so that each call first invokes
      cls.traceExtendableExtensionCall( function ).

      Returns:
         The object installed on the class: ext itself, or the tracing wrapper
         around it.

      Raises:
         AssertionError: if cls is Extendable itself, if a function would replace
            a non-overridable member, or if ext is neither a function nor a class.
      """
      # Raise explicitly rather than using assert, so these checks still run
      # under python -O.
      if cls is Extendable:
         raise AssertionError( "Can't extend Extendable itself!" )

      import types # pylint: disable=import-outside-toplevel
      if isinstance( ext, ( types.FunctionType, functools.partial ) ):
         # NOTE(review): functools.partial objects have neither __name__ nor
         # __code__, so partial extensions fail with AttributeError below --
         # confirm whether they are used at all.

         # Don't allow an extension to replace an existing class member (whether that
         # member was defined in the class itself, or added later as an extension),
         # unless that member was decorated with @overridable.
         name = ext.__name__
         if name in cls.__dict__:
            f = getattr( cls, name )
            if not getattr( f, '__ArPyUtils_overridable', False ):
               # Python 3 functions expose their code object as __code__ (the old
               # 'func_code' spelling was Python 2 only and never matched here).
               if hasattr( f, '__code__' ):
                  msg = "Class %s already has a method %r, defined at %s:%d" % (
                     cls.__name__, name,
                     f.__code__.co_filename,
                     f.__code__.co_firstlineno )
               else:
                  msg = f"Class {cls.__name__} already has a member {name!r}"
               raise AssertionError( msg )
         if hasattr( cls, "traceExtendableExtensionCall" ):
            realExt = ext
            # Tac.memoize looks at f.__code__.co_argcount so we need to play games
            # here so memoized extensions work.  functools.wraps copies the name,
            # docstring and attributes (including the @overridable mark) of the
            # real extension onto the wrapper without touching its __code__.
            if realExt.__code__.co_argcount == 0:
               @functools.wraps( realExt )
               def traceThenCall():
                  # pylint: disable=no-member
                  cls.traceExtendableExtensionCall( realExt )
                  return realExt()
            else:
               @functools.wraps( realExt )
               def traceThenCall( self, *cmds, **kwargs ):
                  # pylint: disable=no-member
                  cls.traceExtendableExtensionCall( realExt )
                  return realExt( self, *cmds, **kwargs )
            ext = traceThenCall
         setattr( cls, name, ext )

      elif isinstance( ext, type ):
         # ext must be added to the start of the __bases__ list, to ensure that
         # extensions of a class take precedence over extensions of its baseclass.
         cls.__bases__ = ( ext, ) + cls.__bases__

      else:
         raise AssertionError(
            "Argument to extension must be a function or a class" )

      return ext

def overridable( fn ):
   """A decorator used to mark a method of a class derived from Extendable as being
   overridable.  "Overridable" here means that the method can be replaced by an
   extension method.  (Normally, @cls.extension is only allowed to add methods to a
   class, not to replace existing methods.)

   Use of this decorator is very strongly discouraged.  Allowing an extension to
   replace an existing method can lead to serious confusion when trying to understand
   how the code works.

   The mark is the attribute '__ArPyUtils_overridable' set to True on fn; fn itself
   is returned."""

   # This assignment is in a module-level function, not a class body, so the
   # double-underscore attribute name is not mangled and matches the literal
   # string that Extendable.extension looks up.
   fn.__ArPyUtils_overridable = True # pylint: disable=protected-access
   return fn

def memoizeDisabled():
   '''Report whether memoization is switched off for this process.

   It is off whenever the TEST_NO_MEMOIZE environment variable is set to a
   non-empty value.'''
   setting = os.environ.get( 'TEST_NO_MEMOIZE' )
   return setting is not None and setting != ''

def staticMemoize( func ):
   '''Decorator that uses a per function static cache to memoize function calls.

   If the decorated function's first parameter is named 'self' (i.e. it is a
   method), the first call uses the instance it was called on, but all
   subsequent calls with the same remaining args return the cached value,
   whichever instance they are made on.

   The cache is stored on the wrapped function (as memoCache_) and can be reset
   with staticUnmemoize( func ).  Cached values never expire; see
   timedStaticMemoize for an expiring variant.
   '''
   return timedStaticMemoize( timeout=None )( func )

TimedValue = namedtuple( "TimedValue", [ 'value', 'expiry' ] )
def timedStaticMemoize( timeout=None ):
   '''Same as staticMemoize, but additionally allows for specifying a timeout
   for the cached values. If a memoized function is called more than timeout
   seconds after it was last calculated, it will be executed and the new value
   will be cached and returned.

   A timeout of None (or any other falsy value, such as 0) means values won't
   ever expire via this mechanism.

   Cache keys are the pickled call arguments, excluding a leading 'self'
   parameter, so all other arguments must be picklable.  Keyword arguments are
   keyed independently of the order in which they were passed.  Memoization is
   bypassed entirely while memoizeDisabled() is true.
   '''
   def decorator( func ):
      cache = func.memoCache_ = {}

      # Work out once, at decoration time, whether the first positional
      # parameter is 'self'; if so it is left out of the cache key so that all
      # instances share one cache.  Checking co_argcount ensures we look at a
      # parameter rather than at some local variable.
      code = getattr( func, '__code__', None )
      skipSelf = bool( code is not None and code.co_argcount and
                       code.co_varnames[ 0 ] == 'self' )

      @functools.wraps( func )
      def doMemoize( *args, **kwargs ):
         if memoizeDisabled():
            return func( *args, **kwargs )

         filteredArgs = args[ 1: ] if skipSelf else args
         # Sort keyword arguments so f( a=1, b=2 ) and f( b=2, a=1 ) share an
         # entry (keyword names are unique strings, so values are never compared).
         lookupArgs = ( filteredArgs, sorted( kwargs.items() ) )
         argsKey = pickle.dumps( lookupArgs, protocol=pickle.HIGHEST_PROTOCOL )

         # retrieve value from cache, compute if necessary
         entry = cache.get( argsKey )
         if entry is None or ( entry.expiry and entry.expiry < time.time() ):
            value = func( *args, **kwargs )
            expiry = time.time() + timeout if timeout else None
            entry = cache[ argsKey ] = TimedValue( value=value, expiry=expiry )
         return entry.value

      return doMemoize
   return decorator

def staticUnmemoize( func ):
   '''Empty the memoCache_ dict that static memoization keeps on func.

   The dict is cleared in place, so every holder of a reference to it sees the
   reset.'''
   memoCache = func.memoCache_
   memoCache.clear()

def revertableMemoize( func ):
   '''Memoizing decorator for zero-argument methods whose memo can be reverted.

   The first call of a decorated method foo on an instance stores the result on
   that instance as the attribute "foo_"; later calls return the stored value.
   This mirrors Tac.memoize, except that unMemoize( obj, 'foo' ) drops the
   stored value so that the next call recomputes it.  While memoizeDisabled()
   is true the method is simply called every time and nothing is stored.'''

   cacheAttr = func.__name__ + "_"
   notCached = object()

   @functools.wraps( func )
   def doMemoize( self ):
      if memoizeDisabled():
         return func( self )

      cached = getattr( self, cacheAttr, notCached )
      if cached is not notCached:
         return cached

      # Tac.memoize replaces the function with a lambda returning the value;
      # we deliberately don't, because that would make unMemoize much harder
      # to get right.
      result = func( self )
      setattr( self, cacheAttr, result )
      return result

   return doMemoize

def unMemoize( obj, funcName ):
   '''Drop the value cached on obj for the memoized function named funcName
   (stored as the attribute funcName + '_'), so that the next call of that
   function computes and caches a fresh value.  Does nothing when no value is
   cached.

   The function is identified by name rather than by reference because the
   memoized function is often buried under lambdas and additional decorators,
   which makes it impossible for calling code to find the right function object
   to pass in.'''

   cacheAttr = funcName + '_'
   if not hasattr( obj, cacheAttr ):
      return
   delattr( obj, cacheAttr )

def memoize( obj, funcName, value ):
   '''Pre-seed obj's memo for the function named funcName with value.

   The value is stored as the attribute funcName + '_', where the memoized
   function will find it until unMemoize( obj, funcName ) removes it.'''
   cacheAttr = funcName + '_'
   setattr( obj, cacheAttr, value )

def retry( retryCheckFunc=None, retryCheckEmbeddedMethodName='', attempts=2,
           retryInterval=60, retryTimeout=None, retryOnExceptions=(),
           retryMessageFn=None ):
   '''Decorator factory that re-invokes the decorated function until it succeeds.

   The function is called at most `attempts` times.  Pass
   attempts=INFINITE_ATTEMPTS for an unbounded number of attempts, in which case
   only retryTimeout limits the retries.  After a failed attempt, retrying stops
   once retryTimeout seconds have passed since the first call.  Between attempts
   the wrapper sleeps retryInterval seconds; no sleep follows the final attempt.

   Deciding whether a call failed (i.e. whether to retry):
    - retryCheckFunc, if given, is called with each return value and returns
      True to request a retry;
    - otherwise, if retryCheckEmbeddedMethodName is given, the method of that
      name on the first positional argument (the `self` of a decorated method)
      is used the same way;
    - otherwise a retry happens whenever bool( return value ) is False.
   An exception that is an instance of retryOnExceptions counts as a return
   value of False; any other exception propagates immediately.

   retryMessageFn, if given, returns a message that is printed to stderr
   (prefixed with sys.argv[ 0 ]) after each retry sleep, provided retryTimeout
   is set and has not yet been reached.  An optional {delay} tag in the message
   is replaced by the whole number of seconds waited so far.

   Returns the value of the first successful call; otherwise the value of the
   last call made (False if it raised a retried exception), or None if no call
   was made at all.'''

   def decorator( func ):
      @functools.wraps( func )
      def wrapper( *args, **kwargs ):
         startTime = time.time()
         deadline = None
         if retryTimeout:
            deadline = startTime + retryTimeout

         checkFn = retryCheckFunc
         if not checkFn and retryCheckEmbeddedMethodName:
            checkFn = getattr( args[ 0 ], retryCheckEmbeddedMethodName )

         def needsRetry( value ):
            if checkFn:
               return checkFn( value )
            return not value

         lastAttempt = attempts - 1
         # itertools.count() yields an endless counter for the unbounded case.
         if attempts == INFINITE_ATTEMPTS:
            attemptNumbers = itertools.count()
         else:
            attemptNumbers = range( attempts )

         result = None
         for attempt in attemptNumbers:
            try:
               result = func( *args, **kwargs )
            except Exception as exc: # pylint: disable=broad-except
               if not isinstance( exc, retryOnExceptions ):
                  raise
               result = False

            if not needsRetry( result ):
               return result

            now = time.time()
            if deadline and now >= deadline:
               break
            if attempt == lastAttempt:
               continue

            time.sleep( retryInterval )
            waited = int( now - startTime + retryInterval )
            if retryMessageFn and retryTimeout and waited < retryTimeout:
               message = retryMessageFn().format( delay=waited )
               print( "%s: %s" % ( sys.argv[ 0 ], message ), file=sys.stderr )
         return result

      return wrapper
   return decorator

# How a runAsFork child finished: `code` is its exit status (0 if it was killed
# by a signal) and `signal` the signal that killed it (0 if it exited normally).
ForkedFuncStatus = namedtuple( 'ForkedFuncStatus', 'code,signal' )

class ForkedFuncError( Exception ):
   '''Raised in the parent process when a function run via runAsFork fails.

   Attributes:
      func: the wrapped function.
      status: the ForkedFuncStatus describing how the child finished.'''

   def __init__( self, func, status ):
      message = "Forked function %s exited with status %r" % ( func, status )
      super().__init__( message )
      self.func = func
      self.status = status

# Process-wide switch read by runAsFork in the child process: while it is True,
# the child does not print diagnostics to stderr about a non-None return value,
# a SystemExit, or an exception raised by the forked function.
_runAsForkQuiet = True
def runAsForkQuietIs( quiet ):
   '''Set the runAsFork quiet switch to `quiet`.'''
   global _runAsForkQuiet
   _runAsForkQuiet = quiet

# Process-wide switch read by runAsFork in the child process: while it is True,
# an exception raised by the forked function drops the child into a pdb
# post-mortem session before it exits.  (The "Ask" spelling is kept because
# runAsFork reads the variable under this name.)
_runAskForkPdb = False
def runAsForkPdbIs( doPdb ):
   '''Set the runAsFork post-mortem-debugging switch to `doPdb`.'''
   global _runAskForkPdb
   _runAskForkPdb = doPdb

# Having RUNASFORK_PDB in the environment (with any value) turns the switch on
# when this module is imported.
if os.environ.get( 'RUNASFORK_PDB' ) is not None:
   runAsForkPdbIs( True )

def _exitStatusFromSystemExitCode( code ):
   """Convert a SystemExit.code into a process exit status the way the interpreter
   does: None means success (0); an int is used as the status, reduced to the 8
   bits a status can hold (so it is always a valid os._exit argument); any other
   object (e.g. a message string) is printed to stderr and gives status 1."""
   if code is None:
      return 0
   if isinstance( code, int ):
      return code & 0xFF
   print( code, file=sys.stderr )
   return 1

def _flushStdStreams():
   """Flush sys.stdout and sys.stderr, best effort.  A stream that is None, or
   whose flush fails (e.g. because it was closed), is skipped: nothing useful can
   be done about it here, and not flushing is what happened before."""
   for stream in ( sys.stdout, sys.stderr ):
      if stream is None:
         continue
      try:
         stream.flush()
      except Exception: # pylint: disable=broad-except
         pass

def runAsFork( func ):
   """A decorator to perform every call to the wrapped function in a fork of the
   current process.

   The return value of the wrapped function must currently always be None.
   Any other return value will result in the fork exiting with non-zero exit status.
   Any uncaught exceptions will also yield a non-zero exit status.
   Calling exit() will cause the fork to exit with the same code exit is called with;
   as in a normal process, exit() / exit( None ) means status 0, and a non-integer
   code (such as a message string) is printed to stderr and gives status 1.
   Whatever happens, the child process never returns into the caller's code.

   Any non-zero exit of the fork will raise ForkedFuncError in the parent process.

   Args:
      func: The function to be wrapped. Should return None. There are no restrictions
            on its arguments.

   Returns:
      A function wrapper with the above behaviour.

   Details:
      A major use case for this is to run many test cases of behaviour which modify
      global process state, which would be too cumbersome/messy to reverse after
      each test case. With runAsFork, many tests can be written in a single file/lib,
      and run in succession without them impacting each other.

      Another could be to run code which expects a signal like SIGABRT, which we want
      to be able to detect. This would usually also be for test purposes.

   Example:
      @runAsFork
      def die():
         os.kill( os.getpid(), signal.SIGKILL )

      try:
         die()
      except ForkedFuncError:
         # Still alive!
         ...
   """
   @functools.wraps( func )
   def wrapper( *args, **kwargs ):
      # Flush buffered output before forking.  The child inherits Python's
      # unflushed buffers and flushes them before exiting, so anything still
      # pending here would otherwise be written twice.
      _flushStdStreams()

      childPid = os.fork()
      if childPid:
         # This is the parent process. Wait for the child to finish.
         _, status = os.waitpid( childPid, 0 )
         exitCode = os.WEXITSTATUS( status ) if os.WIFEXITED( status ) else 0
         exitSignal = os.WTERMSIG( status ) if os.WIFSIGNALED( status ) else 0
         statusObj = ForkedFuncStatus( exitCode, exitSignal )

         if exitCode or exitSignal:
            raise ForkedFuncError( func, statusObj )
         return

      # This is the child process. Run the wrapped function, then exit.
      # Assume failure until the function has returned None, so that anything
      # going wrong later -- including inside the handlers below, or a
      # BaseException such as KeyboardInterrupt -- still exits non-zero.
      code = 1
      try:
         retVal = func( *args, **kwargs )
         if retVal is None:
            code = 0
         elif not _runAsForkQuiet:
            print( 'Forked function returned non-None value (not supported):',
                   repr( retVal ),
                   file=sys.stderr )

      except SystemExit as e:
         # Special handling for any calls to exit(). This will propagate the code.
         # e.code may be None or a non-int, which os._exit would reject.
         code = _exitStatusFromSystemExitCode( e.code )
         if not _runAsForkQuiet:
            print( 'Forked function',
                   func,
                   'raised SystemExit with code',
                   e.code,
                   file=sys.stderr )
      except Exception as e: # pylint: disable=broad-except
         if not _runAsForkQuiet:
            print( 'Forked function',
                   func,
                   'raised exception:',
                   repr( str( e ) ),
                   file=sys.stderr )

         if _runAskForkPdb:
            # Drop into pdb with the exception's traceback before killing
            # the child process.
            import pdb # pylint: disable=import-outside-toplevel
            import traceback # pylint: disable=import-outside-toplevel
            traceback.print_exc()
            # pylint: disable=E1101,no-member
            pdb.post_mortem()

      finally:
         # Always leave the child through os._exit, even while an exception is
         # propagating: the child must never return into the caller's code.
         # _exit (rather than exit) also avoids exit handlers or try except
         # blocks that are defined for SystemExit, so the status reaches the
         # parent without any tampering.  _exit does not flush Python's
         # buffers, so flush them first.
         try:
            _flushStdStreams()
         finally:
            # pylint: disable=protected-access
            os._exit( code )

   return wrapper

def expectSigabrt( expectedMessageRe=None ):
   '''This decorator will fork then execute the wrapped function in the child
   process; the parent process waits for the child process to exit with the expected
   status code for asserts and will check the output against the given expected
   message.

   Args:
      expectedMessageRe: if not None, a regular expression that must be found
         (with re.search) in the child's stderr output.  Bytes in that output
         that are not valid UTF-8 are replaced rather than raising
         UnicodeDecodeError.

   Raises (in the parent):
      AssertionError: if the child was not terminated by SIGABRT, or if its
         stderr output does not match expectedMessageRe.

   The child never returns into the caller's code: if the wrapped function
   returns or raises instead of aborting, the child prints the traceback (if
   any) to the captured stderr and exits with status 1.

   Example:
      @expectSigabrt()
      def testCppAssert():
         # Call some C++ code that is expected to assert.
         ...

      @expectSigabrt( expectedMessageRe="" )
      def testCppAssert():
         # Call some C++ code that is expected to assert.
         ...
   '''
   def decorator( func ):
      @functools.wraps( func )
      def wrapper( *args, **kwargs ):
         # Flush pending Python-level stderr output first.  Otherwise the child
         # would inherit it and could write it into the pipe, where it would be
         # matched against expectedMessageRe.
         sys.stderr.flush()

         # Create a pipe to communicate between parent and child process.
         readPipe, writePipe = os.pipe()

         pid = os.fork()
         if pid == 0:
            # Child process.  Whatever happens it must leave through os._exit:
            # returning or raising would resume the caller's code in the child,
            # and exit() raises SystemExit, which the caller's code could catch.
            try:
               # We are intentionally triggering a crash here, no need for a core
               # file. Generating a core file here would needlessly consume
               # significant amounts of disk space any time the test is run.
               resource.setrlimit( resource.RLIMIT_CORE, ( 0, 0 ) )

               # Redirect stderr to the pipe and close unused pipes.
               os.close( readPipe )
               os.dup2( writePipe, sys.stderr.fileno() )
               os.close( writePipe )

               # Call the provided trigger function, this should trigger the code
               # leading to the assert expected to fail.
               func( *args, **kwargs )
            except BaseException: # pylint: disable=broad-except
               # Unexpected: report it on the (captured) stderr.
               import traceback # pylint: disable=import-outside-toplevel
               traceback.print_exc()
            finally:
               # We don't expect to reach this code.
               try:
                  sys.stderr.flush()
               finally:
                  # pylint: disable=protected-access
                  os._exit( 1 )
         else:
            # Original process - expect child to terminate due to SIGABRT
            os.close( writePipe )

            # Read the child's stderr until EOF before waiting for it, so a child
            # that fills the pipe cannot block while we wait.
            chunks = []
            try:
               while True:
                  readBuf = os.read( readPipe, 4096 )
                  if not readBuf:
                     # We've reached EOF, so exit the loop.
                     break
                  chunks.append( readBuf )
            finally:
               os.close( readPipe )
            childErrOutput = b''.join( chunks )

            # Make sure the child aborted.
            _, status = os.waitpid( pid, 0 )
            assert os.WIFSIGNALED( status ) and \
                   os.WTERMSIG( status ) == signal.SIGABRT, "Expected SIGABRT"

            if expectedMessageRe is not None:
               # The child may emit arbitrary bytes; don't fail on undecodable ones.
               childErrText = childErrOutput.decode( errors='replace' )
               # Make sure we hit the expected assert.
               assert re.search( expectedMessageRe, childErrText ) is not None, \
                  "Expected an assert message matching '{}', but didn't find it " \
                  "in the stderr output:\n{}".format( expectedMessageRe,
                                                      childErrText )
      return wrapper
   return decorator

def _funcName( func ):
   if hasattr( func, '__qualname__' ):
      # py3 shortcut
      return '.'.join(
         func.__qualname__.replace( '<locals>.', '' ).rsplit( '.', 2 )[ -2 : ]
      )

   funcName = []
   # An instance method with an `im_class` attribute.
   boundClass = getattr( func, 'im_class', None )
   if boundClass:
      self = getattr( func, 'im_self', None )
      if boundClass.__name__ == 'type' and self:
         # classmethod
         funcName = [ self.__name__ ]
      else:
         funcName = [ boundClass.__name__ ]
   funcName.append( func.__name__ )
   return '.'.join( funcName )

def _funcSignature( func, args, kwargs ):
   """
   Render a call of func with the given arguments as a string.

   Positional arguments are shown with repr(), keyword arguments as
   name=repr(value), all separated by ", ", after the short name of func.

   Examples:
     - instance method
        "MyClass.foo(1, 2)"
        "MyClass.foo(1, 2, optional=True)"

     - function
        "someFunc(kw=1)"
   """
   pieces = [ repr( arg ) for arg in args ]
   pieces.extend( f"{key}={value!r}" for key, value in kwargs.items() )
   joined = ", ".join( pieces )
   return f'{_funcName( func )}({joined})'

# _DescriptorWrapper is a descriptor, so that when the decorator (with 'trace'
# closure) is accessed, it will call __get__, which then returns an instance
# of TimerWrapper with the actual instance method, as applicable, at runtime.
# For bound method, this means we'll:
#  - have an `im_class` available for _funcSignature
#  - not have 'self' in args.
class _DescriptorWrapper:
   def __init__( self, func ):
      self.func = func
      # Along the lines of functools.wraps:
      for attr in ( '__module__', '__doc__' ):
         if hasattr( func, attr ):
            setattr( self, attr, getattr( func, attr ) )

   def __repr__( self ):
      return repr( self.func )

   def __get__( self, obj, _type=None ):
      # Instantiate a TimerWrapper on the *instance* that we're wrapping.
      # Note that this is happening at runtime instead of at the module
      # level like it would with a basic decorator implementation.
      # Notes:
      # 1. self.__class__ is TimerWrapper itself, so when we access the
      #    unbound method descriptor (the initial TimerWrapper), we actually
      #    create a new TimerWrapper for the bound _instance method_ at the
      #    time of the actual call.
      # 2. The actual call is then handled by __call__ for this new TimerWrapper
      #    instance that has the bound method as it's `func` attribute.
      # 3. self.func.__get__ ensures that we follow the descriptor protocol
      #    correctly which allows us to decorate e.g. staticmethod, classmethod.
      return self.__class__( self.func.__get__( obj, _type ) )

   def __call__( self, *args, **kwargs ):
      raise NotImplementedError

def timed( trace=print ):
   """
   Capture start, end, and timing information for wrapped function

   Returns a decorator (a _DescriptorWrapper subclass) usable on free functions
   and on static, class and instance methods.  Each call invokes
   trace( "start", <call signature> ) before the call and
   trace( "end", <call signature>, "time <seconds>" ) after it returns, with the
   elapsed seconds formatted to one decimal place.  The elapsed time comes from
   time.perf_counter(), a monotonic clock, so changes to the system clock do not
   distort it.  If the wrapped function raises, only the "start" call is made.

   For example:
   class X:
      @timed(trace=print)
      def foo(self, a, b):
         time.sleep(0.3)

   X().foo(1, 15) prints:
   start X.foo(1, 15)
   (X.foo(1, 15) runs for 0.3 seconds)
   end X.foo(1, 15) time 0.3
   """

   class TimerWrapper( _DescriptorWrapper ):
      def __call__( self, *args, **kwargs ):
         callSig = _funcSignature( self.func, args, kwargs )
         startTime = time.perf_counter()
         trace( "start", callSig )
         ret = self.func( *args, **kwargs )
         trace( "end", callSig, f'time {time.perf_counter() - startTime:.1f}' )
         return ret

   return TimerWrapper

def _traced( trace=print, includeFuncName=False ):
   """Return a _DescriptorWrapper subclass that reports each call via `trace`.

   Every call first passes the rendered call signature (see _funcSignature) to
   trace.  With includeFuncName, the wrapped callable additionally receives its
   own short name as the keyword argument funcName; passing funcName explicitly
   is then a TypeError, raised after the call has been traced."""

   class TraceWrapper( _DescriptorWrapper ):
      def __call__( self, *args, **kwargs ):
         trace( _funcSignature( self.func, args, kwargs ) )
         if not includeFuncName:
            return self.func( *args, **kwargs )

         ownName = _funcName( self.func )
         if 'funcName' in kwargs:
            raise TypeError( ownName + " should not be called with a `funcName` "
                  "argument directly." )
         return self.func( *args, **kwargs, funcName=ownName )

   return TraceWrapper

def traced( trace=print ):
   """
   Decorator factory: report every call of the decorated callable via `trace`.

   Works on free functions as well as on static, class, and instance methods.
   Each call is reported as one string argument to `trace`: the callable's short
   name followed by the arguments it was called with.

   For example:
   @traced( trace=print )
   def foo( a, b=1 ):
      pass
   foo(1, b=2)

   --> "foo(1, b=2)"
   """
   wrapperClass = _traced( trace=trace, includeFuncName=False )
   return wrapperClass

def tracedFuncName( trace=print ):
   """
   Decorator factory: like traced(), report every call via `trace`, and also
   hand the callable its own name.

   Works on free functions as well as on static, class, and instance methods.
   On each call the wrapped callable receives the extra keyword argument
   funcName="<funcName>" (its short name) for its own use; callers must not pass
   funcName themselves.

   For example:
   @tracedFuncName( trace=print )
   def slimShady( a, b=1, funcName=None ):
      print( "my name is", funcName )
   slimShady(1, b=2)

   --> "slimShady(1, b=2)"
   --> "my name is slimShady"
   """
   wrapperClass = _traced( trace=trace, includeFuncName=True )
   return wrapperClass
