# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import RcfAst
import RcfAstVisitor
import RcfImmediateValueHelper

# Short local aliases for the validation helper classes from
# RcfImmediateValueHelper: ValueHelper validates single immediate values,
# CollectionHelper validates collections. Both are instantiated by the visitors
# below.
ValueHelper = RcfImmediateValueHelper.RcfImmediateValueValidationHelper
CollectionHelper = RcfImmediateValueHelper.RcfCollectionValidationHelper

class ValueValidationCommon( RcfAstVisitor.Visitor ):
   """ Common implementation for per-function and cross-function value validation.

   These two compiler phases visit different parts of the AST, but the immediate
   values they find (constants, community values, range bounds) and collections
   have to be validated in the same way. Subclasses drive the traversal and keep
   `currentFunction` (and, where they track it, `currentExpression`) up to date;
   the visit methods below delegate the actual checks to the
   RcfImmediateValueHelper validation helpers, handing them `diags` so that
   errors can be reported.
   """
   def __init__( self, diags ):
      """ Initialize the shared validation state.

      Args:
         diags (RcfDiag): the diagnostic report object.
      """
      super().__init__()
      self.diags = diags
      # Function currently being validated; handed to every validation helper.
      # Set by the subclasses' visitFunction.
      self.currentFunction = None
      # Enclosing expression handed to the collection helper as context, or None
      # when no such expression is being visited.
      self.currentExpression = None

   def _makeValueHelper( self ):
      """ Return a new immediate value validation helper for the current
      function, reporting to `self.diags`. """
      return ValueHelper( self.currentFunction, self.diags )

   ###########################################################################
   #                   V I S I T    M E T H O D                              #
   ###########################################################################

   def visitConstant( self, constant, **kwargs ):
      self._makeValueHelper().validate( constant )

   def visitCommunityValue( self, commVal, **kwargs ):
      self._makeValueHelper().validate( commVal )

   def visitRange( self, rangeNode, **kwargs ):
      # Both bounds of a range are immediate values; a single helper checks both.
      valueHelper = self._makeValueHelper()
      valueHelper.validate( rangeNode.lowerBound )
      valueHelper.validate( rangeNode.upperBound )

   def visitCollection( self, collection, **kwargs ):
      # Collections are validated in the context of the enclosing expression
      # (which may be None), so the collection helper also receives it.
      collectionHelper = CollectionHelper(
         self.currentFunction, self.currentExpression, self.diags )
      collectionHelper.validate( collection )

class FunctionScopeValueValidation( ValueValidationCommon ):
   """ Validate the type of immediate values, report errors when needed.

   During this phase, we validate immediate values. This phase can emit errors to
   the diag (immediateValueError, immediateCollectionError).

   In this phase, we assume that all types are resolved.

   The traversal walks a function body statement by statement. External
   reference operations, assignments and binary operations are recorded in
   `currentExpression` while their children are visited, so that collections
   found there are validated with that expression as context. The previously
   recorded expression is restored afterwards, so a nested context-setting
   expression does not clear the context of the one enclosing it.
   """
   ###########################################################################
   #                   V I S I T    M E T H O D                              #
   ###########################################################################

   def _visitInExpression( self, expression, nodes ):
      """ Visit `nodes` in order with `currentExpression` set to `expression`.

      The previous `currentExpression` is restored afterwards, also when a visit
      raises, instead of being unconditionally reset to None.
      """
      previousExpression = self.currentExpression
      self.currentExpression = expression
      try:
         for node in nodes:
            self.visit( node )( node )
      finally:
         self.currentExpression = previousExpression

   def visitFunction( self, function, **kwargs ):
      self.currentFunction = function
      self.visit( function.block )( function.block )

   def visitBlock( self, block, **kwargs ):
      for stmt in block.stmts:
         self.visit( stmt )( stmt )

   def visitIfStmt( self, ifStmt, **kwargs ):
      self.visit( ifStmt.condition )( ifStmt.condition )
      self.visit( ifStmt.thenBlock )( ifStmt.thenBlock )
      if ifStmt.elseBlock:
         self.visit( ifStmt.elseBlock )( ifStmt.elseBlock )

   def visitSequentialExpr( self, sequentialExpr, **kwargs ):
      self.visit( sequentialExpr.expr )( sequentialExpr.expr )
      if sequentialExpr.nextExpr:
         self.visit( sequentialExpr.nextExpr )( sequentialExpr.nextExpr )

   def visitExternalRefOp( self, externalRefOp, **kwargs ):
      # Same visiting order as before: attribute, rhs, then every additional
      # attribute.
      nodes = [ externalRefOp.attribute, externalRefOp.rhs ]
      nodes.extend( externalRefOp.additionalAttributes )
      self._visitInExpression( externalRefOp, nodes )

   def visitAssign( self, assign, **kwargs ):
      self._visitInExpression( assign, ( assign.attribute, assign.value ) )

   def visitReturn( self, returnOperation, **kwargs ):
      self.visit( returnOperation.expr )( returnOperation.expr )

   def visitExit( self, exitOperation, **kwargs ):
      self.visit( exitOperation.expr )( exitOperation.expr )

   def visitBinOp( self, binOp, **kwargs ):
      self._visitInExpression( binOp, ( binOp.lhs, binOp.rhs ) )

   def visitLogicalOp( self, logicalOp, **kwargs ):
      for expr in logicalOp.expressionList:
         self.visit( expr )( expr )

   def visitNot( self, notExpr, **kwargs ):
      self.visit( notExpr.expr )( notExpr.expr )

class LinkerValueValidation( ValueValidationCommon ):
   """ Validate the type of immediate values used as function call arguments,
   reporting errors when needed.

   For each function, only the calls listed in `function.functionCalls` are
   inspected, and within them only the arguments. This phase can emit
   immediateValueError and immediateCollectionError to the diag.

   In this phase, we assume that all types are resolved.
   """
   def visitFunction( self, function, **kwargs ):
      self.currentFunction = function
      # Calls are handed to visitCall directly rather than dispatched through
      # self.visit. currentExpression is never set in this phase, so collections
      # in arguments are validated without an enclosing expression.
      calls = function.functionCalls
      for call in calls:
         self.visitCall( call )

   def visitCall( self, call, **kwargs ):
      receiver = call.functionSelf
      if receiver:
         # The receiver (functionSelf) is not value-validated: today it can only
         # be an attribute, which needs no value validation. Should other node
         # kinds become possible here, validation of the receiver must be
         # reconsidered.
         assert isinstance( receiver, RcfAst.Attribute )
      for argument in call.funcArgs:
         argumentVisitor = self.visit( argument )
         argumentVisitor( argument )
