#!/usr/bin/env python3
# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import sys

import AgentDirectory
from CliModel import cliPrinted
from CliPlugin.MplsModel import MplsBindingsModel
from IpLibConsts import DEFAULT_VRF
import LazyMount
import SharedMem
import Smash
from TypeFuture import TacLazyType

# Tac types used by BgpLabelBindingsCliHelper below, declared via TacLazyType
# (presumably deferring type resolution until first use -- TODO confirm).
#  - CommonLibConsumerSm, LabelBindingTable, LabelBindingTableColl, RouterId:
#    build the consumer SMs that copy the Smash tables into local memory.
#  - FwdEqvClass, IpGenPrefix: build the FEC used for prefix filtering.
#  - BgpMplsBindingsHelper: populates and renders the bindings.
BgpMplsBindingsHelper = TacLazyType( 'BgpShowCli::BgpMplsBindingsHelper' )
CommonLibConsumerSm = TacLazyType( 'CommonLibSmash::CommonLibConsumerSm' )
FwdEqvClass = TacLazyType( 'Mpls::FwdEqvClass' )
IpGenPrefix = TacLazyType( 'Arnet::IpGenPrefix' )
LabelBindingTable = TacLazyType( 'Mpls::LabelBindingTable' )
LabelBindingTableColl = TacLazyType( 'Mpls::LabelBindingTableColl' )
RouterId = TacLazyType( 'Mpls::RouterId' )

class BgpLabelBindingsCliHelper:
   """ Helper to simplify the rendering of the MPLS label bindings produced by the
       Bgp agent (and consumed by the Mpls agent). This class takes care of mounting
       the required Smash tables, instantiating the CommonLibConsumerSm's to copy
       those tables to local memory, and then rendering the bindings as desired.
   """
   def __init__( self, entityManager, lbtPath, lbtPathPeer=None ):
      """ Initialize the helper to retrieve the local and peer label bindings from
          the Smash tables at the specified 'lbtPath'.

          If 'lbtPathPeer' is specified, we will render the peer label bindings
          present at that path (rather than the peer bindings at 'lbtPath').

          All mounts are issued here (via doMounts_). The consumer SMs that copy
          the tables to local memory are only created on the first call to
          renderBindings, and are then reused by subsequent calls.
      """
      self.entityManager = entityManager
      self.lbtPath = lbtPath
      self.lbtPathPeer = lbtPathPeer

      # Mounted entities; filled in by doMounts_.
      self.mplsRoutingConfig = None
      self.routingVrfInfoDir = None
      self.lbt = None
      self.lbtPeer = None  # Stays None when no 'lbtPathPeer' was given.

      # Consumer SMs copying the Smash tables to local memory; created lazily
      # (and cached) by renderBindings.
      self.lbtConsumerSm = None
      self.lbtConsumerSmPeer = None

      self.doMounts_()

   def doMounts_( self ):
      """ Mount the required Smash/Sysdb tables.

          The Sysdb config/status entities are mounted through LazyMount. The
          label binding table at 'lbtPath' (and at 'lbtPathPeer', if set) is
          mounted from Smash with the 'keyshadow' mount info.
      """
      # Used to determine if MPLS routing is enabled.
      self.mplsRoutingConfig = LazyMount.mount(
         self.entityManager, "routing/mpls/config", "Mpls::Config", "r" )

      # Used to determine if IP routing is enabled.
      self.routingVrfInfoDir = LazyMount.mount(
         self.entityManager, "routing/vrf/routingInfo/status", "Tac::Dir", "ri" )

      # Mount the required LBT Smash tables.
      lbtType = "CommonLibSmash::LabelBindingTable"
      shmemEm = SharedMem.entityManager( sysdbEm=self.entityManager )
      self.lbt = shmemEm.doMount(
         self.lbtPath, lbtType, Smash.mountInfo( 'keyshadow' ) )
      # An empty/None peer path means "use the peer bindings at 'lbtPath'".
      if self.lbtPathPeer:
         self.lbtPeer = shmemEm.doMount(
            self.lbtPathPeer, lbtType, Smash.mountInfo( 'keyshadow' ) )

   def showConfigWarnings( self, mode ):
      """ Displays some warnings/hints to the user if the required agents aren't
          running or MPLS/IP routing isn't enabled.

          Warnings are added to 'mode' via mode.addWarning; nothing is returned.
      """
      sysname = mode.entityManager.sysname()

      if not AgentDirectory.agentIsRunning( sysname, 'Bgp' ):
         mode.addWarning( "Agent 'Bgp' is not running" )

      if not AgentDirectory.agentIsRunning( sysname, 'Mpls' ):
         mode.addWarning( "Agent 'Mpls' is not running" )

      if not self.mplsRoutingConfig.mplsRouting:
         mode.addWarning( "MPLS routing is not enabled" )

      # IP routing is checked for the default VRF only.
      routingInfo = self.routingVrfInfoDir.get( DEFAULT_VRF )
      if not ( routingInfo and routingInfo.routing ):
         mode.addWarning( "IP routing is not enabled" )

   def renderBindings( self, fmt, detail=False, summary=False, matchPrefix=None,
                       showPeerBindings=True, showRecvCount=True ):
      """ Renders the label bindings in the required format (text or json).

          If 'detail' is True, the displayed bindings will include additional
          information (like the Uptime and source protocol).

          If 'summary' is True, a summary table of the bindings is displayed (one
          line per binding).

          If 'matchPrefix' is specified, only the bindings that match the specified
          prefix will be displayed. A prefix that cannot be parsed matches nothing.

          If 'showPeerBindings' is False, only the local label bindings will be
          displayed.

          If 'showRecvCount' is False, the "Advertisements Received" column in the
          summary output table will not be displayed.

          The helper writes the output directly to stdout's file descriptor, so
          this returns cliPrinted( MplsBindingsModel ) rather than a populated
          model instance.
      """
      def createConsumerSm( lbt ):
         """ Create a CommonLibConsumerSm copying 'lbt' into fresh local tables. """
         assert lbt is not None, 'Missing label binding table'
         clearOnInit = True
         tmpLocalLbt = LabelBindingTable( RouterId() )
         tmpPeerLbtColl = LabelBindingTableColl( "plbtc" )
         return CommonLibConsumerSm( lbt, tmpLocalLbt, tmpPeerLbtColl, clearOnInit )

      # Setup the Consumer SMs to copy the LLBT/PLBT in Smash to local collections.
      # They are cached on the instance so later renders reuse them.
      if self.lbtConsumerSm is None:
         self.lbtConsumerSm = createConsumerSm( self.lbt )
      if self.lbtConsumerSmPeer is None and self.lbtPeer is not None:
         self.lbtConsumerSmPeer = createConsumerSm( self.lbtPeer )

      # Identify the local/peer binding tables to render.
      localLbt = self.lbtConsumerSm.localLbt
      if not showPeerBindings:
         # Use an empty table (i.e. so there will be no peer bindings to show).
         peerLbtColl = LabelBindingTableColl( "plbtc-dummy" )
      elif self.lbtConsumerSmPeer is not None:
         # Use the peer bindings from the peer-specific tables (if there is one).
         peerLbtColl = self.lbtConsumerSmPeer.peerLbtColl
      else:
         # Use the peer bindings from the normal binding tables.
         peerLbtColl = self.lbtConsumerSm.peerLbtColl

      # Populate the helper with the bindings.
      helper = BgpMplsBindingsHelper()
      if matchPrefix:
         fec = FwdEqvClass()
         # Only the prefix conversion is guarded: an IndexError raised while
         # populating the helper must not be mistaken for "invalid prefix" and
         # silently turned into empty output.
         try:
            fec.prefix = IpGenPrefix( str( matchPrefix ) )
         except IndexError:
            # An invalid prefix matches nothing, so leave the helper empty.
            pass
         else:
            helper.populateFecBindingsFromTablesFiltered(
               localLbt, peerLbtColl, detail, fec )
      else:
         helper.populateFecBindingsFromTables( localLbt, peerLbtColl, detail )

      # Look for conflicts.
      helper.populateLocalConflicts()

      # Now render the bindings. Flush Python's buffered stdout first, since the
      # helper is handed the raw file descriptor and would otherwise risk
      # interleaving its output with pending buffered text.
      sys.stdout.flush()
      fd = sys.stdout.fileno()
      helper.renderBindings( fd, fmt, summary, showRecvCount )

      return cliPrinted( MplsBindingsModel )
