#!/usr/bin/env python3
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import builtins
import collections
import contextlib
import sys
import threading
import traceback

class _ThreadLocalStack:
   '''Like name implies, this is a thread-local stack'''

   def __init__( self ):
      self.__threadLocal = threading.local()

   @property
   def __values( self ):
      '''Returns actual thread-local storage, creating one if needed'''
      result = getattr( self.__threadLocal, 'values', None )
      if result is None:
         result = self.__threadLocal.values = []
      return result

   @contextlib.contextmanager
   def hold( self, value ):
      '''Context manager pushing value to stack and popping it back on exit'''
      self.__values.append( value )
      try:
         yield value
      finally:
         poppedValue = self.__values.pop()
         assert poppedValue is value

   def __bool__( self ):
      '''Returns true value iff stack is non-empty'''
      return bool( self.__values )

   def last( self ):
      '''Returns current value on 'top' of the stack'''
      return self.__values[ -1 ]

# Singleton used internally to denote omitted parameters
_OMIT = object()

# Line of '='s used by unittest to format some messages
_SEPARATOR1 = '=' * 70

# Line of '-'s used by unittest to format some messages
_SEPARATOR2 = '-' * 70

# Information about a subtest currently in stack
_SubTest = collections.namedtuple( '_SubTest', (
      'description',
      'parameters',
   ) )

# Python exception triple in one object
_ExceptionInfo = collections.namedtuple( '_ExceptionInfo', (
      'type',
      'value',
      'traceback',
   ) )

def _currentExceptionInfo():
   '''Returns current exception triple as _ExceptionInfo object, or None if none'''
   result = sys.exc_info()
   if result == ( None, None, None ):
      result = None
   else:
      result = _ExceptionInfo( *result )
   return result

# Information about a failure accumulated by subtest group
_Failure = collections.namedtuple( '_Failure', (
      'subTest',
      'exceptionInfo',
   ) )

# Information about a subtest group currently in stack
_Group = collections.namedtuple( '_Group', (
      'failureByException',
      'failfast',
   ) )

def _formatFailure( groupDescription, failure ):
   '''
   Builds a failure header formatted the way unittest prints one

   For example:
   ```
   ======================================================================
   FAIL: runTest (unittest.case.TestCase) [test] (success=False)
   ----------------------------------------------------------------------
   ```
   Pieces that do not apply are left out: the group description when it is
   _OMIT, the bracketed subtest message when it is _OMIT, and the parameter
   list when there are no parameters. A subtest with neither a message nor
   parameters is shown as "(<subtest>)". A failure outside any subtest, in a
   group without a description, is shown as "<subtestGroup>".

   :param groupDescription: description of the group, or _OMIT
   :param failure: _Failure whose subTest (possibly None) is described
   :returns: header text, ending with a newline
   '''
   labels = []
   if groupDescription is not _OMIT:
      labels.append( f'{groupDescription}' )
   subTestInfo = failure.subTest
   if not subTestInfo:
      if groupDescription is _OMIT:
         labels.append( '<subtestGroup>' )
   else:
      message = subTestInfo.description
      if message is not _OMIT:
         labels.append( f'[{message}]' )
      # Same "key=repr(value)" rendering as unittest's subtest description
      params = ', '.join( f'{name}={value!r}'
                          for name, value in subTestInfo.parameters.items() )
      if params:
         labels.append( f'({params})' )
      elif message is _OMIT:
         labels.append( '(<subtest>)' )
   title = ''.join( ' ' + label for label in labels )
   return f'{_SEPARATOR1}\nFAIL:{title}\n{_SEPARATOR2}\n'

# Use the builtin BaseExceptionGroup when the interpreter provides one (3.11+).
# Look it up on the `builtins` module, not on `__builtins__`: the latter is the
# builtins module only in `__main__`. In an imported module, CPython binds it to
# the builtins *dict*, on which getattr() never finds the class.
BaseExceptionGroup = getattr( builtins, 'BaseExceptionGroup', None )
if BaseExceptionGroup is None:
   # Not on 3.11 yet: minimal stand-in exposing the same attributes
   class BaseExceptionGroup( BaseException ):
      '''Exception representing a group of exceptions.'''

      def __init__( self, msg, excs ):
         super().__init__( msg, excs )
         self.message = msg
         # The builtin stores the wrapped exceptions as a tuple
         self.exceptions = tuple( excs )

# Stack of the subtest groups currently open, kept separately for each thread;
# subTest records its failures in the innermost (last) one
_GROUP_STACK = _ThreadLocalStack()

@contextlib.contextmanager
def subTestGroup( description=_OMIT, *, stream=None, failfast=False ):
   '''Context manager that collects and reports the failures of its subtests

   unittest.TestCase collects subtest results on the test object. When there is
   no test object at hand, this context manager plays that role.

   On exit it does nothing if nothing failed. Otherwise it writes a
   unittest-style "FAIL:" header for each failure, followed by "FAILED".
   - A single failure is re-raised as is, and its traceback is left to whoever
     handles it.
   - Several failures are printed with their tracebacks and then raised together
     as one BaseExceptionGroup. Its message is the description, or
     '<subTests>' when none was given.

   :param description: group description, analogous to a test name in unittest
   :param stream: stream the report is written to; defaults to sys.stderr
   :param failfast: how nested subtests behave on failure
      False (default): every subtest runs and the failures are reported together
      True: the first failing subtest propagates its exception, ending the group
   '''
   out = stream or sys.stderr
   failures = {}  # exception object -> _Failure, in order of first recording
   with _GROUP_STACK.hold( _Group( failureByException=failures,
                                   failfast=failfast ) ):
      try:
         yield
      except:  # pylint: disable=bare-except
         # Anything escaping the body is a failure too; setdefault keeps the
         # record a subtest may already have made for the same exception.
         caught = _currentExceptionInfo()
         failures.setdefault( caught.value,
                              _Failure( subTest=None, exceptionInfo=caught ) )

      if failures:
         try:
            if len( failures ) > 1:
               for record in failures.values():
                  out.write( _formatFailure( description, record ) )
                  traceback.print_exception( *record.exceptionInfo, file=out )
               out.write( f'{_SEPARATOR2}\n' )
               groupMessage = ( '<subTests>' if description is _OMIT
                                else f'{description}' )
               # "from None" drops the implicit exception context so no extra
               # traceback is shown on top of the ones printed above.
               raise BaseExceptionGroup( groupMessage,
                                         tuple( failures ) ) from None
            # Exactly one failure: re-raise that very exception; its handler
            # shows the traceback, so none is printed here.
            ( ( onlyException, onlyFailure ), ) = failures.items()
            out.write( _formatFailure( description, onlyFailure ) )
            raise onlyException
         finally:
            out.write( 'FAILED\n' )
            out.flush()


# Nested subtests belonging to current thread
_SUBTEST_STACK = _ThreadLocalStack()

@contextlib.contextmanager
def subTest( msg=_OMIT, **parameters ):
   '''Returns context manager representing a single subtest

   Arguments have the same meaning as those of unittest.TestCase.subTest:
   :param msg: optional message, shown in brackets in the failure header
   :param parameters: keyword parameters shown in the failure header; they are
      layered on top of the parameters of the enclosing subtest, if any

   An exception escaping the body is recorded in the innermost subTestGroup.
   It is re-raised when that group is failfast, or when it is a
   KeyboardInterrupt or GeneratorExit. Otherwise it is suppressed, so the
   group's remaining subtests still run and the group reports it on exit.
   When no group is active, an implicit failfast group wraps this subtest.
   '''
   if not _GROUP_STACK:
      # Need new subtest group, since we are not currently in one.
      with subTestGroup( failfast=True ):
         with subTest( msg=msg, **parameters ):
            yield
      return
   group = _GROUP_STACK.last()

   # Nested subtests inherit the enclosing subtest's parameters (as unittest
   # does); new_child puts this subtest's values first, so they win on clashes.
   mergedParameters = ( _SUBTEST_STACK.last().parameters.new_child
         if _SUBTEST_STACK else collections.ChainMap )( parameters )
   subTestObject = _SubTest( description=msg, parameters=mergedParameters )
   with _SUBTEST_STACK.hold( subTestObject ):
      try:
         yield
      except:  # pylint: disable=bare-except
         exceptionInfo = _currentExceptionInfo()
         failure = _Failure( subTest=subTestObject, exceptionInfo=exceptionInfo )
         # setdefault keeps the innermost subtest's record of this exception
         group.failureByException.setdefault( exceptionInfo.value, failure )
         # Never swallow KeyboardInterrupt (Ctrl-C must stop the run, as in
         # unittest) or GeneratorExit (suppressing it would keep a generator
         # that is being closed running).
         mustPropagate = group.failfast or isinstance(
               exceptionInfo.value, ( KeyboardInterrupt, GeneratorExit ) )
         if mustPropagate:
            raise
