# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from abc import abstractmethod, ABC as AbstractBaseClass
from CliParser import AlreadyHandledError
from CliPlugin import ConfigTagCommon
from CliPlugin.RcfCliSessionHelpers import isInteractive
from CliPlugin.RcfScratchpad import (
   getScratchpad,
   removeScratchpad,
)
from RcfLinter import RcfLintRequest
from CodeUnitMapping import (
   CodeUnitMapping,
   generateCodeUnitsForDomain,
)
from RcfCommandTagHelpers import getExcludedCodeUnits
import Tac
import Tracing
from Toggles import RcfLibToggleLib
from RcfOpenConfigFunctionTextGen import generateOpenConfigFunctionTexts
from RcfOpenConfigFunctionValidator import (
   RcfOpenConfigFunctionValidationRequest,
   validateOpenConfigFunctions,
)
import RcfTypeFuture as Rcf

# Trace handle for this module; t0 is shorthand for traceHandle.trace0 and is
# what the helpers below use for their trace messages.
traceHandle = Tracing.Handle( 'RcfCliHelpers' )
t0 = traceHandle.trace0

# ----------------------------------------------------------------------------------
#                                H E L P E R S
#----------------------------------------------------------------------------------

def addCompilationStatusToMode( mode, results, codeUnitMapping, addWarnings=True,
                                errorMsg=None ):
   """Report compiler-stage results on the CLI mode.

   Warnings (optionally) and errors of every result are rendered against
   codeUnitMapping and added to mode. If every result succeeded, the message
   "Compilation successful" is added; otherwise AlreadyHandledError is raised.

   Args:
      mode (CliMode): mode to add warnings, errors and messages to.
      results (Iterable[RcfCompilerStageResultBase]): results to report. Any
         iterable is accepted; it is traversed exactly once.
      codeUnitMapping (CodeUnitMapping or None): Mapping from joined code lines to
         their original position in their respective code units.
      addWarnings (bool): whether to add each result's warnings to mode.
      errorMsg (str or None): message of the raised error; defaults to
         "Compilation failed".

   Raises:
      AlreadyHandledError: if any result was unsuccessful. The individual errors
         have already been added to mode at that point.
   """
   # Track success during the single pass: re-iterating `results` (as a
   # separate all() would) silently sees nothing when a generator is passed,
   # which would report success for a failed compilation.
   allSucceeded = True
   for result in results:
      if addWarnings:
         for warning in result.diag.allWarnings:
            mode.addWarning( warning.render( codeUnitMapping ) )

      if result.success:
         # Invariant: a successful stage never carries errors.
         assert not result.diag.allErrors
      else:
         allSucceeded = False
         for error in result.diag.allErrors:
            mode.addError( error.render( codeUnitMapping ) )

   if not allSucceeded:
      msg = errorMsg or "Compilation failed"
      raise AlreadyHandledError( msg=msg, msgType=AlreadyHandledError.TYPE_ERROR )
   mode.addMessage( "Compilation successful" )

def commitRcfUrlInfo( gv, unitName, url, edited ):
   """Store URL info for a code unit in gv.rcfConfig.rcfCodeUnitUrlInfo.

   Args:
      gv: global state holding rcfConfig.
      unitName (str): code unit the URL info belongs to.
      url: URL the unit was last pulled from.
      edited (bool): value for the editSincePull flag.
   """
   t0( "Committing URL info for code unit ", unitName, " to RcfConfig" )
   urlInfo = Tac.Value( 'Rcf::RcfCodeUnitUrlInfo', unitName, url )
   urlInfo.editSincePull = edited
   gv.rcfConfig.rcfCodeUnitUrlInfo.addMember( urlInfo )

@Tac.memoize
def configTagInputAllocator():
   """Return a ConfigTag::ConfigTagInputAllocator instance.

   Wrapped in Tac.memoize, presumably so repeated calls share one allocator
   rather than creating a new instance each time (TODO confirm memoize
   semantics).
   """
   allocator = Tac.newInstance( "ConfigTag::ConfigTagInputAllocator" )
   return allocator

def validateConfigTags( mode, scratchpad ):
   """Run the command-tag checks for every tag the scratchpad would assign.

   Entries whose tag is None are skipped; updateRcfConfig treats a falsy tag
   as removing the association, so there is nothing to check for them. The
   ConfigTagCommon checks presumably report problems through mode (TODO
   confirm whether they raise).
   """
   pendingTags = ( tag for tag in scratchpad.rcfCodeUnitConfigTags.values()
                   if tag is not None )
   for tag in pendingTags:
      ConfigTagCommon.commandTagConfigCheck( mode, tag )
      ConfigTagCommon.checkTagInRemovedOrDisassociatedState( mode, tag,
                                                             featureCheck=True )

def setConfigTag( gv, unitName, configTag ):
   """Associate code unit unitName with the command tag named configTag.

   Nothing happens if the unit is already associated with that tag's id.
   Otherwise the unit is removed from its previous tag's codeUnit set (if it
   had one), both directions of the mapping are updated, and a ConfigTag input
   entry is created for configTag.
   """
   rcfConfig = gv.rcfConfig
   newTagId = gv.configTagConfig.configTagEntry[ configTag ].tagId
   oldTagId = rcfConfig.codeUnitToConfigTag.get( unitName, None )
   if oldTagId == newTagId:
      return

   if oldTagId is not None:
      # Detach from the previous tag; empty tag entries are cleaned up by
      # updateRcfConfig.
      rcfConfig.configTagToCodeUnit[ oldTagId ].codeUnit.remove( unitName )

   rcfConfig.codeUnitToConfigTag[ unitName ] = newTagId
   rcfConfig.configTagToCodeUnit.newMember( newTagId ).codeUnit.add( unitName )
   configTagInputAllocator().newConfigTagInputEntry( gv.configTagInput, configTag,
                                                     "" )

def removeConfigTag( gv, unitName ):
   """Drop unitName's command-tag association from RcfConfig, if it has one.

   Both directions are updated: the unit is removed from its tag's codeUnit
   set and its codeUnitToConfigTag entry is deleted. Units without an
   association are left alone.
   """
   rcfConfig = gv.rcfConfig
   if unitName not in rcfConfig.codeUnitToConfigTag:
      return
   tagId = rcfConfig.codeUnitToConfigTag[ unitName ]
   rcfConfig.configTagToCodeUnit[ tagId ].codeUnit.remove( unitName )
   del rcfConfig.codeUnitToConfigTag[ unitName ]

def updateRcfConfig( gv, scratchpad ):
   """Apply a scratchpad's pending changes to gv.rcfConfig.

   In order, this commits:
   1. URL info for code units listed in scratchpad.rcfCodeUnitUrlInfos;
   2. code unit texts (a None text deletes the code unit);
   3. OpenConfig functions (a None value deletes the function);
   4. command-tag associations (a falsy tag removes the association);
   and finally deletes every ConfigTagToCodeUnit entry left with no code units.
   """
   # Commit URL info for any code units whose URL info may have changed
   for unitName, pendingUnitUrlInfo in scratchpad.rcfCodeUnitUrlInfos.items():
      if pendingUnitUrlInfo.lastPulledUrl:
         # This unit was pulled during this scratchpad session, replace old URL info
         commitRcfUrlInfo( gv,
                           unitName,
                           pendingUnitUrlInfo.lastPulledUrl,
                           pendingUnitUrlInfo.editSincePull )
      else:
         # This unit was only edited (not pulled) during this scratchpad session. If
         # URL info for this unit exists in Sysdb, update the 'edited' field. Else,
         # nothing to do ('edited' field means nothing without a URL).
         existingUnitUrlInfo = gv.rcfConfig.rcfCodeUnitUrlInfo.get( unitName )
         if existingUnitUrlInfo:
            commitRcfUrlInfo( gv,
                              unitName,
                              existingUnitUrlInfo.lastPulledUrl,
                              pendingUnitUrlInfo.editSincePull )

   # Commit RCF text for any code units whose RCF text may have changed
   for unitName, pendingUnitText in scratchpad.rcfCodeUnitTexts.items():
      if pendingUnitText is None:
         t0( "Deleting code unit ", unitName, " from RcfConfig" )
         # This code unit is being deleted
         del gv.rcfConfig.rcfCodeUnitText[ unitName ]
      else:
         t0( "Committing code unit ", unitName, " to RcfConfig" )
         gv.rcfConfig.rcfCodeUnitText[ unitName ] = pendingUnitText

   for functionName, pendingOpenConfigFunction in (
         scratchpad.openConfigFunctions.items() ):
      if pendingOpenConfigFunction is None:
         t0( "Deleting openconfig function ", functionName, " from RcfConfig" )
         # This function is being deleted
         del gv.rcfConfig.openConfigFunction[ functionName ]
      else:
         t0( "Committing openconfig function ", functionName, " to RcfConfig" )
         gv.rcfConfig.openConfigFunction[ functionName ] = (
               pendingOpenConfigFunction )

   # Commit any changed config tag associations
   for unitName, configTag in scratchpad.rcfCodeUnitConfigTags.items():
      if configTag:
         setConfigTag( gv, unitName, configTag )
      else:
         removeConfigTag( gv, unitName )

   # Delete any ConfigTagToCodeUnit entities that are no longer associated with
   # a code unit. Snapshot the keys first: deleting from a collection while
   # iterating its live keys view is unsafe (a RuntimeError for dicts).
   for configTagId in list( gv.rcfConfig.configTagToCodeUnit.keys() ):
      if len( gv.rcfConfig.configTagToCodeUnit[ configTagId ].codeUnit ) == 0:
         del gv.rcfConfig.configTagToCodeUnit[ configTagId ]

def commitRcfCode( gv, scratchpad ):
   """Commit scratchpad changes to gv.rcfConfig and bump the code version.

   The pending version is incremented first. A falsy scratchpad clears the
   whole configuration; otherwise the scratchpad is applied through
   updateRcfConfig. Afterwards rcfCodeVersion catches up with the pending
   version, and RCF is marked enabled exactly when a scratchpad was given.
   """
   rcfConfig = gv.rcfConfig
   rcfConfig.rcfCodeVersionPending += 1

   if not scratchpad:
      rcfConfig.clearConfig()
   else:
      updateRcfConfig( gv, scratchpad )

   rcfConfig.rcfCodeVersion = rcfConfig.rcfCodeVersionPending
   rcfConfig.enabled = bool( scratchpad )

def commitRcfHelper( mode, gv ):
   """Commit the mode's pending RCF scratchpad to Sysdb, if it is ready.

   If there is no scratchpad, a warning is added to mode and nothing else
   happens. If the scratchpad reports it is not ready to commit, nothing is
   committed. Otherwise command tags are validated, the code is committed,
   and the scratchpad is removed.
   """
   scratchpad = getScratchpad( mode )
   if scratchpad is None:
      mode.addWarning(
         'There are no pending routing control function changes to commit' )
      return

   if not scratchpad.isReadyToCommit( mode, gv.rcfConfig, gv.rcfStatus ):
      return

   validateConfigTags( mode, scratchpad )
   # Save RCF text and function names to Sysdb: either linting succeeded, or we
   # are in a startup-config context and commit regardless of linting.
   commitRcfCode( gv, scratchpad )
   removeScratchpad( mode )

def getCodeCmd( gv, unitName ):
   """Return the CLI command that enters the given code unit.

   Args:
      gv: global state; gv.rcfConfig.unnamedCodeUnitName names the unnamed unit.
      unitName (str): code unit name.

   Returns:
      str: "code" for the unnamed code unit, else "code unit <unitName>".
   """
   # Compare by value. The previous identity check (`is`) only matched when
   # the caller happened to pass the very same str object; equal strings are
   # not guaranteed to be the same object.
   if unitName == gv.rcfConfig.unnamedCodeUnitName:
      return "code"
   return f"code unit {unitName}"

def getCodeConfigTagCmd( gv, unitName, configTag ):
   """Return the "<code command> command-tag <configTag>" CLI command string,
   where <code command> is the command that enters unitName (see getCodeCmd).
   """
   return f"{getCodeCmd( gv, unitName )} command-tag {configTag}"

def getCodeUnitStr( gv, unitName, capitalize=False, sentenceFormat=False ):
   """Return a human-readable name for a code unit, for use in CLI messages.

   Args:
      gv: global state; gv.rcfConfig.unnamedCodeUnitName names the unnamed unit.
      unitName (str): code unit name.
      capitalize (bool): upper-case the first character of the result.
      sentenceFormat (bool): for the unnamed code unit, prefix "the " so the
         text reads naturally inside a sentence. No effect on named units.

   Returns:
      str: e.g. "unnamed code unit", "the unnamed code unit" or
      "code unit <unitName>" (first letter upper-cased if capitalize).
   """
   # Compare by value; an identity check (`is`) fails for equal strings that
   # are distinct objects.
   if unitName == gv.rcfConfig.unnamedCodeUnitName:
      codeUnitStr = "unnamed code unit"
      if sentenceFormat:
         codeUnitStr = "the " + codeUnitStr
   else:
      codeUnitStr = f"code unit {unitName}"
   if capitalize:
      # Cannot use <str>.capitalize() because unitName might have capital letters
      codeUnitStr = codeUnitStr[ 0 ].upper() + codeUnitStr[ 1 : ]
   return codeUnitStr

def getUnitName( gv, args ):
   """Return the code unit name referenced by parsed CLI args.

   A named unit is selected only when both the 'unit' keyword and the
   'UNIT_NAME' value were parsed; anything else means the unnamed code unit.
   """
   namesAUnit = 'unit' in args and 'UNIT_NAME' in args
   return args[ 'UNIT_NAME' ] if namesAUnit else gv.rcfConfig.unnamedCodeUnitName

def getEffectiveCodeUnitToConfigTag( mode, gv ):
   """Return the code unit -> command-tag id mapping as it would be after commit.

   Starts from the committed mapping gv.rcfConfig.codeUnitToConfigTag and
   overlays the scratchpad's pending tag changes:
   - a truthy pending tag with an entry in gv.configTagConfig.configTagEntry
     sets the unit's tag id to that entry's tagId;
   - a truthy pending tag without such an entry leaves the committed mapping
     for that unit as it is (NOTE(review): confirm this is intended);
   - a falsy pending tag removes the unit from the mapping.

   If the mode has no scratchpad, the committed mapping is returned as is.

   Returns:
      dict: code unit name -> tag id.
   """
   effectiveCodeUnitToConfigTag = dict( gv.rcfConfig.codeUnitToConfigTag )
   scratchpad = getScratchpad( mode )
   if scratchpad is None:
      # getScratchpad returns None when nothing is pending; without this guard
      # the loop below raised AttributeError.
      return effectiveCodeUnitToConfigTag

   for unitName, configTag in scratchpad.rcfCodeUnitConfigTags.items():
      if not configTag:
         effectiveCodeUnitToConfigTag.pop( unitName, None )
         continue
      tagEntry = gv.configTagConfig.configTagEntry.get( configTag )
      if tagEntry:
         effectiveCodeUnitToConfigTag[ unitName ] = tagEntry.tagId
   return effectiveCodeUnitToConfigTag

def getFunctionNames( mode, gv ):
   """Return a new list of the function names in gv.rcfStatus.

   mode is accepted for call-site compatibility and is not used.
   """
   return [ *gv.rcfStatus.functionNames ]

class ScratchpadHelper( AbstractBaseClass ):
   """Look up named RCF text blobs across the scratchpad and RcfConfig.

   Code units and OpenConfig functions need the same lookups and differ only
   in which attribute holds each collection. Subclasses name those attributes
   by implementing rcfCollectionName() (attribute of rcfConfig) and
   scratchpadCollectionName() (attribute of the scratchpad).
   """
   def __init__( self, rcfConfig ):
      self.rcfConfig = rcfConfig

   def getEffectiveNames( self, mode ):
      """Return the set of names present in Sysdb and/or the scratchpad.

      NOTE: names that the scratchpad is deleting are still included.
      """
      committed = getattr( self.rcfConfig, self.rcfCollectionName() )
      names = set( committed )
      scratchpad = getScratchpad( mode )
      if scratchpad:
         pending = getattr( scratchpad, self.scratchpadCollectionName() )
         names |= set( pending )
      return names

   def getEffectiveText( self, mode, name ):
      """Return the effective text for name.

      A scratchpad entry wins over Sysdb, even when that entry is None (the
      name is being deleted). Otherwise the Sysdb value is returned. Because
      ConfigAgent is multithreaded, None is also returned when name is found
      in neither place.
      """
      scratchpad = getScratchpad( mode )
      if scratchpad:
         pending = getattr( scratchpad, self.scratchpadCollectionName() )
         if name in pending:
            return pending[ name ]
      committed = getattr( self.rcfConfig, self.rcfCollectionName() )
      return committed.get( name, None )

   def getEffectiveTexts( self, mode ):
      """Return a dict mapping every effective name to its effective text.

      Each name's text is resolved as in getEffectiveText (scratchpad first,
      then Sysdb). Names whose effective text is None, such as entries being
      deleted in the scratchpad, are left out of the result.
      """
      candidates = ( ( name, self.getEffectiveText( mode, name ) )
                     for name in self.getEffectiveNames( mode ) )
      return { name: text for name, text in candidates if text is not None }

   @abstractmethod
   def rcfCollectionName( self ):
      """Name of the rcfConfig attribute holding committed entries."""

   @abstractmethod
   def scratchpadCollectionName( self ):
      """Name of the scratchpad attribute holding pending entries."""

class CodeUnitScratchpadHelper( ScratchpadHelper ):
   """ScratchpadHelper for RCF code unit texts."""

   def scratchpadCollectionName( self ):
      # Pending code unit texts live on the scratchpad under this attribute.
      return "rcfCodeUnitTexts"

   def rcfCollectionName( self ):
      # Committed code unit texts live on RcfConfig under this attribute.
      return "rcfCodeUnitText"

class OpenConfigFunctionScratchpadHelper( ScratchpadHelper ):
   """ScratchpadHelper for OpenConfig functions."""

   def scratchpadCollectionName( self ):
      # Pending OpenConfig functions live on the scratchpad under this attribute.
      return "openConfigFunctions"

   def rcfCollectionName( self ):
      # Committed OpenConfig functions live on RcfConfig under this attribute.
      return "openConfigFunction"

def _userDefinedCodeUnitsForLinter( mode, gv ):
   """Return linter code units for the effective user-defined RCF code units,
   leaving out those excluded by their command-tag state."""
   userDefinedCodeUnitTexts = (
         gv.codeUnitScratchpadHelper.getEffectiveTexts( mode ) )
   excludedCodeUnits = getExcludedCodeUnits(
      set( userDefinedCodeUnitTexts ),
      getEffectiveCodeUnitToConfigTag( mode, gv ),
      gv.configTagIdState,
      gv.configTagStatus )
   includedCodeUnitTexts = {
      codeUnitName: codeUnitText
      for codeUnitName, codeUnitText in userDefinedCodeUnitTexts.items()
      if codeUnitName not in excludedCodeUnits
   }
   return generateCodeUnitsForDomain( Rcf.Metadata.FunctionDomain.USER_DEFINED,
                                      includedCodeUnitTexts )

def _openConfigCodeUnitsForLinter( mode, gv, scratchpad, rcfCodeVersion ):
   """Validate the effective OpenConfig functions and record the result on the
   scratchpad.

   Returns:
      (validationResult, codeUnits): codeUnits are the generated OPEN_CONFIG
      domain code units, or an empty dict if validation failed.
   """
   openConfigInput = gv.openConfigScratchpadHelper.getEffectiveTexts( mode )
   validationRequest = RcfOpenConfigFunctionValidationRequest( openConfigInput,
                                                               rcfCodeVersion )
   validationResult = validateOpenConfigFunctions( validationRequest )
   scratchpad.openConfigValidationResult = validationResult
   if not validationResult.success:
      return validationResult, {}
   openConfigFunctionTexts = generateOpenConfigFunctionTexts(
         validationResult.openConfigFunctions,
         gv.rcfExternalConfig.aclListConfig.ipv6PrefixList )
   codeUnits = generateCodeUnitsForDomain( Rcf.Metadata.FunctionDomain.OPEN_CONFIG,
                                           openConfigFunctionTexts )
   return validationResult, codeUnits

def lintRcfHelper( mode, gv ):
   """Compile (lint) the mode's pending RCF changes and report on mode.

   Linter input is built from user-defined code units (minus command-tag
   excluded ones) and, when the policy-definitions toggle is on, code units
   generated from successfully validated OpenConfig functions. The lint
   result, and the OpenConfig validation result if any, are stored on the
   scratchpad. Strict mode (which also passes rcfExternalConfig to the linter)
   is used for interactive sessions.

   Raises:
      AlreadyHandledError: via addCompilationStatusToMode if any stage failed.
   """
   scratchpad = getScratchpad( mode )
   if scratchpad is None:
      mode.addWarning(
         'There are no pending routing control function changes to compile' )
      return

   rcfCodeVersion = gv.rcfConfig.rcfCodeVersion
   validationResults = []
   codeUnitsForLinter = {}
   codeUnitsForLinter.update( _userDefinedCodeUnitsForLinter( mode, gv ) )

   # OpenConfig input is only gathered when it will actually be validated.
   if RcfLibToggleLib.toggleRcfPolicyDefinitionsEnabled():
      ocValidationResult, ocCodeUnits = _openConfigCodeUnitsForLinter(
            mode, gv, scratchpad, rcfCodeVersion )
      validationResults.append( ocValidationResult )
      codeUnitsForLinter.update( ocCodeUnits )

   # Validate code units
   codeUnitMapping = CodeUnitMapping( codeUnitsForLinter )
   lintRequest = RcfLintRequest( codeUnitMapping,
                                 rcfCodeVersion,
                                 strictMode=isInteractive( mode ) )
   rcfExternalConfig = gv.rcfExternalConfig if lintRequest.strictMode else None
   rcfLintResult = gv.rcfLinter.lint( lintRequest, rcfExternalConfig )
   scratchpad.rcfLintResult = rcfLintResult
   validationResults.append( rcfLintResult )

   # Publish result
   addCompilationStatusToMode( mode, validationResults, codeUnitMapping )

# There are two cases when parsing the config: manually entering the command
# or reading the startup config. The latter comes with indentation of 3 spaces
# for every sub config mode, 6 spaces in our case. We need to preserve the
# original indentation of the RCF code. If all lines of the input start with at
# least 6 spaces, we prune the first 6 spaces of each line before proceeding.
# BUG416112: Move removeStartupConfigIndentation() to BasicCliUtils.py
def removeStartupConfigIndentation( multiLineRcfInput, indentLen ):
   """Strip startup-config indentation from multi-line RCF input.

   If every non-empty line starts with indentLen spaces, the input is assumed
   to come from startup-config, and the first indentLen characters of each
   non-empty line are removed. Otherwise the input was entered manually and is
   returned unchanged, as it is for a non-positive indentLen.

   Args:
      multiLineRcfInput (str): newline-separated RCF text.
      indentLen (int): number of leading spaces to strip.

   Returns:
      str: the de-indented text, or multiLineRcfInput unchanged.
   """
   if indentLen <= 0:
      # Nothing to strip (a negative length would otherwise slice from the end).
      return multiLineRcfInput
   indent = " " * indentLen
   lines = multiLineRcfInput.split( "\n" )
   # Check every line, including the last: the input need not end with a
   # newline, and an unchecked last line used to be sliced anyway, silently
   # dropping up to indentLen characters of real content.
   if any( line and not line.startswith( indent ) for line in lines ):
      return multiLineRcfInput
   # Empty lines slice to "" and so need no special case.
   return "\n".join( line[ indentLen : ] for line in lines )
