#!/usr/bin/env python3
# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Cell
import CliGlobal
from collections import Counter, defaultdict

import CliPlugin.PolicyEvalModels
from CliPlugin.PolicyEvalModels import RcfDebugStateOutOfSync
from CliPlugin import RcfCli
from CliPlugin.RcfDebugModels import (
   RcfDebugFragmentLocation,
   RcfDebugFunctionInvocationReference,
   RcfDebugResult,
   RcfDebugFragment,
   RcfDebugFunctionInvocation,
   RcfDebugFunction,
   RcfDebugEvaluation,
   RcfFunctionText,
   RcfCodeUnitText,
)
from RcfLibUtils import convertDebugSymbolsToPythonRepr
import RcfTypeFuture as Rcf

# Plugin-wide global state. debugSymbols starts unset and is assigned the
# 'Rcf::Debug::Symbols' mount in Plugin().
gv = CliGlobal.CliGlobal( { 'debugSymbols': None } )

class RcfDebugEvaluationModelHelper:
   '''
   Helper class to create an RcfDebug CAPI model for a given RcfDebugEvaluationLog.

   args:
     evaluationLog (dict):
       Of the format:
       {
         entryPoint: <top level function for this evaluation>
         entries: [
           {
             aetNodeKey: <key of the AET node that produced the entry>
             result: <result returned by the AET node's eval method>
             termination: <termination flag for the entry>
             functionCallLog: <nested evaluation log, only for function calls>
             ...
           }, ...
         ]
       }
     codeUnitsModel (dict):
       Multiple evaluations can be present in the same CAPI model. While
       this class builds one evaluation model at a time, the codeUnitsModel
       is shared across evaluations. If this is the first or only evaluation
       then codeUnitsModel will be None, otherwise the existing one will be
       passed in and this helper should update it.

   This class glues together three sources to produce the debug CAPI model:
   - RcfDebugEvaluationLog
   - RcfText
   - RcfDebugSymbols

   The RcfText and RcfDebugSymbols are fetched from CliPlugin global variables
   (RcfCli.gv.rcfConfig and this module's gv.debugSymbols). The versions of the
   AETs in the RcfDebugEvaluationLog, the RcfDebugSymbols, and the RCF text are
   verified to be in sync; RcfDebugStateOutOfSync is raised when they are not.
   '''

   # Fragments are sorted by location. Some fragments are points and can overlap
   # with others; these should always come first and, when several share a
   # location, appear in the order given here. Fragment types not listed get
   # priority 0 and so sort after these at the same location.
   _fragmentSortPriority = {
      'openStatementExpression': -4,
      'openImplicitBracket': -3,
      'closeImplicitBracket': -2,
      'closeStatementExpression': -1,
   }

   @classmethod
   def _fragmentSortKey( cls, fragment ):
      '''
      Sort key for fragments: ( line, column, same-location priority ).
      '''
      return ( fragment.location.line, fragment.location.column,
               cls._fragmentSortPriority.get( fragment.fragmentType, 0 ) )

   def __init__( self, evaluationLog, codeUnitsModel ):
      self.evaluationLog = evaluationLog
      self.codeUnitsModel = codeUnitsModel or {}

      self.entryPoint = self.evaluationLog[ 'entryPoint' ]
      self.overallResult = self.getOverallResult( evaluationLog[ 'entries' ] )
      self.overallTermination = self.getOverallTermination(
         evaluationLog[ 'entries' ] )
      self.rcfConfig = RcfCli.gv.rcfConfig
      self.debugSymbols = convertDebugSymbolsToPythonRepr(
         gv.debugSymbols.functionSymbol )

      # Mapping of function to code unit
      self.functionToCodeUnitName = {}
      # Each invocation of a function gets a new index. This counter allocates
      # indexes for invocations as they are encountered in the log. The entry
      # point's own top-level invocation is always processed first and becomes
      # invocation 0 of that function, so any further invocation of the entry
      # point must be allocated from index 1.
      self.functionInvocationIndex = Counter( { self.entryPoint: 1 } )
      # As function invocations are encountered they are added to this backlog.
      self.functionInvocationsToProcess = []

   def getOverallResult( self, evaluationLog ):
      '''
      Retrieve the result of the last entry in the log; this will have been
      logged by Function::eval.

      args:
        evaluationLog (list):
          The entries of the RcfDebugEvaluationLog whose result is to be
          retrieved.

      returns:
        The 'result' of the last entry, converted to str. An empty list raises
        IndexError.
      '''
      return str( evaluationLog[ -1 ][ 'result' ] )

   def getOverallTermination( self, evaluationLog ):
      '''
      Return the 'termination' flag of the last entry of the evaluation log.

      args:
        evaluationLog (list):
          The entries of the RcfDebugEvaluationLog whose termination is to be
          retrieved.

      returns:
        The 'termination' of the last entry, converted to str.
      '''
      return str( evaluationLog[ -1 ][ 'termination' ] )

   def checkVersions( self, v1, v2, msg ):
      '''
      Raise an exception if two versions are not equal. Used to ensure the
      versions of the AETs in the RcfDebugEvaluationLog, the RcfDebugSymbols, and
      the RCF text are in sync.

      args:
        v1 (int):
          The first version number.
        v2 (int):
          The second version number.
        msg (str):
          The exception message; v1 and v2 are formatted into its two '{}'
          placeholders.

      raises:
        RcfDebugStateOutOfSync if v1 != v2.
      '''
      if v1 != v2:
         raise RcfDebugStateOutOfSync( msg.format( v1, v2 ) )

   def iterFunctionInvocationsToProcess( self ):
      '''
      Generator returning the next function invocation to process, as a tuple of
      ( function name, call stack depth, log entries ). Starts with the entry
      point function and then yields called functions in the order they were
      encountered.

      Note: self.functionInvocationsToProcess is appended to (by
      handleInvocation) while this generator is being consumed, so it acts as a
      FIFO work queue.
      '''
      entryPointLog = self.evaluationLog[ 'entries' ]
      callStackDepth = 0
      yield self.entryPoint, callStackDepth, entryPointLog

      # Deliberately iterate the live list rather than a copy: a list iterator
      # observes items appended during iteration, so invocations queued while an
      # earlier invocation is processed are yielded as well.
      yield from self.functionInvocationsToProcess

   def addCodeUnitModel( self, fnName ):
      '''
      Update the codeUnitsModel with the RCF text of functions as they are
      encountered.

      Ensure the rcfCodeVersion matches the debugSymbols.

      args:
        fnName (str):
          The name of the function whose RCF text definition should be added to
          the codeUnitsModel.

      raises:
        RcfDebugStateOutOfSync if the RCF text in the config, or a function text
        model reused from another evaluation, does not match the RCF text
        version recorded in the debug symbols.
      '''
      functionSymbol = self.debugSymbols[ fnName ]
      symbolRcfTextVersion = functionSymbol[ 'rcfCodeVersion' ]
      functionDefinition = functionSymbol[ 'functionDefinition' ]

      codeUnitName = functionDefinition[ 'codeUnitName' ]
      if fnName not in self.functionToCodeUnitName:
         self.functionToCodeUnitName[ fnName ] = codeUnitName

      if codeUnitName not in self.codeUnitsModel:
         self.codeUnitsModel[ codeUnitName ] = RcfCodeUnitText()
      codeUnitModel = self.codeUnitsModel[ codeUnitName ]

      if fnName not in codeUnitModel.functions:
         # Care needs to be taken when accessing the RCF text version. Races are
         # possible here since the Rcf::Config can be modified by a different CLI
         # session at the same time as this function accesses it.
         # To catch these races when they occur the version is accessed as follows:
         # - Write                       - Read
         # 1) rcfCodeVersionPending      1) rcfCodeVersion
         # 2) rcfCodeUnitText[]          2) rcfCodeUnitText[]
         # 3) rcfCodeVersion             3) rcfCodeVersionPending
         # Accessing with this order to verify the versions ensures that the
         # contents of rcfCodeUnitText[] matches rcfCodeVersion, or
         # else rcfCodeVersionPending will not match.

         # Ensure the current RCF text version matches the Debug Symbols.
         configRcfCodeVersion = self.rcfConfig.rcfCodeVersion
         self.checkVersions( configRcfCodeVersion, symbolRcfTextVersion,
                             'RCF text version mismatch: configRcfCodeVersion {}, '
                             'symbolRcfTextVersion {}' )

         source = self.rcfConfig.rcfCodeUnitText[ codeUnitName ]
         # Tabs are expanded before compilation in the RCF agent so that
         # column offsets in Debug Symbols are consistent with how the text will
         # be rendered.
         source = source.expandtabs( 4 )

         # Validate the pending version to ensure rcfCodeUnitText[] was not being
         # updated.
         self.checkVersions( self.rcfConfig.rcfCodeVersionPending,
                             configRcfCodeVersion,
                             'RCF text version update pending: '
                             'rcfConfig.rcfCodeVersionPending {}, '
                             'configRcfCodeVersion {}' )

         sourceLines = source.split( '\n' )
         defnStartLine = functionDefinition[ 'definitionStartLine' ]
         defnEndLine = functionDefinition[ 'definitionEndLine' ]
         # Line numbers are 1-based and inclusive; keep only this definition.
         lines = { lineNumber: sourceLines[ lineNumber - 1 ]
                   for lineNumber in range( defnStartLine, defnEndLine + 1 ) }

         codeUnitModel.functions[ fnName ] = RcfFunctionText(
            lines=lines, _rcfCodeVersion=configRcfCodeVersion )
      else:
         # Ensure the model being reused from another evaluation used the same
         # RCF text.
         functionTextModel = codeUnitModel.functions[ fnName ]
         # NOTE(review): the first value reported is the reused model's version,
         # although the message labels it configRcfCodeVersion.
         self.checkVersions( functionTextModel.rcfCodeVersion(),
                             symbolRcfTextVersion,
                             'RcfFunctionText version mismatch: '
                             'configRcfCodeVersion {}, '
                             'symbolRcfTextVersion {}' )

   def buildFragmentKwargs( self, fragmentType, location, continuation=False,
                            result=None, invocation=None ):
      '''
      Return the keyword arguments used to construct an RcfDebugFragment.
      '''
      return {
         'fragmentType': fragmentType,
         'location': location,
         'continuation': continuation,
         'result': result,
         'invocation': invocation
      }

   def iterDelimiters( self, symbol ):
      '''
      Yield fragment kwargs for the delimiters associated with a node, if any.

      A node can have one of the following pairs of delimiters associated with it:
      - openStatementExpression
      - closeStatementExpression
      or
      - openExplicitBracket
      - closeExplicitBracket
      or
      - openImplicitBracket
      - closeImplicitBracket

      This method acquires the information from the symbol and yields the
      delimiters for the node if present.

      args:
        symbol (Rcf.Debug.AetNodeSymbol):
          The debug symbol for the node.
      '''
      statementExpressionData = symbol.get( 'statementExpressionData' )
      parensData = symbol.get( 'parensData' )
      # These are mutually exclusive, one is for statements the other for expressions
      assert not ( statementExpressionData and parensData )
      if statementExpressionData:
         openingLocations = statementExpressionData.get( 'openStmtPoint' )
         closingLocations = statementExpressionData.get( 'closeExprPoint' )
         fragmentFormat = '{}' + 'StatementExpression'
      elif parensData:
         openingLocations = parensData.get( 'openingParens' )
         closingLocations = parensData.get( 'closingParens' )
         fragmentFormat = '{}' + '{}Bracket'.format(
            'Explicit' if parensData.get( 'explicitParens' ) else 'Implicit' )
      else:
         return

      # Opening and closing locations are paired positionally; zip() ignores any
      # unmatched trailing locations.
      for openingLocation, closingLocation in zip( openingLocations or [],
                                                   closingLocations or [] ):
         yield self.buildFragmentKwargs(
            fragmentFormat.format( 'open' ),
            RcfDebugFragmentLocation( **openingLocation ) )
         yield self.buildFragmentKwargs(
            fragmentFormat.format( 'close' ),
            RcfDebugFragmentLocation( **closingLocation ) )

   def iterBlockFragmentData( self, blockData, termination ):
      '''
      Generator that produces the fragment data for a given Block.

      args:
        blockData (Rcf.Debug.BlockData):
          The DebugSymbol for the block that was evaluated.
        termination (TerminationFlag):
          The result's termination flag for the block.
      '''
      blockLabelData = blockData.get( 'blockLabel' )
      if blockLabelData:
         assert len( blockLabelData ) == 1
         labelLocation = RcfDebugFragmentLocation( **blockLabelData[ 0 ] )
         yield self.buildFragmentKwargs( 'definition', labelLocation )

      openBraceData = blockData[ 'openingBrace' ]
      assert len( openBraceData ) == 1
      openingBraceLocation = RcfDebugFragmentLocation( **openBraceData[ 0 ] )
      yield self.buildFragmentKwargs( 'openBlock', openingBraceLocation )

      # Only emit the closing brace if the evaluation reached the end of the
      # block without returning.
      if termination == 'noTermination':
         closeBraceData = blockData[ 'closingBrace' ]
         assert len( closeBraceData ) == 1
         closingBraceLocation = RcfDebugFragmentLocation( **closeBraceData[ 0 ] )
         yield self.buildFragmentKwargs( 'closeBlock', closingBraceLocation )

   def iterBranchFragmentData( self, ifData, result ):
      '''
      Generator that produces the branch data for a given IfExpression.

      args:
        ifData (Rcf.Debug.IfData):
          The DebugSymbol for the branch that was evaluated.
        result (RcfDebugResult or None):
          The result of the branch condition, or None when no result was
          recorded for the entry.
      '''
      ifLabelData = ifData.get( 'ifLabel' )
      if ifLabelData:
         assert len( ifLabelData ) == 1
         yield self.buildFragmentKwargs(
            'definition', RcfDebugFragmentLocation( **ifLabelData[ 0 ] ) )

      ifPart = ifData[ 'ifPart' ]
      assert len( ifPart ) == 1
      yield self.buildFragmentKwargs( 'ifStatement',
                                      RcfDebugFragmentLocation( **ifPart[ 0 ] ),
                                      result=result )

      # The second part of a branch is the 'else'.
      # It is only considered evaluated if the condition result is NOT true.
      if result and result.value == 'true':
         return

      elseLabelData = ifData.get( 'elseLabel' )
      if elseLabelData:
         assert len( elseLabelData ) == 1
         yield self.buildFragmentKwargs(
            'definition', RcfDebugFragmentLocation( **elseLabelData[ 0 ] ) )

      elsePart = ifData.get( 'elsePart' )
      if elsePart:
         assert len( elsePart ) == 1
         yield self.buildFragmentKwargs(
            'elseStatement', RcfDebugFragmentLocation( **elsePart[ 0 ] ) )

   def handleInvocation( self, functionCallLog, functionCallData, callStackDepth ):
      '''
      Helper to manage state around processing a function invocation: allocate
      the callee's next invocation index and queue the invocation for
      processing.

      args:
        functionCallLog (RcfDebugEvaluationLog):
          The evaluation log of the invocation.
        functionCallData (Rcf.Debug.FunctionCallData):
          The DebugSymbol for the function call.
        callStackDepth (int):
          The callStackDepth of the function call.

      returns:
        An RcfDebugFunctionInvocationReference to the queued invocation.
      '''
      calleeName = functionCallData[ 'calleeName' ]
      invocationIndex = self.functionInvocationIndex[ calleeName ]
      self.functionInvocationIndex[ calleeName ] += 1
      self.functionInvocationsToProcess.append(
         ( calleeName,
           callStackDepth,
           functionCallLog[ 'entries' ]
         )
      )
      return RcfDebugFunctionInvocationReference(
         functionName=calleeName, invocationIndex=invocationIndex )

   def iterEntries( self, log, functionSymbol, callStackDepth ):
      '''
      Generator to walk over the entries in a RcfDebugEvaluationLog and produce
      the resulting fragment data (RcfDebugFragment kwargs) for each entry.
      Function call entries are queued for later processing as a side effect.

      args:
        log (RcfDebugEvaluationLog):
          The evaluation log to walk over.
        functionSymbol (Rcf.Debug.FunctionSymbol):
          The DebugSymbol for the function being evaluated.
        callStackDepth (int):
          The callStackDepth of the function being evaluated.
      '''
      aetNodeSymbols = functionSymbol[ 'aetNodeSymbols' ]

      for entry in log:
         aetNodeKey = entry[ 'aetNodeKey' ]
         symbol = aetNodeSymbols[ aetNodeKey ]
         if symbol.get( 'internal' ):
            continue
         fragmentType = symbol[ 'fragmentType' ]

         result = None
         if entry[ 'termination' ] != 'exitToPoa':
            result = RcfDebugResult( value=str( entry[ 'result' ] ) )

         functionCallLog = entry.get( 'functionCallLog' )
         assert bool( functionCallLog ) == ( fragmentType == 'functionCall' )

         # Handle parenthesis, statementExpression delimiters
         yield from self.iterDelimiters( symbol )

         # Handle blocks
         if fragmentType == 'openBlock':
            yield from self.iterBlockFragmentData( symbol[ 'blockData' ],
                                                   entry[ 'termination' ] )
            # blocks don't require further processing
            continue

         # Handle if-else
         if fragmentType == 'ifStatement':
            yield from self.iterBranchFragmentData( symbol[ 'ifData' ], result )
            # branches don't require further processing
            continue

         # The remaining fragments may have multiple locations
         for i, locationSymbol in enumerate( symbol[ 'location' ] ):
            continuation = bool( i )
            location = RcfDebugFragmentLocation( **locationSymbol )
            invocationReference = None

            if fragmentType in [ 'assignment', 'definition' ]:
               result = None
            elif fragmentType == 'functionCall' and not continuation:
               # Only the first location of a call owns the invocation.
               invocationReference = self.handleInvocation(
                  functionCallLog, symbol[ 'functionCallData' ], callStackDepth + 1 )

            yield self.buildFragmentKwargs( fragmentType, location,
                                            result=result,
                                            continuation=continuation,
                                            invocation=invocationReference )

   def buildModel( self ):
      '''
      The main function called to build the CAPI model. Walks over the function
      invocations and constructs the CAPI models for each fragment, invocation,
      and function. Then combines these into the overall evaluation model.

      returns:
        A tuple of ( RcfDebugEvaluation, codeUnitsModel ).

      raises:
        RcfDebugStateOutOfSync if debug symbols are missing for an invoked
        function or if any of the AET / RCF text versions disagree.
      '''
      functionEvalModels = {}
      functionInvocations = defaultdict( list )
      for fnName, callStackDepth, log in self.iterFunctionInvocationsToProcess():
         functionSymbol = self.debugSymbols.get( fnName )
         if functionSymbol is None:
            raise RcfDebugStateOutOfSync( f'Missing DebugSymbols for "{fnName}"' )

         if ( functionSymbol[ 'functionDomain' ]
              != Rcf.Metadata.FunctionDomain.USER_DEFINED ):
            # Builtin, OpenConfig, or POA wrapper function with no debug support.
            if fnName not in functionEvalModels:
               functionEvalModels[ fnName ] = RcfDebugFunction(
                  invocations=[], functionDomain=functionSymbol[ 'functionDomain' ] )
            continue

         # Ensure the AET version matches before processing the invocation.
         # NOTE(review): the evaluation-side version is read from the aetNodeKey
         # of the last log entry; presumably that key encodes the AET version for
         # the function-level node -- confirm.
         evalAetVersion = log[ -1 ][ 'aetNodeKey' ]
         symbolsAetVersion = functionSymbol[ 'aetVersion' ]
         self.checkVersions( evalAetVersion, symbolsAetVersion,
                             'AET version mismatch: fnName: ' + fnName +
                             ', evalAetVersion {}, symbolAetVersion {}' )

         self.addCodeUnitModel( fnName )

         # Building the fragments also queues any called functions (via
         # handleInvocation), which this loop picks up on later iterations.
         fragments = [ RcfDebugFragment( **fragKwargs )
                       for fragKwargs in self.iterEntries( log, functionSymbol,
                                                           callStackDepth ) ]
         fragments.sort( key=self._fragmentSortKey )

         result = RcfDebugResult( value=self.getOverallResult( log ) )
         invocationModel = RcfDebugFunctionInvocation( fragments=fragments,
                                                       callStackDepth=callStackDepth,
                                                       result=result )
         functionInvocations[ fnName ].append( invocationModel )

      # Construct the Function model once all of the invocations have been
      # processed.
      for fnName, invocations in functionInvocations.items():
         functionEvalModels[ fnName ] = RcfDebugFunction(
            invocations=invocations,
            codeUnitName=self.functionToCodeUnitName[ fnName ],
            functionDomain=Rcf.Metadata.FunctionDomain.USER_DEFINED )

      return (
         RcfDebugEvaluation(
            entryPoint=self.entryPoint,
            result=RcfDebugResult( value=self.overallResult,
                                   termination=self.overallTermination ),
            functionEvaluations=functionEvalModels ),
         self.codeUnitsModel
      )

def createRcfDebugEvaluationModel( evaluationLog, codeUnitsModel ):
   '''
   Build the RcfDebug CAPI model for a single evaluation log.

   args:
     evaluationLog (dict): the evaluation log for one evaluation.
     codeUnitsModel (dict or None): code unit models shared across evaluations,
       or None for the first evaluation.

   returns:
     Whatever RcfDebugEvaluationModelHelper.buildModel() returns.
   '''
   return RcfDebugEvaluationModelHelper(
      evaluationLog, codeUnitsModel ).buildModel()

def installEvaluationModelHook():
   '''
   Register createRcfDebugEvaluationModel as
   CliPlugin.PolicyEvalModels.rcfDebugModelHook.

   Kept as a named function so it is more obvious why tests need to import this
   file.
   '''
   policyEvalModels = CliPlugin.PolicyEvalModels
   policyEvalModels.rcfDebugModelHook = createRcfDebugEvaluationModel

def Plugin( entMan ):
   '''
   Plugin entry point: mount the RCF debug symbols and store the mount in
   gv.debugSymbols.
   '''
   mountGroup = entMan.mountGroup()
   symbolsPath = Cell.path( 'routing/rcf/debugSymbols' )
   gv.debugSymbols = mountGroup.mount( symbolsPath, 'Rcf::Debug::Symbols', 'rS' )
   # Close the mount group without a callback and without blocking.
   mountGroup.close( callback=None, blocking=False )

# Install the RCF debug model hook when this module is imported.
installEvaluationModelHook()
