# Copyright (c) 2015 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import BasicCli
import CliCommand
import CliExtensions
import CliMatcher
from CliPlugin import IntfCli
from CliPlugin import IntfModel
from CliPlugin.AleCountersCli import (
   checkCounterFeatureConfigured,
   checkCounterFeatureEnabled,
   counterFeatureSupported )
import LazyMount
import SmashLazyMount
import Tac

# Module-level state. The mounts below start as None and are assigned by
# Plugin() when the CLI plugin is loaded.

# AleL3Intf::Table mount; showInterfacesTcCounters queries intfIdToL3Id() on it.
aleL3IntfTable = None
# Bridging::Config mount; not read by the code in this section.
brConfig = None
# Number of traffic classes. Bounds the default TC range and the START/END
# integer matchers of the show command.
numTcs = 8
# Smash TcIngress counter tables (current values and snapshot). Both are keyed
# by TcCounterKey( intfId, tc ) through their tcCounter collection.
tcCounterTable = None
tcSnapshotTable = None
# Interface::VlanIntfStatusDir mount; not read by the code in this section.
vlanIntfStatusDir = None

# Tac types used by the command handler and by Plugin().
FapId = Tac.Type( 'FlexCounters::FapId' )
FeatureId = Tac.Type( 'FlexCounters::FeatureId' )
TcCounterKey = Tac.Type( 'Ale::FlexCounter::TcCounterKey' )
TrafficClass = Tac.Type( 'Qos::TrafficClass' )

#---------------------------------------------------------------------------------
# "show interfaces [<name>] counters traffic-class [<start> [<end>]]" command,
# in enable mode.
#---------------------------------------------------------------------------------

# Capability hooks for TC ingress counters. They are presumably populated by
# platform-specific plugins (confirm). This file only queries them through
# .any() in the helper functions that follow.
# Are TC ingress counters limited to L3 interfaces?
tcIngressCountersSupportedOnlyForL3IntfsHook = CliExtensions.CliHook()
# Can TC ingress counters be collected on sub-interfaces?
tcIngressCountersSupportedForSubIntfsHook = CliExtensions.CliHook()
# Are unicast/multicast (and dropped) counter breakdowns available?
tcIngressUcastMcastCountersSupportedHook = CliExtensions.CliHook()

def countersSupportedOnlyForL3Intfs():
   """Report whether TC ingress counters exist only on L3 interfaces.

   Returns the result of tcIngressCountersSupportedOnlyForL3IntfsHook.any().
   This is presumably true when any registered extension says so.
   """
   l3OnlyHook = tcIngressCountersSupportedOnlyForL3IntfsHook
   return l3OnlyHook.any()

def countersSupportedForSubIntfs():
   """Report whether TC ingress counters can be collected on sub-interfaces.

   Returns the result of tcIngressCountersSupportedForSubIntfsHook.any().
   """
   subIntfHook = tcIngressCountersSupportedForSubIntfsHook
   return subIntfHook.any()

def ucastMcastCountersSupported():
   """Report whether per-TC unicast/multicast counter breakdowns are available.

   Returns the result of tcIngressUcastMcastCountersSupportedHook.any().
   """
   ucastMcastHook = tcIngressUcastMcastCountersSupportedHook
   return ucastMcastHook.any()

def tcIngressCounter( intfId, tc, isUcastMcastCounters ):
   """Build the ingress counter model for one (interface, traffic class) pair.

   Each reported value is the current entry in tcCounterTable minus the
   matching entry in tcSnapshotTable, when a snapshot entry exists. The
   snapshot is presumably recorded when counters are cleared (confirm). If
   there is no current entry, every reported value is 0.

   Args:
      intfId: Arnet::IntfId of the interface.
      tc: traffic class number.
      isUcastMcastCounters: when true, also fill the unicast/multicast and
         dropped-packet/octet breakdowns. inPkts/inOctets are then the sum of
         the unicast and multicast values.

   Returns:
      An IntfModel.InterfaceTrafficClassCounters instance.
   """
   key = TcCounterKey( intfId, tc )
   current = tcCounterTable.tcCounter.get( key )
   baseline = tcSnapshotTable.tcCounter.get( key )
   model = IntfModel.InterfaceTrafficClassCounters()

   def delta( field ):
      # With no current entry, report zero even if a snapshot exists.
      if current is None:
         return 0
      value = getattr( current, field )
      return value if baseline is None else value - getattr( baseline, field )

   if not isUcastMcastCounters:
      model.inPkts = delta( 'pkts' )
      model.inOctets = delta( 'octets' )
      return model

   # The hardware 'pkts'/'octets' fields hold the unicast portion only.
   ucastPkts = delta( 'pkts' )
   mcastPkts = delta( 'multicastPkts' )
   ucastOctets = delta( 'octets' )
   mcastOctets = delta( 'multicastOctets' )

   model.inPkts = ucastPkts + mcastPkts
   model.inOctets = ucastOctets + mcastOctets
   model.inUcastPkts = ucastPkts
   model.inMcastPkts = mcastPkts
   model.inUcastOctets = ucastOctets
   model.inMcastOctets = mcastOctets
   model.inDroppedUcastPkts = delta( 'droppedPkts' )
   model.inDroppedMcastPkts = delta( 'droppedMulticastPkts' )
   model.inDroppedUcastOctets = delta( 'droppedOctets' )
   model.inDroppedMcastOctets = delta( 'droppedMulticastOctets' )
   return model

def showInterfacesTcCounters( mode, args ):
   """Handler for "show interfaces [INTF] counters traffic-class [START [END]]".

   Collects TC ingress counters for every counter-capable interface selected
   by the command, and for every traffic class in the requested range.

   Args:
      mode: CLI mode. Used to resolve interfaces and to emit warnings.
      args: parsed command arguments. 'INTF' and 'MOD' select interfaces.
         'START' and 'END' select the traffic-class range: all TCs when both
         are absent, only START when END is absent. A START greater than END
         yields interface entries with no traffic classes.

   Returns:
      An IntfModel.InterfacesTrafficClassesCounters. It is left empty when the
      TcIngress counter feature is not both configured and enabled, or when no
      selected interface supports counters.
   """
   # Named intfArg so the per-interface loop variable below does not shadow it.
   intfArg = args.get( 'INTF' )
   mod = args.get( 'MOD' )
   tcStart = args.get( 'START' )
   tcEnd = args.get( 'END' )

   intfsTcsCounters = IntfModel.InterfacesTrafficClassesCounters()
   if not ( checkCounterFeatureConfigured( FeatureId.TcIngress ) and
            checkCounterFeatureEnabled( FeatureId.TcIngress ) ):
      return intfsTcsCounters

   intfs = IntfCli.counterSupportedIntfs( mode, intfArg, mod )
   if not intfs:
      # Nothing to report. If the user named sub-interface(s) while the
      # sub-interface ingress counter feature is disabled, point them at the
      # command that enables it instead of silently printing nothing.
      if intfArg is not None and countersSupportedForSubIntfs() and \
            not checkCounterFeatureEnabled( FeatureId.SubInterfaceIngress ):
         interfaces = IntfCli.Intf.getAll( mode, intfArg, mod, sort=False )
         if any( candidate.isSubIntf() for candidate in interfaces ):
            mode.addWarning(
               "Please enable the sub-interface ingress counter feature with the "
               "'hardware counter feature subinterface in' command."
            )
      return intfsTcsCounters

   if tcStart is None and tcEnd is None:
      tcRangeStart = 0
      tcRangeEnd = numTcs - 1
   else:
      tcRangeStart = tcStart
      # END can legitimately be traffic class 0, so test for None rather than
      # truthiness. Otherwise an explicit END of 0 would be ignored and only
      # START would be shown.
      tcRangeEnd = tcEnd if tcEnd is not None else tcStart
   trafficClasses = range( tcRangeStart, tcRangeEnd + 1 )

   isL3IntfOnlyCounters = countersSupportedOnlyForL3Intfs()
   isUcastMcastCounters = ucastMcastCountersSupported()
   for intf in intfs:
      intfId = Tac.Value( 'Arnet::IntfId', str( intf ) )
      # When counters exist only for L3 interfaces, an interface without an
      # internal VLAN (L3 id) has no TC counters, so skip it.
      if isL3IntfOnlyCounters and not aleL3IntfTable.intfIdToL3Id( intfId ):
         continue

      intfTcsCounters = IntfModel.InterfaceTrafficClassesCounters(
            _name=intfId, _bumCounters=False,
            _umCounters=isUcastMcastCounters )
      for tc in trafficClasses:
         intfTcsCounters.trafficClasses[ tc ] = \
               tcIngressCounter( intfId, tc, isUcastMcastCounters )
      intfsTcsCounters.interfaces[ intfId ] = intfTcsCounters
   return intfsTcsCounters

# CLI command: "show interfaces counters traffic-class [ START [ END ] ]".
# The interface-selection part (INTF/MOD args read by the handler) presumably
# comes from IntfCli.ShowIntfCommand; confirm. START and END are traffic-class
# numbers in [0, numTcs - 1]. With START alone, only that traffic class is
# reported.
class ShowIntfCountersTc( IntfCli.ShowIntfCommand ):
   syntax = 'show interfaces counters traffic-class [ START [ END ] ]'
   data = {
      'counters' : IntfCli.countersKw,
      # Guarded by counterFeatureSupported( 'TcIngress' ). The keyword is
      # presumably unavailable on platforms without TcIngress counter support.
      'traffic-class' : CliCommand.guardedKeyword( 'traffic-class',
         helpdesc='Show traffic class information',
         guard=counterFeatureSupported( 'TcIngress' ) ),
      'START' : CliMatcher.IntegerMatcher( 0, numTcs - 1,
                                          helpdesc='Start from this traffic-class' ),
      'END' : CliMatcher.IntegerMatcher( 0, numTcs - 1,
                                         helpdesc='End at this traffic-class' )
   }
   handler = showInterfacesTcCounters
   # Model returned by the handler.
   cliModel = IntfModel.InterfacesTrafficClassesCounters

# Register the command with the show-command infrastructure at import time.
BasicCli.addShowCommandClass( ShowIntfCountersTc )

def Plugin( entityManager ):
   """Plugin entry point: mount the state this CLI module reads.

   Populates the module-level mount globals. Every mount is read-only: the
   LazyMount mounts use mode 'r', and the Smash mounts use a 'reader'
   mountInfo.
   """
   global aleL3IntfTable, brConfig, vlanIntfStatusDir
   global tcCounterTable, tcSnapshotTable

   def mountReadOnly( path, typeName ):
      return LazyMount.mount( entityManager, path, typeName, 'r' )

   aleL3IntfTable = mountReadOnly( 'ale/l3IntfTable', 'AleL3Intf::Table' )
   brConfig = mountReadOnly( 'bridging/config', 'Bridging::Config' )
   vlanIntfStatusDir = mountReadOnly( 'interface/status/eth/vlan',
                                      'Interface::VlanIntfStatusDir' )

   # The current and snapshot TcIngress Smash tables share a type and a path
   # layout. They differ only in the table-kind path component.
   readerInfo = SmashLazyMount.mountInfo( 'reader' )

   def mountTcTable( tableKind ):
      path = 'flexCounters/%s/TcIngress/%u' % ( tableKind, FapId.allFapsId )
      return SmashLazyMount.mount( entityManager, path,
                                   "Ale::FlexCounter::TcCounterTable", readerInfo )

   tcCounterTable = mountTcTable( 'counterTable' )
   tcSnapshotTable = mountTcTable( 'snapshotTable' )
