#!/usr/bin/env python3
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Arnet
import Tac
import LazyMount
import SmashLazyMount
import ShowCommand
import CliParser
import CliCommand
import CliMatcher
import BasicCli
import BasicCliModes
import CliToken.Clear
from CliPlugin import MplsCli
from CliPlugin.AleCountersCli import checkCounterFeatureEnabled
from CliPlugin.AleCountersCli import checkCounterFeatureSupported
from CliPlugin.MplsIngressCountersModel import MplsLfibCounterLabelEntry
from CliPlugin.MplsIngressCountersModel import MplsLfibCounters
from CliPlugin.MplsIngressCountersModel import MplsLfibCountersSummary
from TypeFuture import TacLazyType

# TacLazyType handles for the Tac types used throughout this module.
BoundedMplsLabelStack = TacLazyType( 'Arnet::BoundedMplsLabelStack' )
FapId = TacLazyType( 'FlexCounters::FapId' )
FeatureId = TacLazyType( 'FlexCounters::FeatureId' )
MplsLabel = TacLazyType( 'Arnet::MplsLabel' )
MplsRouteHelper = TacLazyType( 'Mpls::MplsRouteHelper' )
RouteKey = TacLazyType( 'Mpls::RouteKey' )
RouteMetric = TacLazyType( "Mpls::RouteMetric" )
LfibViaType = TacLazyType( 'Mpls::LfibViaType' )
LfibCounterUnallocatedEntry =  \
      TacLazyType( 'Ale::FlexCounter::LfibCounterUnallocatedEntry' )

# Mount handles assigned by Plugin() below; they remain None until Plugin() runs.
fcFeatureConfigDir = None
mplsHwStatus = None  # hardware route/adjacency status (labelUnprogrammed)
mplsLfibClear = None  # clear-request config written by the clear commands
mplsLfibCounterTable = None  # current MplsLfib counter Smash table
mplsLfibSnapshotTable = None  # snapshot MplsLfib counter Smash table
mplsLfibSourceInfoDir = None  # per-source LFIB route info (source details)
mplsRoutingInfo = None  # provides mplsLookupLabelCount
transitLfib = None  # transit LFIB routes
decapLfib = None  # decap LFIB routes and via sets
lfibCounterUnallocatedSet = None  # label stacks that failed counter allocation

def mplsLfibCountersGuard( mode, token ):
   '''
   CLI guard for the 'counters' keyword of the Mpls Lfib commands.

   Returns None (command allowed) when the MplsLfib flex-counter feature is
   supported on this platform, otherwise CliParser.guardNotThisPlatform.
   '''
   featureSupported = checkCounterFeatureSupported( FeatureId.MplsLfib )
   return None if featureSupported else CliParser.guardNotThisPlatform

def noResourceGuard( mode, token ):
   """
   CLI guard for the 'no-resource' and 'summary' keywords.

   The command is allowed (None is returned) only when the hardware reports
   mplsLfibFailedAllocationCountersSupported AND mplsLfibCountersGuard itself
   allows the command; in every other case CliParser.guardNotThisPlatform is
   returned.
   """
   hwCapability = MplsCli.mplsHwCapability
   if not hwCapability.mplsLfibFailedAllocationCountersSupported:
      return CliParser.guardNotThisPlatform
   if mplsLfibCountersGuard( mode, token ) is not None:
      return CliParser.guardNotThisPlatform
   return None

# 'counters' keyword shared by every show/clear command in this module.
nodeCounters = CliCommand.guardedKeyword( 'counters',
   helpdesc="MPLS LFIB counters",
   guard=mplsLfibCountersGuard )

# Adjacent string literals are concatenated with no separator, so the trailing
# space after "counter" is required to keep the help text readable.
nodeNoResourceKeyword = CliMatcher.KeywordMatcher( 'no-resource',
   helpdesc="MPLS LFIB labels which failed to be allocated dedicated counter "
            "resources" )

# NOTE(review): 'summary' uses noResourceGuard, so the summary command is only
# offered when failed-allocation counters are supported; the handler's branch
# for the unsupported case is then not reachable from the CLI. Confirm this is
# the intended availability.
nodeSummary = CliCommand.guardedKeyword( 'summary',
   helpdesc="MPLS LFIB counters summary",
   guard=noResourceGuard )

#-------------------------------------------------------------------------
# The "show mpls lfib counters [ { LABELS } ]" command
#-------------------------------------------------------------------------
# MPLS label value matcher shared by the LABELS, START_LABEL and END_LABEL
# arguments of the commands in this module.
nodeLabelMatcher = Arnet.MplsLib.labelValMatcher

def willExceedConfiguredCount( labels ):
   '''
   Return True when more labels were supplied than may be looked up.

   The limit is mplsRoutingInfo.mplsLookupLabelCount when that mount is
   available (truthy), and 1 otherwise. A missing or empty `labels` counts as
   zero labels.
   '''
   limit = mplsRoutingInfo.mplsLookupLabelCount if mplsRoutingInfo else 1
   requested = len( labels ) if labels else 0
   return requested > limit


def createLabelStack( labels ):
   '''
   Build a constant BoundedMplsLabelStack from the CLI-supplied labels.

   The labels are appended to the stack in reverse of the order given; a
   missing or empty `labels` yields an empty stack.
   '''
   stack = BoundedMplsLabelStack()
   for labelValue in reversed( labels or [] ):
      stack.append( labelValue )
   return Tac.const( stack )

def mplsLfibSmashCounter( routeKey ):
   '''
   Return ( packets, octets ) for routeKey from the MplsLfib counter Smash.

   With no current counter entry both values are 0. With a current entry but
   no snapshot the raw values are returned; otherwise the snapshot values are
   subtracted from the current ones.
   '''
   current = mplsLfibCounterTable.counterEntry.get( routeKey )
   baseline = mplsLfibSnapshotTable.counterEntry.get( routeKey )
   if current is None:
      return 0, 0
   if baseline is None:
      return current.pkts, current.octets
   return ( current.pkts - baseline.pkts,
            current.octets - baseline.octets )

def labelUnprogrammed( routeKey ):
   '''
   Report whether routeKey is unprogrammed according to mplsHwStatus.

   A route absent from mplsHwStatus.route counts as unprogrammed. Otherwise the
   route's own `unprogrammed` flag is used; when that flag is false and the
   route's adjacency (adjacencyBase( adjBaseKey )) exists, the adjacency's
   `unprogrammed` flag decides instead.
   '''
   hwRoute = mplsHwStatus.route.get( routeKey )
   if not hwRoute:
      return True
   routeFlag = hwRoute.unprogrammed
   if routeFlag:
      return routeFlag
   adjacency = mplsHwStatus.adjacencyBase( hwRoute.adjBaseKey )
   return adjacency.unprogrammed if adjacency else routeFlag

def addCounterInfoForRouteKey( mplsLfibCounters, mplsLfibMultiLabelCounters,
                               routeKey, lfibTable=None, noResourceOnly=False ):
   '''
   Build an MplsLfibCounterLabelEntry for routeKey and store it in one of the
   two output dicts:

   - the empty RouteKey() (aggregate for non-countable labels) is stored in
     mplsLfibCounters under MplsLabel.null;
   - a single-label stack is stored in mplsLfibCounters keyed by its top label;
   - a deeper stack is stored in mplsLfibMultiLabelCounters keyed by
     labelStack.cliShowString().

   lfibTable, when given, is where the route's source and source description
   are looked up. The unprogrammed flag is only computed for non-empty stacks
   of transitLfib routes. With noResourceOnly, routes whose stack is not in the
   failed-allocation set are skipped, but RouteKey() is always reported.
   '''
   stack = routeKey.labelStack

   if ( noResourceOnly and
        lfibCounterUnallocatedSet.lfibCounterUnallocatedSetEntry.get(
           LfibCounterUnallocatedEntry( stack ) ) is None and
        routeKey != RouteKey() ):
      return

   unprogrammed = False
   if lfibTable is transitLfib and stack != BoundedMplsLabelStack():
      unprogrammed = labelUnprogrammed( routeKey )

   numPkts, numBytes = mplsLfibSmashCounter( routeKey )

   source = None
   sourceDetails = None
   lfibRoute = lfibTable.lfibRoute.get( routeKey ) if lfibTable else None
   if lfibRoute:
      source = MplsRouteHelper.lfibSourceToCapiVal( lfibRoute.source )
      sourceInfo = mplsLfibSourceInfoDir.get( str( lfibRoute.source ) )
      routeInfo = sourceInfo.routeInfo.get( routeKey ) if sourceInfo else None
      if routeInfo:
         sourceDetails = routeInfo.description or None

   entryKwargs = { 'totalPackets': numPkts,
                   'totalBytes': numBytes,
                   'unprogrammed': unprogrammed,
                   'source': source,
                   'sourceDetails': sourceDetails }
   # failedCounterAlloc is passed (as True) only for stacks recorded in the
   # failed-allocation set; otherwise the model's default is left untouched.
   if ( MplsCli.mplsHwCapability.mplsLfibFailedAllocationCountersSupported and
        lfibCounterUnallocatedSet.lfibCounterUnallocatedSetEntry.get(
           LfibCounterUnallocatedEntry( stack ) ) ):
      entryKwargs[ 'failedCounterAlloc' ] = True
   counterInfo = MplsLfibCounterLabelEntry( **entryKwargs )

   # Empty label stack indicates non-countable labels
   if routeKey == RouteKey():
      mplsLfibCounters[ MplsLabel.null ] = counterInfo
   elif stack.stackSize == 1:
      mplsLfibCounters[ stack.top() ] = counterInfo
   else:
      mplsLfibMultiLabelCounters[ stack.cliShowString() ] = counterInfo

def _decapRouteHasIpLookup( routeKey ):
   '''
   Return True if the decap LFIB route for routeKey exists, has a via set key,
   and its via set contains an IP-lookup via.

   Every lookup goes through .get() so that a missing route or via set is
   skipped instead of raising.
   '''
   lfibRoute = decapLfib.lfibRoute.get( routeKey )
   if not lfibRoute:
      return False
   viaSetKey = lfibRoute.viaSetKey
   if not viaSetKey:
      return False
   viaSet = decapLfib.viaSet.get( viaSetKey )
   if not viaSet:
      return False
   return bool( viaSet.hasViaType( LfibViaType.viaTypeIpLookup ) )

def showMplsLfibCounters( mode, args, noResourceOnly=False ):
   '''
   Return Mpls Lfib counter data.

   Without LABELS, the result covers:
   - every transit LFIB route;
   - when mplsVrfLabelCountersSupported, every decap LFIB route whose via set
     performs an IP lookup;
   - the aggregate entry for non-countable labels.

   With LABELS, only the route whose key is built from createLabelStack( LABELS )
   is reported. The transit LFIB is tried first, then the decap LFIB.

   noResourceOnly is forwarded to addCounterInfoForRouteKey, which then keeps
   only stacks that failed counter allocation (plus the aggregate entry).

   An error is added and an empty model returned when the feature is disabled
   or too many labels are given.
   '''
   if not checkCounterFeatureEnabled( FeatureId.MplsLfib ):
      mode.addError( "hardware counter feature mpls lfib in should be "
                     "enabled first" )
      return MplsLfibCounters()

   labels = args.get( 'LABELS' )

   # limit num of labels
   if willExceedConfiguredCount( labels ):
      mode.addError( MplsCli.MAX_LABEL_EXCEEDED )
      return MplsLfibCounters()

   mplsLfibCounters = {}
   mplsLfibMultiLabelCounters = {}

   def addRoute( routeKey, lfibTable=None ):
      # Record routeKey's counters into the two result dicts above.
      addCounterInfoForRouteKey( mplsLfibCounters, mplsLfibMultiLabelCounters,
                                 routeKey, lfibTable=lfibTable,
                                 noResourceOnly=noResourceOnly )

   if not labels:
      for routeKey in transitLfib.lfibRoute:
         addRoute( routeKey, lfibTable=transitLfib )
      if MplsCli.mplsHwCapability.mplsVrfLabelCountersSupported:
         for routeKey in decapLfib.lfibRoute:
            if _decapRouteHasIpLookup( routeKey ):
               addRoute( routeKey, lfibTable=decapLfib )

      # Add counter information for all non-countable labels.
      addRoute( RouteKey() )

   else:
      labelArgsRouteKey = RouteKey( createLabelStack( labels ) )
      if transitLfib.lfibRoute.get( labelArgsRouteKey ):
         addRoute( labelArgsRouteKey, lfibTable=transitLfib )
      elif ( MplsCli.mplsHwCapability.mplsVrfLabelCountersSupported and
             _decapRouteHasIpLookup( labelArgsRouteKey ) ):
         addRoute( labelArgsRouteKey, lfibTable=decapLfib )

   return MplsLfibCounters( counters=mplsLfibCounters,
                            multiLabelCounters=mplsLfibMultiLabelCounters )

class ShowMplsLfibCountersCmd( ShowCommand.ShowCliCommandClass ):
   '''
   'show mpls lfib counters [ { LABELS } ]': display LFIB counters for all
   routes, or only for the label stack given by LABELS. The handler,
   showMplsLfibCounters, returns an MplsLfibCounters model.
   '''
   syntax = 'show mpls lfib counters [ { LABELS } ]'
   data = {
         'mpls' : MplsCli.mplsNodeForShow,
         'lfib' : MplsCli.matcherLfib,
         'counters' : nodeCounters,
         'LABELS' : nodeLabelMatcher,
         }
   privileged = True
   cliModel = MplsLfibCounters

   handler = showMplsLfibCounters

BasicCli.addShowCommandClass( ShowMplsLfibCountersCmd )

def showMplsLfibCountersNoResource( mode, args ):
   """
   Show handler for 'show mpls lfib counters no-resource'.

   Reuses showMplsLfibCounters with noResourceOnly enabled, so only labels that
   failed to get dedicated counter resources (plus the aggregate entry for
   non-countable labels) are reported.
   """
   countersModel = showMplsLfibCounters( mode, args, noResourceOnly=True )
   return countersModel

class ShowMplsLfibCountersNoResourceCmd( ShowCommand.ShowCliCommandClass ):
   '''
   'show mpls lfib counters no-resource': display counters only for labels that
   failed to be allocated dedicated counter resources. The 'no-resource' keyword
   is guarded by noResourceGuard.
   '''
   syntax = 'show mpls lfib counters no-resource'
   data = {
         'mpls' : MplsCli.mplsNodeForShow,
         'lfib' : MplsCli.matcherLfib,
         'counters' : nodeCounters,
         'no-resource' : CliCommand.Node( matcher=nodeNoResourceKeyword,
                                          guard=noResourceGuard ),
         }
   privileged = True
   cliModel = MplsLfibCounters

   handler = showMplsLfibCountersNoResource

BasicCli.addShowCommandClass( ShowMplsLfibCountersNoResourceCmd )

# -------------------------------------------------------------------------
# The "show mpls lfib counters summary" command
# -------------------------------------------------------------------------
def showMplsLfibCountersSummary( mode, args ):
   '''
   Return a summary of MPLS LFIB counter usage.

   totalMplsLabels is the number of distinct non-empty label stacks among:
   - transit LFIB routes;
   - when mplsVrfLabelCountersSupported, decap LFIB routes whose via set
     performs an IP lookup.

   When failed-allocation counters are supported, that total is split into
   countable and no-resource stacks. The aggregate packet/byte values always
   come from the non-countable (RouteKey()) counter entry.
   '''
   numPkts, numBytes = mplsLfibSmashCounter( RouteKey() )
   emptyStack = BoundedMplsLabelStack()

   labelStacks = { routeKey.labelStack for routeKey in transitLfib.lfibRoute
                   if routeKey.labelStack != emptyStack }

   if MplsCli.mplsHwCapability.mplsVrfLabelCountersSupported:
      for routeKey in decapLfib.lfibRoute:
         # Use .get() so a missing route or via set is skipped, not raised on.
         lfibRoute = decapLfib.lfibRoute.get( routeKey )
         viaSetKey = lfibRoute.viaSetKey if lfibRoute else None
         if not viaSetKey:
            continue
         viaSet = decapLfib.viaSet.get( viaSetKey )
         if not viaSet:
            continue
         if ( viaSet.hasViaType( LfibViaType.viaTypeIpLookup ) and
              routeKey.labelStack != emptyStack ):
            labelStacks.add( routeKey.labelStack )

   totalMplsLabels = len( labelStacks )

   # NOTE: the 'agreggate...' keyword names must match MplsLfibCountersSummary.
   if not MplsCli.mplsHwCapability.mplsLfibFailedAllocationCountersSupported:
      return MplsLfibCountersSummary( totalMplsLabels=totalMplsLabels,
                                      agreggateNoResourceBytes=numBytes,
                                      agreggateNoResourcePkts=numPkts )

   unallocated = lfibCounterUnallocatedSet.lfibCounterUnallocatedSetEntry
   noResourceLabels = sum(
      1 for labelStack in labelStacks
      if unallocated.get( LfibCounterUnallocatedEntry( labelStack ) ) )
   return MplsLfibCountersSummary(
      totalMplsLabels=totalMplsLabels,
      countableMplsLabels=totalMplsLabels - noResourceLabels,
      noResourceMplsLabels=noResourceLabels,
      agreggateNoResourceBytes=numBytes,
      agreggateNoResourcePkts=numPkts )

class ShowMplsLfibCountersSummaryCmd( ShowCommand.ShowCliCommandClass ):
   '''
   'show mpls lfib counters summary': display a summary of MPLS LFIB counter
   usage (label totals and the aggregate no-resource counters). The 'summary'
   keyword is guarded by noResourceGuard via nodeSummary.
   '''
   syntax = 'show mpls lfib counters summary'
   data = {
         'mpls' : MplsCli.mplsNodeForShow,
         'lfib' : MplsCli.matcherLfib,
         'counters' : nodeCounters,
         'summary' : nodeSummary,
         }
   privileged = True
   cliModel = MplsLfibCountersSummary

   handler = showMplsLfibCountersSummary

BasicCli.addShowCommandClass( ShowMplsLfibCountersSummaryCmd )

#-------------------------------------------------------------------------
# The "clear mpls lfib counters [ START_LABEL [ END_LABEL ] ]" command
#-------------------------------------------------------------------------

def clearMplsLfibCounters( mode, args, multiLabel=False ):
   '''
   Clear Mpls Lfib counters by queueing clear requests.

   The request collection (mplsLfibClear.clearLfibEntryCountersRequest) is
   emptied first and then refilled with one entry per label stack to clear.
   Routes from the transit LFIB are considered, and also routes from the decap
   LFIB when mplsVrfLabelCountersSupported.

   - Without START_LABEL, every route's stack is queued, plus the empty stack
     that stands for the non-countable labels. With multiLabel=True, only
     stacks deeper than one label are queued and the empty stack is skipped.
   - With START_LABEL [ END_LABEL ], single-label routes whose top label lies
     in the inclusive range are queued. END_LABEL defaults to START_LABEL, and
     a START_LABEL greater than END_LABEL matches nothing.
   '''
   if not checkCounterFeatureEnabled( FeatureId.MplsLfib ):
      mode.addError( "hardware counter feature mpls lfib in should be "
                     "enabled first" )
      return
   firstLabel = args.get( 'START_LABEL' )
   lastLabel = args.get( 'END_LABEL' )
   if lastLabel is None:
      lastLabel = firstLabel

   pending = mplsLfibClear.clearLfibEntryCountersRequest
   pending.clear()

   def isSelected( routeKey ):
      # Decide whether routeKey's counters should be cleared.
      stack = routeKey.labelStack
      if firstLabel is None:
         # 'clear mpls lfib counters' clears all counters
         # 'clear mpls lfib counters multi-label' clears only multi-label counters
         return not multiLabel or stack.stackSize > 1
      # Only clear single label route counters when specifying range
      if stack.stackSize != 1:
         return False
      topLabel = routeKey.topLabel
      # pylint: disable-next=chained-comparison
      return topLabel >= firstLabel and topLabel <= lastLabel

   def queueFrom( lfibTable ):
      # Queue a clear request for every selected route of lfibTable.
      if firstLabel is None and not multiLabel:
         # The empty label stack stands for the non-countable labels.
         pending[ BoundedMplsLabelStack() ] = True
      for routeKey in lfibTable.lfibRoute:
         if isSelected( routeKey ) and routeKey.labelStack not in pending:
            pending[ routeKey.labelStack ] = True

   queueFrom( transitLfib )
   if MplsCli.mplsHwCapability.mplsVrfLabelCountersSupported:
      queueFrom( decapLfib )

class ClearMplsLfibCountersMultiLabelStartEndCmd( CliCommand.CliCommandClass ):
   '''
   'clear mpls lfib counters [ START_LABEL [ END_LABEL ] ]': clear all LFIB
   counters, or those of single-label routes in a label range. The handler is
   clearMplsLfibCounters.

   NOTE(review): despite "MultiLabel" in the class name, this is the
   single-label/range form; the multi-label form is
   ClearMplsLfibCountersMultiLabelCmd.
   '''
   syntax = 'clear mpls lfib counters [ START_LABEL [ END_LABEL ] ]'
   data = {
         'clear' : CliToken.Clear.clearKwNode,
         'mpls' : MplsCli.mplsMatcherForClear,
         'lfib' : MplsCli.matcherLfib,
         'counters' : nodeCounters,
         'START_LABEL' : nodeLabelMatcher,
         'END_LABEL' : nodeLabelMatcher,
         }

   handler = clearMplsLfibCounters

BasicCliModes.EnableMode.addCommandClass(
   ClearMplsLfibCountersMultiLabelStartEndCmd )

#-------------------------------------------------------------------------
# The "clear mpls lfib counters multi-label [ { LABELS } ]" command
#-------------------------------------------------------------------------
# 'multi-label' keyword of the clear command, guarded by
# MplsCli.mplsMultiLabelLookupGuard.
nodeMultiLabel = CliCommand.guardedKeyword( "multi-label",
   helpdesc="Specify a multi-label entry (labels ordered top-most to bottom-most)",
   guard=MplsCli.mplsMultiLabelLookupGuard )

def clearMplsLfibCountersMultiLabel( mode, args ):
   '''
   Clear Mpls Lfib counters for a particular label stack, or all multi-label
   route counters if no label stack is specified.

   Like clearMplsLfibCounters, this refuses to queue any clear request (and
   leaves the pending requests untouched) unless the MplsLfib hardware counter
   feature is enabled. The number of LABELS may not exceed the configured
   lookup label count (willExceedConfiguredCount).
   '''
   # Check the feature before touching the request collection, consistently
   # with clearMplsLfibCounters; previously the specific-stack path queued a
   # request with no check at all.
   if not checkCounterFeatureEnabled( FeatureId.MplsLfib ):
      mode.addError( "hardware counter feature mpls lfib in should be "
                     "enabled first" )
      return

   labels = args.get( 'LABELS' )
   if willExceedConfiguredCount( labels ):
      mode.addError( MplsCli.MAX_LABEL_EXCEEDED )
      return

   if not labels:
      # Clear all multi-label route counters. clearMplsLfibCounters empties the
      # request collection itself before refilling it.
      clearMplsLfibCounters( mode, args, multiLabel=True )
      return

   # Clear counter for specific label stack.
   clearReq = mplsLfibClear.clearLfibEntryCountersRequest
   clearReq.clear()
   clearReq[ createLabelStack( labels ) ] = True

class ClearMplsLfibCountersMultiLabelCmd( CliCommand.CliCommandClass ):
   '''
   'clear mpls lfib counters multi-label [ { LABELS } ]': clear the counter of
   one label stack, or of every multi-label route when LABELS is omitted. The
   handler is clearMplsLfibCountersMultiLabel.
   '''
   syntax = "clear mpls lfib counters multi-label [ { LABELS } ]"
   data = {
      'clear' : CliToken.Clear.clearKwNode,
      'mpls' : MplsCli.mplsMatcherForClear,
      'lfib' : MplsCli.matcherLfib,
      'counters' : nodeCounters,
      'multi-label' : nodeMultiLabel,
      'LABELS' : nodeLabelMatcher,
      }

   handler = clearMplsLfibCountersMultiLabel

BasicCliModes.EnableMode.addCommandClass( ClearMplsLfibCountersMultiLabelCmd )

def Plugin( em ):
   '''
   CLI plugin entry point.

   Creates the LazyMount / SmashLazyMount handles used by the show and clear
   handlers in this module and stores them in the module-level globals.
   '''
   global mplsLfibCounterTable
   global mplsLfibSnapshotTable
   global mplsLfibClear
   global mplsHwStatus
   global mplsRoutingInfo
   global fcFeatureConfigDir
   global mplsLfibSourceInfoDir
   global transitLfib
   global decapLfib
   global lfibCounterUnallocatedSet
   # Clear-request config filled in by the clear commands (mounted with "w").
   mplsLfibClear = LazyMount.mount( em,
                                    'hardware/counter/mplsLfib/clear/config',
                                    "Ale::FlexCounter::MplsLfibCliClearConfig",
                                    "w" )

   # Hardware route/adjacency status consulted by labelUnprogrammed().
   mplsHwStatus = LazyMount.mount( em,
                                   "routing/hardware/mpls/status",
                                   "Mpls::Hardware::Status",
                                   "r" )

   # Supplies mplsLookupLabelCount for willExceedConfiguredCount().
   mplsRoutingInfo = LazyMount.mount( em,
                                      "routing/mpls/routingInfo/status",
                                      "Mpls::RoutingInfo",
                                      "r" )

   # NOTE(review): fcFeatureConfigDir is not referenced elsewhere in the visible
   # part of this file; presumably it is needed by the flex-counter feature
   # checks (AleCountersCli) or elsewhere -- confirm before removing.
   fcFeatureConfigDir = LazyMount.mount( em,
         "flexCounter/featureConfigDir/cliAgent",
         "Ale::FlexCounter::FeatureConfigDir", 'r' )

   # Per-LFIB-source route info, keyed by str( source ); routeInfo descriptions
   # become the 'sourceDetails' of counter entries.
   mplsLfibSourceInfoDir = LazyMount.mount( em,
         "routing/mpls/lfibSourceInfo",
         "Tac::Dir",
         "riS" )

   # Smash mount profiles: decapLfib uses 'keyshadow', the others 'reader'.
   readerInfo = SmashLazyMount.mountInfo( 'reader' )
   keyshadowInfo = SmashLazyMount.mountInfo( 'keyshadow' )

   # Mount the Mpls Lfib current counter smash.
   # pylint: disable-next=consider-using-f-string
   mountPath = 'flexCounters/counterTable/MplsLfib/%u' % ( FapId.allFapsId )
   mplsLfibCounterTable = SmashLazyMount.mount( em, mountPath,
                                          "Ale::FlexCounter::MplsLfibCounterTable",
                                          readerInfo )

   # Mount the Mpls Lfib snapshot counter smash.
   # pylint: disable-next=consider-using-f-string
   mountPath = 'flexCounters/snapshotTable/MplsLfib/%u' % ( FapId.allFapsId )
   mplsLfibSnapshotTable = SmashLazyMount.mount( em, mountPath,
                                          "Ale::FlexCounter::MplsLfibCounterTable",
                                          readerInfo )
   # Transit and decap LFIB route tables (both Mpls::LfibStatus).
   transitLfib = SmashLazyMount.mount( em, "mpls/transitLfib", "Mpls::LfibStatus",
                                       readerInfo )

   decapLfib = SmashLazyMount.mount( em, "mpls/decapLfib", "Mpls::LfibStatus",
                                     keyshadowInfo )

   # Mount the collection of labelStacks that failed counter allocation.
   mountPath = 'hardware/counter/mplsLfib/counterunallocatedset'
   lfibCounterUnallocatedSet = SmashLazyMount.mount( em, mountPath,
                                       "Ale::FlexCounter::LfibCounterUnallocatedSet",
                                       readerInfo )
