# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from collections import (
   defaultdict,
   deque,
)
import itertools

import Tac
import QuickTrace
from RcfLibUtils import collectionNameFromTypename
from Task import (
   Task,
   TaskRunResult,
)

# Short aliases for the QuickTrace helpers used below: qv() wraps each traced
# value and qt8() emits the trace message (trace8, presumably trace level 8).
qv, qt8 = QuickTrace.Var, QuickTrace.trace8

class KeyCounter:
   '''
   Monotonically increasing source of keys. Each value handed out is used as the
   'key' constructor argument when instantiating an entity node in a collection.

   This is a thin wrapper so that tests can rewind the sequence via reset().
   '''
   def __init__( self ):
      # A fresh counter is identical to one that was just reset to its default.
      self.reset()

   def get( self ):
      '''Return the next key and advance the counter by one.'''
      return next( self.counter )

   def reset( self, key=1 ):
      '''Restart the sequence so that the next get() returns int( key ).'''
      self.counter = itertools.count( int( key ) )

class InstantiatingCollectionGV:
   '''
   Global objects required to use instantiating collections to construct RCF objects
   such as AET nodes or DebugSymbols.
   These objects live in this class's namespace to make it a little more explicit
   that we're interacting with global variables.
   They are class members because they are singletons, and this prevents us from
   having multiple instances of each.
   '''
   # AET
   allAetEntities = None
   allAetEntitiesIdGen = KeyCounter()

   # DebugSymbols
   allDebugSymbolEntities = None
   allDebugSymbolEntitiesIdGen = KeyCounter()

   # For testing to track keys used per function in a compilation.
   # Keys allocated since the last logAetKeysUsed() / teardown().
   allAetEntitiesKeysUsed = []
   # keyed by aet
   allAetEntitiesKeysUsedPerFn = defaultdict( list )
   # Keys allocated since the last logDebugSymbolKeysUsed() / popDebugKeys() /
   # teardown().
   allDebugSymbolEntitiesKeysUsed = []
   # keyed by Rcf::Debug::FunctionSymbol
   allDebugSymbolEntitiesKeysUsedPerFn = defaultdict( list )

   @classmethod
   def setup( cls, allAetEntities, allDebugSymbolEntities ):
      '''
      Install the collection objects used by instantiator()/cleaner(), after
      resetting all key-tracking state via teardown().
      '''
      cls.teardown()
      cls.allAetEntities = allAetEntities
      cls.allDebugSymbolEntities = allDebugSymbolEntities

   @classmethod
   def teardown( cls ):
      '''
      Drop the collection objects and reset all key tracking.
      The key generators are NOT reset; see resetNodeInstantiatorForTest.
      '''
      cls.allAetEntities = None
      cls.allDebugSymbolEntities = None
      cls.allAetEntitiesKeysUsed = []
      cls.allDebugSymbolEntitiesKeysUsed = []
      cls.allAetEntitiesKeysUsedPerFn = defaultdict( list )
      cls.allDebugSymbolEntitiesKeysUsedPerFn = defaultdict( list )

   @classmethod
   def popDebugKeys( cls ):
      '''
      Return the debug-symbol keys recorded so far and start a fresh list.
      '''
      keys = cls.allDebugSymbolEntitiesKeysUsed
      cls.allDebugSymbolEntitiesKeysUsed = []
      return keys

   @classmethod
   def pushDebugKeys( cls, keys ):
      '''
      Make keys (typically a list returned by popDebugKeys) the current list of
      recorded debug-symbol keys. This replaces, rather than merges with, the
      current list.
      '''
      cls.allDebugSymbolEntitiesKeysUsed = keys

   @classmethod
   def resetNodeInstantiatorForTest( cls ):
      '''Rewind both key generators so keys start again from 1.'''
      cls.allAetEntitiesIdGen.reset()
      cls.allDebugSymbolEntitiesIdGen.reset()

   @classmethod
   def logAetKeysUsed( cls, aet ):
      '''
      Attribute every AET key recorded since the last call (or teardown) to aet
      and start a fresh list. aet must not already be logged and at least one
      key must have been recorded. Returns the logged keys.
      '''
      assert aet not in cls.allAetEntitiesKeysUsedPerFn
      keys = cls.allAetEntitiesKeysUsed
      assert keys
      cls.allAetEntitiesKeysUsedPerFn[ aet ] = keys
      cls.allAetEntitiesKeysUsed = []
      return keys

   @classmethod
   def logDebugSymbolKeysUsed( cls, debugSymbols ):
      '''
      Attribute every debug-symbol key recorded since the last call (or
      pop/teardown) to debugSymbols and start a fresh list. debugSymbols must
      not already be logged and at least one key must have been recorded.
      Returns the logged keys.
      '''
      assert debugSymbols not in cls.allDebugSymbolEntitiesKeysUsedPerFn
      keys = cls.allDebugSymbolEntitiesKeysUsed
      assert keys
      cls.allDebugSymbolEntitiesKeysUsedPerFn[ debugSymbols ] = keys
      cls.allDebugSymbolEntitiesKeysUsed = []
      return keys

   @staticmethod
   def getAllKeysInInstantiatingColls( collObject ):
      '''
      Return every key held by collObject's instantiating collections, i.e. its
      attributes whose names end in 'Coll', visited in sorted name order.
      '''
      collNames = sorted( attrName for attrName in collObject.attributes
                          if attrName.endswith( 'Coll' ) )
      keys = []
      for collName in collNames:
         keys.extend( getattr( collObject, collName ) )
      return keys

   @classmethod
   def getAllAetKeysInInstantiatingColls( cls ):
      return cls.getAllKeysInInstantiatingColls( cls.allAetEntities )

   @classmethod
   def getAllDebugKeysInInstantiatingColls( cls ):
      return cls.getAllKeysInInstantiatingColls( cls.allDebugSymbolEntities )

   @staticmethod
   def getLoggedKeysForFunctions( functions, keysPerFn ):
      '''
      Concatenate, in order, the keys logged in keysPerFn for each of functions.
      Functions with nothing logged contribute no keys.

      Uses .get() rather than indexing. keysPerFn is normally one of the
      defaultdicts above, and indexing a defaultdict with an unlogged function
      would insert an empty entry as a side effect. That phantom entry would
      then show up in getAllLogged*() and trip the "not already logged" assert
      when the function is logged for real.
      '''
      allKeys = []
      for fn in functions:
         allKeys += keysPerFn.get( fn, [] )
      return allKeys

   @classmethod
   def getAllLoggedAets( cls ):
      '''Return every aet currently tracked in allAetEntitiesKeysUsedPerFn.'''
      return list( cls.allAetEntitiesKeysUsedPerFn )

   @classmethod
   def getAllLoggedAetKeys( cls ):
      return cls.getLoggedKeysForFunctions( cls.getAllLoggedAets(),
                                            cls.allAetEntitiesKeysUsedPerFn )

   @classmethod
   def getAllLoggedDebugSymbols( cls ):
      '''
      Return every debug symbol currently tracked in
      allDebugSymbolEntitiesKeysUsedPerFn.
      '''
      return list( cls.allDebugSymbolEntitiesKeysUsedPerFn )

   @classmethod
   def getAllLoggedDebugSymbolKeys( cls ):
      return cls.getLoggedKeysForFunctions( cls.getAllLoggedDebugSymbols(),
                                            cls.allDebugSymbolEntitiesKeysUsedPerFn )

   @classmethod
   def getLoggedAetKeysForAet( cls, aet ):
      '''Return the keys logged for aet ([] if nothing is logged for it).'''
      return cls.getLoggedKeysForFunctions(
         [ aet ], cls.allAetEntitiesKeysUsedPerFn )

   @classmethod
   def instantiator( cls, collObjectName ):
      '''
      Returns a constructor that will instantiate types with a given object of
      collections.

      collObjectName names one of the collection-object class attributes above
      (e.g. 'allAetEntities'). The matching '<name>IdGen' attribute supplies the
      keys, and '<name>KeysUsed' records them.
      '''
      def typeInstantiator( typename ):
         '''
         Returns a constructor that will instantiate a given type.
         '''
         collName = collectionNameFromTypename( typename )
         keyGen = getattr( cls, collObjectName + 'IdGen' )

         def instantiate( *args ):
            '''
            Returns an object created with a given instantiating collection.
            '''
            # Resolved on every call because setup()/teardown() and the
            # log*/pop*/push* methods rebind these class attributes.
            collObj = getattr( cls, collObjectName )
            coll = getattr( collObj, collName )
            key = keyGen.get()
            getattr( cls, collObjectName + 'KeysUsed' ).append( key )
            qt8( 'Alloc key ', qv( key ), ' for collName ', qv( collName ) )
            return coll.newMember( key, *args )

         # Wrapped in staticmethod, presumably so callers can assign it as a
         # class attribute without it binding to an instance.
         return staticmethod( instantiate )
      return typeInstantiator

   @classmethod
   def cleaner( cls, collObjectName, baseType ):
      '''
      Returns a destructor that will remove a given node from its instantiating
      collection. The destructor ignores None and asserts that the node is a
      baseType instance that is present in its collection.
      '''
      def nodeCleaner( node ):
         if node is None:
            return
         assert isinstance( node, baseType )

         collName = collectionNameFromTypename( node.tacType.fullTypeName )
         collObj = getattr( cls, collObjectName )
         coll = getattr( collObj, collName )
         qt8( 'Free key ', qv( node.key ), ' for collName ', qv( collName ) )

         assert node.key in coll
         del coll[ node.key ]
      return nodeCleaner

class CleanupHelperBase:
   """
   Shared machinery for freeing a tree of nodes from its instantiating collections.

   Children are discovered by python/TACC reflection. Every attribute listed in
   node.attributes that is backed by a data member and holds an instance of
   baseTypeForAllNodes is visited recursively before the node itself is freed.

   Constructor args:
   - typeMap: maps an exact node type to a specialised cleanup callable. For
              that type, the callable replaces the generic reflective walk.
   - baseTypeForAllNodes: base type of the nodes in the tree that need to be
                          cleaned up from instantiating collections
   - cleanupInstantiatingCollFunc: callable that removes one node (and nothing
                                   else) from its instantiating collection

   Public functions:
   * cleanup( node )

   Everything else is an internal helper.
   """
   def __init__( self, typeMap, baseTypeForAllNodes, cleanupInstantiatingCollFunc ):
      self.typeMap = typeMap
      self.baseTypeForAllNodes = baseTypeForAllNodes
      self.cleanupInstantiatingCollFunc = cleanupInstantiatingCollFunc

   def cleanup( self, node ):
      if node is None:
         return

      specialCleanup = self.typeMap.get( type( node ) )
      if specialCleanup:
         specialCleanup( node )
         return

      # Generic path: free the children first, then the node itself.
      for attrName in node.attributes:
         # Aliases and functions have no backing data member, so they are skipped.
         if node.tacType.attr( attrName ).hasDataMember:
            child = getattr( node, attrName )
            if isinstance( child, self.baseTypeForAllNodes ):
               self.cleanup( child )
      self.cleanupInstantiatingCollFunc( node )

   def _cleanupLinkedList( self, head ):
      # Lists can be long. Rather than recursing from each element into the
      # next, gather the elements iteratively first. Then free each element's
      # 'entry' subtree and its 'next' element, and finally the head.
      listNodes = []
      node = head
      while node is not None:
         listNodes.append( node )
         node = node.next

      freeNode = self.cleanupInstantiatingCollFunc
      for listNode in listNodes:
         self.cleanup( listNode.entry )
         freeNode( listNode.next )
      freeNode( head )

class AetCleanupHelper( CleanupHelperBase ):
   """
   Cleanup helper for AET nodes held in the 'allAetEntities' instantiating
   collections. Function-call expressions and the core RCF linked-list types get
   dedicated handlers; every other node goes through the generic reflective walk.

   Public functions:
   * cleanup( node )

   Everything else is internal.
   """
   def __init__( self ):
      # pylint: disable-next=cyclic-import,import-outside-toplevel
      import RcfTypeFuture as Rcf
      evalTypes = Rcf.Eval
      typeMap = {
         evalTypes.FunctionCallExpressionType: self._cleanupFunctionCallExpression,
      }
      # These lists are unrolled iteratively instead of being recursed into.
      linkedListTypes = (
         evalTypes.ExpressionListType,
         evalTypes.LogicalOperatorListType,
         evalTypes.IntLinkedListType,
         evalTypes.CommunityLinkedListType,
      )
      typeMap.update( dict.fromkeys( linkedListTypes, self._cleanupLinkedList ) )
      nodeCleaner = InstantiatingCollectionGV.cleaner( 'allAetEntities',
                                                       evalTypes.AetNodeBase )
      super().__init__( typeMap, evalTypes.AetNodeBase, nodeCleaner )

   def _cleanupFunctionCallExpression( self, node ):
      # Free the argument list, then the call node itself. calleePtr is
      # deliberately not followed, so the callee's AET is not cleaned up
      # recursively from here.
      self._cleanupLinkedList( node.argListHead )
      self.cleanupInstantiatingCollFunc( node )

class DebugSymbolsCleanupHelper( CleanupHelperBase ):
   """
   Cleanup helper for DebugSymbol nodes held in the 'allDebugSymbolEntities'
   instantiating collections. Symbol containers are unrolled as linked lists;
   every other symbol goes through the generic reflective walk.

   Public functions:
   * cleanup( node )
   """
   def __init__( self ):
      # pylint: disable-next=cyclic-import,import-outside-toplevel
      import RcfTypeFuture as Rcf
      debugTypes = Rcf.Debug
      symbolBase = debugTypes.SymbolBase
      super().__init__(
         { debugTypes.SymbolContainerType: self._cleanupLinkedList },
         symbolBase,
         InstantiatingCollectionGV.cleaner( 'allDebugSymbolEntities', symbolBase ),
      )

class InstantiatingCollectionCleaner:
   '''
   Frees RCF trees (AETs, FunctionSymbols) and individual debug symbols from
   their instantiating collections in deferred tasks.

   cleanup() only queues nodes onto the backlog of their CleanupOperation. One
   Task per operation drains its backlog, stopping early whenever shouldYield()
   says so and asking to be rescheduled while work remains.
   '''

   class CleanupOperation:
      AET = 1
      FUNCTION_SYMBOL = 2
      SINGLE_SYMBOL = 3

   def __init__( self ):
      # Called by RcfAgentConfigReactor at startup
      CleanupOperation = self.CleanupOperation
      # Build every piece of state the task callbacks read before the tasks are
      # created and scheduled. This way a callback can never observe a partially
      # initialised object, whenever Task.schedule() ends up running it.
      self.backlog = {
         # trees deferred until later to recurse and cleanup
         CleanupOperation.AET: deque(),
         CleanupOperation.FUNCTION_SYMBOL: deque(),
         # individual nodes deferred until later to cleanup
         CleanupOperation.SINGLE_SYMBOL: deque(),
      }
      self.aetCleanupHelper = AetCleanupHelper()
      self.debugSymbolsCleanupHelper = DebugSymbolsCleanupHelper()
      # Set when the AETs are constructed manually and so won't be present in the
      # instantiating collections
      self.dontCleanupForTest = False
      self.tasks = {}
      for taskName, op, func in [
         ( 'AETs', CleanupOperation.AET, self.cleanupAetBacklog ),
         ( 'FunctionSymbols', CleanupOperation.FUNCTION_SYMBOL,
           self.cleanupFunctionSymbolBacklog ),
         ( 'SingleSymbols', CleanupOperation.SINGLE_SYMBOL,
           self.cleanupSingleSymbolBacklog, ),
      ]:
         task = Task( "InstantiatingColl-cleanup-" + taskName, func )
         task.schedule()
         self.tasks[ op ] = task

   def teardown( self ):
      # The tasks hold a reference to this object so we have a circular reference
      # here.
      # The python garbage collector would take care of this, but a backlog of
      # now orphaned entities hanging around until then could cause orphan entities
      # to be attrlogged which we don't want, so break the cycle manually.
      for task in self.tasks.values():
         # Don't let the task fire again
         task.suspend()
      self.tasks.clear()
      self.backlog.clear()

   def alwaysYieldForTest( self ):
      for task in self.tasks.values():
         task.taskForceYieldInTest = True

   def getCleanupOp( self, nodes ):
      '''
      Get the cleanup operation type for all of the nodes.
      Ensure they are all of the same type.

      Every node is checked, not just the first. Each backlog handler assumes its
      entries are of one kind (e.g. cleanupAetBacklog deletes each entry from
      gv.allAetEntitiesKeysUsedPerFn), so a mixed batch would otherwise fail or
      leak later inside a task, far from the caller.
      '''
      # pylint: disable-next=cyclic-import,import-outside-toplevel
      import RcfTypeFuture as Rcf

      def getOp( node ):
         # pylint: disable-next=isinstance-second-argument-not-valid-type
         if isinstance( node, Rcf.Eval.FunctionType ):
            return self.CleanupOperation.AET
         # pylint: disable-next=isinstance-second-argument-not-valid-type
         elif isinstance( node, Rcf.Debug.FunctionSymbolType ):
            return self.CleanupOperation.FUNCTION_SYMBOL
         else:
            # pylint: disable-next=isinstance-second-argument-not-valid-type
            assert isinstance( node, Rcf.Debug.SymbolBase )
            return self.CleanupOperation.SINGLE_SYMBOL

      ops = { getOp( node ) for node in nodes }
      assert len( ops ) == 1, ops
      return ops.pop()

   def deferCleanup( self, nodes ):
      '''
      Queue nodes on the backlog of their (shared) cleanup operation and schedule
      the task that drains it. Empty input is a no-op.
      '''
      if self.dontCleanupForTest:
         return
      # Materialise once. getCleanupOp and extend() both iterate nodes, and an
      # iterator/generator would otherwise lose entries to the first pass.
      nodes = list( nodes )
      if not nodes:
         return
      op = self.getCleanupOp( nodes )
      self.backlog[ op ].extend( nodes )
      self.tasks[ op ].schedule()

   def cleanupAetBacklog( self, shouldYield ):
      gv = InstantiatingCollectionGV
      workDone = 0

      deferredAets = self.backlog[ self.CleanupOperation.AET ]
      while deferredAets:
         aet = deferredAets.popleft()
         del gv.allAetEntitiesKeysUsedPerFn[ aet ]
         self.aetCleanupHelper.cleanup( aet )
         workDone += 1
         if shouldYield():
            break

      return TaskRunResult( shouldReschedule=bool( deferredAets ),
                            workDone=workDone )

   def cleanupFunctionSymbolBacklog( self, shouldYield ):

      if self.backlog[ self.CleanupOperation.SINGLE_SYMBOL ]:
         # These individual symbols were replaced in a FunctionSymbol tree under
         # construction and so hold references to parts of the FunctionSymbol
         # tree they were intended for.
         # See RcfDebugLocationLib.py:updateAetNodeKeysInDebugSymbols
         # We MUST wait for these individual symbols to be cleaned up before
         # cleaning up any FunctionSymbols trees or we risk these individual symbols
         # holding references to parts of the tree we just cleaned up.
         # When we cleanup these trees the nodes become orphans.
         # Holding references to these orphans will risk attrlogging them and
         # if we do that the attrlog connection will close and the agent will
         # restart.
         # Yield the task until the individual symbols are all cleaned up.
         return TaskRunResult( shouldReschedule=True, workDone=0 )

      gv = InstantiatingCollectionGV
      workDone = 0

      deferredFunctionSymbols = self.backlog[ self.CleanupOperation.FUNCTION_SYMBOL ]
      while deferredFunctionSymbols:
         functionSymbol = deferredFunctionSymbols.popleft()
         del gv.allDebugSymbolEntitiesKeysUsedPerFn[ functionSymbol ]
         self.debugSymbolsCleanupHelper.cleanup( functionSymbol )
         workDone += 1
         if shouldYield():
            break

      return TaskRunResult( shouldReschedule=bool( deferredFunctionSymbols ),
                            workDone=workDone )

   def cleanupSingleSymbolBacklog( self, shouldYield ):
      workDone = 0

      deferredDebugSymbols = self.backlog[ self.CleanupOperation.SINGLE_SYMBOL ]
      while deferredDebugSymbols:
         debugSymbol = deferredDebugSymbols.popleft()
         # If an individual symbol is being cleaned up then it MUST have been created
         # and dropped within a functions compilation and so is not part of
         # gv.allDebugSymbolEntitiesKeysUsedPerFn
         # Only this symbol should be cleaned up, NOT its children.
         self.debugSymbolsCleanupHelper.cleanupInstantiatingCollFunc( debugSymbol )
         workDone += 1
         if shouldYield():
            break

      return TaskRunResult( shouldReschedule=bool( deferredDebugSymbols ),
                            workDone=workDone )

   def cleanup( self, nodes ):
      '''
      Triggers the cleanup of the provided nodes by a deferred task.

      nodes may be a single node, or a list/set/tuple/frozenset of nodes that
      all need the same kind of cleanup (see getCleanupOp). A falsy value
      (None, an empty collection) is a no-op.
      '''
      if not nodes:
         return
      if not isinstance( nodes, ( list, set, tuple, frozenset ) ):
         nodes = [ nodes ]
      self.deferCleanup( nodes )
