#!/usr/bin/env python3
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from collections import defaultdict
from contextlib import contextmanager

import RcfAst
import RcfAstVisitor
import RcfMetadata
import RcfSymbol
import RcfSymbolTable
import RcfTypeFuture as Rcf

BT = RcfMetadata.RcfBuiltinTypes
RcfBuiltinSymbols = RcfMetadata.RcfBuiltinSymbols
RcfTypeSystem = RcfMetadata.RcfTypeSystem

class FunctionScopeDefinitionPhase( RcfAstVisitor.Visitor ):
   """ Definition phase of the semantic analysis.

   During this phase, we visit the AST, and most importantly:

      - we define built-in symbols (med, prefix etc...)
      - we define each function we encounter
      - we annotate the AST nodes with the scope it belongs to.
      - we gather metadata (any)

   !Rules (don't change unless discussing with authors first)
      - Define the visit methods in the order in which the AST nodes are defined.

      @author: matthieu (rcf-dev)

   Attributes:
      diag (RcfDiag): the diagnostic report object.
      currentScope (Scope): the current scope changes as we change scope.
      currentFunctionName (str): name of the function currently being visited
         (None until the first function is visited).
      currentFunctionScope (RcfSymbol.Function): symbol of the function currently
         being visited; it doubles as the scope into which labels and function
         parameters are defined.
   """
   def __init__( self, diag ):
      """ Constructor

      Args:
         diag (RcfDiag): the diagnostic report object.
      """
      super().__init__()
      self.diag = diag
      self.currentScope = None
      self.currentFunctionName = None
      self.currentFunctionScope = None

   @contextmanager
   def scope( self, newScope ):
      """ Make newScope the current scope for the duration of the context, and
      restore the previous scope when exiting the context, including when an
      exception is raised inside it.

      Args:
         newScope ( AbstractScope ): the new scope.

      Yields:
         newScope
      """
      previousScope = self.currentScope
      self.currentScope = newScope
      try:
         yield newScope
      finally:
         # Restore unconditionally so an error raised while visiting a nested
         # scope does not leave this visitor pointing at that nested scope.
         self.currentScope = previousScope

   #---------------------------------------------------------------------------
   #                       visit methods override.
   #---------------------------------------------------------------------------
   def visitFunction( self, function, **kwargs ):
      """ Create the function symbol and attach it to the AST node. Then visit the
      function parameters and the function body with that symbol as the
      current scope.

      If the function declares more parameters than
      Rcf.Eval.FunctionArgumentListType allows, a diagnostic is reported and
      the visit continues.
      """
      # Builtin functions provide route map features directly because they do not
      # contain expressions with attributes to provide them.
      builtinFunctionData = RcfMetadata.metadataProcessor.allBuiltinFunctions.get(
         function.name, {} )
      rmFeatures = builtinFunctionData.get( 'routeMapFeatures', None )
      functionSelf = builtinFunctionData.get( 'functionSelf', None )

      funcSymbolScope = RcfSymbol.Function(
         name=function.name, rcfType=BT.Function, retType=BT.Trilean, node=function,
         enclosingScope=self.currentScope, routeMapFeatures=rmFeatures,
         functionSelf=functionSelf, allowedVariableTypenames=BT.variableTypenames )

      # The functions themselves are not defined at this point, because this
      # is only useful for cross function compilation.
      # In other words, we don't have access to the global scope at this point.

      self.currentFunctionName = function.name
      self.currentFunctionScope = funcSymbolScope
      function.symbol = funcSymbolScope

      # Handle function parameters
      maxCountFunctionParams = (
            Rcf.Eval.FunctionArgumentListType.maximumFunctionArgumentCount() )
      if len( function.funcParams ) > maxCountFunctionParams:
         self.diag.functionMaxParametersDefinitionError( function,
                                                         maxCountFunctionParams )
      with self.scope( funcSymbolScope ):
         for funcParam in function.funcParams:
            self.visitFuncParam( funcParam )
         self.visit( function.block )( function.block )

   def defineLabel( self, label ):
      """ Define a label symbol in the current function's scope. Report a
      labelDefinitionError if a symbol with the same name already resolves
      from that scope.
      """
      symbol = self.currentFunctionScope.resolve( label.name )
      if symbol:
         # redefinition of the same label in the same function
         # label names all start with a @, so it's assumed that
         # if similar name is already defined, it's a label as well
         assert isinstance( symbol.node, RcfAst.Label )
         self.diag.labelDefinitionError( self.currentFunctionScope.node,
            label, symbol.node )
      # NOTE(review): unlike defineFuncParam, the label is still defined after a
      # redefinition error. Whether define() replaces or keeps the earlier
      # symbol depends on the scope implementation; confirm this is intended.
      self.currentFunctionScope.define(
         RcfSymbol.Label( label.name, node=label ) )

   def defineFuncParam( self, funcParam, rcfType ):
      """ Define a user function parameter in the current function's scope.

      Reports a variableNameDefinitionError, and defines nothing, if the name
      already resolves from that scope. Otherwise it creates a FunctionParam
      symbol with its own argVector index and attaches it to funcParam.symbol.
      """
      symbol = self.currentFunctionScope.resolve( funcParam.name )
      if symbol:
         # Redefinition of the same function parameter or variable in the same
         # function. Parameter/variable names all start with a '$', so it's assumed
         # that if similar name is already defined, it's a variable as well.
         # Until we allow definition of variables outside of function parameters,
         # we can assume that the node defining the variable is a FunctionParam.
         #
         # BUG997321, redefinition error will not be generated if only one variable
         # with the name has a legal type
         assert isinstance( symbol.node, RcfAst.FunctionParam )
         self.diag.variableNameDefinitionError( self.currentFunctionScope.node,
                                                funcParam,
                                                symbol.node )
      else:
         # Get a unique argVector index to be used when accessing the value passed to
         # this parameter using a variable node. The argVector index is unique per
         # function and per type of variable. Note that the type of the variable is
         # not necessarily the same as the type of the function parameter, e.g. the
         # value passed to a function parameter defined as "as_number_type" is
         # accessed in the function body using an "IntVariable" node.
         argVectorIndex = self.currentFunctionScope.getVariableArgVectorIndex(
               funcParam.name,
               rcfType.variableAetTypename )
         functionParamSymbol = RcfSymbol.FunctionParam(
               name=funcParam.name,
               rcfType=rcfType,
               node=funcParam,
               argVectorIndex=argVectorIndex )
         self.currentFunctionScope.define( functionParamSymbol )
         funcParam.symbol = functionParamSymbol

   def addParamTypeToFunction( self, funcParamRcfType ):
      """ Append a parameter RCF type to the current function's signature.

      funcParamRcfType may be None when a user parameter type is unknown. The
      entry is still appended, so list positions match parameter positions.
      """
      # Add the function parameter RCF type to the list of expected parameter types
      #  for this function symbol.
      self.currentFunctionScope.funcParamTypes.append( funcParamRcfType )

   def visitFuncParam( self, funcParam, **kwargs ):
      """ Record the parameter's type in the function signature. For user
      defined functions, also define the parameter in the function's scope, or
      report variableTypeDefinitionError when its type name is unknown.
      """
      if funcParam.name is None:
         # This function parameter is for a builtin function. The typename string is
         # from the metadata so we can assume it is an RCF type name.
         funcParamRcfType = getattr( BT, funcParam.typeStr )
         self.addParamTypeToFunction( funcParamRcfType )
         # There is no need to define a function parameter in this function's scope
         # for builtin function parameters because builtin functions do not refer
         # to their parameter values by name (their relative positions on the stack
         # are hardcoded in the implementation).
      else:
         # This function parameter is for a user defined function. The typename
         # string is whatever the user provided, not an RCF type name.
         funcParamRcfType = RcfTypeSystem.funcParamTypeMapping.get(
               funcParam.typeStr, None )
         self.addParamTypeToFunction( funcParamRcfType )
         if funcParamRcfType is None:
            self.diag.variableTypeDefinitionError( self.currentFunctionScope.node,
                                                   funcParam )
         else:
            # Define the function parameter in this function's scope
            self.defineFuncParam( funcParam, funcParamRcfType )

   def visitBlock( self, block, **kwargs ):
      """ Define the block's label (if any) in the function scope, then visit the
      statements inside a new BlockScope nested in the current scope.
      """
      if block.label is not None:
         self.defineLabel( block.label )
      with self.scope( RcfSymbolTable.BlockScope( self.currentScope ) ):
         for stmt in block.stmts:
            self.visit( stmt )( stmt )

   def visitIfStmt( self, ifStmt, **kwargs ):
      """ Define the if/else labels (if any), then visit the condition, the then
      block and, when present, the else block.
      """
      if ifStmt.ifLbl is not None:
         self.defineLabel( ifStmt.ifLbl )
      if ifStmt.elseLbl is not None:
         self.defineLabel( ifStmt.elseLbl )
      self.visit( ifStmt.condition )( ifStmt.condition )
      self.visit( ifStmt.thenBlock )( ifStmt.thenBlock )
      if ifStmt.elseBlock:
         self.visit( ifStmt.elseBlock )( ifStmt.elseBlock )

   def visitCall( self, call, **kwargs ):
      """ Record the call's scope and visit its arguments. """
      call.scope = self.currentScope
      for funcArg in call.funcArgs:
         self.visit( funcArg )( funcArg )
      # call.functionSelf has not been resolved yet

   def visitSequentialExpr( self, sequentialExpr, **kwargs ):
      """ Visit the expression and, if present, the next expression in the chain. """
      self.visit( sequentialExpr.expr )( sequentialExpr.expr )
      if sequentialExpr.nextExpr:
         self.visit( sequentialExpr.nextExpr )( sequentialExpr.nextExpr )

   def visitExternalRefOp( self, externalRefOp, **kwargs ):
      """ Visit the attribute, the right-hand side and any additional attributes. """
      self.visit( externalRefOp.attribute )( externalRefOp.attribute )
      self.visit( externalRefOp.rhs )( externalRefOp.rhs )
      for additionalAttribute in externalRefOp.additionalAttributes:
         self.visit( additionalAttribute )( additionalAttribute )

   def visitExternalRef( self, extRef, **kwargs ):
      """ Record the scope in which the external reference appears. """
      extRef.scope = self.currentScope

   def visitAssign( self, assign, **kwargs ):
      """ Visit the assigned attribute and the assigned value. """
      self.visit( assign.attribute )( assign.attribute )
      self.visit( assign.value )( assign.value )

   def visitReturn( self, returnOperation, **kwargs ):
      """ Visit the returned expression. """
      self.visit( returnOperation.expr )( returnOperation.expr )

   def visitExit( self, exitOperation, **kwargs ):
      """ Visit the exit expression. """
      self.visit( exitOperation.expr )( exitOperation.expr )

   def visitCollection( self, collection, **kwargs ):
      """ Visit every value of the collection. """
      for value in collection.values:
         self.visit( value )( value )

   def visitCommunityValue( self, commVal, **kwargs ):
      """ Visit every section of the community value. """
      for section in commVal.sections:
         self.visit( section )( section )

   def visitAttribute( self, attribute, **kwargs ):
      """ Record the scope in which the attribute appears. """
      attribute.scope = self.currentScope

   def visitBinOp( self, binOp, **kwargs ):
      """ Visit both operands of the binary operation. """
      self.visit( binOp.lhs )( binOp.lhs )
      self.visit( binOp.rhs )( binOp.rhs )

   def visitLogicalOp( self, logicalOp, **kwargs ):
      """ Visit every operand of the logical operation. """
      for expr in logicalOp.expressionList:
         self.visit( expr )( expr )

   def visitNot( self, notExpr, **kwargs ):
      """ Visit the negated expression. """
      self.visit( notExpr.expr )( notExpr.expr )

   def visitVariable( self, variable, **kwargs ):
      """ Record the scope in which the variable appears. """
      variable.scope = self.currentScope

class LinkerDefinitionPhase:
   """ Cross-function (linker) definition phase of the semantic analysis.

   Defines the function symbols attached to each function node in a single
   global scope. It reports functions whose name clashes with a keyword or
   with an already defined function. It also invalidates all definitions of a
   duplicated function name, and removes that name from the global scope when
   the duplicates disagree on their parameter types.

   Attributes:
      diag (RcfDiag): the diagnostic report object.
      globalScope (GlobalScope): scope in which function symbols are defined.
      duplicateFunctions (defaultdict(list)): function name -> function nodes
         that redefine a name already defined in globalScope.
   """
   def __init__( self, diag ):
      self.diag = diag
      self.globalScope = RcfSymbolTable.GlobalScope()
      self.duplicateFunctions = defaultdict( list )

   # ---------------------------------------------------------------------------
   #                       helpers and entry point.
   # ---------------------------------------------------------------------------

   @staticmethod
   def getFunctionArgTuple( function ):
      """ Return the function's signature as a tuple of parameter type names,
      taken from function.symbol.funcParamTypes.
      """
      assert function.symbol, f"{function.name} symbol not defined"
      return tuple( paramType.name for paramType in function.symbol.funcParamTypes )

   def undefineFunctionWithConflictingSignatures( self ):
      '''
      For every duplicate function, verify the function parameters are the same.
      If they differ, make sure we undefine the function from the global scope
      This will prevent functions linking non deterministically against the first
      resolved function definition.

      Every definition of a duplicated name (the defined one and all duplicates)
      is marked invalid, whether or not the signatures match. The name is
      undefined at most once, even when several duplicates have a different
      signature.
      '''
      for functionName, functions in self.duplicateFunctions.items():
         definedSymbol = self.globalScope.resolve( functionName )
         definedFunction = definedSymbol.node
         assert definedFunction.name == functionName
         definedFunction.valid = False
         firstFunctionParams = self.getFunctionArgTuple( definedFunction )
         signaturesConflict = False
         for otherFunction in functions:
            assert otherFunction.name == functionName
            otherFunction.valid = False
            if self.getFunctionArgTuple( otherFunction ) != firstFunctionParams:
               signaturesConflict = True
         if signaturesConflict:
            # Undefine once, outside the loop: the symbol is already gone after
            # the first undefine, so repeating it per mismatch is wrong.
            self.globalScope.undefine( functionName )

   def __call__( self, root ):
      """ Define every function of root in the global scope, then drop
      duplicated names whose definitions have conflicting signatures.
      """
      # This must also include all invalid functions. This is done to make sure
      # function resolution does not fail against invalid functions
      for function in root.validAndInvalidFunctions():
         self.visitFunction( function )

      self.undefineFunctionWithConflictingSignatures()

   def visitFunction( self, function ):
      """ Define a single function symbol in the global scope.

      Builtin functions are defined unconditionally. User functions are skipped
      if any parameter type is unknown. If the name is a keyword, an error is
      reported and the function is marked invalid. If the name is already
      defined, a redefinition error is reported and the function is recorded
      as a duplicate. Otherwise the function is defined.
      """
      if function.codeUnitKey.isBuiltin():
         self.globalScope.define( function.symbol )
         return

      # In the cross-function definition phase, we check the function name for
      # keyword clashes and redefinitions.
      function.symbol.enclosingScope = self.globalScope

      # We can only define the function if all function arguments are correct,
      # see BUG997321. If one of the function argument is incorrect we will
      # not have any resolution against this function.
      validFunctionArgTypes = all( arg.typeStr in RcfTypeSystem.funcParamTypeMapping
                                   for arg in function.funcParams )

      if not validFunctionArgTypes:
         assert not function.valid, ( "This function should have been marked as "
                                      "invalid during FunctionScopeDefinitionPhase" )
         return

      if RcfMetadata.RcfKeywords.isKeyword( function.name ):
         # The function name clashes with a language keyword
         self.diag.functionNameDefinitionError( function )
         function.valid = False
         return

      symbol = self.globalScope.resolve( function.name )
      if symbol:
         self.duplicateFunctions[ function.name ].append( function )
         self.diag.functionNameDefinitionError( function,
                                                existingSymbolNode=symbol.node )
      else:
         # Actually define the function symbol in the scope now that we know it is
         # unique
         self.globalScope.define( function.symbol )
