#!/usr/bin/env python3
# Copyright (c) 2022 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from enum import Enum

from Toggles import RcfLibToggleLib

import RcfAst
import RcfAstAccessibleAttributeGen
import RcfAstListener
import RcfAstWalker
from RcfImmediateValueHelper import AetExtCommHelper, RcfImmediateValueHelper
from RcfMetadata import RcfBuiltinTypes as BT
from RcfTypeFuture import Arnet

class CollectionWithRangeTranform( RcfAstListener.Listener ):
   """ Listener that goes through all collections and applies an AST transform.
   This transform moves all Range ast nodes from the `values` attribute to the
   `ranges` attribute.

   Only the Range nodes still present in `values` are moved, so visiting a
   collection whose `ranges` attribute is already populated (for example when
   the same AST is walked a second time) leaves it untouched instead of
   failing while removing nodes that are no longer in `values`.
   """
   def listenCollection( self, collection, action, **kwargs ):
      """Move the Range nodes of `collection` from `values` to `ranges`.

      Args:
         collection: AST node being visited; its `values` and `ranges` lists
            are updated in place.
         action (RcfAstWalker.Action): Action of the Ast Walker; the transform
            is only applied on EXIT.
      """
      if action is not RcfAstWalker.Action.EXIT:
         return

      # Partition `values` in a single pass, preserving the relative order of
      # both the moved Range nodes and the remaining values.
      remaining = []
      movedAny = False
      for item in collection.values:
         if isinstance( item, RcfAst.Range ):
            collection.ranges.append( item )
            movedAny = True
         else:
            remaining.append( item )

      if movedAny:
         # Slice assignment keeps the same list object, so any other reference
         # to `collection.values` observes the update.
         collection.values[ : ] = remaining

class FunctionCallCollector( RcfAstListener.Listener ):
   """ Listener that records every function call and external reference on the
   Function AST node that is being walked when they are entered.

   Attributes:
      diag (RcfDiag): RCF diagnostics object.
      currentFunction (RcfAst.Function): Function currently being walked, None
         before the first Function is entered and after a Function is exited
   """
   def __init__( self, diag ):
      """Constructor

      Args:
         diag (RcfDiag): RCF diagnostics object.
      """
      super().__init__()
      self.currentFunction = None
      self.diag = diag

   def listenFunction( self, function, action, **kwargs ):
      """Track the function being walked: set on ENTRY, cleared on EXIT."""
      if action is RcfAstWalker.Action.ENTRY:
         self.currentFunction = function
      elif action is RcfAstWalker.Action.EXIT:
         self.currentFunction = None

   def listenCall( self, call, action, **kwargs ):
      """On ENTRY, append the call to the current function's functionCalls."""
      if action is not RcfAstWalker.Action.ENTRY:
         return
      self.currentFunction.functionCalls.append( call )

   def listenExternalRef( self, extRef, action, **kwargs ):
      """On ENTRY, append the reference to the current function's externalRefs."""
      if action is not RcfAstWalker.Action.ENTRY:
         return
      self.currentFunction.externalRefs.append( extRef )

class CodeAnalysisBase( RcfAstListener.Listener ):
   """ Common base for the code analysis listeners driven by the AstWalker

   Keeps track, on behalf of the derived checkers, of the nodes that are
   currently being walked and of the last Function node that was visited.

   Attributes:
      diag (RcfDiag): RCF diagnostics object.
      nodeStack (list): Nodes that were entered and not exited yet, the
         innermost one being last
      currentFunction (RcfAst.Function): Last Function node seen by listen
   """
   def __init__( self, diag ):
      """Constructor

      Args:
         diag (RcfDiag): RCF diagnostics object.
      """
      super().__init__()
      self.currentFunction = None
      self.nodeStack = []
      self.diag = diag

   def checkPreviousNodeType( self, nodeType ):
      """Tell whether the node on top of nodeStack is of the given nodeType

      Args:
         nodeType (RcfAstNode class): RcfAst node type

      Returns:
         bool: True when the top of nodeStack is an instance of nodeType
      """
      assert self.nodeStack, "no previous node in nodeStack"
      previousNode = self.nodeStack[ -1 ]
      return isinstance( previousNode, nodeType )

   def listen( self, node, action, **kwargs ):
      """
      Maintain nodeStack and currentFunction around the regular dispatch
      - On EXIT the node is popped before the dispatch
      - On ENTRY the node is pushed after the dispatch
      - As a result the top of nodeStack is the enclosing node whenever a
        listenXxx handler runs
      - currentFunction is updated after the dispatch whenever the node is a
        Function
      """
      entering = action is RcfAstWalker.Action.ENTRY
      if action is RcfAstWalker.Action.EXIT:
         assert self.nodeStack, "empty nodeStack"
         assert self.nodeStack[ -1 ] is node, "unexpected node"
         self.nodeStack.pop()
      else:
         assert entering, "unexpected action"

      super().listen( node, action, **kwargs )

      if entering:
         self.nodeStack.append( node )
      if isinstance( node, RcfAst.Function ):
         self.currentFunction = node

class DirectiveErrorChecker( CodeAnalysisBase ):
   """ Directive Error Checker

   Reports a directive error through diag for every Directive node entered.
   """

   def listenDirective( self, directive, action, **kwargs ):
      """On ENTRY, report the directive together with its value."""
      if action is RcfAstWalker.Action.ENTRY:
         self.diag.directiveError(
            self.currentFunction, directive, f"'{directive.value}' found" )

class NoEffectWarningChecker( CodeAnalysisBase ):
   """
   No Effect Warning/Error Checker

   Reports statements that have no effect: operations located directly under
   a Block in which no function call was seen, and constants located directly
   under a Block.

   Attributes:
      what (str): Warning/Error message
      callFound (bool): Whether a function call was entered since the flag was
         last cleared
   """
   def __init__( self, diag ):
      super().__init__( diag )
      self.callFound = False
      self.what = 'statement has no effect'

   def listenCall( self, call, action, **kwargs ):
      """Remember that a function call was entered."""
      if action is not RcfAstWalker.Action.ENTRY:
         return
      self.callFound = True

   def handleOp( self, context, action, **kwargs ):
      """
      Shared handler for LogicalOp, Not, BinOp and ExternalRefOp

      Only nodes whose enclosing node is a Block (i.e. statements) are
      considered. On EXIT of such a node a no effect warning is raised if no
      function call was found; callFound is cleared on both ENTRY and EXIT.

      Args:
         context (RcfParser.ParserRuleContext): Context of the Ast Node
         action (RcfAstWalker.Action): Action of the Ast Walker
      """
      if self.checkPreviousNodeType( RcfAst.Block ):
         leaving = action is RcfAstWalker.Action.EXIT
         if leaving and not self.callFound:
            self.diag.noEffectWarning( self.currentFunction, context, self.what )
         self.callFound = False

   def listenLogicalOp( self, logicalOp, action, **kwargs ):
      opContext = logicalOp.context
      self.handleOp( opContext, action, **kwargs )

   def listenNot( self, notExpr, action, **kwargs ):
      opContext = notExpr.context
      self.handleOp( opContext, action, **kwargs )

   def listenBinOp( self, binOp, action, **kwargs ):
      opContext = binOp.context
      self.handleOp( opContext, action, **kwargs )

   def listenExternalRefOp( self, externalRefOp, action, **kwargs ):
      opContext = externalRefOp.context
      self.handleOp( opContext, action, **kwargs )

   def listenConstant( self, constant, action, **kwargs ):
      """On EXIT, report a constant whose enclosing node is a Block."""
      if ( action is RcfAstWalker.Action.EXIT and
           self.checkPreviousNodeType( RcfAst.Block ) ):
         self.diag.immediateValueError( self.currentFunction, constant, self.what )

class EndOfBlock( Enum ):
   """ State of a block with regard to its Return/Exit statements

   Used to raise the unreachable code warning at most once per block.
   """
   NOT_FOUND = 0 # no Return/Exit seen in the block so far
   FOUND = 1 # a Return/Exit was seen in the block
   WARNED = 2 # the warning has already been raised for the block

class UnreachableCodeWarningChecker( CodeAnalysisBase ):
   """Unreachable Code Warning Checker

   Warns about statements located after a return/exit statement
   - within the same block

   Attributes:
      what (str): Warning message
      endOfBlockStack: Maintains a stack of EndOfBlock Enums, one per block
         being walked (nested label blocks share their parent's entry)
   """
   def __init__( self, diag ):
      super().__init__( diag )
      self.endOfBlockStack = []
      self.what = 'unreachable code'

   def listenBlock( self, block, action, **kwargs ):
      """
      Maintain the endOfBlockStack
      - A block whose enclosing node is itself a Block (label block) is
        treated as part of the enclosing block because it always runs
      - Any other block pushes a NOT_FOUND entry on ENTRY and pops it on EXIT
      """
      nestedInBlock = self.checkPreviousNodeType( RcfAst.Block )
      if nestedInBlock:
         return

      if action is RcfAstWalker.Action.EXIT:
         self.endOfBlockStack.pop()
      elif action is RcfAstWalker.Action.ENTRY:
         self.endOfBlockStack.append( EndOfBlock.NOT_FOUND )
      else:
         assert False, "unexpected action"

   def handleStmt( self, context, action, **kwargs ):
      """ Shared handler for IfStmt, Call, ExternalRefOp, Assign, BinOP,
      LogicalOp, Not, Return, and Exit

      Unreachable code condition:
      - node is a statement (its enclosing node is a Block) and a return/exit
        was already found in the current block

      The warning is raised on ENTRY only, and only once per block: the top of
      the endOfBlockStack is switched to WARNED afterwards.

      Args:
         context (RcfParser.ParserRuleContext): Context of the Ast Node
         action (RcfAstWalker.Action): Action of the Ast Walker
      """
      if ( not self.checkPreviousNodeType( RcfAst.Block ) or
           action is not RcfAstWalker.Action.ENTRY ):
         return

      if self.endOfBlockStack[ -1 ] is not EndOfBlock.FOUND:
         return

      self.diag.unreachableCodeWarning( self.currentFunction, context, self.what )
      self.endOfBlockStack[ -1 ] = EndOfBlock.WARNED

   def listenIfStmt( self, ifStmt, action, **kwargs ):
      stmtContext = ifStmt.context.ifPart
      self.handleStmt( stmtContext, action, **kwargs )

   def listenCall( self, call, action, **kwargs ):
      stmtContext = call.context
      self.handleStmt( stmtContext, action, **kwargs )

   def listenExternalRefOp( self, externalRefOp, action, **kwargs ):
      stmtContext = externalRefOp.context
      self.handleStmt( stmtContext, action, **kwargs )

   def listenAssign( self, assign, action, **kwargs ):
      stmtContext = assign.context
      self.handleStmt( stmtContext, action, **kwargs )

   def listenBinOp( self, binOp, action, **kwargs ):
      stmtContext = binOp.context
      self.handleStmt( stmtContext, action, **kwargs )

   def listenLogicalOp( self, logicalOp, action, **kwargs ):
      stmtContext = logicalOp.context
      self.handleStmt( stmtContext, action, **kwargs )

   def listenNot( self, notExpr, action, **kwargs ):
      stmtContext = notExpr.context
      self.handleStmt( stmtContext, action, **kwargs )

   def handleEndOfBlock( self, action, **kwargs ):
      """ Shared handler for Return, and Exit

      On ENTRY, switch the top of the endOfBlockStack from NOT_FOUND to FOUND;
      any other state is left as is.

      Args:
         action (RcfAstWalker.Action): Action of the Ast Walker
      """
      if ( action is RcfAstWalker.Action.ENTRY and
           self.endOfBlockStack[ -1 ] is EndOfBlock.NOT_FOUND ):
         self.endOfBlockStack[ -1 ] = EndOfBlock.FOUND

   def listenExit( self, exitOperation, action, **kwargs ):
      # The exit itself may be unreachable, so check it before recording it
      stmtContext = exitOperation.context.exitPart
      self.handleStmt( stmtContext, action, **kwargs )
      self.handleEndOfBlock( action, **kwargs )

   def listenReturn( self, returnOperation, action, **kwargs ):
      # The return itself may be unreachable, so check it before recording it
      stmtContext = returnOperation.context.returnPart
      self.handleStmt( stmtContext, action, **kwargs )
      self.handleEndOfBlock( action, **kwargs )

class RouterIdAndIPv6ComparisonWarningChecker( CodeAnalysisBase ):
   """RouterIdAndIPv6ComparisonWarningChecker

   Warns when an attribute whose name ends with router_id is compared with an
   IPv6 address constant or with a prefix list that is not IPv4.
   """
   def listenBinOp( self, binOp, action, **kwargs ):
      """On ENTRY, warn if a router_id is compared with an IPv6 constant."""
      if action is not RcfAstWalker.Action.ENTRY:
         return

      lhsName = binOp.lhs.name
      if not lhsName.endswith( 'router_id' ):
         return

      # Only IP address constants on the RHS are of interest
      rhs = binOp.rhs
      if ( not isinstance( rhs, RcfAst.Constant ) or
           rhs.resolveType() != BT.IpAddress ):
         return

      if Arnet.IpGenAddr( rhs.value ).af == Arnet.AddressFamily.ipv6:
         self.diag.routerIdAndIPv6ComparisonWarning(
            self.currentFunction, binOp.context,
            f"'{lhsName}' should only be compared with an IPv4 address" )

   def listenExternalRefOp( self, externalRefOp, action, **kwargs ):
      """On ENTRY, warn if a router_id is matched against a non IPv4
      external reference.
      """
      if action is not RcfAstWalker.Action.ENTRY:
         return

      # Attributes other than router_id are not checked
      # e.g. source_session.remote.ip_address
      attrName = externalRefOp.attribute.name
      if not attrName.endswith( 'router_id' ):
         return

      # Only an external reference on the RHS is checked (e.g. not a variable),
      # and an IPv4 one is what is expected
      rhs = externalRefOp.rhs
      if not isinstance( rhs, RcfAst.ExternalRef ) or rhs.isIpV4:
         return

      self.diag.routerIdAndIPv6ComparisonWarning(
         self.currentFunction, externalRefOp.context,
         f"'{attrName}' should only be compared with an IPv4 prefix list" )

class LinkBandwidthValueAssignmentChecker( CodeAnalysisBase ):
   """ LinkBandwidthValueAssignmentChecker

   Check if more than one LBW value is present in an immediate ext_community set
   if the op is = or add
   Check if ASN is specified for any LBW values present in an immediate
   ext_community set if the op is = or add

   Besides raising the warnings, this listener also rewrites the collection:
   a single LBW value is kept in `values` and all the other LBW values are
   removed from it.
   """
   def listenCollection( self, collection, action, **kwargs ):
      """Check, on ENTRY, the LBW values of an immediate ext_community set.

      Args:
         collection: AST node being visited; its `values` list is pruned in
            place so that at most one LBW value remains.
         action (RcfAstWalker.Action): Action of the Ast Walker
      """
      if action is not RcfAstWalker.Action.ENTRY:
         return

      if not ( collection.evalType == BT.ImmediateExtCommunitySet and
               AetExtCommHelper.opIntroducesNewElement( collection ) and
               collection.parentLhsSymbol.rcfType == BT.ExtCommunity ):
         return

      lbwValues = []
      lbwAsSpecified = []
      for v in collection.values:
         if ( v.evalType == BT.ImmediateExtCommunityValue and
              v.sections[ 0 ].value == 'LINK-BANDWIDTH-AS' ):
            lbwValues.append( v )
            if v.context.section2 is not None:
               lbwAsSpecified.append( v.sections[ 1 ] )

      if len( lbwValues ) > 1:
         self.diag.multipleLinkBandwidthWarning( self.currentFunction,
                                                 collection, lbwValues )

      if lbwAsSpecified:
         self.diag.adminAsSpecifiedWarning( self.currentFunction,
                                            collection, lbwAsSpecified )

      singleLbwExtComm = self._selectLbwValue( lbwValues )

      # Remove all unused LBW values. Nodes are matched by identity rather
      # than by equality: the node to keep is one specific object, and two
      # distinct nodes must never be mistaken for one another.
      unusedIds = { id( lbwValue ) for lbwValue in lbwValues
                    if lbwValue is not singleLbwExtComm }
      if unusedIds:
         # Slice assignment keeps the same list object, so any other reference
         # to `collection.values` observes the update.
         collection.values[ : ] = [ v for v in collection.values
                                    if id( v ) not in unusedIds ]

   @staticmethod
   def _selectLbwValue( lbwValues ):
      """Return the LBW value to keep among lbwValues (None if it is empty).

      The first dynamic LBW value (third section is an Attribute) is used.
      Otherwise, the smallest static LBW value seen is used, the earliest one
      winning in case of a tie.
      """
      def staticBps( lbwValue ):
         return RcfImmediateValueHelper.bandwidthStrToNum(
            lbwValue.sections[ 2 ].value, 'bps' )

      selected = None
      for lbwValue in lbwValues:
         if isinstance( lbwValue.sections[ 2 ], RcfAst.Attribute ):
            return lbwValue
         if selected is None or staticBps( lbwValue ) < staticBps( selected ):
            selected = lbwValue
      return selected

class FunctionScopeCodeAnalysisPhase:
   """ Code Analysis phase

   During this phase, the AST is walked once with the following listeners
   registered, in this order:
      - link bandwidth value assignment checker
      - no effect statement checker
      - unreachable code checker
      - router_id and IPv6 comparison checker
      - directive error checker
      - accessible attributes generator (only when the toggle is enabled and
        diag holds no error at construction time)
      - function call collector
      - collection range transform

   Attributes:
      astWalker (RcfAstWalker.Walker): RCF Abstract Syntax Tree Walker
   """
   def __init__( self, diag ):
      """ Constructor
         - Creates walker
         - Registers each listener

      Args:
         diag (RcfDiag): RCF diagnostics object.
      """
      self.astWalker = RcfAstWalker.Walker()
      register = self.astWalker.registerListener

      # Each checker is built and registered right away, one after the other
      for checkerType in ( LinkBandwidthValueAssignmentChecker,
                           NoEffectWarningChecker,
                           UnreachableCodeWarningChecker,
                           RouterIdAndIPv6ComparisonWarningChecker,
                           DirectiveErrorChecker ):
         register( checkerType( diag ) )

      accessAttrsEnabled = RcfLibToggleLib.toggleRcfAccessAttrsEnabled()
      if not diag.hasErrors() and accessAttrsEnabled:
         register( RcfAstAccessibleAttributeGen.AccessibleAttributesGenerator() )

      register( FunctionCallCollector( diag ) )
      register( CollectionWithRangeTranform() )

   def __call__( self, node, **kwargs ):
      """ Starts the walker on the given node (kwargs are not forwarded)
      """
      self.astWalker( node )
