# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

'''
Registers an SM translating OpenConfig routing-policy definitions
(`routing/policy/openconfig/policyDefinitions`) into EOS route map config.

Additionally, registers a pre-commit handler performing some post-translation work.
Most importantly, it raises any errors detected during translation and stored in
a `ToNativeSmResult` instance.
'''

from Toggles.OpenConfigRoutingPolicyToggleLib import (
   toggleAirStreamPolicyDefinitionsEnabled,
)

import AirStreamLib
import CliSession
import GnmiSetCliSession
import Tac
import Tracing

# Trace handle for this plugin: `t0` traces coarse milestones (SM creation,
# handler invocation, registration) and `t3` traces per-object details.
th = Tracing.Handle( 'RoutingPolicyOpenConfig' )
t0, t3 = th.trace0, th.trace3

# TACC namespace holding the types used by this plugin
TACC_NAMESPACE = 'Routing::Policy::OpenConfig'

# OpenConfig policy definitions: the entity translated from
EXTERNAL_MOUNT_PATH = 'routing/policy/openconfig/policyDefinitions'
EXTERNAL_MOUNT_TYPE = TACC_NAMESPACE + '::PolicyDefinitions'
# Native EOS entities: route map config (translation target) and the ACL config
# whose IPv4/IPv6 prefix lists are consulted by the pre-commit handler
NATIVE_PATH = 'routing/routemap/config'
ACL_LIST_PATH = 'routing/acl/config'

COPY_HANDLER_DIR_NAME = 'PolicyDefinitionsCopyHandlerDir'
COPY_HANDLER_DIR_TYPE = TACC_NAMESPACE + '::' + COPY_HANDLER_DIR_NAME

SM_TYPE_NAME = TACC_NAMESPACE + '::ToNativeSm'
SM_RESULT_TYPE_NAME = TACC_NAMESPACE + '::ToNativeSmResult'

# Upper bound, in encoded bytes, on the user-facing error message (0x100)
ERROR_MESSAGE_TRUNCATION_SIZE = 256
# Marker reported in place of error messages dropped to respect the bound above
ERROR_MESSAGE_TRUNCATION_TEXT = '(error list truncated)'

def _textSize( text ):
   '''Returns size of the text after encoding'''
   return len( text.encode() )

# Sanity check: the truncation marker alone must fit within the size limit,
# otherwise even a truncated error report would exceed that limit.
assert ERROR_MESSAGE_TRUNCATION_SIZE > _textSize( ERROR_MESSAGE_TRUNCATION_TEXT )

class RoutingPolicyOpenConfigToNativeSmWrapper(
      GnmiSetCliSession.GnmiSetCliSessionSm ):
   '''Wrapper installing the to-native SM on OpenConfig session paths

   `configure` stores the shared state on the class and registers the class as a
   pre-commit SM (once). `run` then creates a `ToNativeSm` reading the session's
   OpenConfig policy definitions and writing the session's route map config.
   '''

   # override from parent; NOTE(review): presumably the paths whose session
   # entities the parent exposes through `sessionEntities` — confirm
   entityPaths = [ EXTERNAL_MOUNT_PATH, NATIVE_PATH ]

   gnmiSetSessionState = None  # set by `configure`
   entityManager = None  # set by `configure`
   smResult = None  # set by `configure`, stores SM results
   isRegistered = False  # set by `configure` upon registration

   @classmethod
   def configure( cls, entityManager, smResult ):
      '''Stores shared state on the class and registers it, at most once

      `smResult` is the `ToNativeSmResult` instance handed to every created SM.'''
      cls.gnmiSetSessionState = Tac.singleton( 'AirStream::GnmiSetSessionState' )
      cls.entityManager = entityManager
      cls.smResult = smResult
      if cls.isRegistered:
         return
      # `registerPreCommitSm` does not tolerate being called repeatedly
      GnmiSetCliSession.registerPreCommitSm( cls )
      cls.isRegistered = True

   def __init__( self ):
      super().__init__()
      # Filled in by `run`
      self.toNativeSm = None

   def externalEntity( self ):
      '''Returns the session's OpenConfig policy definitions (SM input)'''
      entities = self.sessionEntities
      return entities[ EXTERNAL_MOUNT_PATH ]

   def nativeEntity( self ):
      '''Returns the session's route map config (SM output)'''
      currentSession = self.gnmiSetSessionState.sessionName
      return AirStreamLib.getSessionEntity(
            self.entityManager, currentSession, NATIVE_PATH )

   # override from parent
   def run( self ):
      '''Creates the to-native SM for the current session'''
      sourceEntity = self.externalEntity()
      targetEntity = self.nativeEntity()
      # Keep a reference on the instance so the SM is not garbage collected.
      self.toNativeSm = Tac.newInstance( SM_TYPE_NAME, sourceEntity, targetEntity,
                                         self.smResult )
      t0( 'To-native SM created' )
      t0( 'To-native SM created' )

class RoutingPolicyOpenConfigPreCommitHandler( GnmiSetCliSession.PreCommitHandler ):
   '''Pre-commit handler

   Performs post-translation fix-ups on the session's native config, then raises
   any error detected by SM as `AirStreamLib.ToNativeSyncherError`.'''

   externalPathList = [ EXTERNAL_MOUNT_PATH ]  # override from parent
   nativePathList = [ NATIVE_PATH, ACL_LIST_PATH ]  # override from parent

   entityManager = None  # set by `configure`
   smResult = None  # set by `configure`, stores SM results

   @classmethod
   def configure( cls, entityManager, smResult ):
      '''Configures and registers the class

      `smResult` is the `ToNativeSmResult` instance whose contents `run` consumes
      and then clears.'''
      cls.entityManager = entityManager
      cls.smResult = smResult
      GnmiSetCliSession.registerPreCommitHandler( cls )

   @classmethod
   def fixEmptyRouteMaps( cls, sessionName ):
      '''Adds auto entries to route maps that ended up being empty.

      The auto entry is entry 10 (statement name '10'), of type `permitMatch`,
      with `hasContinue` set. Route maps recorded as emptied that have since
      gained entries are left untouched. `sessionName` is unused; it is accepted
      for a uniform signature with the other fix-up steps called by `run`.

      This method might work on route maps that no longer exist in route map config
      due to being deleted, for example if a policy definition was changed to RCF
      policy type. This is fine though, as these orphan route maps will be cleaned up
      when the to-native SM result is cleaned up.
      '''
      for routeMap in cls.smResult.emptiedRouteMap.values():
         if routeMap.mapEntry:
            # This route map is not really empty (something was added later).
            continue
         t3( f'fixEmptyRouteMaps: {routeMap.name=}' )
         mapEntry = routeMap.mapEntry.newMember( 10 )
         mapEntry.statementName = '10'
         mapEntry.statementNameFromCli = True
         mapEntry.permit = 'permitMatch'
         mapEntry.hasContinue = True

   @classmethod
   def fixPrefixListMatches( cls, sessionName ):
      '''Chooses between matchPrefixList and matchIpv6PrefixList for updated entries

      The SM always emits `matchPrefixList`. For each updated map entry having such
      a rule, the named list is looked up in the session's IPv4 and IPv6 prefix
      lists (ACL config at `ACL_LIST_PATH`).

      Rules for choosing:
      * If neither prefix list exists, deletes the match;
      * If only IPv4 prefix list exists, keeps `matchPrefixList`;
      * If only IPv6 prefix list exists, replaces it with `matchIpv6PrefixList`;
      * If both exist, adds a message to `cls.smResult.errorMessage` at most once.
      '''
      aclListConfig = AirStreamLib.getSessionEntity( cls.entityManager, sessionName,
                                                     ACL_LIST_PATH )
      ipVersionsLists = { 4: aclListConfig.prefixList,
                          6: aclListConfig.ipv6PrefixList }
      errorMessageAdded = False
      for mapEntry in cls.smResult.updatedMapEntry.values():
         assert 'matchIpv6PrefixList' not in mapEntry.matchRule, (
               'SM always deletes this match rule' )
         rule = mapEntry.matchRule.get( 'matchPrefixList' )
         if rule is None:
            continue
         listName = rule.strValue
         availableIpVersions = { ipVersion for ipVersion, prefixLists in
               ipVersionsLists.items() if listName in prefixLists }
         t3( f'fixPrefixListMatches: {listName=} {availableIpVersions=}' )
         if not availableIpVersions:
            del mapEntry.matchRule[ 'matchPrefixList' ]
         elif availableIpVersions == { 4 }:
            # Already have the right version
            continue
         elif availableIpVersions == { 6 }:
            del mapEntry.matchRule[ 'matchPrefixList' ]
            # Re-create the rule (same type as the deleted one) under the IPv6 key
            rule = Tac.newInstance( type( rule ).__name__, 'matchIpv6PrefixList' )
            rule.strValue = listName
            mapEntry.matchRule.addMember( rule )
         elif len( availableIpVersions ) > 1:
            # The error is not specific, so we don't want to add it multiple times.
            if not errorMessageAdded:
               cls.smResult.errorMessage.push(
                     'MIXED prefix-sets are unsupported in route-maps in EOS' )
               errorMessageAdded = True
         else:
            assert False, f'Unexpected {availableIpVersions=}'

   @classmethod
   def raiseErrorsFromSm( cls, sessionName ):
      '''Raises any error detected by SM as `AirStreamLib.ToNativeSyncherError`

      Does nothing if no error message was recorded. Otherwise all messages are
      joined with newlines into the user message. If that would exceed
      `ERROR_MESSAGE_TRUNCATION_SIZE` bytes, only the leading messages that fit
      are kept and `ERROR_MESSAGE_TRUNCATION_TEXT` is appended, so the reported
      user message never exceeds that size.
      '''
      errorMessages = cls.smResult.errorMessage
      if not errorMessages:
         return

      separator = '\n'
      messages = list( errorMessages.values() )
      userMessage = separator.join( messages )
      # Only truncate when the complete list really does not fit; reserving room
      # for the marker up front would needlessly drop messages that do fit.
      if _textSize( userMessage ) > ERROR_MESSAGE_TRUNCATION_SIZE:
         separatorSize = _textSize( separator )
         remainingBytes = ( ERROR_MESSAGE_TRUNCATION_SIZE -
                            _textSize( ERROR_MESSAGE_TRUNCATION_TEXT ) )
         reportedMessages = []
         for errorMessage in messages:
            # Every kept message is followed by a separator (before the marker)
            errorMessageSize = _textSize( errorMessage ) + separatorSize
            if remainingBytes < errorMessageSize:
               break
            reportedMessages.append( errorMessage )
            remainingBytes -= errorMessageSize
         reportedMessages.append( ERROR_MESSAGE_TRUNCATION_TEXT )
         userMessage = separator.join( reportedMessages )

      raise AirStreamLib.ToNativeSyncherError( sessionName=sessionName,
            syncherName=cls.__name__, userMsg=userMessage )

   # override from parent
   @classmethod
   def run( cls, sessionName ):
      '''Runs the fix-ups, then reports SM errors; always clears the SM result'''
      t0( 'Pre-commit handler' )
      try:
         cls.fixEmptyRouteMaps( sessionName )
         cls.fixPrefixListMatches( sessionName )
         cls.raiseErrorsFromSm( sessionName )  # MUST be last to report all errors
         assert not cls.smResult.errorMessage, 'Should have raised an error'
      finally:
         cls.smResult.clear()

def registerCopyHandler( entityManager ):
   '''Registers the custom copy handler for the OpenConfig policy definitions

   The directory entity owning the handler is registered first, then the handler
   is registered for `EXTERNAL_MOUNT_PATH` / `EXTERNAL_MOUNT_TYPE`.
   '''
   copyHandlerDir = CliSession.registerCopyHandlerDir(
         entityManager, COPY_HANDLER_DIR_NAME, COPY_HANDLER_DIR_TYPE )
   # NOTE(review): assigning `()` presumably instantiates `copyHandler` with
   # default constructor arguments (TACC idiom) — confirm
   copyHandlerDir.copyHandler = ()
   CliSession.registerCustomCopyHandler( entityManager, EXTERNAL_MOUNT_PATH,
                                         EXTERNAL_MOUNT_TYPE,
                                         copyHandlerDir.copyHandler )
   t0( 'Copy handler registered' )

def Plugin( entityManager ):
   '''Plugin entry point setting up the OpenConfig policy definitions translation

   Does nothing unless the AirStream policy definitions toggle is enabled.'''
   if not toggleAirStreamPolicyDefinitionsEnabled():
      return

   CliSession.registerConfigGroup( entityManager, 'airstream-cmv',
                                   EXTERNAL_MOUNT_PATH )
   registerCopyHandler( entityManager )

   # One result instance is shared: the wrapper hands it to the to-native SM it
   # creates, and the pre-commit handler reads and then clears it.
   sharedSmResult = Tac.newInstance( SM_RESULT_TYPE_NAME )
   for configurable in ( RoutingPolicyOpenConfigToNativeSmWrapper,
                         RoutingPolicyOpenConfigPreCommitHandler ):
      configurable.configure( entityManager, sharedSmResult )

   t0( 'RoutingPolicyOpenConfigToNative plugin installed' )
