#!/usr/bin/env python3
# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tracing
import RcfTypeFuture as Rcf
from antlr4.tree.Tree import TerminalNodeImpl
from RcfInstantiatingCollHelpers import InstantiatingCollectionGV
from collections import namedtuple

# Module-local alias for the Tracing.t8 handle; the functions below call it to
# emit trace output while walking parse trees.
t8 = Tracing.t8

# Plain-Python location record with three fields: the line, the column offset on
# that line, and the length of the covered text. Built by
# RcfDebugLocationVisitor.LocationData.NamedTuple().
Location = namedtuple( "Location", "line column length" )

class RcfDebugLocationVisitor:
   '''
   Visitor to walk the parse tree and determine the location of the RCF text it
   represents. This location is used to build the RcfDebugSymbols for a given
   function. As AET nodes are constructed their parse contexts are used to determine
   the RCF text that they were generated from.

   Constructor Args:
      excludeChildNodes (list of ParserContext objects):
         Objects which should not be included in the traversal, and thus not included
         in the final location. Traversal should terminate at these nodes in the
         parse tree.
         For example, when traversing the FuncDecl below we would not want to
         include all of its children or else we would get the location of the
         entire function. Specifying the BlockDecl as part of excludeChildNodes
         would get us the location of 'function foo()' without the block.

                   +------------+
                   |  FuncDecl  |
                   +------------+
                          |
              +-----------+---------+-------+-------------+
              v           v         v       v             v
         +---------+   +-----+    +--+    +--+     +------------+
         |function |   | foo |    |( |    |) |     | BlockDecl  |
         +---------+   +-----+    +--+    +--+     +------------+
                                                          |
                                                  +-------+--------+
                                                  v       v        v
                                                +--+   +-----+   +--+
                                                |{ |   | ... |   |} |
                                                +--+   +-----+   +--+
      useNamedTuples (bool):
         When True the locations are collected as a Python list of Location
         named tuples instead of a linked list of Rcf.Debug.SymbolContainer.

   Attributes:
      _currentLocationData (RcfDebugLocationVisitor.LocationData):
         The location in progress. As nodes in the parse tree are consumed this
         location is updated to include these nodes. This data is per line.
      _locationContainerHead (Rcf.Debug.SymbolContainer, list or None)
         LocationData is per line. If the parse context extends across multiple lines
         then multiple Rcf.Debug.SymbolLocation objects are required. These are
         stored in a linked list, where this is the head of the list. These
         locations are stored in reverse order.
         In useNamedTuples mode this is instead a list of Location tuples in the
         order they were added. It stays None until the first location is added.
   '''
   def __init__( self, excludeChildNodes, useNamedTuples=False ):
      self._excludeChildNodes = excludeChildNodes
      self._useNamedTuples = useNamedTuples
      self._currentLocationData = None
      self._locationContainerHead = None

   class LocationData:
      '''
      Storage type for the location of multiple nodes in a ParseTree. This location
      is per line. When a token is on a new line an Rcf.Debug.SymbolLocation is built
      for this location and a new LocationData object is created for the new line.

      The line and column are taken from the first token; the length spans from
      the start index of the first token to the stop index of the most recently
      consumed token (both inclusive).
      '''
      def __init__( self, initialToken ):
         self._lineNumber = initialToken.line
         self._columnOffset = initialToken.column
         self._firstTokenStart = initialToken.start
         self._lastTokenStop = initialToken.stop

      def __repr__( self ):
         return "LocationData(line={}, columnOffset={}, length={})".format(
            self._lineNumber, self._columnOffset, self.length() )

      def lineNumber( self ):
         '''Return the line of the first token in this location.'''
         return self._lineNumber

      def length( self ):
         '''Return the number of characters covered; indexes are inclusive.'''
         return self._lastTokenStop - self._firstTokenStart + 1

      def consumeToken( self, token ):
         '''Extend this location so that it ends at the given token.'''
         self._lastTokenStop = token.stop

      def SymbolLocation( self ):
         '''Return this location as an Rcf.Debug.SymbolLocation.'''
         return Rcf.Debug.SymbolLocation( self._lineNumber,
                                          self._columnOffset,
                                          self.length() )

      def NamedTuple( self ):
         '''Return this location as a Location named tuple.'''
         return Location( self._lineNumber,
                          self._columnOffset,
                          self.length() )

   def addLocation( self ):
      '''
      Store the in-progress location and reset it.

      In linked-list mode the new Rcf.Debug.SymbolContainer becomes the head and
      points at the previous head. In named-tuple mode the Location is appended
      to a list. There must be a location in progress when this is called.
      '''
      assert self._currentLocationData
      t8( "Adding location:", self._currentLocationData )
      if self._useNamedTuples:
         if self._locationContainerHead is None:
            self._locationContainerHead = []
         self._locationContainerHead.append(
            self._currentLocationData.NamedTuple() )
      else:
         self._locationContainerHead = Rcf.Debug.SymbolContainer(
            self._currentLocationData.SymbolLocation(),
            self._locationContainerHead )
      self._currentLocationData = None

   def getLocations( self ):
      '''Return the stored locations, or None if none have been added.'''
      return self._locationContainerHead

   def visitChildren( self, context ):
      '''Visit every child of context that is not in excludeChildNodes.'''
      # A rule context that matched no input can carry children=None rather
      # than an empty list; treat that as having no children instead of
      # failing to iterate.
      children = [ child for child in ( context.children or () )
                   if child not in self._excludeChildNodes ]
      for child in children:
         self.visit( child )

   def visit( self, node, **kwargs ):
      '''
      Fold node into the location in progress.

      Nodes with a 'children' attribute are recursed into. Anything else is
      treated as a terminal: its token is node.symbol when present, otherwise the
      node itself. A token on a different line than the location in progress
      causes that location to be stored and a new one to be started.
      '''
      if hasattr( node, 'children' ):
         self.visitChildren( node )
         return

      token = node.symbol if hasattr( node, 'symbol' ) else node
      t8( "Token:", token.text )
      inProgress = self._currentLocationData
      # Compare against None explicitly so that a location recorded on line 0
      # is still closed when the line changes.
      if inProgress is not None and token.line != inProgress.lineNumber():
         self.addLocation()

      if self._currentLocationData is None:
         self._currentLocationData = RcfDebugLocationVisitor.LocationData( token )

      self._currentLocationData.consumeToken( token )

def getLocation( nodes, excludeChildNodes=None, useNamedTuples=False ):
   '''
   Compute the locations covered by one parse node or a list of parse nodes.

   Args:
      nodes: A single node or a list of nodes. None yields None.
      excludeChildNodes: Optional list of nodes at which traversal stops; these
         are handed to RcfDebugLocationVisitor. Defaults to an empty list.
      useNamedTuples: Passed through to RcfDebugLocationVisitor.

   Returns:
      Whatever RcfDebugLocationVisitor.getLocations() reports once every node has
      been visited and its trailing location has been added, or None when nodes
      is None.
   '''
   if nodes is None:
      return None

   t8( "Getting location for", type( nodes ) )

   excluded = [] if excludeChildNodes is None else excludeChildNodes
   assert isinstance( excluded, list )

   nodeList = nodes if isinstance( nodes, list ) else [ nodes ]

   locationVisitor = RcfDebugLocationVisitor( excluded, useNamedTuples )
   for parseNode in nodeList:
      locationVisitor.visit( parseNode )
      # Close off the location still in progress before starting the next node.
      locationVisitor.addLocation()
   return locationVisitor.getLocations()

class PointToken:
   '''
   A zero-length marker for a position in the RCF text that matters but has no
   real token behind it. For example, precedence implies parentheses at certain
   positions even though no parenthesis characters appear in the RCF text.

   A PointToken exposes the same attributes as a normal token (text, line,
   column, start, stop) so it can be passed wherever a token is accepted. Its
   text is empty and its length is zero because it marks a point rather than a
   span of text.
   '''
   def __init__( self, line, column ):
      self.line, self.column = line, column
      self.text = ''
      # start and stop are inclusive indexes, so a token's length is
      # stop - start + 1. Keeping stop one below start (a difference of -1)
      # makes that length zero.
      self.start, self.stop = 0, -1

def getEdgeTokenImpl( node, index ):
   '''
   Descend from node to a terminal by repeatedly taking children[ index ].

   Args:
      node: A parse node, a token, or None.
      index: The child index to follow at every level (0 for the leftmost edge,
         -1 for the rightmost edge).

   Returns:
      The token reached: node.symbol when the terminal has a 'symbol' attribute,
      otherwise the terminal itself. None is returned when node is None or when
      a context without any children is reached, since there is then no token
      on that edge.
   '''
   if node is None:
      return None

   while hasattr( node, 'children' ):
      children = node.children
      if not children:
         # A context that matched no input has children of None (or an empty
         # list); indexing it would raise, so report that no token was found.
         return None
      node = children[ index ]
   return node.symbol if hasattr( node, 'symbol' ) else node

def getLeftmostPointToken( node ):
   '''
   Build a PointToken marking where the text of node begins.

   Args:
      node: A parse node, or a list of parse nodes of which the first is used.

   Returns:
      A PointToken at the line and column of the leftmost token under node, or
      None when no such token is found.
   '''
   firstNode = node[ 0 ] if isinstance( node, list ) else node
   edgeToken = getEdgeTokenImpl( firstNode, 0 )
   if edgeToken:
      # The point sits at the token's own column, i.e. immediately before its
      # first character.
      return PointToken( edgeToken.line, edgeToken.column )
   return None

def getRightmostPointToken( node ):
   '''
   Build a PointToken marking where the text of node ends.

   Args:
      node: A parse node, or a list of parse nodes of which the last is used.

   Returns:
      A PointToken on the line of the rightmost token under node, at the column
      just after that token's final character, or None when no such token is
      found.
   '''
   lastNode = node[ -1 ] if isinstance( node, list ) else node
   edgeToken = getEdgeTokenImpl( lastNode, -1 )
   if not edgeToken:
      return None
   # start and stop are inclusive indexes, so stop - start is the offset of the
   # token's final character from its first one.
   finalCharOffset = edgeToken.stop - edgeToken.start
   # RHS points MUST come after the represented text, hence the extra column.
   columnAfterToken = edgeToken.column + finalCharOffset + 1
   return PointToken( edgeToken.line, columnAfterToken )

def updateAetNodeKeysInDebugSymbols( aetNodeSymbolContainer,
                                     origAetKeys, newAetKeys,
                                     instantiatingCollectionCleaner ):
   '''
   Rebuild the DebugSymbols of one AET so that they use the AetNodeKeys of
   another AET.

   When an AET is deduped, an existing AET is used instead of the one just
   created. The existing AET must be identical in all but two ways
   - function names may differ (not relevant here)
   - AetNodeKeys may differ
   The AetNodeKeys matter for the DebugSymbols because they map an AET node to
   that node's debug symbol, which for example holds the location of the RCF text
   that produced the node. The DebugSymbols must therefore reference the keys
   used in the deduped AET.

   This needs a mapping of AetNodeKey values from the original AET to the new
   AET. AetNodeKeys are assigned incrementally in the same order each time an AET
   with the same structure is created, so sorting the keys of both AETs and
   pairing them in ascending order gives that mapping.

   With the mapping in hand the debug symbols are recreated with the same values,
   swapping each original AetNodeKey for its new counterpart.

   Arguments:
   - aetNodeSymbolContainer: Linked list of debug symbols to be recreated with new
                             AetNodeKeys. Must not be empty.
   - origAetKeys: The keys used to construct the original AET referenced in the
                  debug symbols.
   - newAetKeys: The keys in the AET that the debug symbols should reference.
                 Must have the same length as origAetKeys.
   - instantiatingCollectionCleaner: Cleaner responsible for removing
                                     instantiated AET nodes from the
                                     instantiating collections.

   Returns:
      The head of the recreated symbol list, in the same order as the input list.
   '''
   # Pair the sorted keys off: origAetKeys -> newAetKeys
   assert len( origAetKeys ) == len( newAetKeys )
   keyRemap = dict( zip( sorted( origAetKeys ), sorted( newAetKeys ) ) )

   # Take the keys logged for the debug symbols that were just constructed.
   # Some of them are dropped below once their symbols have been replaced.
   previouslyLoggedKeys = InstantiatingCollectionGV.popDebugKeys()
   assert previouslyLoggedKeys
   liveDebugKeys = set( previouslyLoggedKeys )
   staleObjects = []

   # Walk the linked list, remembering each entry. The containers themselves
   # are going to be recreated, so every old container is marked for cleanup.
   container = aetNodeSymbolContainer
   assert container
   oldEntries = []
   while container:
      oldEntries.append( container.entry )
      staleObjects.append( container )
      container = container.next

   # Build the replacement list back to front so that prepending each new
   # container leaves the result in the original order.
   # Location symbol nodes hold no AetNodeKey references, so they are reused
   # rather than recreated.
   newHead = None
   for entry in reversed( oldEntries ):
      # pylint: disable-next=isinstance-second-argument-not-valid-type
      if isinstance( entry, Rcf.Debug.InternalAetNodeSymbolType ):
         replacement = Rcf.Debug.InternalAetNodeSymbol(
            keyRemap[ entry.aetNodeKey ] )
      else:
         replacement = Rcf.Debug.AetNodeSymbol(
            keyRemap[ entry.aetNodeKey ], entry.fragmentType, entry.location,
            entry.nodeSpecificData )
      # The old entry has been superseded, so it is cleaned up as well.
      staleObjects.append( entry )
      newHead = Rcf.Debug.SymbolContainer( replacement, newHead )

   for stale in staleObjects:
      liveDebugKeys.remove( stale.key )
   instantiatingCollectionCleaner.cleanup( staleObjects )

   # The keys of the DebugSymbol objects created above should have been logged.
   # Merge them with the earlier keys (minus the ones just cleaned up) and push
   # the combined set back.
   freshlyLoggedKeys = InstantiatingCollectionGV.popDebugKeys()
   assert freshlyLoggedKeys
   liveDebugKeys.update( freshlyLoggedKeys )
   InstantiatingCollectionGV.pushDebugKeys( list( liveDebugKeys ) )

   return newHead

def splitTokenContainingWhitespace( parser, tokenNodeWithWhitespace ):
   '''
   Replace a terminal node whose token text contains whitespace with one
   terminal node per whitespace-separated piece.

   Some RCF syntax has to match a string containing special characters, for
   example a community list with a hyphen, 'COMM-LIST'. Matching such characters
   in every string is not desirable, so the lexer rules for them require a
   keyword prefix ahead of the string.

   As a result the prefix keyword and the string form a single token; for example
   'community_list COMM-LIST' is one token. Arbitrary whitespace, including
   newlines, may sit between the keyword and the string. However, DebugSymbols
   construction and the code that locates the text of a compilation error both
   expect a token to exist on a single line only.

   The antlr parser calls this function on the parse contexts where such tokens
   are used. It passes itself (its current context, parser._ctx, must contain
   tokenNodeWithWhitespace among its children) and the node to be split.

   Each piece becomes a clone of the original token with its text, line, column,
   start and stop adjusted as if the lexer had produced the pieces separately.
   The new nodes take the place of tokenNodeWithWhitespace in the children of
   parser._ctx; a token made only of whitespace is simply removed.
   '''
   originalToken = tokenNodeWithWhitespace.symbol
   lineNo = originalToken.line
   columnPos = originalToken.column
   charIndex = originalToken.start
   replacementNodes = []

   for lineText in originalToken.text.split( '\n' ):
      unconsumed = lineText
      # Nothing is emitted for a blank line or for trailing whitespace
      while unconsumed.strip():
         # Each piece starts life as a copy of the token being split
         piece = originalToken.clone()

         # Skip over the whitespace in front of the next word
         withoutIndent = unconsumed.lstrip()
         indent = len( unconsumed ) - len( withoutIndent )
         word = withoutIndent.split()[ 0 ]
         wordLen = len( word )
         columnPos += indent
         charIndex += indent

         # Give the copy the text and position of this word; stop is inclusive
         piece.text = word
         piece.line = lineNo
         piece.column = columnPos
         piece.start = charIndex
         piece.stop = charIndex + wordLen - 1

         # Wrap it in a terminal node that hangs off the same parent
         pieceNode = TerminalNodeImpl( piece )
         pieceNode.parentCtx = tokenNodeWithWhitespace.parentCtx
         replacementNodes.append( pieceNode )

         # Step past the word that was just emitted
         columnPos += wordLen
         charIndex += wordLen
         unconsumed = withoutIndent[ wordLen : ]

      # Move to the start of the next line, accounting for whatever whitespace
      # was left on this one plus the newline character itself
      lineNo += 1
      columnPos = 0
      charIndex += len( unconsumed ) + 1

   # Swap the original node for its pieces in the current parse context
   # pylint: disable=protected-access
   siblings = parser._ctx.children
   position = siblings.index( tokenNodeWithWhitespace )
   parser._ctx.children = ( siblings[ : position ] +
                            replacementNodes +
                            siblings[ position + 1 : ] )
