# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import json
import AirStreamLib
import CliSession
from CodeUnitMapping import (
   CodeUnitMapping,
   generateCodeUnitsForDomain,
)
import GnmiSetCliSession
import Tac
import Tracing
from RcfDiag import RcfDiag
from RcfLexer import (
   RcfLexer,
   InputStream,
)
from RcfLinter import (
   RcfLinter,
   RcfLintRequest,
)
from RcfOpenConfigFunctionValidator import (
   RcfOpenConfigFunctionValidationRequest,
   validateOpenConfigFunctions,
)
import RcfTypeFuture as Rcf
from Toggles.RcfLibToggleLib import toggleRcfPolicyDefinitionsEnabled

# Trace handle for this plugin; t0 is used for per-sync summary messages and
# t2 for per-code-unit detail.
th = Tracing.Handle( "ConfigSessionRcf" )
t0, t2 = th.trace0, th.trace2

# Entity paths used by the pre-commit handlers and registrations in this file:
# the OpenConfig RCF directory, the OpenConfig policy-definitions, and the
# native RCF config they are synced into.
EXTERNAL_RCF_PATH = "routing/rcf/openconfig/rcfDir"
EXTERNAL_POLICY_DEFINITIONS_PATH = "routing/policy/openconfig/policyDefinitions"
NATIVE_PATH = "routing/rcf/config"

def validateRcfConfig( cls, sessionName, funcName, cfg ):
   '''
   Validate the RCF source text held in cfg.code.

   The text must contain exactly one RCF function definition, and that
   function must be named funcName (the name of the enclosing container).
   Any violation raises AirStreamLib.ToNativeSyncherError carrying sessionName
   and cls.__name__.
   '''
   def fail( message ):
      raise AirStreamLib.ToNativeSyncherError( sessionName, cls.__name__,
                                               message )

   # Route lexer errors into a non-strict RcfDiag instead of the default
   # listeners. The errors collected in diag are not inspected here.
   diag = RcfDiag( strict=False )
   lexer = RcfLexer( InputStream( cfg.code ) )
   lexer.removeErrorListeners()
   lexer.addErrorListener( diag )

   # Names taken from each FUNCTION token seen so far (its second
   # whitespace-separated word); finding a second one is an error.
   definedNames = []
   for token in lexer.getAllTokens():
      if token.type != lexer.FUNCTION:
         continue
      definedNames.append( token.text.split()[ 1 ] )
      if len( definedNames ) > 1:
         fail( "More than one RCF function definition found in "
               f"policy-definition {funcName}" )

   if not definedNames:
      fail( f"No RCF function definition found in policy-definition {funcName}" )

   # The single definition must carry the same name as its container.
   definedName = definedNames[ 0 ]
   if definedName != funcName:
      fail( f"Expected name of RCF function in definition to be {funcName}, "
            f"got {definedName}" )

class RcfPolicyDefinitionsToNativeHandler( GnmiSetCliSession.PreCommitHandler ):
   '''
   Pre-commit handler that syncs OpenConfig policy-definitions
   (EXTERNAL_POLICY_DEFINITIONS_PATH) into the native RCF config (NATIVE_PATH).

   The TAC helper "Rcf::OpenConfig::RcfPolicyDefinitionsToNativeHelper" fills
   nativeConfig.openConfigFunction with JSON strings. This class rewrites each
   string (dash-case keys, no "state" dicts, no redundant "config" dicts) and
   then validates all of them, failing the commit on any validation error.
   '''
   externalPathList = [ EXTERNAL_POLICY_DEFINITIONS_PATH ]
   nativePathList = [ NATIVE_PATH ]
   # Set by configure(); run() uses it to look up the session entities.
   entityManager = None

   @classmethod
   def convertTacNameToOCName( cls, name ):
      '''
      Convert a camel-case TACC attribute name to dash-case, e.g.
      "policyName" -> "policy-name".

      Every uppercase character becomes '-' followed by its lowercase form.
      Afterwards a single trailing '0' (and only '0') is dropped.
      '''
      assert name
      ocName = ''.join( '-' + c.lower() if c.isupper() else c for c in name )
      # Drop a trailing '0' from the TACC attribute name. Only the digit '0' is
      # removed; names ending in any other digit are left unchanged.
      ocName = ocName.removesuffix( '0' )
      return ocName

   @classmethod
   def renameJsonKeys( cls, jsonValue ):
      '''
      Return a JSON value with its keys converted from camel-case format to dash-case
      format (lowercase words separated by '-'). Dicts and lists are walked
      recursively; other values are returned unchanged.
      '''
      if isinstance( jsonValue, dict ):
         return { cls.convertTacNameToOCName( key ): cls.renameJsonKeys( value )
                  for key, value in jsonValue.items() }
      elif isinstance( jsonValue, list ):
         return [ cls.renameJsonKeys( val ) for val in jsonValue ]
      else:
         return jsonValue

   @classmethod
   def removeStateDicts( cls, jsonValue ):
      '''
      The JSON serializer walks over the external config/state TAC object, which
      contains some "state" objects. Return a copy of jsonValue with every
      "state" key (at any depth) removed.
      '''
      if isinstance( jsonValue, dict ):
         return { key: cls.removeStateDicts( value )
                  for key, value in jsonValue.items() if key != "state" }
      elif isinstance( jsonValue, list ):
         return [ cls.removeStateDicts( val ) for val in jsonValue ]
      else:
         return jsonValue

   @classmethod
   def removeRedundantDicts( cls, jsonDict ):
      '''
      Remove the top-level "config" dict and the "config" dict of each entry in
      "statements"/"statement".

      jsonDict is modified in place; the same dict is returned for chaining.
      '''
      if "config" in jsonDict:
         jsonDict.pop( "config" )
      statements = jsonDict.get( "statements" )
      if statements:
         statementColl = statements.get( "statement" )
         if statementColl:
            for st in statementColl:
               if "config" in st:
                  st.pop( "config" )
      return jsonDict

   @classmethod
   def processJson( cls, jsonStr ):
      '''
      Turn one serialized policy-definition (JSON text) into the JSON text
      expected by the OpenConfig function validator, indented by 3.
      '''
      jsonDict = json.loads( jsonStr )
      # We might receive an empty policy-definition, in which case we need to
      # stand up the "statements/statement" hierarchy ourselves (required by our
      # OpenConfig function validator)
      if "statements" not in jsonDict:
         jsonDict[ "statements" ] = {}
      if "statement" not in jsonDict[ "statements" ]:
         jsonDict[ "statements" ][ "statement" ] = []
      jsonDict = cls.renameJsonKeys( jsonDict )
      jsonDict = cls.removeStateDicts( jsonDict )
      jsonDict = cls.removeRedundantDicts( jsonDict )
      return json.dumps( jsonDict, indent=3 )

   @classmethod
   def configure( cls, entityManager ):
      '''Store the entity manager and register this class as a pre-commit
      handler.'''
      cls.entityManager = entityManager
      GnmiSetCliSession.registerPreCommitHandler( cls )

   @classmethod
   def syncExternalToNative( cls, policyDefinitions, nativeConfig, sessionName ):
      '''
      Populate nativeConfig.openConfigFunction from policyDefinitions via the
      TAC helper, post-process every entry with processJson(), then validate
      them all.

      Raises AirStreamLib.ToNativeSyncherError listing every validation error.
      '''
      helper = Tac.newInstance(
         "Rcf::OpenConfig::RcfPolicyDefinitionsToNativeHelper", policyDefinitions,
         nativeConfig )

      helper.syncPolicyDefinitionsToNative()

      # At this point, openConfigFunction is populated with the JSON strings.
      # Now we just do some post-processing on them.
      t0( f"Functions: {nativeConfig.openConfigFunction.keys()}" )
      # Snapshot the names first: the loop writes back into the same collection,
      # so it must not iterate over a live view of it.
      for name in list( nativeConfig.openConfigFunction.keys() ):
         jsonStr = nativeConfig.openConfigFunction[ name ]
         nativeConfig.openConfigFunction[ name ] = cls.processJson( jsonStr )

      validationRequest = RcfOpenConfigFunctionValidationRequest(
         dict( nativeConfig.openConfigFunction ), nativeConfig.rcfCodeVersion )
      validationResult = validateOpenConfigFunctions( validationRequest )
      validationErrors = validationResult.diag.allErrors
      if validationErrors:
         errorStr = "\n".join( e.render( None ) for e in validationErrors )
         raise AirStreamLib.ToNativeSyncherError( sessionName,
            cls.__name__, errorStr )

   @classmethod
   def run( cls, sessionName ):
      '''Pre-commit entry point: sync the session's policy-definitions into
      its native RCF config.'''
      policyDefinitions = AirStreamLib.getSessionEntity( cls.entityManager,
                                                         sessionName,
                                                         cls.externalPathList[ 0 ] )
      nativeConfig = AirStreamLib.getSessionEntity( cls.entityManager, sessionName,
                                                    cls.nativePathList[ 0 ] )
      cls.syncExternalToNative( policyDefinitions, nativeConfig, sessionName )

def Plugin( entMan ):
   '''
   Register the OpenConfig RCF config group, the pre-commit handlers that sync
   OpenConfig RCF config (and, when the toggle is enabled, OpenConfig
   policy-definitions) into native RCF config, and the copy handler for
   "Rcf::OpenConfig::RcfDir".
   '''
   CliSession.registerConfigGroup( entMan, "airstream-cmv", EXTERNAL_RCF_PATH )

   # -------------------------------------------------------------------------
   # pre-commit handler for AirStream OpenConfig -> EOS handling
   # -------------------------------------------------------------------------
   def nativeTextMatches( text, nativeText ):
      # The syncher stores external text with a trailing "\n" appended, so a
      # native value written by an earlier sync is text + "\n". A value equal
      # to text itself is accepted too.
      return nativeText in ( text, text + "\n" )

   def toNativeRcfSyncher( cls, sessionName ):
      '''
      Sync the session's OpenConfig RCF directory (cls.externalPath) into the
      native RCF config (cls.nativePath).

      Each external entry is validated with validateRcfConfig(), and entries are
      grouped by code unit (the entry name when cfg.codeUnit is None). If the
      result already matches native config nothing is done. Otherwise the
      combined code is linted, and native code units are added, updated or
      deleted to match.

      Raises AirStreamLib.ToNativeSyncherError on validation or lint failure.
      '''
      externalConfigDir = AirStreamLib.getSessionEntity( entMan, sessionName,
                                                         cls.externalPath )
      nativeConfig = AirStreamLib.getSessionEntity( entMan, sessionName,
                                                    cls.nativePath )

      # Code-unit name -> RCF text; entries sharing a code unit are joined with
      # a blank line between them.
      extRcfCodeUnitText = {}
      for name, rcf in externalConfigDir.rcf.items():
         cfg = rcf.config
         validateRcfConfig( cls, sessionName, name, cfg )

         codeUnit = cfg.codeUnit
         # cfg.codeUnit can be None as it is defined as Tac::String::Optional
         if codeUnit is None:
            codeUnit = name

         if codeUnit in extRcfCodeUnitText:
            extRcfCodeUnitText[ codeUnit ] += f"\n\n{cfg.code}"
         else:
            extRcfCodeUnitText[ codeUnit ] = f"{cfg.code}"

      nativeRcfCodeUnitTextDict = dict( nativeConfig.rcfCodeUnitText.items() )

      # Compare modulo the trailing newline added on write below; a plain
      # equality check could never match native text written by a prior sync.
      if ( extRcfCodeUnitText.keys() == nativeRcfCodeUnitTextDict.keys() and
           all( nativeTextMatches( text, nativeRcfCodeUnitTextDict[ codeUnit ] )
                for codeUnit, text in extRcfCodeUnitText.items() ) ):
         t0( "No changes to RCF config" )
         return

      rcfLinter = RcfLinter()
      codeUnitsForLinter = generateCodeUnitsForDomain(
         Rcf.Metadata.FunctionDomain.OPEN_CONFIG, extRcfCodeUnitText )
      codeUnitMapping = CodeUnitMapping( codeUnitsForLinter )
      lintRequest = RcfLintRequest( codeUnitMapping, rcfCodeVersion=0,
                                    strictMode=False )
      rcfLintResult = rcfLinter.lint( lintRequest )

      if not rcfLintResult.success:
         errorStr = "\n".join( e.render( codeUnitMapping )
                               for e in rcfLintResult.diag.allErrors )
         errorStr += "\nCompilation failed"
         raise AirStreamLib.ToNativeSyncherError( sessionName,
            cls.__name__, errorStr )

      t0( f"Syncing from {externalConfigDir} to {nativeConfig}" )
      nativeRcfCodeUnitText = nativeConfig.rcfCodeUnitText
      rcfCodeUnitUrlInfo = nativeConfig.rcfCodeUnitUrlInfo

      def updateRcfCodeUnitUrlInfo( codeUnit, url ):
         # Re-add codeUnit's URL info with the given url and editSincePull set.
         updatedUrlInfo = Tac.Value( "Rcf::RcfCodeUnitUrlInfo", codeUnit, url )
         updatedUrlInfo.editSincePull = True
         rcfCodeUnitUrlInfo.addMember( updatedUrlInfo )

      # Remove stale RCF config. Iterate over a snapshot of the names because
      # entries are deleted from the collection inside the loop.
      for codeUnit in list( nativeRcfCodeUnitText ):
         if codeUnit not in extRcfCodeUnitText:
            t2( f"Deleting code unit {codeUnit}" )
            del nativeRcfCodeUnitText[ codeUnit ]
            urlInfo = rcfCodeUnitUrlInfo.get( codeUnit )
            if urlInfo:
               updateRcfCodeUnitUrlInfo( codeUnit, urlInfo.lastPulledUrl )

      # Sync external config to native config. Unchanged code units are skipped
      # so that their URL info is not flagged as edited.
      for codeUnit, text in extRcfCodeUnitText.items():
         nativeText = nativeRcfCodeUnitText.get( codeUnit )
         if nativeText and nativeTextMatches( text, nativeText ):
            continue
         t2( f"Adding/updating code unit {codeUnit}" )
         # Newline added so that EOF is on a line on its own
         nativeRcfCodeUnitText[ codeUnit ] = text + "\n"
         urlInfo = rcfCodeUnitUrlInfo.get( codeUnit )
         if urlInfo:
            updateRcfCodeUnitUrlInfo( codeUnit, urlInfo.lastPulledUrl )

      nativeConfig.rcfCodeVersion = nativeConfig.rcfCodeVersionPending
      nativeConfig.enabled = bool( extRcfCodeUnitText )
      t0( "Sync complete" )

   class ToNativeRcfHandler( GnmiSetCliSession.PreCommitHandler ):
      # externalPath / nativePath are read by toNativeRcfSyncher().
      externalPath = EXTERNAL_RCF_PATH
      nativePath = NATIVE_PATH
      externalPathList = [ EXTERNAL_RCF_PATH ]
      nativePathList = [ NATIVE_PATH ]

      @classmethod
      def run( cls, sessionName ):
         toNativeRcfSyncher( cls, sessionName )

   GnmiSetCliSession.registerPreCommitHandler( ToNativeRcfHandler )
   if toggleRcfPolicyDefinitionsEnabled():
      RcfPolicyDefinitionsToNativeHandler.configure( entMan )
   AirStreamLib.registerCopyHandler( entMan, "OCRcfConfig",
                                     typeName="Rcf::OpenConfig::RcfDir" )
