# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import jsonschema
import json

from TypeFuture import TacLazyType

# Handles to project (TAC) types obtained through TacLazyType.
# NOTE(review): presumably resolved lazily on first use -- confirm against
# TypeFuture. AegisDirection is not referenced in the visible part of this
# file.
AegisDirection = TacLazyType( 'Aegis::Direction' )
TcamCapacity = TacLazyType( 'Aegis::TcamCapacity' )

# Maps direction ( 'in' / 'out' ) -> address family ( 'ipv4' / 'ipv6' ) -> the
# TCAM profile feature name for that combination. populateSubCfg() uses it to
# decide whether the GAC entries of a direction / address family are counted:
# they are only counted when the mapped feature is present in the configured
# tcamProfile.
AF_DIR_TCAM_FEATURE_MAP = {
   'in': {
      'ipv4': 'traffic-policy port ipv4',
      'ipv6': 'traffic-policy port ipv6'
   },
   'out': {
      'ipv4': 'traffic-policy port ipv4 egress',
      'ipv6': 'traffic-policy port ipv6 egress'
   }
}

class UnsupportedHardwareDefinitionError( Exception ):
   """Raised when a hardware definition key carries a value that is not allowed.

   The rendered text is kept on the instance as `message` and is also passed
   to `Exception`, so `str( err )` yields the same text.
   """

   def __init__( self, key, value, supportedValues=None ):
      """Build the error message.

      Args:
         key: name (or '/'-joined path) of the offending key.
         value: the rejected value.
         supportedValues: optional iterable of accepted values. When it is
            missing or empty, no "(valid values: ...)" hint is added.
      """
      hint = ''
      if supportedValues:
         quoted = [ "'%s'" % candidate for candidate in supportedValues ]
         hint = ' (valid values: %s)' % ', '.join( quoted )

      self.message = "Invalid value '%s' for key '%s'%s." % ( value, key, hint )
      super().__init__( self.message )

class TcamKeySizeError( Exception ):
   """Raised when a summed bit count is larger than the allowed limit.

   The rendered text is kept on the instance as `message` and is also passed
   to `Exception`.
   """

   def __init__( self, attribute, value, limit ):
      """Build the error message.

      Args:
         attribute: description of what was measured.
         value: the computed number of bits.
         limit: the limit that `value` exceeds.
      """
      # %d is used deliberately: both numbers are rendered as integers.
      template = "Total number of bits %d exceeds the limit %d for %s."
      self.message = template % ( value, limit, attribute )
      super().__init__( self.message )

class HardwareDefinitionJsonValidator:
   def __init__( self ):
      """Load the capability files and build the validation schema.

      Reads /etc/HardwareCapacity.json and /etc/TcamProfile.json, then lets
      each populate*() method add its attribute to `self.schema`. The order
      of the calls determines the order of the schema's 'required' list.
      """
      self.hardwareCapacityJson = readJson( '/etc/HardwareCapacity.json' )
      self.tcamProfileParamJson = readJson( '/etc/TcamProfile.json' )

      # Top-level schema skeleton; the populators below fill it in.
      self.schema = dict( type='object', properties={}, required=[] )

      populators = (
         self.populateChipType,
         self.populateChipRevision,
         self.populateBigKaps,
         self.populateMdbProfile,
         self.populateTcamProfile,
      )
      for populate in populators:
         populate()

   def populateChipType( self ):
      """Add the required 'chipType' string property to the schema.

      The accepted values are the entries of
      hardwareCapacityJson[ 'supportedChipTypes' ].
      """
      knownChipTypes = self.hardwareCapacityJson[ 'supportedChipTypes' ]
      chipTypeSchema = { 'type': 'string', 'enum': list( knownChipTypes ) }
      self.schema[ 'properties' ][ 'chipType' ] = chipTypeSchema
      self.schema[ 'required' ].append( 'chipType' )

   def populateChipRevision( self ):
      """Add the required 'chipRevision' string property to the schema.

      One 'anyOf' alternative is emitted per supported chip type, listing the
      revisions hardwareCapacityJson declares for that chip type.

      NOTE(review): the nested 'properties' / 'chipType' constraint is applied
      to the chipRevision string itself, and JSON Schema's 'properties' only
      constrains object instances, so it looks like a revision of *any* chip
      type is accepted here -- confirm whether that is intended.
      """
      capacity = self.hardwareCapacityJson
      alternatives = []
      self.schema[ 'properties' ][ 'chipRevision' ] = {
         'type': 'string',
         'anyOf': alternatives,
      }
      alternatives.extend(
         {
            'properties': { 'chipType': { 'enum': [ chip ] } },
            'enum': list( capacity[ chip ][ 'chipRevision' ] ),
         }
         for chip in capacity[ 'supportedChipTypes' ] )
      self.schema[ 'required' ].append( 'chipRevision' )

   def populateBigKaps( self ):
      """Add the required 'bigKaps' boolean property to the schema.

      One 'anyOf' alternative is emitted per supported chip type, listing the
      'bigKaps' values hardwareCapacityJson declares for that chip type.
      """
      key = 'bigKaps'
      perChipChoices = []
      self.schema[ 'properties' ][ key ] = {
         'type': 'boolean',
         'anyOf': perChipChoices,
      }
      for chip in self.hardwareCapacityJson[ 'supportedChipTypes' ]:
         chipSelector = { 'chipType': { 'enum': [ chip ] } }
         allowed = list( self.hardwareCapacityJson[ chip ][ key ] )
         perChipChoices.append( { 'properties': chipSelector, 'enum': allowed } )
      self.schema[ 'required' ].append( key )

   def populateMdbProfile( self ):
      """Add the required 'mdbProfile' string property to the schema.

      One 'anyOf' alternative is emitted per supported chip type; its accepted
      values are obtained by iterating hardwareCapacityJson[ chip ]
      [ 'mdbProfile' ] (i.e. the keys, when that entry is a mapping).
      """
      branches = []
      self.schema[ 'properties' ][ 'mdbProfile' ] = dict(
         type='string', anyOf=branches )
      for chip in self.hardwareCapacityJson[ 'supportedChipTypes' ]:
         profiles = self.hardwareCapacityJson[ chip ][ 'mdbProfile' ]
         branches.append( dict(
            properties=dict( chipType=dict( enum=[ chip ] ) ),
            enum=list( profiles ) ) )
      self.schema[ 'required' ].append( 'mdbProfile' )

   def populateTcamProfile( self ):
      """Add the required 'tcamProfile' object property to the schema.

      One 'anyOf' alternative is emitted per supported chip type. Each
      alternative restricts the feature names (the keys of the tcamProfile
      object) to the features tcamProfileParamJson declares for that chip
      type, and describes every feature as an object with:
         'key-field': required array of supported key-field names,
         'action': required array of supported action names,
         'portQualifierSize': optional number.
      """
      attribute = 'tcamProfile'
      self.schema[ 'properties' ][ attribute ] = {
         'type': 'object',
         'properties': {},
         'anyOf': [],
      }
      for chipType in self.hardwareCapacityJson[ 'supportedChipTypes' ]:
         featuresPerChipType = self.tcamProfileParamJson[ chipType ]
         tcamProfileFeaturesSchema = {
            'type': 'object',
            'properties': {},
            # Feature names must match exactly. A '|'-joined 'pattern' is an
            # unanchored, unescaped regular expression search, so any name that
            # merely contains a supported feature name would be accepted and
            # only fail later when the feature is looked up. An 'enum' gives
            # exact matching and reports the same list of supported names.
            'propertyNames': {
               'enum': list( featuresPerChipType )
            }
         }
         self.schema[ 'properties' ][ attribute ][ 'anyOf' ].append(
            tcamProfileFeaturesSchema )
         for feature, featureSupportedVals in featuresPerChipType.items():
            tcamProfileFeaturesSchema[ 'properties' ][ feature ] = {
               'type': 'object',
               'properties': {
                  'key-field': {
                     'type': 'array',
                     'items': {
                        'type': 'string',
                        'enum': list( featureSupportedVals[ 'key-field' ] )
                     }
                  },
                  'action': {
                     'type': 'array',
                     'items': {
                        'type': 'string',
                        'enum': list( featureSupportedVals[ 'action' ] )
                     }
                  },
                  'portQualifierSize': { 'type': 'number' }
               },
               'required': [ 'key-field', 'action' ]
            }
      self.schema[ 'required' ].append( attribute )

   def validate( self, data, subCfg ):
      """Validate a hardware definition and fill `subCfg` from it.

      Runs three steps in order: schema validation of `data`, population of
      `subCfg` from `data`, and the size / profile checks on the populated
      `subCfg`. An exception raised by any step propagates to the caller and
      the later steps are not run.

      Args:
         data: the parsed hardware definition (a dict, as indexed by
            populateSubCfg).
         subCfg: the configuration object to populate; its type is declared
            outside this file.
      """
      # Schema check first, so the later steps can index `data` directly.
      self.validateJson( data )
      self.populateSubCfg( data, subCfg )
      self.validateSubCfg( data, subCfg )

   def validateJson( self, data ):
      """Validate `data` against `self.schema` and translate failures.

      Args:
         data: the parsed hardware definition to check.

      Raises:
         KeyError: a required key is missing. The argument is the first word
            of the jsonschema message (the quoted key name).
         UnsupportedHardwareDefinitionError: a value has the wrong type, is
            not one of the supported values, a tcamProfile feature name is
            not supported, or the schema rejected the data for any other
            reason.
         AssertionError: a 'pattern' failure was reported for a key other
            than 'tcamProfile' (not expected with the schema built here).
      """
      try:
         jsonschema.validate( data, self.schema )
      except jsonschema.ValidationError as e:
         # Missing key. Checking the failed keyword itself (rather than the
         # first element of the schema path) also covers required keys that
         # are missing inside nested objects.
         if e.validator == 'required':
            missingKeyword = str( e ).split( ' ', maxsplit=1 )[ 0 ]
            raise KeyError( missingKeyword ) from e

         # Array indices are dropped so that the key reads 'outer/inner'.
         associatedKey = '/'.join(
            path for path in e.absolute_path if not isinstance( path, int ) )

         # Patterns are expected to be used to validate
         # keys in the following situation:
         # {
         #    outerKey : {
         #       innerKey1 : {},
         #       innerKey2 : {},
         #   }
         # }
         # The inner keys are expected to satisfy some pattern.
         if e.validator == 'pattern':
            if associatedKey != 'tcamProfile':
               # An explicit raise (not `assert`) so that this still fires
               # when Python runs with assertions disabled.
               raise AssertionError( f'Unsupported key {associatedKey}' ) from e
            # If we have more associated keys in the future, we
            # should encode in some object _how_ to derive the
            # supported values from the pattern.
            unsupportedValue = e.instance
            supportedValues = e.validator_value.split( '|' )
            raise UnsupportedHardwareDefinitionError( associatedKey,
                                                      unsupportedValue,
                                                      supportedValues ) from e

         # Value with the wrong type
         if e.validator == 'type':
            unsupportedValue = str( e.instance )
            if len( unsupportedValue ) > 20:
               unsupportedValue = unsupportedValue[ : 20 ] + '...'
            supportedTypes = [ 'value of type: ' + e.validator_value ]
            raise UnsupportedHardwareDefinitionError( associatedKey,
                                                      unsupportedValue,
                                                      supportedTypes ) from e

         # Value has correct type but is an unsupported value
         if e.validator == 'enum':
            unsupportedValue = e.instance
            supportedValues = e.validator_value
            raise UnsupportedHardwareDefinitionError( associatedKey,
                                                      unsupportedValue,
                                                      supportedValues ) from e

         # Any other failure (for example an 'anyOf' error reported without
         # descending into its alternatives) must not be swallowed: the data
         # is invalid and the caller would otherwise continue with it.
         unsupportedValue = str( e.instance )
         if len( unsupportedValue ) > 20:
            unsupportedValue = unsupportedValue[ : 20 ] + '...'
         raise UnsupportedHardwareDefinitionError( associatedKey,
                                                   unsupportedValue ) from e

   def populateSubCfg( self, cfgJson, subCfg ):
      """Fill `subCfg` from a hardware definition and the capability tables.

      Copies the scalar settings, the KAPS and TCAM capacities of the selected
      chip type / revision, creates one TCAM profile feature per configured
      feature, records the TCAM key size chosen for each feature, and sums the
      GAC entries per direction.

      Args:
         cfgJson: the parsed hardware definition; must contain 'chipType',
            'chipRevision', 'bigKaps', 'tcamProfile' and 'mdbProfile'.
         subCfg: the configuration object to populate; its type is declared
            outside this file.

      Raises:
         KeyError: a looked-up chip type, revision, profile, feature,
            key-field or action is missing from the capability tables.
      """
      chipType = cfgJson[ 'chipType' ]
      chipRevision = cfgJson[ 'chipRevision' ]
      isBigKaps = cfgJson[ 'bigKaps' ]
      tcamProfileCfg = cfgJson[ 'tcamProfile' ]
      mdbProfile = cfgJson[ 'mdbProfile' ]

      # Capability table of the selected chip type.
      supportedChipType = self.hardwareCapacityJson[ chipType ]

      subCfg.chipType = chipType
      subCfg.chipRevision = chipRevision
      subCfg.bigKaps = isBigKaps
      subCfg.mdbProfile = mdbProfile

      # KAPS capacity is looked up per chip revision and MDB profile and
      # copied per prefix type.
      kapsCapInfo = supportedChipType[ 'kapsCapacity' ][
         chipRevision ][ mdbProfile ]
      for prefixType, kapsCap in kapsCapInfo.items():
         subCfg.kapsCapacity[ prefixType ] = kapsCap

      tcamCapacityDict = supportedChipType[ 'tcamCapacity' ][ chipRevision ]
      subCfg.tcamCapacity = TcamCapacity(
         tcamCapacityDict[ 'numLargeBanks' ],
         tcamCapacityDict[ 'numEntriesPerBank' ] )
      # NOTE(review): presumably this (re)creates the TCAM profile object with
      # an empty name via the TAC constructor-tuple convention -- confirm.
      subCfg.tcamProfile = ( '', )
      tcamProfileFeaturesJson = self.tcamProfileParamJson[ chipType ]

      # The limits are sorted ascending and paired up positionally further
      # below (smallest key size with smallest action size, and so on).
      keySizeLimits = sorted( supportedChipType[ 'tcamKeySize' ] )
      actionSizeLimits = sorted( supportedChipType[ 'actionSize' ] )

      # Start TCAM profile capability handling.
      for feature, featureContents in tcamProfileCfg.items():
         supportedFeatureJson = tcamProfileFeaturesJson[ feature ]
         keyFieldTotalSize = 0
         tcamFeature = subCfg.tcamProfile.feature.newMember( feature )
         # Fields belonging to an overlap group that has already been counted.
         maybeOverlappedKeyFields = set()
         # key-field name -> the set of fields in its overlap group. A field
         # listed in several groups keeps the last group seen.
         knownOverlappingFields = {}
         for group in supportedFeatureJson[ 'overlappingKeyFieldGroups' ]:
            for okf in group:
               knownOverlappingFields[ okf ] = set( group )

         for keyField in featureContents[ 'key-field' ]:
            keyFieldSize = supportedFeatureJson[ 'key-field'
               ][ keyField ]
            tcamFeature.qualifier[ keyField ] = True
            kfOverlapGroup = knownOverlappingFields.get( keyField )
            if kfOverlapGroup:
               # Only the first configured field of an overlap group adds to
               # the total, and it adds its own size.
               if keyField not in maybeOverlappedKeyFields:
                  # This is the first keyField in the group that's seen.
                  keyFieldTotalSize += keyFieldSize
                  # Add every field in the group so it doesn't get
                  # double counted.
                  maybeOverlappedKeyFields |= kfOverlapGroup
            else:
               keyFieldTotalSize += keyFieldSize

         # The port qualifier size comes from the feature definition, or from
         # the feature's default when not given, and is part of the key.
         keyPortQualSize = featureContents.get( 'portQualifierSize',
               supportedFeatureJson[ 'defaultPortQualifierSize' ] )
         keyFieldTotalSize += keyPortQualSize

         actionTotalSize = 0
         for action in featureContents[ 'action' ]:
            actionSize = supportedFeatureJson[ 'action' ][ action ]
            actionTotalSize += actionSize
            tcamFeature.action[ action ] = True

         # Record the smallest key size limit whose paired action size limit
         # also fits. When no pair fits, nothing is recorded for this feature
         # here.
         for keySizeLimit, actionSizeLimit in zip( keySizeLimits,
                                                   actionSizeLimits ):
            if ( keyFieldTotalSize <= keySizeLimit and
                 actionTotalSize <= actionSizeLimit ):
               subCfg.tcamKeySize[ feature ] = keySizeLimit
               break

      # GAC entries per direction: only address families whose mapped feature
      # is present in the configured tcamProfile contribute.
      for direction, afToGacEntries in supportedChipType[
            'gacEntriesPerFeature' ].items():
         subCfg.numGacEntries[ direction ] = 0
         for af, numGacEntries in afToGacEntries.items():
            if AF_DIR_TCAM_FEATURE_MAP[ direction ][ af ] not in tcamProfileCfg:
               continue
            subCfg.numGacEntries[ direction ] += numGacEntries

   def validateSubCfg( self, cfgJson, subCfg ):
      chipType = subCfg.chipType
      supportedChipType = self.hardwareCapacityJson[ chipType ]
      # validate key field size and action size
      if not subCfg.bigKaps:
         supportedInNonBigKaps = supportedChipType[ 'mdbProfile' ][
               subCfg.mdbProfile ]
         if not supportedInNonBigKaps:
            supportedProfiles = [
               profileName for profileName, supported in
               supportedChipType[ 'mdbProfile' ].items() if supported ]
            raise UnsupportedHardwareDefinitionError(
               'mdbProfile',
               subCfg.mdbProfile,
               supportedProfiles )
      tcamProfileFeaturesJson = self.tcamProfileParamJson[ chipType ]

      keySizeLimits = sorted( supportedChipType[ 'tcamKeySize' ] )
      actionSizeLimits = sorted( supportedChipType[ 'actionSize' ] )

      for tcamFeatureName, tcamFeature in subCfg.tcamProfile.feature.items():
         totalQualifierSize = 0
         supportedFeatureJson = tcamProfileFeaturesJson[ tcamFeatureName ]
         for qualifier in tcamFeature.qualifier:
            qualSize = supportedFeatureJson[ 'key-field' ][ qualifier ]
            totalQualifierSize += qualSize
         tcamFeatureDefinition = cfgJson[ 'tcamProfile' ][ tcamFeatureName ]
         keyPortQualSize = tcamFeatureDefinition.get( 'portQualifierSize',
               supportedFeatureJson[ 'defaultPortQualifierSize' ] )
         totalQualifierSize += keyPortQualSize
         if totalQualifierSize > keySizeLimits[ -1 ]:
            raise TcamKeySizeError( 'key-field (%s)' % tcamFeatureName,
                                    totalQualifierSize,
                                    keySizeLimits[ -1 ] )
         totalActionSize = 0
         for action in tcamFeature.action:
            actionSize = supportedFeatureJson[ 'action' ][ action ]
            totalActionSize += actionSize
         if totalActionSize > actionSizeLimits[ -1 ]:
            raise TcamKeySizeError( 'action (%s)' % tcamFeatureName,
                                    totalActionSize,
                                    actionSizeLimits[ -1 ] )

def readJson( fileName ):
   """Parse the JSON file at `fileName` and return the resulting object.

   The file is decoded as UTF-8 explicitly, so the result does not depend on
   the locale's default encoding.

   Raises:
      OSError: the file cannot be opened.
      json.JSONDecodeError: the content is not valid JSON.
   """
   with open( fileName, encoding='utf-8' ) as fd:
      return json.load( fd )
