#!/usr/bin/env python3
# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pkgdeps: rpmwith %{_libdir}/libRsvpLibSysdbTypes.so*

from CliDynamicSymbol import CliDynamicPlugin
from CliModel import UnknownEntityError
from CliPlugin.AleCountersCli import checkCounterFeatureEnabled
from CliPlugin.EthIntfCli import EthIntf
from CliPlugin.MplsCli import showMplsConfigWarnings
from CliPlugin import IntfCli
from CliPlugin.SubIntfCli import SubIntf
from CliPlugin.VlanIntfCli import VlanIntf
from IpLibConsts import DEFAULT_VRF
import LazyMount
from RoutingIntfUtils import isManagement
from TypeFuture import TacLazyType
import SmashLazyMount
from Arnet.MplsLib import labelListToMplsLabelStack
import Tac

# CLI model module for the MPLS show commands, loaded via CliDynamicPlugin.
MplsCliModel = CliDynamicPlugin( "MplsCliModel" )

# Tac types resolved lazily (TacLazyType) on first attribute access.
FapId = TacLazyType( 'FlexCounters::FapId' )
FeatureId = TacLazyType( 'FlexCounters::FeatureId' )
RsvpCliConfig = TacLazyType( 'Rsvp::RsvpCliConfig' )
RouteKey = TacLazyType( 'Mpls::RouteKey' )

# Mounted Sysdb/Smash entities read by the handlers in this module. They are
# None at import time and are assigned by Plugin().
ldpConfigColl = None
mplsInterfaceCounterTable = None
mplsInterfaceSnapshotTable = None
mplsRoutingConfig = None
routingHwMplsStatus = None
transitLfib = None
rsvpCliConfig = None
mplsVskFecIdCliCollection = None

#--------------------------------------------------------------------------------
# show mpls interface [ INTFS ]
#--------------------------------------------------------------------------------

def generateInterfaceModel( mode, args ):
   """Build the model for 'show mpls interface [ INTFS ]'.

   Collects Ethernet, Vlan and sub-interfaces (optionally restricted to INTFS),
   drops management interfaces and interfaces that cannot currently route, and
   reports, for every remaining interface, its state and whether MPLS, LDP and
   RSVP are configured on it.

   Raises:
      UnknownEntityError: no interface is left after filtering.
   """
   requested = args.get( 'INTFS' )
   # FIXME: Pass "module" argument mod=xxx to getAll(), BUG374639
   candidates = IntfCli.Intf.getAll( mode, intf=requested,
                                     intfType=( EthIntf, VlanIntf, SubIntf ) )
   # Keep only routable, non-management interfaces.
   eligible = []
   for candidate in candidates:
      if candidate.routingCurrentlySupported() and \
         not isManagement( candidate.name ):
         eligible.append( candidate )

   if not eligible:
      if requested:
         raise UnknownEntityError(
            f"{next( iter( requested ) )} does not support MPLS/IP" )
      raise UnknownEntityError( "No MPLS/IP capable interfaces" )

   ldpConfig = ldpConfigColl.config.get( DEFAULT_VRF )
   ldpEnabled = ldpConfig is not None and ldpConfig.enabled
   rsvpEnabled = rsvpCliConfig is not None and rsvpCliConfig.enabled

   infoByName = {}
   for intf in eligible:
      # MPLS counts as configured when it is enabled globally and the
      # interface is not flagged in the per-interface disable collection.
      mplsOn = bool( mplsRoutingConfig.mplsRouting ) and \
         not mplsRoutingConfig.mplsRoutingDisabledIntf.get( intf.name )

      # LDP needs MPLS too; when LDP is restricted to an explicit interface
      # list, the interface must also appear in that list.
      ldpOn = mplsOn and ldpEnabled
      if ldpOn and ldpConfig.onlyLdpEnabledIntfs:
         ldpOn = intf.name in ldpConfig.ldpEnabledIntf

      rsvpOn = ( mplsOn and rsvpEnabled and
                 intf.name not in mplsRoutingConfig.mplsRoutingDisabledIntf )

      infoByName[ intf.name ] = MplsCliModel.MplsInterfaceInfo(
         status=intf.getIntfState(),
         mplsConfigured=mplsOn,
         ldpConfigured=ldpOn,
         rsvpConfigured=rsvpOn )

   return MplsCliModel.MplsInterfaceModel( intfs=infoByName )

#--------------------------------------------------------------------------------
# show mpls interface [ INTFS ] counters
#--------------------------------------------------------------------------------

def mplsInterfaceSmashCounter( intf ):
   """Return the ( pkts, octets ) MPLS ingress counters of interface `intf`.

   If the interface has no current counter entry, ( 0, 0 ) is returned.
   If it has a current entry but no snapshot entry, the raw current values
   are returned. Otherwise the snapshot values are subtracted from the
   current ones.
   """
   current = mplsInterfaceCounterTable.counterEntry.get( intf )
   baseline = mplsInterfaceSnapshotTable.counterEntry.get( intf )
   if current is None:
      return 0, 0
   if baseline is None:
      return current.pkts, current.octets
   return current.pkts - baseline.pkts, current.octets - baseline.octets

def showInterfacesMplsCounter( mode, args ):
   """Build the model for 'show mpls interface [ INTFS ] counters'.

   If the "hardware counter feature mpls interface in" feature is not enabled,
   an error is added to `mode` and None is returned. Otherwise the returned
   model holds MPLS ingress packet/octet counters for every counter-capable
   Ethernet interface (optionally restricted to INTFS) on which MPLS routing
   is enabled.
   """
   requestedIntfs = args.get( 'INTFS' )
   mod = None # FIXME, BUG374639
   if not checkCounterFeatureEnabled( FeatureId.MplsInterfaceIngress ):
      feature = "hardware counter feature mpls interface in"
      mode.addError( f'"{feature}" should be enabled first' )
      return None
   mplsInterfaceCounters = MplsCliModel.MplsInterfaceIngressCounters()
   intfs = IntfCli.counterSupportedIntfs( mode, requestedIntfs, mod )
   # With MPLS routing globally disabled no interface can qualify. The check
   # is loop-invariant, so do it once (after counterSupportedIntfs, which
   # still has to run for its own validation/reporting).
   if not intfs or not mplsRoutingConfig.mplsRouting:
      return mplsInterfaceCounters
   disabledIntfs = mplsRoutingConfig.mplsRoutingDisabledIntf
   for intfName in ( i.name for i in intfs ):
      # Only Ethernet interfaces are reported, and only when MPLS is not
      # disabled on them.
      if not intfName.startswith( 'Ethernet' ) or disabledIntfs.get( intfName ):
         continue
      # NOTE(review): the per-interface entry type is reached through the
      # top-level model instance; presumably a nested model class - verify.
      intfCounters = mplsInterfaceCounters.MplsInterfaceIngressCounters()
      intfCounters.mplsInPkts, intfCounters.mplsInOctets = \
            mplsInterfaceSmashCounter( intfName )
      mplsInterfaceCounters.interfaces[ intfName ] = intfCounters
   return mplsInterfaceCounters

def ShowMplsRouteAckCmd_handler( mode, args ):
   """Build the model for the MPLS route acknowledgement show command.

   For every route key present in the transit LFIB and/or the MPLS hardware
   route status, compare the route version requested by the LFIB with the
   version acknowledged in the hardware status, and report installed FEC
   details when they are available.

   Args:
      mode: CLI mode, passed to showMplsConfigWarnings.
      args: parsed CLI arguments. 'all' reports every route instead of only
         those whose versions differ; 'LABELS' restricts output to that
         single label stack.

   Returns:
      MplsCliModel.MplsRouteAckTable whose rows are keyed by the label
      stack's CLI string.
   """
   allRoutes = 'all' in args
   requestedLabels = args.get( 'LABELS' )

   showMplsConfigWarnings( mode )

   routes = {}
   transitLfibRoutes = transitLfib.lfibRoute
   hwRoutes = routingHwMplsStatus.route

   # Create a superset of routeKeys from the transitLfib and mpls hw route status
   # as we want to detect when keys appear in one collection but not the other
   routeKeys = set( transitLfibRoutes )
   routeKeys.update( hwRoutes )
   if requestedLabels:
      labelStack = labelListToMplsLabelStack( requestedLabels )
      rk = Tac.const( RouteKey( labelStack ) )
      routeKeys = [ rk ] if rk in routeKeys else []

   for rk in routeKeys:
      requestedRouteVersion = None
      acknowledgedRouteVersion = None
      installedFecId = None
      installedFecVersion = None
      fecProgrammingStatus = None

      # Set the requested route version if routeKey is found in the LFIB
      lfibRoute = transitLfibRoutes.get( rk )
      if lfibRoute and lfibRoute.versionId:
         requestedRouteVersion = lfibRoute.versionId

      # Set all the other acknowledged params if the routeKey is found in the hw
      # route status
      hwRoute = hwRoutes.get( rk )
      if hwRoute:
         acknowledgedRouteVersion = hwRoute.routeVersionId
         # If FEC is not versioned, we won't publish related fields in hwRoute
         # because we don't do integrated ACK in that case and the fec info could
         # be inaccurate. However, we may still want to display the inaccurate
         # info of installed FecId to align with another show command like `show
         # mpls hardware adj` and we have another collection for that purpose.
         if hwRoute.fecVersionId:
            installedFecId = hwRoute.fecId
            installedFecVersion = hwRoute.fecVersionId
            fecProgrammingStatus = hwRoute.fecHwProgrammingState
         elif lfibRoute is not None:
            # The via-set key comes from the LFIB entry. A key present only in
            # the hw route status has no LFIB entry (lfibRoute is None), so
            # there is nothing to look up; dereferencing it would raise.
            mplsVskFecIds = mplsVskFecIdCliCollection.mplsVskFecIdEntry.get(
               lfibRoute.viaSetKey )
            if mplsVskFecIds and mplsVskFecIds.installedFecId:
               installedFecId = mplsVskFecIds.installedFecId

      # Print a row for the routeKey if any of the following are true:
      # 1. The `all` subcommand is provided - in this case all routeKeys are printed
      # 2. The requested route versionId and acknowledged route versionId don't match
      # 3. The routeKey appears in the LFIB but not in the programmedMplsRoute coll
      # 4. The routeKey appears in the programmedMplsRoute coll but not in the LFIB
      # (3 and 4 surface as a mismatch because the missing side stays None.)
      mismatchedVersions = requestedRouteVersion != acknowledgedRouteVersion
      if allRoutes or mismatchedVersions:
         routeLabels = rk.labelStack.cliShowString()
         routes[ routeLabels ] = MplsCliModel.MplsRouteAckTableEntry(
            requestedId=requestedRouteVersion,
            acknowledgedId=acknowledgedRouteVersion,
            installedFecId=installedFecId,
            installedFecVersion=installedFecVersion,
            fecProgrammingStatus=fecProgrammingStatus
         )
   return MplsCliModel.MplsRouteAckTable( routes=routes )

def Plugin( em ):
   """Mount the Sysdb and Smash entities read by this module's handlers.

   Sysdb entities are mounted with LazyMount in "r" mode; Smash tables are
   mounted with SmashLazyMount using 'reader' mount info. Each result is stored
   in the corresponding module-level global.
   """
   global ldpConfigColl, mplsRoutingConfig, rsvpCliConfig, routingHwMplsStatus
   global mplsInterfaceCounterTable, mplsInterfaceSnapshotTable
   global transitLfib, mplsVskFecIdCliCollection

   # ---- Sysdb ----
   # LDP configuration collection.
   ldpConfigColl = LazyMount.mount(
         em, "mpls/ldp/ldpConfigColl", "Ldp::LdpConfigColl", "r" )
   # MPLS routing configuration.
   mplsRoutingConfig = LazyMount.mount(
         em, "routing/mpls/config", "Mpls::Config", "r" )
   # RSVP CLI configuration, at the mount path exposed by its Tac type.
   rsvpCliConfig = LazyMount.mount(
         em, RsvpCliConfig.mountPath, "Rsvp::RsvpCliConfig", "r" )
   # MPLS hardware route status (source of acknowledged route versions).
   routingHwMplsStatus = LazyMount.mount(
         em, "routing/hardware/mpls/status", "Mpls::Hardware::Status", "r" )

   # ---- Smash ----
   smashReader = SmashLazyMount.mountInfo( 'reader' )
   counterTableType = "Ale::FlexCounter::MplsInterfaceIngressCounterTable"
   allFaps = FapId.allFapsId

   # Current MPLS interface ingress counters for the all-FAPs id.
   # pylint: disable-next=consider-using-f-string
   counterPath = 'flexCounters/counterTable/MplsInterfaceIngress/%u' % allFaps
   mplsInterfaceCounterTable = SmashLazyMount.mount(
         em, counterPath, counterTableType, smashReader )

   # Snapshot of the same counters (same table type, snapshot path).
   # pylint: disable-next=consider-using-f-string
   snapshotPath = 'flexCounters/snapshotTable/MplsInterfaceIngress/%u' % allFaps
   mplsInterfaceSnapshotTable = SmashLazyMount.mount(
         em, snapshotPath, counterTableType, smashReader )

   # Transit LFIB (requested route versions).
   transitLfib = SmashLazyMount.mount(
         em, 'mpls/transitLfib', 'Mpls::LfibStatus', smashReader )

   # Via-set-key -> installed FEC id collection.
   mplsVskFecIdCliCollection = SmashLazyMount.mount(
         em, "mpls/mplsVskFecIdColl", "Smash::Mpls::MplsVskFecIdCliCollection",
         smashReader )
