# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import RcfAst
import RcfAstVisitor
from RcfHelperTypes import ResolutionRibTypes
import RcfMetadata
import RcfImmediateValueHelper

RcfTypeSystem = RcfMetadata.RcfTypeSystem
BT = RcfMetadata.RcfBuiltinTypes
ValueHelper = RcfImmediateValueHelper.RcfImmediateValueHelper

class FunctionScopeTypeBindingPhase( RcfAstVisitor.Visitor ):
   """ Assign type to attributes and expressions, report errors when needed.

   During this phase, we try to assign types to attributes and expressions.
   This enforces that the program is sound. This phase can emit errors to
   the diags object (RcfTypingError).

   In this phase, we assume that all symbols are resolved.

   Attributes:
      diags (RcfDiag): Diagnostic object where errors are emitted.
      currentFunction (RcfFunction): the function currently being visited; used
         to attribute every reported error.

   !Rules (don't change unless discussing with authors first)

      - This implements the pristine typing system, free of shortcomings.
          To implement exceptions due to implementation constraints (say
          AET is not ready yet for a particular operation on a type), use
          the @TypeBindingPhaseOverride.

      - Define the visit methods in the order in which the AST nodes are defined.

   @author: matthieu (rcf-dev)
   """
   def __init__( self, diags ):
      """ Constructor

      Args:
         diags (RcfDiag): the diagnostic report object.
      """
      super().__init__()
      self.diags = diags
      self.currentFunction = None

   ###########################################################################
   #                   V I S I T    M E T H O D                              #
   ###########################################################################

   def visitFunction( self, function, **kwargs ):
      """ Type-bind the body of 'function'.

      'function' is recorded as the current function so that errors found while
      visiting its block are attributed to it.
      """
      self.currentFunction = function
      self.visit( function.block )( function.block )

   def visitBlock( self, block, **kwargs ):
      """ Type-bind every statement of 'block', in order. """
      for stmt in block.stmts:
         self.visit( stmt )( stmt )

   def visitIfStmt( self, ifStmt, **kwargs ):
      """ Type-bind the condition and both branches of an if statement.

      Reports an error if the condition is neither Boolean nor Trilean.
      """
      self.visit( ifStmt.condition )( ifStmt.condition )
      self.visit( ifStmt.thenBlock )( ifStmt.thenBlock )
      if ifStmt.elseBlock:
         self.visit( ifStmt.elseBlock )( ifStmt.elseBlock )
      # A condition's expr must either be Boolean or Trilean
      if ifStmt.condition.evalType not in ( BT.Boolean, BT.Trilean ):
         what = "if condition must evaluate to true, false, or unknown"
         self.diags.typingError( self.currentFunction, ifStmt.condition, what )

   def visitCall( self, call, **kwargs ):
      """ Type-bind a function call within the current function.

      The call itself always evaluates to Trilean. Its arguments are only
      type-bound here; they are not promoted.
      """
      # At this point we cannot resolve the function call. We set call.evalType
      # manually, because it is currently a fact of the RCF programming language
      # that functions return trilean values.
      call.evalType = BT.Trilean
      # Only perform type binding on the arguments themselves, but don't promote
      # their types. Type promotion can only be done at the cross function type
      # binding phase.
      for funcArg in call.funcArgs:
         self.visit( funcArg )( funcArg )

   def visitExternalRefOp( self, externalRefOp, **kwargs ):
      """ Type-check a 'match' / 'match_exact' / 'match_covered' operation.

      The result type comes from RcfTypeSystem.conditionOpAllowed. If the
      operation is not allowed, an error is reported and the result is assumed
      to be Boolean.
      """
      self.visit( externalRefOp.attribute )( externalRefOp.attribute )
      self.visit( externalRefOp.rhs )( externalRefOp.rhs )
      for additionalAttribute in externalRefOp.additionalAttributes:
         self.visit( additionalAttribute )( additionalAttribute )
      lhsType = externalRefOp.attribute.evalType
      rhsType = externalRefOp.rhs.evalType
      # The exact and covered flavours are mutually exclusive.
      assert not externalRefOp.isExact or not externalRefOp.isMatchCovered
      if externalRefOp.isExact:
         op = "match_exact"
      elif externalRefOp.isMatchCovered:
         op = "match_covered"
      else:
         op = "match"
      evalType = RcfTypeSystem.conditionOpAllowed( lhsType, op, rhsType )
      externalRefOp.evalType = evalType
      if evalType is None:
         lName = externalRefOp.attribute.resolveType().displayName
         rName = externalRefOp.rhs.resolveType().displayName
         what = f"invalid operation '{op}' between {lName} and {rName}"
         self.diags.typingError( self.currentFunction, externalRefOp, what )
         # /!\ On error, Assume that the result of a match operation is a boolean,
         # so that we don't propagate the error all the way up in the expression:
         # e.g:
         #
         # prefix is 10.0.0.0/24 or med > 0 or med match prefix_list FOO
         #                                     ^~~~~~~~~~~~~~~~~~~~~~~~~
         #                                     assume eval type boolean
         #
         # This will help user focusing on the right error.
         externalRefOp.evalType = BT.Boolean

   def visitExternalRef( self, extRef, **kwargs ):
      """ Infer the RCF eval type of an external reference from its 'etype'.

      An unrecognised 'etype' leaves extRef.evalType untouched.
      """
      # We infer the RCF eval type for the external reference based on the type of
      # the external reference. For example, a "prefix_list_v4" external reference
      # type maps to a "PrefixList" RCF eval type.
      if 'prefix_list' in extRef.etype:
         extRef.evalType = BT.PrefixList
      elif extRef.etype == 'as_path_list':
         extRef.evalType = BT.AsPathList
      elif extRef.etype == "community_list":
         extRef.evalType = BT.CommunityList
      elif extRef.etype == "ext_community_list":
         extRef.evalType = BT.ExtCommunityList
      elif extRef.etype == "large_community_list":
         extRef.evalType = BT.LargeCommunityList
      elif extRef.etype == 'roa_table':
         extRef.evalType = BT.RoaTable
      elif extRef.etype in ResolutionRibTypes:
         extRef.evalType = BT.ImmediateResolutionRib

   def visitAssign( self, assign, **kwargs ):
      """ Type-check a modify operation ('assign.op') on an attribute.

      Reports an error when:
        - the target attribute is read only, or
        - the operation is not allowed between the attribute type and the value
          type, even after trying every allowable promotion of the value.
      When a promotion makes the operation legal, assign.value.promoteToType is
      set for the AET gen module.
      """
      self.visit( assign.attribute )( assign.attribute )
      lhsType = assign.attribute.evalType
      # The value is visited with the lhs type as context (keyword 'lhsType').
      self.visit( assign.value )( assign.value, lhsType=lhsType )
      fromType = assign.value.evalType

      # Check for constness first, then check the rest.
      if assign.attribute.symbol.const:
         op = assign.op
         what = f"invalid operation '{op}': cannot modify a read only attribute"
         self.diags.typingError( self.currentFunction, assign, what )
         return

      # Check whether the modify operation is supported for this type
      allowed = RcfTypeSystem.modifyOpAllowed( lhsType=lhsType, op=assign.op,
                rhsType=fromType )

      # Set the promoteToType properly as expected by the AET gen module.
      if not allowed:
         for toType in RcfTypeSystem.allowableTypePromotions[ fromType ]:
            if RcfTypeSystem.modifyOpAllowed( lhsType, assign.op, toType ):
               assign.value.promoteToType = toType
               allowed = True
               break

      if not allowed:
         op = assign.op
         rName = assign.value.resolveType().displayName
         lName = assign.attribute.resolveType().displayName
         what = f"invalid operation '{op}' between {lName} and {rName}"
         self.diags.typingError( self.currentFunction, assign, what )

   def visitReturn( self, returnOperation, **kwargs ):
      """ Type-check the expression of a 'return' statement. """
      self.typeCheckReturnedExpr(
         returnOperation,
         "return expression must evaluate to true, false, or unknown" )

   def visitExit( self, exitOperation, **kwargs ):
      """ Type-check the expression of an 'exit' statement. """
      self.typeCheckReturnedExpr(
         exitOperation,
         "exit expression must evaluate to true, false, or unknown" )

   def typeCheckReturnedExpr( self, stmt, what ):
      """ Type-bind stmt.expr and check it against the current function's retType.

      Shared by 'return' and 'exit' statements. If the expression type can be
      promoted to the function's return type, stmt.expr.promoteToType is set.
      Otherwise, if the types differ and no promotion is recorded on the
      expression, 'what' is reported on 'stmt'.

      Args:
         stmt: the Return or Exit AST node; only its 'expr' is inspected.
         what (str): the error message to report on a type mismatch.
      """
      expr = stmt.expr
      self.visit( expr )( expr )
      exprType = expr.evalType
      retType = self.currentFunction.symbol.retType
      if RcfTypeSystem.canPromote( exprType, retType ):
         expr.promoteToType = retType

      if exprType != retType and not expr.promoteToType:
         self.diags.typingError( self.currentFunction, stmt, what )

   def resolveImmediateCollection( self, immediateCollection,
                                   contentTypeToCollectionType ):
      """ Infer the collection type of an immediate set/list from its values.

      Every value is type-bound. Each value's evalType is mapped to a collection
      type through 'contentTypeToCollectionType', and the running type is widened
      whenever one type can be promoted to the other. If two types are
      incompatible, resolution stops and immediateCollection.evalType is left
      unchanged (the ValueValidation phase reports the error). An empty
      collection also leaves immediateCollection.evalType unchanged.

      Args:
         immediateCollection: the Collection AST node whose 'values' are visited.
         contentTypeToCollectionType (dict): maps a value eval type to the
            corresponding collection eval type.
      """
      valueIter = iter( immediateCollection.values )
      value = next( valueIter, None )
      if value is None:
         # Empty collection: there is nothing to infer a type from. Previously
         # this escaped the visitor as a bare StopIteration.
         return
      self.visit( value )( value )
      currentCollectionType = contentTypeToCollectionType[ value.evalType ]

      for value in valueIter:
         # Every value is visited, even once resolution has failed.
         self.visit( value )( value )
         if currentCollectionType is not None:
            newCollectionType = contentTypeToCollectionType[ value.evalType ]
            if currentCollectionType == newCollectionType:
               continue
            if RcfTypeSystem.canPromote( newCollectionType, currentCollectionType ):
               continue
            if RcfTypeSystem.canPromote( currentCollectionType, newCollectionType ):
               currentCollectionType = newCollectionType
            else:
               # Failed to resolve the collection, no need to try to resolve types
               # for the rest of the collection. The ValueValidation phase
               # determines the compilation error.
               currentCollectionType = None

      if currentCollectionType is not None:
         immediateCollection.evalType = currentCollectionType

   def communityValueTypeCheck( self, commVal ):
      """ Check an extended community value against its allowed sections.

      The allowed sections come from
      RcfTypeSystem.extCommunitySections[ commVal.extCommType ]. A section whose
      evalType is NoneType is treated as absent. The checks are:
        - a missing required section, or any extra present section, reports a
          single length error;
        - each present section must have (or be promotable to) the expected
          type; promotions are recorded on section.promoteToType.

      NOTE(review): assumes commVal.sections has at least as many entries as the
      allowed sections (absent ones carrying NoneType) -- confirm.
      """
      def commValTypeCheck( section, sectionType ):
         # True if 'section' already has 'sectionType' or can be promoted to it.
         if section.evalType == sectionType:
            return True
         for toType in RcfTypeSystem.allowableTypePromotions[ section.evalType ]:
            if toType == sectionType:
               section.promoteToType = toType
               return True
         return False

      def mismatchLengthError( commVal, allowedSections ):
         numToString = { 1: "one", 2: "two", 3: "three" }
         numString = numToString[ len( allowedSections ) ]
         what = ( f"'{commVal.extCommType}' extended community requires "
                  f"{numString} parts" )
         self.diags.typingError( self.currentFunction, commVal, what )

      sections = iter( commVal.sections )
      allowedSections = RcfTypeSystem.extCommunitySections[ commVal.extCommType ]
      for allowedSection in allowedSections:
         section = next( sections )
         if section.evalType == BT.NoneType:
            mismatchLengthError( commVal, allowedSections )
            return
         if not commValTypeCheck( section, allowedSection.rcfType ):
            what = f"invalid type '{ section.evalType.displayName }'"\
                   f", expected '{ allowedSection.rcfType.displayName }'"
            self.diags.typingError( self.currentFunction, section, what )
      for section in sections:
         if section.evalType != BT.NoneType:
            # Report the length mismatch once, not once per extra section.
            mismatchLengthError( commVal, allowedSections )
            return

   def visitCommunityValue( self, commVal, **kwargs ):
      """ Type-bind a community value, its sections, then type-check them. """
      for section in commVal.sections:
         self.visit( section )( section )
      commVal.evalType = ValueHelper.communityValueTypeToEvalType( commVal.type )
      self.communityValueTypeCheck( commVal )

   def visitConstant( self, constant, **kwargs ):
      """ Set the eval type of a constant from its literal type. """
      constant.evalType = ValueHelper.constantTypeToEvalType( constant.type )

   def visitRange( self, rangeNode, **kwargs ):
      """ Type a range as AsNumberRange if either bound is AsDot, else IntRange. """
      self.visit( rangeNode.lowerBound )( rangeNode.lowerBound )
      self.visit( rangeNode.upperBound )( rangeNode.upperBound )
      if BT.AsDot in ( rangeNode.lowerBound.resolveType(),
                       rangeNode.upperBound.resolveType() ):
         rangeNode.evalType = BT.AsNumberRange
      else:
         rangeNode.evalType = BT.IntRange

   def visitCollection( self, collection, **kwargs ):
      """ Type-bind an immediate collection.

      Immediate sets and lists are refined by resolveImmediateCollection. Any
      other collection type only has its values visited.
      """
      collection.evalType = ValueHelper.collectionTypeToEvalType( collection.type )
      if collection.evalType == BT.ImmediateSet:
         self.resolveImmediateCollection( collection, BT.contentTypeToSetType )
      elif collection.evalType == BT.ImmediateList:
         self.resolveImmediateCollection( collection, BT.contentTypeToListType )
      else:
         # as path immediate
         for value in collection.values:
            self.visit( value )( value )

   def visitAttribute( self, attribute, **kwargs ):
      """ An attribute evaluates to the RCF type of its resolved symbol. """
      attribute.evalType = attribute.symbol.rcfType

   def typeCheckOperandsWithOperator( self, lhs, operator, rhs ):
      """ Set operator.evalType for 'lhs <operator> rhs', promoting rhs if needed.

      If the operation is not allowed as-is, each allowable promotion of the rhs
      type is tried in order, and the first one that works is recorded on
      rhs.promoteToType. If none works, an error is reported and the operator
      is assumed to evaluate to Boolean.
      """
      opSym = operator.operator
      if opSym == 'is_not':
         # 'is_not' is special. The AST phase has enclosed this in a not so
         # a regular 'is' operator is used here.
         opSym = 'is'

      operator.evalType = RcfTypeSystem.conditionOpAllowed( lhs.evalType, opSym,
                                                            rhs.evalType )
      if operator.evalType is None:
         for toType in RcfTypeSystem.allowableTypePromotions[ rhs.evalType ]:
            operatorType = RcfTypeSystem.conditionOpAllowed( lhs.evalType, opSym,
                                                             toType )
            if operatorType is not None:
               operator.evalType = operatorType
               rhs.promoteToType = toType
               break

      if operator.evalType is None:
         # The original operator spelling (e.g. 'is_not') is used in the message.
         op = operator.operator
         lName = lhs.resolveType().displayName
         rName = rhs.resolveType().displayName
         what = f"invalid operation '{op}' between {lName} and {rName}"
         self.diags.typingError( self.currentFunction, operator, what )
         # /!\ On error, Assume that the result of a binary operation is a boolean,
         # so that we don't propagate the error all the way up in the expression:
         # e.g:
         #
         # prefix is 10.0.0.0/24 or prefix is 20.0.0.0/24 or prefix is 2
         #                                                   ^~~~~~~~~~~
         #                                                 assume eval type boolean
         #
         # This will help user focusing on the right error.
         operator.evalType = BT.Boolean

   def visitBinOp( self, binOp, **kwargs ):
      """ Type-bind both operands (rhs sees the lhs type), then type-check. """
      lhs = binOp.lhs
      rhs = binOp.rhs
      self.visit( lhs )( lhs )
      self.visit( rhs )( rhs, lhsType=lhs.evalType )
      self.typeCheckOperandsWithOperator( lhs, binOp, rhs )

   def visitLogicalOp( self, logicalOp, **kwargs ):
      """ Type-bind every operand, then type-check each adjacent operand pair. """
      for expr in logicalOp.expressionList:
         self.visit( expr )( expr )

      # match every pair of expressions
      # NOTE(review): each pair overwrites logicalOp.evalType, so the final type
      # is the one computed for the last pair -- confirm this is intended when
      # Boolean and Trilean operands are mixed.
      firstList = logicalOp.expressionList[ : -1 ]
      secondList = logicalOp.expressionList[ 1 : ]
      for lhs, rhs in zip( firstList, secondList ):
         self.typeCheckOperandsWithOperator( lhs, logicalOp, rhs )

   def visitNot( self, notExpr, **kwargs ):
      """ 'not' keeps the type of its operand, which must be Boolean or Trilean. """
      self.visit( notExpr.expr )( notExpr.expr )
      # Not's expression must either be Boolean or Trilean
      if notExpr.expr.evalType not in ( BT.Boolean, BT.Trilean ):
         what = "'not' must be applied to true, false, or unknown expression"
         self.diags.typingError( self.currentFunction, notExpr, what )
      notExpr.evalType = notExpr.expr.evalType

   def visitVariable( self, variable, **kwargs ):
      """ A variable evaluates to the RCF type of its resolved symbol. """
      variable.evalType = variable.symbol.rcfType

class LinkerTypeBindingPhase:
   """ Cross-function type binding phase for function calls.

   For every call recorded in a function, each argument's eval type is checked
   against the parameter types of the callee's symbol. When a promotion to the
   expected type is possible, funcArg.promoteToType is set. Otherwise a typing
   error is reported if the types differ.

   Attributes:
      diags (RcfDiag): the diagnostic report object where errors are emitted.
   """
   def __init__( self, diags ):
      self.diags = diags

   ###########################################################################
   #                   V I S I T    M E T H O D                              #
   ###########################################################################

   def __call__( self, function ):
      """ Type-check every call listed in function.functionCalls. """
      for functionCall in function.functionCalls:
         self.visitCall( functionCall, function )

   def visitCall( self, call, function ):
      """ Check and promote the arguments of 'call' against the callee's params.

      Args:
         call: the Call AST node. call.symbol.funcParamTypes lists the expected
            type of each argument.
         function: the function containing 'call'; errors are attributed to it.
      """
      # Validate types of funcArgs against symbols
      # Assume the type of functionSelf is correct as this is coming
      # from the metadata
      if call.functionSelf:
         # There is no type checking to be done on call.functionSelf today because
         # this can only be an attribute which does not require any value validation.
         # In the future if call.functionSelf can be something other than an
         # attribute, we will need to re-consider how to validate it here.
         assert isinstance( call.functionSelf, RcfAst.Attribute )

      # If we allow method calls on variables we will need to revisit this.
      # NOTE(review): zip() stops at the shorter sequence, so an argument-count
      # mismatch is not reported here (presumably checked elsewhere -- confirm).
      argsAndParamTypes = zip( call.funcArgs, call.symbol.funcParamTypes )
      for argNum, ( funcArg, expArgRcfType ) in enumerate( argsAndParamTypes,
                                                          start=1 ):
         argType = funcArg.evalType
         # Evaluated once: it decides both the promotion and the error below.
         promotable = RcfTypeSystem.canPromote( argType, expArgRcfType )
         if promotable:
            funcArg.promoteToType = expArgRcfType
         elif expArgRcfType != argType:
            funcArgTypeName = funcArg.resolveType().displayName
            expArgTypeName = expArgRcfType.displayName
            fnName = call.funcName
            what = ( f"expected argument {argNum} ({funcArgTypeName}) to be of "
                     f"{expArgTypeName} when calling function '{fnName}'" )
            # Pass the call node as the enclosing expression so that the entire
            # function call is rendered in the error message
            self.diags.typingError( function, funcArg, what,
                                    enclosingExpression=call )
