#!/usr/bin/env python3
# Copyright (c) 2022 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Arnet
import BasicCli
import CliCommand
from CliPlugin import (
   TunnelCli,
   TunnelCounterModel,
)
import CliToken.TunnelCli
import LazyMount
import SharedMem
import ShowCommand
import Smash
import SmashLazyMount
import Tac
import CliParser
from TypeFuture import TacLazyType

# Module-level mount handles used by the commands in this file. They are None
# at import time and are assigned by Plugin().
fcFeatureStatusDir = None  # mounted from 'flexCounter/featureStatusDir'
protocolTunnelNameStatus = None  # mounted from 'tunnel/tunnelNameStatus'
tunnelCounterTable = None  # "flexCounters/counterTable/Nexthop/<allFapsId>"
tunnelCounterSnapshotTable = None  # "flexCounters/snapshotTable/Nexthop/<allFapsId>"
tunnelFib = None  # mounted from 'tunnel/tunnelFib'

# -------------------------------------------------------------------------
# The "show tunnel counters ip-in-ip" command
# -------------------------------------------------------------------------

# Borrowed from AleCountersCli.
# TODO: Maybe move this function to a common package.
def getCurrentCounter( idx, counterTable, snapshotTable ):
   '''
   Return ( pkts, octets ) for idx: the values held in counterTable minus the
   values held in snapshotTable.

   snapshotTable holds the copy of the counterTable values taken when the
   clear command was issued. An entry that is absent from either table, or
   that lacks the attribute, contributes 0 to the subtraction.
   '''
   liveEntry = counterTable.counter.get( idx )
   clearedEntry = snapshotTable.counter.get( idx )

   def sinceClear( field ):
      # getattr() with a default also covers a missing (None) entry
      return getattr( liveEntry, field, 0 ) - getattr( clearedEntry, field, 0 )

   return sinceClear( 'pkts' ), sinceClear( 'octets' )

def getTunnelCountersFromSmash( tunnelId ):
   '''
   Return the ( pkts, octets ) pair that getCurrentCounter() computes for
   tunnelId from tunnelCounterTable and tunnelCounterSnapshotTable.
   '''
   # TODO: Refer the bug 750839. Check if we can use counterActiveForTunnelId()
   # for the IpInIp tunnels before reading smash collection to fetch the counters.
   counterKey = Tac.Type( 'FlexCounters::CounterIndex' )( tunnelId )
   return getCurrentCounter( counterKey,
                             tunnelCounterTable,
                             tunnelCounterSnapshotTable )

def ipInIpTunnelCountersGuard( mode, token ):
   '''
   CLI guard for the ip-in-ip tunnel counters commands.

   Returns None (command allowed) when fcFeatureStatusDir holds an entry for
   the IpInIpEncap feature id and that entry's state is truthy. Otherwise
   returns CliParser.guardNotThisPlatform. mode and token are not used.
   '''
   # Feature supported on all platforms, no need to check that
   featureId = TacLazyType( 'FlexCounters::FeatureId' ).IpInIpEncap
   featureStatus = fcFeatureStatusDir.entityPtr.get( featureId )
   featureEnabled = bool( featureStatus and featureStatus.state )
   return None if featureEnabled else CliParser.guardNotThisPlatform

# Keyword node for the 'ip-in-ip' token; ipInIpTunnelCountersGuard is attached
# as its guard.
ipInIpCountersNode = CliCommand.guardedKeyword( 'ip-in-ip',
      helpdesc='IP-in-IP tunnel counters',
      guard=ipInIpTunnelCountersGuard )

def getTunnelName( tunnelId ):
   '''
   Look up the name recorded for tunnelId in protocolTunnelNameStatus.

   Returns None when there is no name status for the tunnel's type, or
   whatever tunnelIdToName.get() yields for tunnelId otherwise.
   '''
   tunnelIdValue = Tac.Value( "Tunnel::TunnelTable::TunnelId", tunnelId )
   nameStatus = protocolTunnelNameStatus.localTunnelNameStatus.get(
      tunnelIdValue.tunnelType() )
   if not nameStatus:
      return None
   return nameStatus.tunnelIdToName.get( tunnelId )

def getTunnelCounterEntryModel( tunnelId=None, endpoint=None ):
   '''
   Build the TunnelCountersEntry model for a single tunnel.

   Returns None when tunnelFib has no entry for tunnelId, or when a truthy
   endpoint is supplied and it differs from the tunnel's own endpoint.
   Otherwise returns an entry carrying the tx packet/byte counters, the
   tunnel index, the tunnel endpoint and the tunnel name.
   '''
   index = TunnelCli.getTunnelIndexFromId( tunnelId )
   if not tunnelFib.entry.get( tunnelId ):
      return None

   tunnelEndpoint = TunnelCli.getEndpointFromTunnelId( tunnelId )
   if endpoint and tunnelEndpoint != endpoint:
      # Caller asked for a specific endpoint and this tunnel is not it
      return None

   pkts, octets = getTunnelCountersFromSmash( tunnelId )
   return TunnelCounterModel.TunnelCountersEntry(
      txPackets=pkts,
      txBytes=octets,
      tunnelIndex=index,
      endpoint=tunnelEndpoint,
      tunnelName=getTunnelName( tunnelId ) )

def getTunnelCounterModel( args ):
   '''
   Build a dict mapping tunnel id to its counter entry model for ip-in-ip
   tunnels.

   args is the parsed CLI argument dict. Two optional keys are read:
      'INDEX'    - only the tunnel id derived from this index is considered
      'ENDPOINT' - converted with Arnet.IpGenPrefix( str( ... ) ) and passed
                   to getTunnelCounterEntryModel() as the endpoint filter

   When 'INDEX' is absent, every id in tunnelFib.entry whose tunnel type is
   'ipInIpTunnel' is considered. Tunnels for which
   getTunnelCounterEntryModel() returns a falsy value are left out.
   '''
   tunnelType = 'ipInIpTunnel'

   index = args.get( 'INDEX' )
   # Test against None rather than truthiness: an index of 0 is falsy and
   # would otherwise be handled as if no index had been given.
   if index is not None:
      tunnelIds = [ TunnelCli.getTunnelIdFromIndex( tunnelType, index ) ]
   else:
      # Only scan the tunnel FIB when no specific index was requested
      tunnelIds = [ tId
                    for tId in tunnelFib.entry
                    if tunnelType == TunnelCli.getTunnelTypeFromTunnelId( tId ) ]

   endpoint = args.get( 'ENDPOINT' )
   if endpoint:
      endpoint = Arnet.IpGenPrefix( str( endpoint ) )

   tunnelCounterEntries = {}
   for tunnelId in tunnelIds:
      tunnelCounterEntryModel = getTunnelCounterEntryModel(
         tunnelId=tunnelId, endpoint=endpoint )
      if tunnelCounterEntryModel:
         tunnelCounterEntries[ tunnelId ] = tunnelCounterEntryModel

   return tunnelCounterEntries

class ShowTunnelCountersCmd( ShowCommand.ShowCliCommandClass ):
   # show tunnel counters ip-in-ip [ ( index INDEX ) | ( endpoint ENDPOINT ) ]
   #
   # Returns a TunnelCounters model whose entries are sorted by endpoint.
   # The literal pieces below are joined with explicit separating spaces so
   # that 'ip-in-ip' and '[' do not run together in the resulting string.
   syntax = ( 'show tunnel counters ip-in-ip '
              '[ ( index INDEX ) '
              '| ( endpoint ENDPOINT ) ]' )
   data = {
         'tunnel': CliToken.TunnelCli.tokenTunnelMatcher,
         'counters': CliToken.TunnelCli.countersAfterTunnelNode,
         'ip-in-ip': ipInIpCountersNode,
         'index': CliToken.TunnelCli.indexMatcher,
         'INDEX': TunnelCli.tunnelIndexMatcher,
         'endpoint': CliToken.TunnelCli.endpointMatcher,
         'ENDPOINT': CliToken.TunnelCli.tunnelEndpointMatcher,
   }
   cliModel = TunnelCounterModel.TunnelCounters

   @staticmethod
   def handler( mode, args ):
      # ipInIpTunnelCountersGuard is not called here: it is already attached
      # to the 'ip-in-ip' keyword node, and a call whose result is thrown away
      # has no effect on what the handler returns.
      def endpointKey( entry ):
         # Entries without an endpoint sort as the empty string
         return entry.endpoint.stringValue if entry.endpoint else ""

      entries = sorted( getTunnelCounterModel( args ).values(),
                        key=endpointKey )
      return TunnelCounterModel.TunnelCounters( entries=entries )

# Hand ShowTunnelCountersCmd to BasicCli.addShowCommandClass at import time.
BasicCli.addShowCommandClass( ShowTunnelCountersCmd )

# ------------------------------------------------------------------------
# Plugin
# ------------------------------------------------------------------------
def Plugin( em ):
   '''
   Plugin entry point: assign the module-level mount handles.

   em is the entity manager passed to every mount call. The name status,
   both counter tables and the feature status dir are mounted through the
   lazy mount helpers; tunnelFib is mounted with doMount() on a SharedMem
   entity manager built from em.
   '''
   global fcFeatureStatusDir, protocolTunnelNameStatus
   global tunnelCounterTable, tunnelCounterSnapshotTable
   global tunnelFib

   sharedMemEm = SharedMem.entityManager( sysdbEm=em )

   protocolTunnelNameStatus = LazyMount.mount(
      em,
      'tunnel/tunnelNameStatus',
      "Tunnel::TunnelTable::ProtocolTunnelNameStatus",
      "r" )

   # Counter and snapshot tables share the same type, reader mount info and
   # per-FAP path suffix
   readerInfo = SmashLazyMount.mountInfo( 'reader' )
   fapSuffix = str( Tac.Type( 'FlexCounters::FapId' ).allFapsId )
   counterTableType = "FlexCounters::CounterTable"

   tunnelCounterTable = SmashLazyMount.mount(
      em,
      "flexCounters/counterTable/Nexthop/" + fapSuffix,
      counterTableType,
      readerInfo )

   tunnelCounterSnapshotTable = SmashLazyMount.mount(
      em,
      "flexCounters/snapshotTable/Nexthop/" + fapSuffix,
      counterTableType,
      readerInfo )

   # Feature status dir read by ipInIpTunnelCountersGuard
   fcFeatureStatusDir = LazyMount.mount(
      em, 'flexCounter/featureStatusDir', 'Tac::Dir', 'ri' )

   tunnelFib = sharedMemEm.doMount( 'tunnel/tunnelFib',
                                    'Tunnel::TunnelFib::TunnelFib',
                                    Smash.mountInfo( 'reader' ) )
