# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import RcfAst
from RcfVisitor import RcfVisitor

class LabelGenVisitor( RcfVisitor ):
   """ This visitor walks the ANTLR parse tree, collecting the labels defined
   in each function.

   !Rules (don't change unless discussing with authors first)

      - Define the visit methods following the same order in which
        the parser rules are defined in the grammar file.

   Args:
      diag: diag object to store errors/warnings.
      codeUnitKey (CodeUnitKey): key of the code unit being visited; it is
         attached to every RcfAst.Function node this visitor builds.

   Attributes:
      self.diag: diag object receiving errors/warnings. Function redefinition
         errors are reported on it directly.
      self.functionDiag: diag obtained from diag.emptyFunctionDiag(); collects
         label redefinition errors and is merged into self.diag by __call__.
      self.codeUnitKey (CodeUnitKey): the current code unit key.
      self.currentFnName (str): name of the most recently visited function;
         None until the first funcDecl has been visited.
      self.currentFnCtx (RcfFunction): the current function. NOTE(review):
         never assigned by this visitor, so it stays None.
      self.llocPerFunction (dict): [ funcName ][ labelName ] = LabelLoc.
      self.functionDefinitions (dict): [ funcName ] = RcfAst.Function node of
         the first definition seen for that name.
   """
   def __init__( self, diag, codeUnitKey ):
      self.diag = diag
      self.functionDiag = diag.emptyFunctionDiag()
      self.codeUnitKey = codeUnitKey
      self.currentFnName = None
      self.currentFnCtx = None

      # [ funcName ][ labelName ] = labelLoc
      self.llocPerFunction = {}
      # [ funcName ] = RcfAst.Function of the first definition; used as the
      # existing symbol when a later redefinition is reported.
      self.functionDefinitions = {}

   def __call__( self, parseTree ):
      """ Visitor instances are functors.

      Walks the parse tree with this visitor, then merges the diagnostics
      collected in self.functionDiag into self.diag.

      Args:
         parseTree: the parse tree from Antlr.

      Returns:
         dict: the label locations of every visited function, keyed as
         [ funcName ][ labelName ] = LabelLoc (this is self.llocPerFunction).
      """
      parseTree.accept( self )
      self.diag.update( self.functionDiag )
      return self.llocPerFunction

   # Visit a parse tree produced by RcfParser#funcDecl.
   def visitFuncDecl( self, ctx ):
      """ Register the declared function, then visit its children.

      The first definition of a name gets an empty label map in
      llocPerFunction and is remembered in functionDefinitions. Any later
      definition of the same name is reported through
      diag.functionNameDefinitionError, pointing at the first definition.

      NOTE(review): on a redefinition currentFnName is still set to the
      duplicated name, so labels inside the redefined body are recorded in
      (and checked against) the first definition's label map. currentFnName
      is also not reset once the children have been visited.

      Args:
         ctx: the funcDecl parse tree context.

      Returns:
         The result of visitChildren( ctx ).
      """
      funcDef = ctx.funcDef()
      # 'function <name>' will be two tokens, extract <name>
      assert len( funcDef.children ) == 2
      self.currentFnName = funcDef.children[ 1 ].getText()
      functionNode = RcfAst.Function( ctx, self.currentFnName, self.codeUnitKey,
                                      None )

      if self.currentFnName not in self.llocPerFunction:
         self.llocPerFunction[ self.currentFnName ] = {}
         # Both maps are populated together, so they must agree on membership.
         assert self.currentFnName not in self.functionDefinitions
         self.functionDefinitions[ self.currentFnName ] = functionNode
      else: # function redefinition error
         self.diag.functionNameDefinitionError(
            functionNode,
            existingSymbolNode=self.functionDefinitions[ self.currentFnName ] )
      return self.visitChildren( ctx )

   # Visit a parse tree produced by RcfParser#labeledBlock.
   def visitLabeledBlock( self, ctx ):
      """ Record a block label in the current function's label map.

      If the label name is already present for the current function, a
      labelDefinitionError is reported on self.functionDiag instead, and the
      original LabelLoc is kept.

      Args:
         ctx: the labeledBlock parse tree context.

      Returns:
         The result of visitChildren( ctx ).
      """
      # Imported here rather than at module level because RcfLabeling and this
      # module import each other.
      # pylint: disable-next=cyclic-import,import-outside-toplevel
      from RcfLabeling import LabelLoc
      lname = ctx.LABEL().getPayload().text
      lloc = LabelLoc( ctx=ctx, labelType=LabelLoc.LabelType.block, name=lname )
      # We have already visited funcDecl, an entry for this function must exist
      assert self.currentFnName in self.llocPerFunction
      labelsOfFunction = self.llocPerFunction[ self.currentFnName ]
      if lname not in labelsOfFunction:
         labelsOfFunction[ lname ] = lloc
      else: # label redefinition error
         # TODO: remove Ast.Label
         # TODO: check for re-definition using this visitor instead of symbol gen
         # For now, this will do
         originalLloc = labelsOfFunction[ lname ]
         originalLabel = RcfAst.Label( originalLloc.ctx.LABEL().getPayload() )
         redefiningLabel = RcfAst.Label( lloc.ctx.LABEL().getPayload() )
         # The diag only receives the function's name and code unit key here;
         # the Function node is built without a parse tree context.
         function = RcfAst.Function( None, self.currentFnName, self.codeUnitKey,
                                     None )
         self.functionDiag.labelDefinitionError(
            func=function,
            redefiningLabel=redefiningLabel,
            originalLabel=originalLabel )
      return self.visitChildren( ctx )
