#!/usr/bin/env python3
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import RcfAst
import RcfAstVisitor
import RcfMetadata
import RcfSymbol
import RcfSymbolTable

BT = RcfMetadata.RcfBuiltinTypes

class FunctionScopeResolutionPhase( RcfAstVisitor.Visitor ):
   """ Per-function resolution phase of the semantic analysis.

   During this phase we try to resolve, within their scope, the symbols found in
   a function's body:

      - attributes are resolved against the attribute scope;
      - variables are resolved against the enclosing function's symbol.

   Function calls themselves are not resolved here. Only their arguments are
   resolved; the calls are resolved later by the cross function (linker) phase.

   This phase can emit diagnostics (resolution errors and/or warnings):

      - We report Rcf diag errors if symbols are not defined.
      - We annotate AST nodes with metadata:
        - resolved symbols;
        - integer values for named immediate communities and ISIS levels;
        - the parent operator and lhs symbol of collections and community values.

   Note that this phase doesn't decide whether a given error is fatal or not;
   that is RcfDiag's job.

   Maintainers: define the visit methods in the order in which the AST nodes
   are defined.
   """
   def __init__( self, diags ):
      """ Constructor

      Args:
         diags (RcfDiag): the diagnostic report object; resolution errors found
            during this phase are reported through it.
      """
      super().__init__()
      self.diags = diags
      self.attributeScope = RcfSymbolTable.AttributeScope()
      # The function whose body is currently being visited.
      self.currentFunction = None
      # The innermost Assign / BinOp / ExternalRefOp being visited, or None when
      # outside of any. Passed to diagnostics and used to annotate rhs values.
      self.currentExpression = None

   def _visitWithinExpression( self, expression, children ):
      """ Visit each of `children` with `expression` as the current expression.

      The previously current expression is restored afterwards, rather than
      being reset to None. This way an expression nested inside another one
      (e.g. via a function call argument) does not wipe out the enclosing
      expression's context for the children visited after it.
      """
      enclosing = self.currentExpression
      self.currentExpression = expression
      try:
         for child in children:
            self.visit( child )( child )
      finally:
         self.currentExpression = enclosing

   def visitFunction( self, function, **kwargs ):
      self.currentFunction = function
      self.visit( function.block )( function.block )

   def visitBlock( self, block, **kwargs ):
      for stmt in block.stmts:
         self.visit( stmt )( stmt )

   def visitIfStmt( self, ifStmt, **kwargs ):
      self.visit( ifStmt.condition )( ifStmt.condition )
      self.visit( ifStmt.thenBlock )( ifStmt.thenBlock )
      if ifStmt.elseBlock:
         self.visit( ifStmt.elseBlock )( ifStmt.elseBlock )

   def visitSequentialExpr( self, sequentialExpr, **kwargs ):
      self.visit( sequentialExpr.expr )( sequentialExpr.expr )
      if sequentialExpr.nextExpr:
         self.visit( sequentialExpr.nextExpr )( sequentialExpr.nextExpr )

   def visitCall( self, call, **kwargs ):
      # At this point we can only resolve the symbols in the function arguments,
      # for example attributes. The function call itself will be resolved in the
      # cross function symbol resolution.
      for funcArg in call.funcArgs:
         self.visit( funcArg )( funcArg )

   def visitExternalRefOp( self, externalRefOp, **kwargs ):
      self._visitWithinExpression(
         externalRefOp,
         ( externalRefOp.attribute, externalRefOp.rhs,
           *externalRefOp.additionalAttributes ) )

   def visitAssign( self, assign, **kwargs ):
      # The attribute is visited before the value, so its symbol is already
      # resolved when the value is annotated with the lhs symbol.
      self._visitWithinExpression( assign, ( assign.attribute, assign.value ) )

   def visitReturn( self, returnOperation, **kwargs ):
      self.visit( returnOperation.expr )( returnOperation.expr )

   def visitExit( self, exitOperation, **kwargs ):
      self.visit( exitOperation.expr )( exitOperation.expr )

   def resolveImmediateCommSetValue( self, immediateSetValue ):
      """ Replace a well-known community name with its integer value, in place.

      Immediate communities in the AST are either ints (<U32> or <U16>:<U16>)
      or strings (<WELL_KNOWN_COMMUNITY>). Integer immediate communities do not
      require any resolution and are left untouched. String immediate
      communities are looked up in BT.ImmediateCommunity.valueDict to get the
      integer value of the well-known community.
      """
      if isinstance( immediateSetValue.value, str ):
         immediateSetValue.value = int( BT.ImmediateCommunity.valueDict[
            immediateSetValue.value ] )

   def resolveImmediateIsisLevelSetValue( self, immediateSetValue ):
      """ Replace an ISIS level name with its integer enum value, in place.

      ISIS level string values are looked up in BT.ImmediateIsisLevel.valueDict.
      Values that are not strings (i.e. already resolved) are left untouched.
      """
      if isinstance( immediateSetValue.value, str ):
         immediateSetValue.value = int( BT.ImmediateIsisLevel.valueDict[
            immediateSetValue.value ] )

   def resolveImmediateSet( self, immediateSet ):
      """ Resolve, in place, every element of an immediate set collection. """
      for value in immediateSet.values:
         if isinstance( value, RcfAst.Attribute ):
            self.visitAttribute( value )
         elif isinstance( value, RcfAst.Constant ):
            if value.type == 'ImmediateIsisLevel':
               self.resolveImmediateIsisLevelSetValue( value )
            elif value.type == 'ImmediateCommunity':
               self.resolveImmediateCommSetValue( value )
         elif isinstance( value, RcfAst.CommunityValue ):
            self.visitCommunityValue( value )
         elif isinstance( value, RcfAst.Range ):
            # Ranges need no resolution.
            pass
         else:
            assert False, \
                   f"Invalid RcfAst type in an immediateSet : { type( value ) }"

   def resolveAsPath( self, asPathValue ):
      """ Visit every node of an immediate AS path collection. """
      for node in asPathValue.values:
         self.visit( node )( node )

   def resolveImmediateList( self, immediateList ):
      """ Visit every element of an immediate list; only external references are
      expected there.
      """
      for value in immediateList.values:
         if isinstance( value, RcfAst.ExternalRef ):
            # NOTE(review): visitExternalRef is not defined in this class;
            # presumably inherited from RcfAstVisitor.Visitor — confirm.
            self.visitExternalRef( value )
         else:
            assert False, \
                   f"Invalid RcfAst type in an immediateList : { type( value ) }"

   def visitCollection( self, collection, **kwargs ):
      if collection.type == RcfAst.Collection.Type.ImmediateSet:
         self.resolveImmediateSet( collection )
      elif collection.type == RcfAst.Collection.Type.AsPathImmediate:
         self.resolveAsPath( collection )
      elif collection.type == RcfAst.Collection.Type.ImmediateList:
         self.resolveImmediateList( collection )

      # Providing information about the operation and lhs can allow us
      # to optimise the RHS value in the AET phase.
      self.setOperatorAndLhsSymbol( collection )

   def setOperatorAndLhsSymbol( self, astNode ):
      """ Annotate `astNode` with the operator and lhs symbol of the current
      expression, if any.

      Sets `astNode.parentOperator` and, when the lhs resolved,
      `astNode.parentLhsSymbol`. Only Assign and BinOp parents are expected.
      """
      parent = self.currentExpression
      if parent is not None:
         if isinstance( parent, RcfAst.Assign ):
            symbol = parent.attribute.symbol
            astNode.parentOperator = parent.op
         elif isinstance( parent, RcfAst.BinOp ):
            symbol = parent.lhs.symbol
            astNode.parentOperator = parent.operator
         else:
            assert False, "Unexpected parent node type"
         # The lhs symbol may not have resolved, which would be a compilation error,
         # the rhs is still visited.
         if symbol:
            astNode.parentLhsSymbol = symbol

   def visitCommunityValue( self, commVal, **kwargs ):
      for section in commVal.sections:
         self.visit( section )( section )

      # Providing information about the operation and lhs can allow us
      # to optimise the RHS value in the AET phase.
      self.setOperatorAndLhsSymbol( commVal )

   def visitAttribute( self, attribute, **kwargs ):
      attributeSymbol = self.attributeScope.resolve( attribute.name )
      if attributeSymbol:
         attribute.symbol = attributeSymbol
      else:
         self.diags.resolutionError( self.currentFunction, attribute,
                                     self.currentExpression )

   def visitBinOp( self, binOp, **kwargs ):
      # The lhs is visited before the rhs, so its symbol is already resolved
      # when the rhs is annotated with the lhs symbol.
      self._visitWithinExpression( binOp, ( binOp.lhs, binOp.rhs ) )

   def visitLogicalOp( self, logicalOp, **kwargs ):
      for expr in logicalOp.expressionList:
         self.visit( expr )( expr )

   def visitNot( self, notExpr, **kwargs ):
      self.visit( notExpr.expr )( notExpr.expr )

   def visitVariable( self, variable, **kwargs ):
      # Today, variables can only be defined through function parameters whose scope
      # is the entire function scope. If we support local variables in the future,
      # this resolution logic will have to change.
      variableSymbol = self.currentFunction.symbol.resolve( variable.name )
      if variableSymbol:
         variable.symbol = variableSymbol
      else:
         self.diags.resolutionError( self.currentFunction, variable,
                                     self.currentExpression )

class LinkerResolutionPhase:
   """ Cross-function (linker) resolution phase of the semantic analysis.

   Invoked once per function, it checks that:
   - every function call made by the function names an existing function, and
   - every external reference in the function resolves.

   Anything undefined is reported through RcfDiag as an error.
   """
   def __init__( self, diags, rcfExternalConfig ):
      """ Constructor

      Args:
         diags (RcfDiag): the diagnostic report object. Its `strict` flag decides
            whether non built-in external references are checked.
         rcfExternalConfig: external configuration used to build the
            ExternalScope; must be provided if and only if diags.strict is set.
      """
      assert bool( diags.strict ) == bool( rcfExternalConfig ), \
             "rcfExternalConfig must be provided if and only if compiling strictly"
      self.diags = diags
      self.builtInExternalScope = RcfSymbolTable.BuiltInExternalScope()
      # Only available when external config was supplied (i.e. strict mode).
      self.externalScope = ( RcfSymbolTable.ExternalScope( rcfExternalConfig )
                             if rcfExternalConfig else None )
      self.attributeScope = RcfSymbolTable.AttributeScope()

   def __call__( self, function ):
      """ Check all calls, then all external references, made by `function`. """
      for callNode in function.functionCalls:
         self.visitCall( callNode, function )
      for refNode in function.externalRefs:
         self.visitExternalRef( refNode, function )

   def visitCall( self, call, function ):
      """ Bind `call` to the function symbol it names.

      Argument symbols were already resolved by the per-function phase; here we
      only check that the callee exists and that the argument count matches.
      Reports a resolution error (and stops) if the callee is missing or is not
      a function. Otherwise it reports an argument-count mismatch if there is
      one, attaches the callee's implicit `self` attribute when it has one, and
      sets `call.symbol`.
      """
      callee = call.scope.resolve( call.funcName )
      if not ( callee and isinstance( callee, RcfSymbol.Function ) ):
         self.diags.resolutionError( function, call, None )
         return

      expectedCount = len( callee.funcParamTypes )
      suppliedCount = len( call.funcArgs )
      if suppliedCount != expectedCount:
         self.diags.mismatchedArgumentCountError( function, call, expectedCount,
                                                  suppliedCount )

      # The callee's `self` is assumed to be an attribute. This will need to be
      # revisited once variables are allowed.
      if callee.functionSelf:
         selfAttr = call.functionSelf = RcfAst.Attribute( None, callee.functionSelf )
         selfAttr.scope = call.scope
         selfAttr.symbol = self.attributeScope.resolve( selfAttr.name )

      call.symbol = callee

   def visitExternalRef( self, extRef, function ):
      """ Report `extRef` if it cannot be resolved.

      References whose type is handled by the built-in external scope are always
      checked. All other references are checked only when compiling strictly.
      """
      if extRef.etype in self.builtInExternalScope.dispatchOnType:
         resolved = self.builtInExternalScope.resolve( extRef )
      elif self.diags.strict:
         resolved = self.externalScope.resolve( extRef )
      else:
         # Non built-in references are not checked in non-strict mode.
         resolved = True
      if not resolved:
         self.diags.extRefResolutionError( function, extRef, None )
