#!/usr/bin/env python3
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from AleFlexCounterTableMounter import tableMountPath
import Arnet
import BasicCli
import CliCommand
import CliExtensions
import CliMatcher
import CliParser
from CliPlugin import AleCountersCli
# pylint: disable-next=consider-using-from-import
import CliPlugin.IpAddrMatcher as IpAddrMatcher
# pylint: disable-next=consider-using-from-import
import CliPlugin.IpGenAddrMatcher as IpGenAddrMatcher
from CliPlugin.SrTePolicyCli import srTeTunnelCountersHook
from CliPlugin import MplsCli
from CliPlugin.MplsDebugCli import tunnelTypeValMatcher
from CliPlugin import TunnelCli
from CliPlugin.TunnelFibCli import getMplsViaModelFromTunnelVia
from CliPlugin import TunnelModels

from CliPlugin.MplsTunnelCountersModel import (
   MplsTunnelCountersEntry,
   MplsTunnelCounters,
   MplsTunnelCountersRateEntry,
   MplsTunnelCountersRate,
)
import CliToken
import LazyMount
from MplsTypeLib import (
   showTunnelFibIgnoredTunnelTypes,
   tunnelTypeXlate,
)
import ShowCommand
import SmashLazyMount
import Tac

from operator import attrgetter

# Mount handles populated by Plugin(); they remain None until the plugin is
# loaded.
# Clear-request config written by 'clear mpls tunnel counters'.
mplsTunnelClear = None
# Current counter values (smash), read via AleCountersCli.getCurrentCounter.
mplsTunnelCounterTable = None
# Snapshot counter table (smash), passed alongside the current table to
# AleCountersCli.getCurrentCounter.
mplsTunnelSnapshotTable = None
# Per-counter rate table (smash) used by 'show tunnel counters mpls rate'.
mplsTunnelRateCounterTable = None
# Per-protocol tunnel-id -> tunnel-name mappings.
protocolTunnelNameStatus = None

# Counter tables are keyed by CounterIndex( tunnelId ).
CounterIndex = Tac.Type( 'FlexCounters::CounterIndex' )
# Feature id used to test whether MPLS tunnel counters are supported/configured.
MplsTunnelFeatureId = Tac.Type( 'FlexCounters::FeatureId' ).MplsTunnel

def mplsTunnelCountersGuard( mode, token ):
   '''
   CLI guard for the MPLS tunnel counter keywords.

   Returns CliParser.guardNotThisPlatform when the platform does not support
   the MplsTunnel counter feature; otherwise returns None so the token is
   accepted. `mode` and `token` are unused.
   '''
   supported = AleCountersCli.checkCounterFeatureSupported( MplsTunnelFeatureId )
   return None if supported else CliParser.guardNotThisPlatform

def ensureMplsTunnelCountersEnabled( mode ):
   '''
   Report an error through mode.addErrorAndStop unless the MPLS tunnel counter
   feature has been configured ('hardware counter feature mpls tunnel').
   '''
   if AleCountersCli.checkCounterFeatureConfigured( MplsTunnelFeatureId ):
      return
   mode.addErrorAndStop(
      "'hardware counter feature mpls tunnel' should be enabled first" )

# Platform-specific hook indicating if given counter is actually programmed.
# Extensions are called with a tunnel id and return whether that tunnel's
# counter is in use. Only the first registered extension is consulted
# (see counterActiveForTunnelId).
mplsTunnelCounterActiveHook = CliExtensions.CliHook()

def getTunnelName( tunnelId ):
   '''
   Look up the name of a tunnel.

   The tunnel type is decoded from tunnelId and selects that protocol's name
   status. Returns the name, or None when that protocol publishes no name
   status or has no name for this tunnel.
   '''
   tunnelType = Tac.Value( "Tunnel::TunnelTable::TunnelId", tunnelId ).tunnelType()
   nameStatusByType = protocolTunnelNameStatus.localTunnelNameStatus
   if tunnelType not in nameStatusByType:
      return None
   return nameStatusByType[ tunnelType ].tunnelIdToName.get( tunnelId )

def getTunnelNameList( mode ):
   '''
   Name source for tunnelNameMatcher: every tunnel name published by any
   protocol's name status. `mode` is unused.
   '''
   names = []
   for nameStatus in protocolTunnelNameStatus.localTunnelNameStatus.values():
      names.extend( nameStatus.tunnelIdToName.values() )
   return names

def matchesTunnelName( tunnelId, name ) -> bool:
   """
   Return whether the tunnel identified by tunnelId is selected by `name`.

   A tunnel matches when its own name equals `name`. An rsvpLerSubTunnel also
   matches when its super tunnel's name equals `name`. Sub tunnels are named
   "<Super Tunnel>-<index>", so the super tunnel name is the part before the
   last '-'. A tunnel without a name never matches.
   """
   tunnelType = TunnelCli.getTunnelTypeFromTunnelId( tunnelId )
   tunnelName = getTunnelName( tunnelId )
   if tunnelName is None:
      return False
   if tunnelName == name:
      return True
   if tunnelType != "rsvpLerSubTunnel":
      return False
   # Show sub tunnels that belong to the super tunnel of the given name.
   superTunnelName = tunnelName.rsplit( "-", 1 )[ 0 ]
   return superTunnelName == name

def counterActiveForTunnelId( tunnelId ):
   '''
   Ask the platform whether a counter is programmed for tunnelId.

   Only the first registered mplsTunnelCounterActiveHook extension is
   consulted. With no extension registered, the counter is reported as not
   active (False).
   '''
   for platformHook in mplsTunnelCounterActiveHook.extensions():
      return platformHook( tunnelId )
   return False

def getMplsTunnelCountersFromSmash( tunnelId ):
   '''
   Read the egress counters of one tunnel.

   Returns a ( txPackets, txBytes, counterInUse ) tuple. All three are None
   when the MPLS tunnel counter feature is not configured. counterInUse is the
   platform hook's answer for whether a counter is programmed for tunnelId.
   '''
   featureConfigured = AleCountersCli.checkCounterFeatureConfigured(
      MplsTunnelFeatureId )
   if not featureConfigured:
      return ( None, None, None )

   counterInUse = counterActiveForTunnelId( tunnelId )
   currentCounter = AleCountersCli.getCurrentCounter( CounterIndex( tunnelId ),
                                                      mplsTunnelCounterTable,
                                                      mplsTunnelSnapshotTable )
   txPackets, txBytes = currentCounter
   return txPackets, txBytes, counterInUse

def getMplsTunnelRateCountersFromSmash( tunnelId ):
   '''
   Read the egress rates of one tunnel.

   Returns:
   - ( pktsRate, bitsRate ) from the rate counter table when the platform
     reports the tunnel's counter as programmed.
   - ( 0.0, 0.0 ) when the counter is programmed but has no rate entry.
   - ( None, None ) when the counter is not programmed.
   '''
   if not counterActiveForTunnelId( tunnelId ):
      return None, None
   rateCtr = mplsTunnelRateCounterTable.rateCounter.get(
      CounterIndex( tunnelId ) )
   if rateCtr is None:
      return 0.0, 0.0
   return rateCtr.pktsRate, rateCtr.bitsRate

# Let the SR-TE policy CLI obtain MPLS tunnel counters through this module.
srTeTunnelCountersHook.addExtension( getMplsTunnelCountersFromSmash )
#-------------------------------------------------------------------------
# The "show mpls tunnel counters" command
#-------------------------------------------------------------------------
# Shared by the type/nexthop/interface/endpoint nodes. Per the comment in
# getMplsTunnelCounterModel, those filters are mutually exclusive. The name
# node deliberately does not share it, so `name` may be combined with any one
# of them.
sharedMatchObj = object()
# 'counters' keyword, guarded by mplsTunnelCountersGuard (rejected on
# platforms without MPLS tunnel counter support).
countersAfterTunnelNode = CliCommand.guardedKeyword( 'counters',
      helpdesc='Tunnel egress hardware counters',
      guard=mplsTunnelCountersGuard )
typeMatcher = CliMatcher.KeywordMatcher( 'type',
      helpdesc='Match tunnel type' )
typeNode = CliCommand.singleNode(
   matcher=typeMatcher,
   sharedMatchObj=sharedMatchObj,
)
indexMatcher = CliMatcher.KeywordMatcher( 'index',
      helpdesc='Match tunnel index' )
nexthopMatcher = CliMatcher.KeywordMatcher( 'nexthop',
      helpdesc='Match tunnel nexthop' )
nexthopNode = CliCommand.singleNode(
   matcher=nexthopMatcher,
   sharedMatchObj=sharedMatchObj,
)
interfaceMatcher = CliMatcher.KeywordMatcher( 'interface',
      helpdesc='Match tunnel interface' )
interfaceNode = CliCommand.singleNode(
   matcher=interfaceMatcher,
   sharedMatchObj=sharedMatchObj,
)
nameMatcher = CliMatcher.KeywordMatcher( 'name',
      helpdesc='Match tunnel name' )
nameNode = CliCommand.singleNode( matcher=nameMatcher )
tableOutputMatcher = CliMatcher.KeywordMatcher( 'table-output',
      helpdesc='Provide results in a table format' )
endpointNode = CliCommand.singleNode(
   matcher=MplsCli.endpointMatcher,
   sharedMatchObj=sharedMatchObj,
)

# Value matchers for the filters above.
nexthopValMatcher = IpGenAddrMatcher.IpGenAddrMatcher( MplsCli.nhStr )
# Endpoint value: address or prefix (allowAddr=True); overlap handling is
# PREFIX_OVERLAP_AUTOZERO for both IPv4 and IPv6.
tunnelEndpointMatcher = IpGenAddrMatcher.IpGenAddrOrPrefixExprFactory(
   ipOverlap=IpAddrMatcher.PREFIX_OVERLAP_AUTOZERO,
   ip6Overlap=IpAddrMatcher.PREFIX_OVERLAP_AUTOZERO,
   allowAddr=True )
# Tunnel names are offered from getTunnelNameList.
tunnelNameMatcher = CliMatcher.DynamicNameMatcher(
   getTunnelNameList, 'Tunnel Name',
   pattern=CliParser.namePattern )

def getMplsTunnelCounterEntryModel( tunnelId=None, endpoint=None,
                                    nexthop=None, intfId=None ):
   '''
   Build the MplsTunnelCountersEntry for a single tunnel, or return None.

   None is returned when any of these holds:
   - the tunnel type is in showTunnelFibIgnoredTunnelTypes;
   - the tunnel has no MPLS tunnel FIB entry;
   - `endpoint` is given and differs from the tunnel's endpoint;
   - no tunnel via passes the via filters.

   A tunnel via passes when every via model derived from it matches `nexthop`
   and `intfId`; each filter is checked only when given. A tunnel with no vias
   therefore never passes. All vias of a passing tunnel are reported, not just
   the matching ones.
   '''
   tunnelType = TunnelCli.getTunnelTypeFromTunnelId( tunnelId )
   tunnelIndex = TunnelCli.getTunnelIndexFromId( tunnelId )
   if tunnelType in showTunnelFibIgnoredTunnelTypes:
      return None

   fibEntry = MplsCli.tunnelFib.entry.get( tunnelId )
   if not fibEntry:
      return None

   tunnelEndpoint = TunnelCli.getEndpointFromTunnelId( tunnelId )
   if endpoint and tunnelEndpoint != endpoint:
      return None

   def passesViaFilters( viaModel ):
      # Every filter that was supplied must hold for this via model.
      if nexthop and not ( viaModel.nexthop == nexthop ):
         return False
      if intfId and not ( viaModel.interface == intfId ):
         return False
      return True

   reportedVias = []
   someViaPassed = False
   for tunnelVia in fibEntry.tunnelVia.values():
      allModelsPassed = True
      for viaModel in getMplsViaModelFromTunnelVia( tunnelVia ):
         if not passesViaFilters( viaModel ):
            allModelsPassed = False
         reportedVias.append( TunnelModels.IpVia( nexthop=viaModel.nexthop,
                                                  interface=viaModel.interface,
                                                  type='ip' ) )
      someViaPassed = someViaPassed or allModelsPassed
   if not someViaPassed:
      return None

   txPackets, txBytes, counterInUse = getMplsTunnelCountersFromSmash( tunnelId )
   tunnelName = getTunnelName( tunnelId )
   return MplsTunnelCountersEntry(
      txPackets=txPackets, txBytes=txBytes, counterInUse=counterInUse,
      tunnelIndex=tunnelIndex,
      tunnelType=TunnelModels.tunnelTypeStrDict[ tunnelType ],
      endpoint=tunnelEndpoint, vias=reportedVias,
      tunnelName=tunnelName )

def getMplsTunnelCounterModel( args ):
   '''
   Collect MplsTunnelCountersEntry models for the tunnels selected by `args`.

   `args` is the parsed argument dict of the show/clear MPLS tunnel counters
   commands. Recognized filters:
   - TYPE [ INDEX ]: restrict to one tunnel type, or to one tunnel by index.
   - ENDPOINT / NEXTHOP / INTF: passed down to getMplsTunnelCounterEntryModel.
   - NAME: may be combined with any one of the filters above.

   Returns a dict keyed by tunnel id. Tunnels that are filtered out, or that
   getMplsTunnelCounterEntryModel rejects, are omitted.
   '''
   allTunnelIds = MplsCli.tunnelFib.entry
   tunnelIds = allTunnelIds
   endpoint = None
   nexthop = None
   intfId = None

   # Filters in this "block" are mutually exclusive. The Cli nodes use singleNode
   # so if the arg is in `args`, the list must have exactly one element.
   # Each branch tests the same (upper-case) value token that it then reads.
   if 'TYPE' in args:
      tunnelType = tunnelTypeXlate[ args[ 'TYPE' ][ 0 ] ]
      if 'INDEX' in args:
         tunnelIds = [ TunnelCli.getTunnelIdFromIndex( tunnelType,
                                                       args[ 'INDEX' ][ 0 ] ) ]
      else:
         tunnelIds = [ tId for tId in allTunnelIds if tunnelType ==
                       TunnelCli.getTunnelTypeFromTunnelId( tId ) ]
   elif 'ENDPOINT' in args:
      endpoint = Arnet.IpGenPrefix( str( args[ 'ENDPOINT' ][ 0 ] ) )
   elif 'NEXTHOP' in args:
      nexthop = Arnet.IpGenAddr( str( args[ 'NEXTHOP' ][ 0 ] ) )
   elif 'INTF' in args:
      intf = args[ 'INTF' ][ 0 ]
      intfId = Tac.Value( "Arnet::IntfId", str( intf ) ) if intf else None

   # `name` can be specified with one of any of the above filters.
   # CliCommand.singleNode is used so if "NAME" is in args, then the value has
   # exactly one element; read it once rather than per tunnel.
   if 'NAME' in args:
      name = args[ 'NAME' ][ 0 ]
      tunnelIds = [ tId for tId in tunnelIds if matchesTunnelName( tId, name ) ]

   mplsTunnelCounterEntries = {}
   for tunnelId in tunnelIds:
      entryModel = getMplsTunnelCounterEntryModel(
         tunnelId=tunnelId, endpoint=endpoint, nexthop=nexthop, intfId=intfId )
      if entryModel:
         mplsTunnelCounterEntries[ tunnelId ] = entryModel

   return mplsTunnelCounterEntries

class ShowMplsTunnelCountersCmd( ShowCommand.ShowCliCommandClass ):
   # "show mpls tunnel counters": per-tunnel egress counters, optionally
   # filtered (see getMplsTunnelCounterModel) and optionally rendered as a
   # table ('table-output').
   syntax = \
      """show mpls tunnel counters
         [ {
            ( type TYPE [ index INDEX ] )
            | ( endpoint ENDPOINT )
            | ( nexthop NEXTHOP )
            | ( interface INTF )
            | ( name NAME )
         } ]
         [ table-output ]
      """

   data = {
      'mpls': MplsCli.mplsNodeForShow,
      'tunnel': MplsCli.tunnelAfterMplsMatcherForShow,
      'counters': countersAfterTunnelNode,
      'type': typeNode,
      'TYPE': tunnelTypeValMatcher,
      'index': indexMatcher,
      'INDEX': TunnelCli.tunnelIndexMatcher,
      'endpoint': endpointNode,
      'ENDPOINT': tunnelEndpointMatcher,
      'nexthop': nexthopNode,
      'NEXTHOP': nexthopValMatcher,
      'interface': interfaceNode,
      'INTF': MplsCli.intfValMatcher,
      'name': nameNode,
      'NAME': tunnelNameMatcher,
      'table-output': tableOutputMatcher,
   }
   cliModel = MplsTunnelCounters

   @staticmethod
   def handler( mode, args ):
      ensureMplsTunnelCountersEnabled( mode )

      def endpointSortKey( entry ):
         # Entries without an endpoint sort first (empty-string key).
         if not entry.endpoint:
            return ""
         return entry.endpoint.stringValue

      modelsByTunnelId = getMplsTunnelCounterModel( args )
      sortedEntries = sorted( modelsByTunnelId.values(), key=endpointSortKey )
      tableOutput = 'table-output' in args
      return MplsTunnelCounters( entries=sortedEntries, _tableOutput=tableOutput )

BasicCli.addShowCommandClass( ShowMplsTunnelCountersCmd )

# -------------------------------------------------------------------------
# The "show tunnel counters mpls rate" command
# -------------------------------------------------------------------------
# 'mpls' keyword, guarded by mplsTunnelCountersGuard (rejected on platforms
# without MPLS tunnel counter support).
mplsCounterNodeMatcher = CliCommand.guardedKeyword( 'mpls',
      helpdesc='MPLS tunnel counters',
      guard=mplsTunnelCountersGuard )
rateCounterNodeMatcher = CliMatcher.KeywordMatcher( 'rate',
      helpdesc='MPLS tunnel rates' )

def getMplsTunnelCountersRateModelEntry( args, tunnelId ):
   '''
   Build the MplsTunnelCountersRateEntry for one tunnel, or return None.

   `args` is the parsed CLI argument dict. None is returned when any of these
   holds:
   - the tunnel has no MPLS tunnel FIB entry;
   - its type is in showTunnelFibIgnoredTunnelTypes;
   - it fails the TYPE, ENDPOINT or NAME filter present in `args`.

   Rates come from getMplsTunnelRateCountersFromSmash.
   '''
   if MplsCli.tunnelFib.entry.get( tunnelId ) is None:
      return None

   tunnelType = TunnelCli.getTunnelTypeFromTunnelId( tunnelId )
   if tunnelType in showTunnelFibIgnoredTunnelTypes:
      return None

   # Apply the filters before the remaining per-tunnel lookups, cheapest
   # first, so tunnels that are filtered out cost as little as possible.
   typeFilter = args.get( 'TYPE' )
   if typeFilter is not None and tunnelType != tunnelTypeXlate[ typeFilter[ 0 ] ]:
      return None

   tunnelEndpoint = TunnelCli.getEndpointFromTunnelId( tunnelId )
   endpointFilter = args.get( 'ENDPOINT' )
   if endpointFilter is not None:
      if tunnelEndpoint != Arnet.IpGenPrefix( str( endpointFilter[ 0 ] ) ):
         return None

   nameFilter = args.get( 'NAME' )
   if nameFilter is not None and not matchesTunnelName( tunnelId, nameFilter[ 0 ] ):
      return None

   tunnelIndex = TunnelCli.getTunnelIndexFromId( tunnelId )
   tunnelName = getTunnelName( tunnelId )
   txPps, txBps = getMplsTunnelRateCountersFromSmash( tunnelId )
   return MplsTunnelCountersRateEntry(
      txPps=txPps, txBps=txBps,
      tunnelIndex=tunnelIndex,
      tunnelType=TunnelModels.tunnelTypeStrDict[ tunnelType ],
      endpoint=tunnelEndpoint,
      tunnelName=tunnelName )

def showMplsTunnelCountersRate( mode, args ):
   '''
   Handler for "show tunnel counters mpls rate".

   Errors out unless MPLS tunnel counters are enabled. When both TYPE and
   INDEX are given, only that single tunnel is examined. Otherwise every
   tunnel in the MPLS tunnel FIB is passed through
   getMplsTunnelCountersRateModelEntry. Entries are ordered by tunnel type,
   then tunnel index.
   '''
   ensureMplsTunnelCountersEnabled( mode )

   candidateIds = MplsCli.tunnelFib.entry
   typeFilter = args.get( 'TYPE' )
   indexFilter = args.get( 'INDEX' )
   if typeFilter is not None and indexFilter is not None:
      # Type and index together identify exactly one tunnel.
      candidateIds = [ TunnelCli.getTunnelIdFromIndex(
         tunnelTypeXlate[ typeFilter[ 0 ] ], indexFilter[ 0 ] ) ]

   entryModels = [ model for model in
                   ( getMplsTunnelCountersRateModelEntry( args, tId )
                     for tId in candidateIds )
                   if model ]
   entryModels.sort( key=attrgetter( 'tunnelType', 'tunnelIndex' ) )
   return MplsTunnelCountersRate( entries=entryModels )

class ShowMplsTunnelCountersRateCmd( ShowCommand.ShowCliCommandClass ):
   # "show tunnel counters mpls rate": per-tunnel egress packet/bit rates,
   # optionally filtered by type [ index ], endpoint and/or name.
   # Handled by showMplsTunnelCountersRate.
   syntax = \
      """show tunnel counters mpls rate
         [ {
            ( type TYPE [ index INDEX ] )
            | ( endpoint ENDPOINT )
            | ( name NAME )
         } ]
      """
   # Reuses the filter nodes/matchers of "show mpls tunnel counters".
   data = {
         'tunnel' : TunnelCli.tokenTunnelMatcher,
         'counters' : TunnelCli.tokenCountersMatcher,
         'mpls' : mplsCounterNodeMatcher,
         'rate' : rateCounterNodeMatcher,
         'type' : typeNode,
         'TYPE' : tunnelTypeValMatcher,
         'index' : indexMatcher,
         'INDEX' : TunnelCli.tunnelIndexMatcher,
         'endpoint' : endpointNode,
         'ENDPOINT' : tunnelEndpointMatcher,
         'name' : nameNode,
         'NAME' : tunnelNameMatcher,
   }
   cliModel = MplsTunnelCountersRate
   handler = showMplsTunnelCountersRate

BasicCli.addShowCommandClass( ShowMplsTunnelCountersRateCmd )

#------------------------------------------
# clear mpls tunnel counters
#------------------------------------------
class ClearMplsTunnelCountersCmd( CliCommand.CliCommandClass ):
   # Requests a clear of the MPLS tunnel counters selected by the same filters
   # as "show mpls tunnel counters" (there is no table-output option here).
   syntax = \
      """clear mpls tunnel counters
         [ {
           ( type TYPE [ index INDEX ] )
           | ( endpoint ENDPOINT )
           | ( nexthop NEXTHOP )
           | ( interface INTF )
           | ( name NAME )
         } ]
      """
   data = {
      'clear': CliToken.Clear.clearKwNode,
      'mpls': MplsCli.mplsMatcherForClear,
      'tunnel': MplsCli.tunnelAfterMplsMatcherForShow,
      'counters': countersAfterTunnelNode,
      'type': typeNode,
      'TYPE': tunnelTypeValMatcher,
      'index': indexMatcher,
      'INDEX': TunnelCli.tunnelIndexMatcher,
      'endpoint': endpointNode,
      'ENDPOINT': tunnelEndpointMatcher,
      'nexthop': nexthopNode,
      'NEXTHOP': nexthopValMatcher,
      'interface': interfaceNode,
      'INTF': MplsCli.intfValMatcher,
      'name': nameNode,
      'NAME': tunnelNameMatcher,
   }

   @staticmethod
   def handler( mode, args ):
      ensureMplsTunnelCountersEnabled( mode )

      entries = getMplsTunnelCounterModel( args )
      if not entries:
         print( "% No tunnel counter entries found" )
         return

      # Replace any previously requested tunnel ids with the current selection.
      # NOTE(review): presumably an agent watches this collection and performs
      # the actual clear -- confirm.
      mplsTunnelClear.clearTunnelCountersRequest.clear()
      for tunnelId in entries:
         mplsTunnelClear.clearTunnelCountersRequest[ tunnelId ] = True
      n = len( entries )
      suffix = 'y' if n == 1 else 'ies'
      print( f"{n} tunnel counter entr{suffix} cleared successfully" )

BasicCli.EnableMode.addCommandClass( ClearMplsTunnelCountersCmd )

#-------------------------------------------------------------------------
# The "show mpls tunnel interface counters" command
#-------------------------------------------------------------------------
#class ShowMplsTunnelInterfaceCountersCmd( CliCommand.CliCommandClass ):
#   syntax = 'show mpls tunnel interface [ INTF ] counters egress'
#   data = {
#      'mpls': MplsCli.mplsNodeForShow,
#      'tunnel': MplsCli.tunnelAfterMplsMatcherForShow,
#      'interface': CliCommand.guardedKeyword( 'interface',
#                      helpdesc="Per-interface MPLS tunnel information",
#                      guard=mplsTunnelCountersGuard ),
#      'counters': 'Aggregate MPLS tunnel counters',
#      'INTF': MplsCli.intfValMatcher,
#      'egress': 'Aggregate egress MPLS tunnel counters',
#   }
#   cliModel = MplsTunnelInterfaceCounters
#
#   @staticmethod
#   def handler( mode, args ):
#      tunnelInterfaceCounters = MplsTunnelInterfaceCounters()
#      interfaces = tunnelInterfaceCounters.interfaces
#
#      intfId = None
#      intf = args.get( 'INTF' )
#      if intf is not None:
#         intfId = Tac.Value( "Arnet::IntfId", str( intf ) )
#
#      # Generate per-interface aggregate stats for mpls tunnels
#      for tunnelId, tunnelEntry in MplsCli.tunnelFib.entry.items():
#         for via in tunnelEntry.tunnelVia:
#            viaModels = TunnelCli.getMplsViaModelFromTunnelVia( via )
#            for viaModel in viaModels:
#               if not viaModel.interface:
#                  continue
#               # Skip if need to filter by interface and not matched
#               if intfId and viaModel.interface != intfId:
#                  continue
#               txPackets, txBytes, _ = getMplsTunnelCountersFromSmash( tunnelId )
#               if txPackets and txBytes:
#                  interface = viaModel.interface
#                  if interface in interfaces:
#                     interfaces[ interface ].txPackets += txPackets
#                     interfaces[ interface ].txBytes += txBytes
#                  else:
#                     interfaces[ interface ] = \
#                        MplsTunnelInterfaceCountersEntry(
#                           txPackets=txPackets, txBytes=txBytes )
#
#      return tunnelInterfaceCounters
#
# Command disabled, see BUG252582
# Todor: I've converted the above command, but it may need a little work.
#
# BasicCli.registerLegacyShowCommandClass( ShowMplsTunnelInterfaceCountersCmd )

def Plugin( em ):
   '''
   CLI plugin entry point: mount the state this module reads and writes.

   Populates the module-level handles:
   - the clear-request config (LazyMount, mode 'w');
   - the current counter table, snapshot counter table and rate counter table
     (smash, reader);
   - the per-protocol tunnel name status (LazyMount, mode "r").
   '''
   global mplsTunnelClear, mplsTunnelCounterTable, mplsTunnelSnapshotTable
   global mplsTunnelRateCounterTable, protocolTunnelNameStatus

   mplsTunnelClear = LazyMount.mount(
      em, 'hardware/counter/tunnel/clear/config',
      'Ale::FlexCounter::TunnelCliClearConfig', 'w' )

   # Reference to Nexthop feature ID is deliberate here as we are
   # reusing Nexthop counter smashes by design
   nexthopFeatureId = Tac.Type( 'FlexCounters::FeatureId' ).Nexthop
   allFapsId = Tac.Type( 'FlexCounters::FapId' ).allFapsId
   readerMountInfo = SmashLazyMount.mountInfo( 'reader' )

   def mountCounterTable( snapshot ):
      # Mount the current counter table (snapshot=False) or the snapshot
      # table (snapshot=True).
      path = tableMountPath( em, nexthopFeatureId, allFapsId, snapshot )
      return SmashLazyMount.mount( em, path, "FlexCounters::CounterTable",
                                   readerMountInfo )

   mplsTunnelCounterTable = mountCounterTable( False )
   mplsTunnelSnapshotTable = mountCounterTable( True )

   rateTablePath = Tac.Type( 'FlexCounters::RateCounterTable' ).mountPath(
      "", nexthopFeatureId, allFapsId )
   mplsTunnelRateCounterTable = SmashLazyMount.mount(
      em, rateTablePath, 'FlexCounters::RateCounterTable', readerMountInfo )

   protocolTunnelNameStatus = LazyMount.mount(
      em, 'tunnel/tunnelNameStatus',
      "Tunnel::TunnelTable::ProtocolTunnelNameStatus", "r" )
