# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from collections import defaultdict
from functools import total_ordering

@total_ordering
class CallGraphNode:
   '''
   One endpoint (caller or callee) of a caller -> callee arc in the callgraph.

   The node wraps a function name. Equality, ordering and hashing look only at
   that name, so two nodes with the same name are interchangeable as dict keys
   and set members even when their other attributes differ.

   The other attributes exist so that a cycle found in the callgraph can be
   traced back to source locations:
   - callerNode, the caller function
   - callSiteNode, the call site of the callee within the caller

   When a cycle in the callgraph is found, nodes of this type are returned and
   the locations of the call cycle can be read from them when producing the
   compilation error.
   '''
   def __init__( self, name, callerNode=None, callSiteNode=None ):
      self.name, self.callerNode, self.callSiteNode = (
         name, callerNode, callSiteNode )

   def __hash__( self ):
      # must agree with __eq__: only the name participates
      return hash( self.name )

   def __eq__( self, other ):
      # any object exposing a `name` attribute is comparable by that name
      if hasattr( other, 'name' ):
         return self.name == other.name
      return NotImplemented

   def __lt__( self, other ):
      # total_ordering derives <=, > and >= from this and __eq__
      if hasattr( other, 'name' ):
         return self.name < other.name
      return NotImplemented

class Callgraph:
   """ Directed graph of callers and callees

   Each vertex is a function name (or an object standing in for one, such as a
   CallGraphNode) and each directed arc represents a caller -> callee
   relationship. Arcs are stored as an ordered list of callees per caller; the
   same callee may appear in that list more than once.
   """
   def __init__( self ):
      # caller -> list of callees, in insertion order
      self._calleesByCaller = defaultdict( list )

   def copy( self ):
      '''
      Return a new Callgraph holding the same callers and callees.

      Every caller gets a fresh callee list, but the vertex objects themselves
      are shared. A regular deep copy is avoided on purpose: it would also
      recurse into the vertices, i.e. the AST nodes together with their
      attached (ANTLR) parser context.
      '''
      duplicate = Callgraph()
      for caller in self._calleesByCaller:
         duplicate.add( caller, self._calleesByCaller[ caller ] )
      return duplicate

   def add( self, caller, callee ):
      ''' Record caller -> callee.

      callee is either a single vertex or a list of vertices; a list is
      appended element by element, in order.
      '''
      calleeList = self._calleesByCaller[ caller ]
      if isinstance( callee, list ):
         calleeList.extend( callee )
      else:
         calleeList.append( callee )

   def delete( self, caller ):
      ''' Drop caller and all of its outgoing arcs.

      Raises KeyError if caller was never added. Arcs from other callers that
      point at caller are left untouched.
      '''
      self._calleesByCaller.pop( caller )

   def callees( self, caller ):
      ''' Return the list of callees of caller.

      The returned list is the graph's own storage, so mutating it mutates the
      graph. An unknown caller yields a new empty list and is not added.
      '''
      if caller in self._calleesByCaller:
         return self._calleesByCaller[ caller ]
      return []

   def callers( self ):
      ''' Return an iterator over every vertex that has been added as a caller.
      '''
      return iter( self._calleesByCaller )

def findStronglyConnected( callgraph ):
   """ Tarjan's algorithm for finding strongly connected components. A strongly
   connected component of a graph is one where every vertex is reachable from every
   other vertex. The algorithm runs in linear time.

   The depth-first search is driven by an explicit stack instead of recursion,
   so long call chains cannot exceed Python's recursion limit. A set mirrors
   the Tarjan stack so that "is it on the stack" checks are O(1).

   Args:
      callgraph (Callgraph): graph to analyse; only its callers() and callees()
         methods are used and it is not modified.

   Returns:
      list of lists: one list per strongly connected component, in the order
      Tarjan's algorithm completes them (a component is emitted only after
      every component reachable from it). Within a component the vertices are
      in the order they were popped off the Tarjan stack, so the vertex through
      which the component was first entered comes last.
   """
   result = []
   index = {} # vertex -> DFS discovery number
   # represents the smallest index of any node known to be
   # reachable from v through v's DFS subtree
   lowlink = {}
   stack = [] # Tarjan stack: visited vertices not yet assigned to a component
   onStack = set() # same contents as stack, for O(1) membership tests

   def _discover( vertex ):
      """ Number vertex, push it on the Tarjan stack and return its DFS frame.
      """
      # len( index ) is evaluated before the insertion: the next unused number
      index[ vertex ] = lowlink[ vertex ] = len( index )
      stack.append( vertex )
      onStack.add( vertex )
      # frame: ( vertex, iterator over the callees still to be examined )
      return ( vertex, iter( callgraph.callees( vertex ) ) )

   for root in callgraph.callers():
      if root in index:
         continue
      work = [ _discover( root ) ]
      while work:
         funcName, successors = work[ -1 ]
         for successor in successors:
            if successor not in index:
               # Successor has not yet been visited; descend into it and resume
               # funcName's remaining callees once it is finished
               work.append( _discover( successor ) )
               break
            if successor in onStack:
               # the successor is in the stack and therefore
               # in the current strongly connected component
               lowlink[ funcName ] = min( lowlink[ funcName ], index[ successor ] )
         else:
            # every callee of funcName has been handled: funcName is finished
            work.pop()
            if lowlink[ funcName ] == index[ funcName ]:
               # funcName is a root node: pop the stack and generate the scc
               component = []
               while True:
                  member = stack.pop()
                  onStack.discard( member )
                  component.append( member )
                  if member == funcName:
                     break
               result.append( component )
            if work:
               # what the recursive formulation does when the call returns
               parent = work[ -1 ][ 0 ]
               lowlink[ parent ] = min( lowlink[ parent ], lowlink[ funcName ] )
   return result # a list of lists, one per strongly connected component

def removeFunc( callgraph, target ):
   """ Remove target function from the callgraph

   Removes target's own entry (its outgoing calls) if it has one, and every
   call to target from any remaining caller, including repeated calls from the
   same caller (e.g. several call sites). target need not be present.

   Args:
      callgraph (Callgraph): graph to modify in place.
      target: the vertex to remove.
   """
   try:
      callgraph.delete( target )
   except KeyError:
      pass # target was never a caller; it may still appear as a callee
   for caller in callgraph.callers():
      # callees() hands back the graph's own list, so editing it edits the graph
      callees = callgraph.callees( caller )
      # a caller may call target from several call sites: drop every one
      while target in callees:
         callees.remove( target )

def subgraph( callgraph, vertices ):
   """ Get the subgraph of the function callgraph induced
   by the given set of function names.

   Only arcs whose caller and callee are both in vertices are kept. A vertex
   with no such arc does not appear in the result at all.

   Args:
      callgraph (Callgraph): graph to take arcs from; it is not modified.
      vertices (iterable): function names (or nodes) to keep.

   Returns:
      Callgraph: a new graph holding the retained arcs, with callers in the
      order of vertices and each caller's callees in their original order.
   """
   # materialise once so a one-shot iterable is not consumed by the lookups
   vertexList = list( vertices )
   # hash-based membership; scanning the list per arc made this O(V * E)
   vertexSet = set( vertexList )
   sub = Callgraph()
   for v in vertexList:
      for callee in callgraph.callees( v ):
         if callee in vertexSet:
            sub.add( v, callee )
   return sub

# disable W1401 (anomalous backslash) to accommodate the example graph
# pylint: disable-msg=W1401
def findAllCycles( callgraph ):
   r""" Return every elementary cycle within the callgraph. An elementary
   cycle is a cycle in a graph whose vertices (and, by extension, its edges)
   are used at most once, the only exception being the start vertex, which
   closes the cycle. By contrast, a simple cycle is one in which no edge
   appears more than once but vertices may be repeated.

   For example, take the graph:

   A -- B
   |  / | \
   | /  |  C
   F -- D /
    \  /
     E

   One elementary cycle for this would be A > B > C > D > E > F > A
   A simple but not elementary cycle would be: A > B > C > D > E > F > B > D > F > A

   Outline (an iterative form of Johnson's algorithm): strongly connected
   components are processed one at a time. From a start vertex, a depth-first
   walk restricted to that vertex's component records every path that leads
   back to it. A "blocked" set keeps the walk out of vertices that currently
   cannot lead back. The start vertex is then removed from the (copied) graph,
   and the rest of its component is split into components again and queued.

   Args:
      callgraph (Callgraph): graph to search. It is copied first and is not
         modified.

   Returns:
      list of lists: one list per cycle. For a cycle found from start vertex S
      along S -> n1 -> ... -> nk -> S the list is [ n1, ..., nk, S ]. Every
      element is the object taken from its predecessor's callee list, so any
      call-site information carried by those objects is preserved. Parallel
      arcs (the same callee listed twice) can yield the same cycle more than
      once, each built from different callee-list entries.
   """
   def _unblock( thisFuncName, blocked, B ):
      """ Unblock thisFuncName and, transitively, every vertex recorded in B
      as waiting on it.
      """
      toUnblock = { thisFuncName }
      while toUnblock:
         currFunc = toUnblock.pop()
         if currFunc in blocked:
            blocked.remove( currFunc )
            toUnblock.update( B[ currFunc ] )
            B[ currFunc ].clear()

   def _componentCallees( graph, funcName, members ):
      """ Callees of funcName that belong to members, in callee-list order
      (repeated entries are kept).
      """
      return [ callee for callee in graph.callees( funcName )
               if callee in members ]

   # removeFunc() below mutates the graph, so work on a private copy
   callgraph = callgraph.copy()
   components = findStronglyConnected( callgraph )
   cyclesFound = []
   # pylint: disable-msg=R1702
   while components:
      component = components.pop() # currently investigated strongly connected comp
      # A cycle through startFunc can only visit vertices of its own strongly
      # connected component, so the walk is restricted to them. Otherwise every
      # start vertex, even an acyclic singleton, walks all it can reach.
      # Built before pop() so that startFunc itself is a member.
      members = set( component )
      startFunc = component.pop()
      path = [ startFunc ]
      blocked = { startFunc }
      closed = set() # vertices on a path that has reached startFunc
      # B[ x ] holds vertices that stay blocked until x is unblocked
      B = defaultdict( set )
      stack = [ ( startFunc, _componentCallees( callgraph, startFunc, members ) ) ]
      while stack: # depth-first walk from startFunc
         visitingFunc, calledFuncs = stack[ -1 ] # using last-in
         if calledFuncs:
            nextFunc = calledFuncs.pop()
            if nextFunc == startFunc: # found a cycle
               # path[ 0 ] is startFunc itself; end the cycle with the callee
               # object instead so that every element comes from a callee list
               cyclesFound.append( path[ 1 : ] + [ nextFunc ] )
               closed.update( path )
            elif nextFunc not in blocked:
               path.append( nextFunc )
               stack.append(
                  ( nextFunc, _componentCallees( callgraph, nextFunc, members ) ) )
               closed.discard( nextFunc )
               blocked.add( nextFunc )
               continue
         if not calledFuncs: # every callee of visitingFunc has been explored
            if visitingFunc in closed:
               # visitingFunc lies on a found cycle: release it and its waiters
               _unblock( visitingFunc, blocked, B )
            else:
               # keep it blocked until one of its callees gets unblocked
               inComponent = _componentCallees( callgraph, visitingFunc, members )
               for nbrFunc in inComponent:
                  B[ nbrFunc ].add( visitingFunc )
            stack.pop()
            path.pop()
      removeFunc( callgraph, startFunc )
      # by this time, component has start func removed, component may be empty
      H = subgraph( callgraph, component )
      components.extend( findStronglyConnected( H ) )
   return cyclesFound
