# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from typing import (
   Iterator,
   Optional,
)

from CliDynamicPlugin.CbfNhgCliModel import (
   CbfNhgOverride,
   CbfNhgOverrideTable,
)
from CliPlugin.CbfNhgCli import (
   gv,
   IPv4,
   IPv6,
)
import Tracing
from TypeFuture import (
   TacLazyType,
)

# Trace handle for this plugin; t0 (the handle's trace0) is used for all debug
# tracing in this file.
th = Tracing.Handle( 'CbfNhgCli' )
t0 = th.trace0

# TACC types referenced through TacLazyType. Used below to build the
# FecOverrideConfig/FecOverrideStatus lookup keys (CbfFecMappingSource,
# CbfForwardingType, FecOverrideKey, TrafficClass) and as fallback values for
# missing status entries (NexthopGroupEntry, NhgIdToFecIdSet).
# NOTE(review): FecOverrideKey is Qos::Smash::OverrideKey while OverrideKey is
# CbfNhgExt::CbfNhgOverrideKey -- easy to confuse. NexthopGroupEntryKey and
# OverrideKey are not referenced in the code visible here; confirm whether other
# code imports them from this module before removing.
CbfFecMappingSource = TacLazyType( 'Qos::Smash::CbfFecMappingSource' )
CbfForwardingType = TacLazyType( 'Qos::Cbf::CbfForwardingType' )
FecOverrideKey = TacLazyType( 'Qos::Smash::OverrideKey' )
NexthopGroupEntry = TacLazyType( 'NexthopGroup::NexthopGroupEntry' )
NexthopGroupEntryKey = TacLazyType( 'NexthopGroup::NexthopGroupEntryKey' )
NhgIdToFecIdSet = TacLazyType( 'CbfNhg::NhgIdToFecIdSet' )
OverrideKey = TacLazyType( 'CbfNhgExt::CbfNhgOverrideKey' )
TrafficClass = TacLazyType( 'Qos::TrafficClass' )

class CbfNhgOverrideTableBuilder:
   '''Build a CbfNhgOverrideTable CLI model from CBF nexthop-group override state.

   build() walks every CbfNhgOverrideEntry in cbfNhgOverrideConfig. It resolves
   the entry's default nexthop group to the default FecIds that reference it
   (via nhgEntryStatus and nhgFecIdStatus). For each address family, DSCP and
   default FecId it looks up the override FEC in fecOverrideConfig and its
   status in fecOverrideStatus. The filter sets passed to the constructor
   narrow which rows are emitted; an empty filter set disables filtering on
   that attribute.
   '''

   def __init__(
         self,
         addressFamilyFilter: set,
         dscpFilter: set,
         defaultFecFilter: set,
         overrideFecFilter: set,
         defaultNhgFilter: set,
         overrideNhgFilter: set,
         counterFilter: bool,
         cbfNhgOverrideConfig,
         fecOverrideConfig,
         fecOverrideStatus,
         nhgEntryStatus,
         nhgFecIdStatus ):
      '''Store the filters and the mounted state used by build().

      Filters (an empty set means "do not filter on this attribute"):
         addressFamilyFilter: matched against the lower-cased address family.
         dscpFilter: DSCP values to include.
         defaultFecFilter: default FecIds to include.
         overrideFecFilter: override FecIds to include.
         defaultNhgFilter: default nexthop group names to include.
         overrideNhgFilter: override nexthop group names to include.
         counterFilter: forwarded to the CbfNhgOverrideTable model as
            _renderCounters.

      State:
         cbfNhgOverrideConfig: provides the nhgOverride collection to walk.
         fecOverrideConfig, fecOverrideStatus: provide the override
            collections looked up by a Qos::Smash::OverrideKey.
         nhgEntryStatus: maps a default nexthop group key to its
            NexthopGroupEntry (for the nhgId).
         nhgFecIdStatus: maps an nhgId to the default FecIds referencing it.
      '''
      t0( 'defaultNhgFilter:', defaultNhgFilter )
      t0( 'overrideNhgFilter:', overrideNhgFilter )
      t0( 'defaultFecFilter:', defaultFecFilter )
      t0( 'overrideFecFilter:', overrideFecFilter )
      t0( 'addressFamilyFilter:', addressFamilyFilter )
      t0( 'dscpFilter:', dscpFilter )
      t0( 'counterFilter:', counterFilter )
      # Use these filter sets to determine which CbfNhgOverride models should be
      # included in a CbfNhgOverrideTable model. A CbfNhgOverride model's attribute
      # must be present in the associated filter to be included in the
      # CbfNhgOverrideTable model. If a filter is empty then there is no filtering
      # for that specific attribute.
      self.addressFamilyFilter = addressFamilyFilter
      self.dscpFilter = dscpFilter
      self.defaultFecFilter = defaultFecFilter
      self.overrideFecFilter = overrideFecFilter
      self.defaultNhgFilter = defaultNhgFilter
      self.overrideNhgFilter = overrideNhgFilter
      # Passed to the table model as _renderCounters in build().
      # NOTE(review): this builder does not itself populate inOctets/inPackets;
      # confirm the model/renderer is responsible for counter data.
      self.counterFilter = counterFilter

      t0( 'cbfNhgOverrideConfig', cbfNhgOverrideConfig )
      t0( 'fecOverrideConfig', fecOverrideConfig )
      t0( 'fecOverrideStatus', fecOverrideStatus )
      t0( 'nhgEntryStatus', nhgEntryStatus )
      t0( 'nhgFecIdStatus', nhgFecIdStatus )
      self.cbfNhgOverrideConfig = cbfNhgOverrideConfig
      self.fecOverrideConfig = fecOverrideConfig
      self.fecOverrideStatus = fecOverrideStatus
      self.nhgEntryStatus = nhgEntryStatus
      self.nhgFecIdStatus = nhgFecIdStatus

   def skipAddressFamily( self, addressFamily: str ) -> bool:
      '''Return True if addressFamily should be hidden.

      Always False when addressFamilyFilter is empty; otherwise True when the
      lower-cased addressFamily is not in addressFamilyFilter.
      '''
      if not self.addressFamilyFilter:
         return False
      return addressFamily.lower() not in self.addressFamilyFilter

   def skipDscp( self, dscp: int ) -> bool:
      '''Return True if dscp should be hidden.

      Always False when dscpFilter is empty; otherwise True when dscp is not in
      dscpFilter.
      '''
      if not self.dscpFilter:
         return False
      return dscp not in self.dscpFilter

   def skipDefaultFecId( self, fecId: int ) -> bool:
      '''Return True if the default fecId should be hidden.

      Always False when defaultFecFilter is empty; otherwise True when fecId is
      not in defaultFecFilter.
      '''
      if not self.defaultFecFilter:
         return False
      return fecId not in self.defaultFecFilter

   def skipOverrideFecId( self, fecId: int ) -> bool:
      '''Return True if the override fecId should be hidden.

      Always False when overrideFecFilter is empty; otherwise True when fecId is
      not in overrideFecFilter.
      '''
      if not self.overrideFecFilter:
         return False
      return fecId not in self.overrideFecFilter

   def skipDefaultNhg( self, nhgName: str ) -> bool:
      '''Return True if the default nexthop group nhgName should be hidden.

      Always False when defaultNhgFilter is empty; otherwise True when nhgName
      is not in defaultNhgFilter.
      '''
      if not self.defaultNhgFilter:
         return False
      return nhgName not in self.defaultNhgFilter

   def skipOverrideNhg( self, nhgName: str ) -> bool:
      '''Return True if the override nexthop group nhgName should be hidden.

      Always False when overrideNhgFilter is empty; otherwise True when nhgName
      is not in overrideNhgFilter.
      '''
      if not self.overrideNhgFilter:
         return False
      return nhgName not in self.overrideNhgFilter

   def createOverride(
         self,
         defaultNhgName: str,
         overrideNhgName: str,
         addressFamily: str,
         dscp: int,
         defaultFecId: int,
         ecmp: bool ) -> Optional[ CbfNhgOverride ]:
      '''Build one CbfNhgOverride row, or return None if it is filtered out.

      A Qos::Smash::OverrideKey (source cbfSourceNhg) is built from the address
      family, DSCP and default FecId. It is used to look up the override FEC in
      fecOverrideConfig and its status in fecOverrideStatus.

      Returns None when:
         1. defaultFecFilter is non-empty and defaultFecId is not in it, or
         2. overrideFecFilter is non-empty and either no FecOverrideEntry exists
            for the key or its override FecId is not in overrideFecFilter.

      A defaultFecId of 0 (invalid) leaves defaultFecId and ecmp unset on the
      returned model; the override/status lookups still use that key.
      '''
      t0( 'createOverride', 'defaultNhgName:', defaultNhgName, 'overrideNhgName:',
          overrideNhgName, 'addressFamily:', addressFamily, 'dscp:', dscp,
          'defaultFecId:', defaultFecId )
      # The default FecId was not selected by a non-empty filter; hide the row.
      if self.skipDefaultFecId( defaultFecId ):
         return None
      override = CbfNhgOverride()
      override.defaultNhg = defaultNhgName
      override.overrideNhg = overrideNhgName
      override.addressFamily = addressFamily
      override.dscp = dscp
      # 0 is the invalid FecId, used for placeholder rows: leave defaultFecId and
      # ecmp unpopulated in that case.
      if defaultFecId:
         override.defaultFecId = defaultFecId
         override.ecmp = ecmp
      ft = ( CbfForwardingType.ipv4 if addressFamily == IPv4
             else CbfForwardingType.ipv6 )
      fecOverrideKey = FecOverrideKey( ft, dscp, TrafficClass.invalid, defaultFecId )
      fecOverrideKey.source = CbfFecMappingSource.cbfSourceNhg
      # Get the FecOverrideEntry for this FecOverrideKey from the Smash mounted
      # Qos::Smash::FecOverrideConfig.
      fecOverrideEntry = self.fecOverrideConfig.override.get( fecOverrideKey )
      t0( 'FecOverrideEntry:', fecOverrideEntry )
      if fecOverrideEntry:
         # This override FecId was not chosen in the filter so do not display it.
         if self.skipOverrideFecId( fecOverrideEntry.fecId ):
            return None
         override.overrideFecId = fecOverrideEntry.fecId
      elif self.overrideFecFilter:
         # An override FecId filter was given but this key has no override FecId,
         # so nothing here can match the filter.
         return None
      fecOverrideStatusEntry = self.fecOverrideStatus.override.get( fecOverrideKey )
      t0( 'FecOverrideStatusEntry:', fecOverrideStatusEntry )
      if fecOverrideStatusEntry:
         override.status = fecOverrideStatusEntry.status
      return override

   def __build(
         self,
         addressFamily: str,
         defaultNhgName: str,
         overrideNhgName: str,
         dscpIter: Iterator[ int ],
         fecIds: list[ tuple[ int, bool ] ] ) -> list[ CbfNhgOverride ]:
      '''Return the CbfNhgOverride rows for one override entry and address family.

      fecIds holds ( defaultFecId, isEcmp ) pairs. For every DSCP that passes
      dscpFilter, one row is attempted per default FecId. If none of them yields
      a row for that DSCP, one placeholder row (defaultFecId 0) is attempted.
      This keeps the configured entry visible for that DSCP. The nexthop group
      name filters are applied by the caller.
      '''
      overrides = []
      if self.skipAddressFamily( addressFamily ):
         t0( 'Address family:', addressFamily, 'excluded' )
         return overrides
      # Only process DSCP values which should not be skipped, skipDscp returns False.
      # DscpBitmap iterates in non-decreasing (ascending) DSCP order, so the rows
      # come out sorted by DSCP.
      for dscp in filter( lambda dscp: not self.skipDscp( dscp ), dscpIter ):
         # Rows are tracked per DSCP: the placeholder decision below must only
         # consider this DSCP, not rows already produced for earlier DSCPs.
         dscpOverrides = []
         for ( fecId, isEcmp ) in fecIds:
            if override := self.createOverride(
                  defaultNhgName,
                  overrideNhgName,
                  addressFamily=addressFamily,
                  dscp=dscp,
                  defaultFecId=fecId,
                  ecmp=isEcmp ):
               t0( 'Adding:', override )
               dscpOverrides.append( override )
         # If we did not create any overrides for this DSCP, then there are no
         # default FECs and no override FECs present for it. We still want to
         # display the default nexthop group name, override nexthop group name,
         # address family, and DSCP so it is clear a CbfNhgOverrideEntry is
         # configured.
         if not dscpOverrides:
            if override := self.createOverride(
                  defaultNhgName,
                  overrideNhgName,
                  addressFamily=addressFamily,
                  dscp=dscp,
                  # Set the default FecId to 0 (invalid) so we do not populate the
                  # defaultFecId and ecmp attributes in the CbfNhgOverride model.
                  defaultFecId=0,
                  ecmp=False ):
               t0( 'Adding:', override )
               dscpOverrides.append( override )
         overrides.extend( dscpOverrides )
      t0( addressFamily, 'overrides:', overrides )
      return overrides

   def build( self ) -> CbfNhgOverrideTable:
      '''Populate a CbfNhgOverrideTable model with CbfNhgOverride model objects.

      Iterates over the CbfNhgOverrideConfig::nhgOverride collection. For each
      entry that passes the nexthop group name filters, and for each address
      family (IPv4 then IPv6), it looks up FecOverrideEntry objects in the
      FecOverrideConfig::override collection.
      '''
      cbfNhgOverrideTable = CbfNhgOverrideTable( _renderCounters=self.counterFilter )
      # Process each CbfNhgOverrideEntry in the Smash mounted
      # CbfNhgExt::CbfNhgOverrideConfig.
      for overrideEntry in self.cbfNhgOverrideConfig.nhgOverride.values():
         t0( 'CbfNhgOverrideEntry:', overrideEntry )
         defaultNhgName = overrideEntry.key.defaultNhg.nhgName()
         overrideNhgName = overrideEntry.key.overrideNhg.nhgName()
         # The name filters do not depend on the address family, so apply them
         # once here, before doing any status lookups for this entry.
         if ( self.skipDefaultNhg( defaultNhgName ) or
              self.skipOverrideNhg( overrideNhgName ) ):
            t0( 'Nexthop group name excluded' )
            continue
         # Get the default nexthop group's NexthopGroupId from the Smash mounted
         # NexthopGroup::EntryStatus.
         defaultNhgEntry = self.nhgEntryStatus.nexthopGroupEntry.get(
            overrideEntry.key.defaultNhg, NexthopGroupEntry() )
         t0( 'Default NexthopGroupEntry:', defaultNhgEntry )
         # Get the default FecId's which reference this default NexthopGroupId from
         # the Sysdb mounted CbfNhgExt::NhgFecIdStatus.
         fecIdSet = self.nhgFecIdStatus.nhgFecIdRef.get(
            defaultNhgEntry.nhgId, NhgIdToFecIdSet( 0 ) )
         # ( defaultFecId, isEcmp ) pairs, shared by both address families.
         defaultFecIds = list( fecIdSet.fecId.items() )
         t0( 'NhgFecIdStatus::nhgFecIdRef:', defaultFecIds )
         for addressFamily in [ IPv4, IPv6 ]:
            dscpIter = ( overrideEntry.v4Dscps.dscp if addressFamily == IPv4
                         else overrideEntry.v6Dscps.dscp )
            overrides = self.__build( addressFamily, defaultNhgName, overrideNhgName,
                                      dscpIter, defaultFecIds )
            cbfNhgOverrideTable.overrides.extend( overrides )
      return cbfNhgOverrideTable

def showCbfNhgOverrideCmdHandler( mode, args ):
   '''Return a CbfNhgOverrideTable model built from the parsed CLI arguments.

   Each optional token list in args becomes a filter set. A token that is absent
   yields an empty set, which disables that filter. The presence of 'counters'
   enables counter rendering. mode is not used.
   '''
   t0( 'args:', args )

   def tokenSet( token ):
      # Collect the values matched for a CLI token; empty when not given.
      return set( args.get( token, [] ) )

   builder = CbfNhgOverrideTableBuilder(
      addressFamilyFilter=tokenSet( 'AF' ),
      dscpFilter=tokenSet( 'DSCP' ),
      defaultFecFilter=tokenSet( 'DEFAULT_ID' ),
      overrideFecFilter=tokenSet( 'OVERRIDE_ID' ),
      defaultNhgFilter=tokenSet( 'DEFAULT_NAME' ),
      overrideNhgFilter=tokenSet( 'OVERRIDE_NAME' ),
      counterFilter='counters' in args,
      cbfNhgOverrideConfig=gv.cbfNhgOverrideConfig,
      fecOverrideConfig=gv.fecOverrideConfig,
      fecOverrideStatus=gv.fecOverrideStatus,
      nhgEntryStatus=gv.nhgEntryStatus,
      nhgFecIdStatus=gv.nhgFecIdStatus )

   table = builder.build()
   t0( table )
   return table
