#!/usr/bin/env python3
# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from CallbackRegistry import CallbackRegistry
import json
from RcfCompilerStageCommon import RcfCompilerStageResultBase
import RcfOpenConfigCommon
from RcfOpenConfigDiag import RcfOpenConfigDiag
import re
import ipaddress
import yaml

class RcfOpenConfigFunctionValidationRequest:
   """A request to validate one or more OpenConfig functions.

   Args:
      openConfigFunctions (dict of string): Maps each function name to the raw
         JSON text of that function; the text is deserialized and validated.
      rcfCodeVersion (int): Code version number from Sysdb.
   """
   def __init__( self, openConfigFunctions, rcfCodeVersion ):
      self.rcfCodeVersion = rcfCodeVersion
      self.openConfigFunctions = openConfigFunctions

class RcfOpenConfigFunctionValidationResult( RcfCompilerStageResultBase ):
   """Output from the OpenConfig function validator.

   Args:
      request (RcfOpenConfigFunctionValidationRequest): The request to which this is
         a response.
      success (bool): Whether or not validation was successful.
      openConfigFunctions (dict of dicts): Deserialized OpenConfig functions, keyed
         by function name. Each value is the function data returned by validation,
         or None if validation reported an error.
      diag: Diagnostics collected during validation (an RcfOpenConfigDiag when
         built by validateOpenConfigFunctions). Passed, together with success, to
         RcfCompilerStageResultBase.
   """
   def __init__( self, request, success, openConfigFunctions, diag ):
      super().__init__( success=success, diag=diag )
      self.request = request
      self.openConfigFunctions = openConfigFunctions

def validateOpenConfigFunctions( request ):
   """Validate every OpenConfig function carried by a request.

   A single diag object and a single validator are shared by all functions in
   the request, and the overall success flag reflects every error recorded on
   that diag.

   Args:
      request (RcfOpenConfigFunctionValidationRequest): The request containing the
         openconfig functions to validate.

   Returns:
      RcfOpenConfigFunctionValidationResult: maps each function name to its
      validated data (or None), plus the success flag and the diag.
   """
   diag = RcfOpenConfigDiag()
   validator = RcfOpenConfigFunctionValidator( diag )
   validatedFunctions = {
      name: validator.validate( text, name )
      for name, text in request.openConfigFunctions.items()
   }
   # No warnings are expected for OpenConfig function validation. They aren't put
   # into the result and aren't factored into `success`. The request does not
   # distinguish between strict and non-strict, either.
   assert not diag.allWarnings
   return RcfOpenConfigFunctionValidationResult(
      request, not diag.hasErrors(), validatedFunctions, diag )

class RcfOpenConfigFunctionValidator( CallbackRegistry ):
   """Validate json text against the OpenConfig policy-definitions model.

   The API is validate( text, functionName ) which returns the deserialized json as
   a dictionary if validation succeeds, else None.

   Three types of validation are performed:
   - JSON decoding. If the builtin json module fails to decode the text, an error
     is produced. A key repeated within one JSON object is also reported.
   - Schema validation by comparing against RcfOpenConfigMetadata.yaml:
     - Values are of the expected type. Dictionaries ("containers" in OpenConfig
       terminology) and lists are present where expected. There are not values
       where collections are expected, and vice versa.
     - There are no unexpected keys in containers.
   - Validation performed by functions registered with registerCallback:
     - Missing required keys
     - ...and anything else.

   Schema validation ensures the json is an inclusive subset of the paths covered
   by RcfOpenConfigMetadata.yaml. Registered validators are otherwise responsible for
   validation. Leaving all value validation to registered validators provides the
   most flexibility to enforce constraints, which is important because some
   constraints would be difficult to describe with a schema. Here are some examples:
   - Internally keyed lists with duplicate keys
   - Leaves that are generally optional, but at least one must be present
   - Leaves that are conditionally required/forbidden. For example,
     set-community/inline|reference must agree with the setting of
     set-community/config/method.

   Validators are registered with registerCallback. See CallbackRegistry for details.
   RcfOpenConfigFunctionValidator.validate maintains a validator key which
   represents a node in the json/OpenConfig model. If one or more validators are
   registered with the current node/validator key, they are executed.

   The validator key differs from the OpenConfig path in two ways:
   - Slashes are replaced with periods (.) to make them visually distinct.
   - List keys (statement[ FOO ]) are omitted in favor of empty brackets
     (statement[]). Without doing this, two elements in a list would not produce
     the same validator key, and some type or regex/globbing would be needed to
     match a validator.

   Validator functions are executed after the type of the node they are
   registered on has been checked against the schema, therefore they can assume
   they receive a value of the expected type. The traversal is pre-order, though:
   the children of that node have NOT been checked yet, so a validator that looks
   inside a container or list must tolerate unexpectedly typed children. Elements
   of leaf-lists (lists whose schema entry is not a container) are never
   type-checked by the traversal, only by registered validators.

   Some examples:
   - "" handles validation for the root node
   - ".statements" handles validation for the contents of statements
   - ".statements.statement" handles validation contents of statements/statement,
     which is guaranteed to be a list.
   - ".statements.statement[]" handles validation for each element of statement

   The expected signature of a validator function:
   Args:
      self (RcfOpenConfigFunctionValidator)
      data (any): The data at the path this validator was registered with.
      functionName (str): The name of the function under validation
      openConfigPath (str): The OpenConfig path
   Returns: True if validation succeeds, False otherwise.
   """
   # Mapping from python types deserializable from JSON to readable strings
   typeMap = {
         str: "string",
         int: "integer",
         float: "float",
         bool: "boolean",
         type( None ): "null",
         dict: "container",
         list: "list",
   }

   def __init__( self, diag ):
      with open( "/usr/share/Rcf/RcfOpenConfigMetadata.yaml" ) as f:
         metadata = yaml.safe_load( f )
      self.schema = metadata[ 'Schema' ]
      self.diag = diag

   def validate( self, text, functionName ):
      """Validate text.

      Args:
         text (str): Incoming json text
         functionName (str): The name of the function. This is needed to produce
            error messages.
      Returns: the deserialized json if validation succeeds, otherwise None.

      NOTE(review): success is decided with self.diag.hasErrors(), which also
      counts errors recorded by earlier validate() calls on the same diag.
      validateOpenConfigFunctions shares one diag across all functions, so once
      one function fails, every function validated after it also yields None.
      """
      def findDuplicateKeys( keysAndValues ):
         # object_pairs_hook for json.loads: report any key repeated within one
         # JSON object, then build the dict (the last duplicate wins). The hook
         # does not know where in the document the object sits, so the error is
         # reported against the root path ("").
         seen = set()
         for key, _ in keysAndValues:
            if key in seen:
               self.diag.openConfigFunctionValidationDuplicateContainerKey(
                  functionName, openConfigPath, key )
            seen.add( key )
         return dict( keysAndValues )

      openConfigPath = ""
      validatorKey = ""

      try:
         data = json.loads( text, object_pairs_hook=findDuplicateKeys )
      except json.JSONDecodeError:
         self.diag.openConfigFunctionDecodeError( functionName, openConfigPath )
         return None

      self._validateRecursive( data, self.schema, functionName, openConfigPath,
                               validatorKey )

      if self.diag.hasErrors():
         return None
      return data

   def getExpectedTypesForSchema( self, schema ):
      """Return the expected JSON data types according to the schema.

      Args:
         schema: a value in the OpenConfig metadata. This is either a list,
         dictionary, or a string. If it's a string, it's one or more comma-separated
         tokens denoting the expected type of a leaf (see values in self.typeMap).

      Returns: a list of readable type names (values of self.typeMap).
      """
      if ( schemaType := type( schema ) ) in [ list, dict ]:
         expectedTypes = [ self.typeMap[ schemaType ] ]
      else:
         # Schema is a value, a string denoting the expected type.
         assert schemaType == str
         expectedTypes = []
         for dataType in schema.split( "," ):
            expectedType = dataType.strip()
            assert expectedType in self.typeMap.values()
            expectedTypes.append( expectedType )
      return expectedTypes

   def _validateRecursive( self, data, schema, functionName, openConfigPath,
                           validatorKey ):
      # Validation is performed with a pre-ordered depth-first traversal.
      # 1. Check for schema errors in this node
      # 2. Run registered validators
      # 3. Recurse inside lists/containers
      actualType = self.typeMap[ type( data ) ]
      expectedTypes = self.getExpectedTypesForSchema( schema )

      if actualType not in expectedTypes:
         # A mistyped node is reported once; its validators and children are
         # skipped because they could not rely on its type.
         self.diag.openConfigFunctionValidationUnexpectedTypeError(
               functionName, openConfigPath, expectedTypes, actualType, data )
         return

      for validator in self.getCallbacks( validatorKey ):
         validator( self, data, functionName, openConfigPath )

      if actualType == "container":
         for key, value in data.items():
            if key not in schema:
               self.diag.openConfigFunctionValidationUnexpectedPathError(
                     functionName, openConfigPath, key )
               continue
            self._validateRecursive( value, schema[ key ], functionName,
                                     openConfigPath + f"/{key}",
                                     validatorKey + f".{key}" )

      if actualType == "list":
         # The schema should have one entry containing the schema to validate
         # each entry in this list.
         schemaEntry, = schema
         # If the list element is a python dictionary ( i.e. a container ) we want
         # to recursively check all the leafs/containers inside it. If the list is
         # a leaf-list we want to skip this.
         if isinstance( schemaEntry, dict ):
            for value in data:
               # Get the name/key for the path. The name is assumed to exist in
               # /name. If it's missing or not a string, or the element itself is
               # not a container (the recursive call reports that type error),
               # report it as "?".
               key = value.get( 'name', '?' ) if isinstance( value, dict ) else '?'
               key = key if isinstance( key, str ) else '?'
               self._validateRecursive( value, schemaEntry, functionName,
                                        openConfigPath + f"[{key}]",
                                        validatorKey + "[]" )

# Begin validation functions

# Non-condition/action validation

def requirePath( validatorKey, node ):
   """Register a validator requiring child `node` at the given validator key.

   The registered check runs only when the traversal reaches a node matching
   validatorKey and that node passed its type check. Requirements are therefore
   not transitive: a missing parent means the child is never checked. When `node`
   is absent, a missing-path error is reported at that node's OpenConfig path.

   Args:
      validatorKey (str): Validator key of the parent node, e.g. ".statements".
      node (str): Name of the child that must be present.
   """
   def validateBasicRequiredPath( self, data, functionName, openConfigPath ):
      if node in data:
         return
      self.diag.openConfigFunctionValidationMissingPathError(
            functionName, openConfigPath, node )

   RcfOpenConfigFunctionValidator.registerCallback( validatorKey )(
         validateBasicRequiredPath )

# Validation Helper

# Largest values representable in unsigned 8/16/32-bit integers.
U8_MAX = ( 1 << 8 ) - 1
U16_MAX = ( 1 << 16 ) - 1
U32_MAX = ( 1 << 32 ) - 1

def validateInteger( data, maxVal, minVal=0 ):
   """Return whether data lies in the inclusive range [minVal, maxVal]."""
   if data < minVal:
      return False
   return data <= maxVal

# Every function needs a name and a statements container holding a statement
# list, and every statement needs a name. A requirePath check only runs for a
# node that is present and correctly typed, so each level is required
# separately. The registration order is kept as-is because it presumably fixes
# the order in which callbacks (and their diags) run.
requirePath( "", "name" )
requirePath( "", "statements" )
requirePath( ".statements", "statement" )
requirePath( ".statements.statement[]", "name" )

# A legal RCF function identifier: a letter or underscore followed by letters,
# digits, or underscores. validateFunctionName applies it with fullmatch().
functionNameRe = re.compile( r"[_a-zA-Z][_a-zA-Z0-9]*" )

@RcfOpenConfigFunctionValidator.registerCallback( ".name" )
def validateFunctionName( self, data, functionName, openConfigPath ):
   """Check the top-level /name leaf.

   The name embedded in the JSON must equal the name the function was submitted
   under (the CLI name), and it must be a legal RCF function identifier. The two
   problems are reported independently, so both may be emitted.
   """
   namesAgree = data == functionName
   if not namesAgree:
      self.diag.openConfigFunctionValidationFunctionNameMismatchError(
            functionName, openConfigPath, data )
   if functionNameRe.fullmatch( data ) is None:
      self.diag.openConfigFunctionValidationInvalidFunctionNameError(
            functionName, openConfigPath, data )

@RcfOpenConfigFunctionValidator.registerCallback( ".statements.statement" )
def validateStatementNames( self, data, functionName, openConfigPath ):
   """Report any statement whose name repeats an earlier statement's name.

   This validator runs on the statement list before the list's elements have
   been type-checked (validation is pre-order). Elements that are not containers
   are therefore skipped here; the traversal reports their type error when it
   reaches them. Names that are lists or containers are skipped as well, since
   they cannot be put in a set; they are left to schema validation of /name.
   """
   seen = set()
   for statement in data:
      if not isinstance( statement, dict ):
         continue
      # validateBasicRequiredPath handles requiring the "name" node, so silently
      # ignore a missing name here.
      statementName = statement.get( "name" )
      if statementName is None or isinstance( statementName, ( dict, list ) ):
         continue
      if statementName in seen:
         self.diag.openConfigFunctionValidationDuplicateStatementError(
            functionName, openConfigPath, statementName )
      seen.add( statementName )

# Condition validation

# requirePath is not transitive so this has no effect unless the match-prefix-set
# node exists.
requirePath( ".statements.statement[].conditions.match-prefix-set", "config" )
# Require the prefix-set path underneath config. We don't want to handle the
# scenario where match-set-options exists (when that path is supported) but
# prefix-set does not (an obvious misconfig).
requirePath( ".statements.statement[].conditions.match-prefix-set.config",
             "prefix-set" )

# Equivalent to EXTERNAL_REF in RcfCommonLexerRules.g4: a first character from
# letters, digits, and : \ [ ] = + - _ followed by any number of those
# characters or '.'. Validators apply it with fullmatch().
externalRefRe = re.compile( r"[a-zA-Z0-9:\\\[\]=+\-_][a-zA-Z0-9:\\\[\]=+\-_.]*" )

@RcfOpenConfigFunctionValidator.registerCallback(
      ".statements.statement[].conditions.match-prefix-set.config.prefix-set" )
def validateMatchPrefixSet( self, data, functionName, openConfigPath ):
   """Check that the referenced prefix-set name is a legal RCF external reference.

   The whole value must match externalRefRe; otherwise an invalid-leaf-value
   error is reported for this path.
   """
   if externalRefRe.fullmatch( data ) is not None:
      return
   leafType = self.typeMap[ type( data ) ]
   self.diag.openConfigFunctionValidationInvalidLeafValueError(
         functionName, openConfigPath, leafType, data )

@RcfOpenConfigFunctionValidator.registerCallback(
      ".statements.statement[].conditions.bgp-conditions.as-path-length."
      "config.value" )
def validateAsPathLenghtValue( self, data, functionName, openConfigPath ):
   """Check as-path-length/config/value against the supported range.

   The deviation only supports values from 0 to 4000 inclusive; anything else is
   reported as an invalid leaf value.
   """
   if validateInteger( data, 4000 ):
      return
   leafType = self.typeMap[ type( data ) ]
   self.diag.openConfigFunctionValidationInvalidLeafValueError(
      functionName, openConfigPath, leafType, data )

# Require the config container under as-path-length: an as-path-length container
# without config is a misconfiguration.

requirePath( ".statements.statement[].conditions.bgp-conditions.as-path-length",
             "config" )

# Both operator and value are required.
requirePath( ".statements.statement[].conditions.bgp-conditions.as-path-length."
             "config", "operator" )
requirePath( ".statements.statement[].conditions.bgp-conditions.as-path-length."
             "config", "value" )

@RcfOpenConfigFunctionValidator.registerCallback(
      ".statements.statement[].conditions.bgp-conditions.as-path-length."
      "config.operator" )
def validateAsPathLenghtOperator( self, data, functionName, openConfigPath ):
   """Check that the as-path-length operator is ATTRIBUTE_EQ, _GE, or _LE."""
   supportedOperators = ( 'ATTRIBUTE_EQ', 'ATTRIBUTE_GE', 'ATTRIBUTE_LE' )
   if data in supportedOperators:
      return
   self.diag.openConfigFunctionValidationInvalidLeafValueError(
      functionName, openConfigPath, self.typeMap[ type( data ) ], data )

# Bgp-Conditions validation

@RcfOpenConfigFunctionValidator.registerCallback(
      ".statements.statement[].conditions.bgp-conditions.config.local-pref-eq" )
def validateLocalPrefEq( self, data, functionName, openConfigPath ):
   """Check that local-pref-eq is within the unsigned 32-bit range [0, U32_MAX]."""
   if validateInteger( data, U32_MAX ):
      return
   leafType = self.typeMap[ type( data ) ]
   self.diag.openConfigFunctionValidationInvalidLeafValueError(
      functionName, openConfigPath, leafType, data )

@RcfOpenConfigFunctionValidator.registerCallback(
      ".statements.statement[].conditions.bgp-conditions.config.route-type" )
def validateRouteType( self, data, functionName, openConfigPath ):
   """Check that route-type is either INTERNAL or EXTERNAL."""
   if data in ( 'INTERNAL', 'EXTERNAL' ):
      return
   self.diag.openConfigFunctionValidationInvalidLeafValueError(
      functionName, openConfigPath, self.typeMap[ type( data ) ], data )

# Action validation

@RcfOpenConfigFunctionValidator.registerCallback(
      ".statements.statement[].actions.bgp-actions.config.set-med" )
def validateSetMed( self, data, functionName, openConfigPath ):
   """Check a set-med value (arista-bgp-set-med-type).

   arista-bgp-set-med-type is a union of uint32, string, and an enum. A JSON
   number arrives as a python int; the string and enum forms arrive as python
   strings. A string passes when RcfOpenConfigCommon.parseSetMed returns a truthy
   metric type for it (presumably meaning it matched the type's string regex), or
   when it is the enum value "IGP" (the only enum value for now). Any non-zero
   numeric value obtained must also lie within [0, U32_MAX].
   """
   invalid = False
   medValue = None
   if isinstance( data, str ):
      medType, medValue = RcfOpenConfigCommon.parseSetMed( data )
      invalid = not medType and data != "IGP"
   elif isinstance( data, int ):
      medValue = data
   # A falsy value (None, or 0 which is always in range) skips the range check.
   if medValue and ( medValue < 0 or medValue > U32_MAX ):
      invalid = True
   if invalid:
      self.diag.openConfigFunctionValidationInvalidLeafValueError(
            functionName, openConfigPath, self.typeMap[ type( data ) ], data )

# set-as-path-prepend must contain a config container, and that config must set
# repeat-n. Validator keys are dot-joined path components and never end in a
# period, so the key must not carry a trailing "." or it would never match.
requirePath( ".statements.statement[].actions.bgp-actions.set-as-path-prepend",
             "config" )
requirePath( ".statements.statement[].actions.bgp-actions.set-as-path-prepend."
             "config", "repeat-n" )

@RcfOpenConfigFunctionValidator.registerCallback(
   ".statements.statement[].actions.bgp-actions.config.set-local-pref" )
@RcfOpenConfigFunctionValidator.registerCallback(
   ".statements.statement[].actions.bgp-actions.set-as-path-prepend."
   "config.asn" )
def validateU32( self, data, functionName, openConfigPath ):
   """Report data unless it lies in the unsigned 32-bit range [0, U32_MAX].

   Shared by set-local-pref and set-as-path-prepend/config/asn.
   """
   if validateInteger( data, U32_MAX ):
      return
   leafType = self.typeMap[ type( data ) ]
   self.diag.openConfigFunctionValidationInvalidLeafValueError(
      functionName, openConfigPath, leafType, data )

@RcfOpenConfigFunctionValidator.registerCallback(
   ".statements.statement[].actions.bgp-actions.set-as-path-prepend."
   "config.repeat-n" )
def validateRepeatN( self, data, functionName, openConfigPath ):
   """Check that repeat-n lies in [1, U8_MAX]; zero repetitions are rejected."""
   if validateInteger( data, U8_MAX, minVal=1 ):
      return
   leafType = self.typeMap[ type( data ) ]
   self.diag.openConfigFunctionValidationInvalidLeafValueError(
      functionName, openConfigPath, leafType, data )

def validateIpAddr( addr ):
   """Return True if ipaddress.ip_address() accepts addr, otherwise False."""
   try:
      ipaddress.ip_address( addr )
   except ValueError:
      return False
   return True

@RcfOpenConfigFunctionValidator.registerCallback(
     ".statements.statement[].actions.bgp-actions.config.set-next-hop" )
def validateSetNextHop( self, data, functionName, openConfigPath ):
   """Check a set-next-hop value.

   The keywords PEER_ADDRESS and SELF are accepted as-is; any other value must be
   an address that validateIpAddr accepts.
   """
   if data in ( "PEER_ADDRESS", "SELF" ) or validateIpAddr( data ):
      return
   self.diag.openConfigFunctionValidationInvalidLeafValueError(
         functionName, openConfigPath, self.typeMap[ type( data ) ], data )

@RcfOpenConfigFunctionValidator.registerCallback(
      ".statements.statement[].actions.bgp-actions.config.set-route-origin" )
def validateSetRouteOrigin( self, data, functionName, openConfigPath ):
   """Check that set-route-origin is one of IGP, EGP, or INCOMPLETE."""
   knownOrigins = ( 'IGP', 'EGP', 'INCOMPLETE' )
   if data in knownOrigins:
      return
   self.diag.openConfigFunctionValidationInvalidLeafValueError(
      functionName, openConfigPath, self.typeMap[ type( data ) ], data )

@RcfOpenConfigFunctionValidator.registerCallback(
      ".statements.statement[].actions.config.policy-result" )
def validatePolicyResult( self, data, functionName, openConfigPath ):
   """Check that policy-result is ACCEPT_ROUTE, REJECT_ROUTE, or NEXT_STATEMENT."""
   knownResults = ( "ACCEPT_ROUTE", "REJECT_ROUTE", "NEXT_STATEMENT" )
   if data in knownResults:
      return
   self.diag.openConfigFunctionValidationInvalidLeafValueError(
         functionName, openConfigPath, self.typeMap[ type( data ) ], data )

@RcfOpenConfigFunctionValidator.registerCallback(
   ".statements.statement[].actions.isis-actions.config.set-level" )
def validateSetLevel( self, data, functionName, openConfigPath ):
   """Check an IS-IS set-level value.

   Accepts the integers 1 and 2, or one of the strings LEVEL_1, LEVEL_2, and
   LEVEL_1_2. Any other value (or type) is reported.
   """
   if isinstance( data, str ):
      valid = data in ( 'LEVEL_1', 'LEVEL_2', 'LEVEL_1_2' )
   elif isinstance( data, int ):
      valid = validateInteger( data, 2, minVal=1 )
   else:
      valid = False
   if not valid:
      self.diag.openConfigFunctionValidationInvalidLeafValueError(
         functionName, openConfigPath, self.typeMap[ type( data ) ], data )

@RcfOpenConfigFunctionValidator.registerCallback(
   ".statements.statement[].actions.isis-actions.config.set-metric-style-type" )
def validateSetMetricStyleType( self, data, functionName, openConfigPath ):
   """Reject any string metric style other than WIDE_METRIC.

   Non-string values are not checked here.
   """
   if not isinstance( data, str ) or data == "WIDE_METRIC":
      return
   self.diag.openConfigFunctionValidationInvalidLeafValueError(
      functionName, openConfigPath, self.typeMap[ type( data ) ], data )

@RcfOpenConfigFunctionValidator.registerCallback(
   ".statements.statement[].actions.isis-actions.config.set-metric" )
def validateSetMetric( self, data, functionName, openConfigPath ):
   """Check that set-metric lies in [1, 16777215] (the largest 24-bit value)."""
   maxMetric = ( 1 << 24 ) - 1
   if validateInteger( data, maxMetric, minVal=1 ):
      return
   self.diag.openConfigFunctionValidationInvalidLeafValueError(
      functionName, openConfigPath, self.typeMap[ type( data ) ], data )

# A set-community container must have config, and config must give both method
# and options. Which of the inline/reference children is required depends on
# method, and is checked by validateSetCommunity.
requirePath( ".statements.statement[].actions.bgp-actions.set-community",
             "config" )
requirePath( ".statements.statement[].actions.bgp-actions.set-community.config",
             "method" )
requirePath( ".statements.statement[].actions.bgp-actions.set-community.config",
             "options" )

@RcfOpenConfigFunctionValidator.registerCallback(
      ".statements.statement[].actions.bgp-actions.set-community" )
def validateSetCommunity( self, data, functionName, openConfigPath ):
   """Check that set-community's inline/reference child agrees with config/method.

   config/method (compared case-insensitively) selects which of the "inline" and
   "reference" children must be present; the other one must be absent. An
   unrecognised method is reported, and then neither child is allowed.

   This validator runs before the children of set-community are type-checked
   (validation is pre-order). A "config" that is not a container, or a "method"
   that is not a string, is therefore skipped here and left to schema
   validation instead of crashing.
   """
   config = data.get( 'config' )
   if not isinstance( config, dict ):
      # A missing config is reported by the requirePath() registration for
      # set-community/config; a mistyped one by schema validation.
      return
   method = config.get( 'method' )
   if not isinstance( method, str ):
      # Likewise, a missing method is handled by requirePath(); a mistyped one by
      # schema validation.
      return
   expectedMethod = method.lower()
   validMethods = [ "reference", "inline" ]
   if expectedMethod not in validMethods:
      # Invalid method specified
      self.diag.openConfigFunctionValidationInvalidLeafValueError(
            functionName,
            openConfigPath + "/config/method",
            self.typeMap[ type( expectedMethod ) ],
            expectedMethod.upper() )
   else:
      if expectedMethod not in data:
         self.diag.openConfigFunctionValidationMissingPathError(
               functionName, openConfigPath, expectedMethod )
      validMethods.remove( expectedMethod )
   # Whatever remains in validMethods must not appear alongside the chosen one.
   for method in validMethods:
      if method in data:
         self.diag.openConfigFunctionValidationUnexpectedPathError(
            functionName, openConfigPath, method )

@RcfOpenConfigFunctionValidator.registerCallback(
      ".statements.statement[].actions.bgp-actions.set-community.config."
      "options" )
def validateSetCommuniyOptions( self, data, functionName, openConfigPath ):
   """Check that set-community/config/options is ADD, REMOVE, or REPLACE."""
   knownOptions = ( "ADD", "REMOVE", "REPLACE" )
   if data in knownOptions:
      return
   self.diag.openConfigFunctionValidationInvalidLeafValueError(
         functionName, openConfigPath, self.typeMap[ type( data ) ], data )

# When set-community/inline is present it must contain config, and that config
# must list arista-communities.
requirePath( ".statements.statement[].actions.bgp-actions.set-community.inline",
             "config" )
requirePath( ".statements.statement[].actions.bgp-actions.set-community.inline."
             "config", "arista-communities" )

@RcfOpenConfigFunctionValidator.registerCallback(
      ".statements.statement[].actions.bgp-actions.set-community.inline.config."
      "arista-communities" )
def validateAristaCommunities( self, data, functionName, openConfigPath ):
   """Validate each entry of the arista-communities list.

   Accepted entry forms:
   - an integer in [0, U32_MAX]. JSON booleans are rejected even though python
     treats bool as a subclass of int.
   - "ASN:VALUE", where both parts are ASCII decimal numbers in [0, U16_MAX]
   - a well-known community name, or "NONE"

   Each invalid entry is reported individually. The schema traversal does not
   descend into leaf-list elements (presumably what this list is), so any JSON
   type may reach this validator; types other than the above are reported too.
   """
   wellKnownCommunities = ( 'INTERNET', 'GSHUT', 'NO_EXPORT', 'NO_ADVERTISE',
                            'LOCAL_AS', 'NONE' )

   def isU16Decimal( text ):
      # str.isnumeric() also accepts characters such as '2' superscript or
      # vulgar fractions that int() rejects with ValueError, so only plain ASCII
      # digits are allowed before converting.
      return ( text.isascii() and text.isdigit() and
               validateInteger( int( text ), U16_MAX ) )

   for community in data:
      if isinstance( community, bool ):
         valid = False
      elif isinstance( community, int ):
         valid = validateInteger( community, U32_MAX )
      elif isinstance( community, str ):
         if ':' in community:
            parts = community.split( ':' )
            valid = len( parts ) == 2 and all( isU16Decimal( p ) for p in parts )
         else:
            valid = community in wellKnownCommunities
      else:
         valid = False
      if not valid:
         self.diag.openConfigFunctionValidationInvalidLeafValueError(
               functionName, openConfigPath, self.typeMap[ type( community ) ],
               community )

# When set-community/reference is present it must contain config, and that
# config must list community-set-refs.
requirePath( ".statements.statement[].actions.bgp-actions.set-community.reference",
             "config" )
requirePath( ".statements.statement[].actions.bgp-actions.set-community.reference."
             "config", "community-set-refs" )

@RcfOpenConfigFunctionValidator.registerCallback(
      ".statements.statement[].actions.bgp-actions.set-community.reference.config."
      "community-set-refs" )
def vaildateCommunitySetRefs( self, data, functionName, openConfigPath ):
   """Require every community-set reference to be a legal RCF external reference.

   The schema traversal does not descend into leaf-list elements (presumably
   what this list is), so an element may be any JSON type. Non-string elements
   are reported as invalid instead of crashing re.fullmatch(). Each offending
   element, rather than the whole list, is reported, matching
   validateAristaCommunities.
   """
   for communitySet in data:
      if ( isinstance( communitySet, str ) and
           externalRefRe.fullmatch( communitySet ) ):
         continue
      self.diag.openConfigFunctionValidationInvalidLeafValueError(
            functionName, openConfigPath, self.typeMap[ type( communitySet ) ],
            communitySet )
