#!/usr/bin/env python3
# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable-msg=protected-access

import ast
from collections import namedtuple
from Ark import (
   switchTimeToUtc
)
from AgentCommandRequest import runCliPrintSocketCommand
from BasicCli import (
   addShowCommandClass,
)
import AgentDirectory
import AgentCommandRequest
from io import StringIO
from CliModel import cliPrinted
import CliParser
import CliCommand
import CliMatcher
from CliPlugin import SfeFtModel
from CliPlugin.FlowTrackingCliLib import (
   getFlowGroupsAndIds,
   getIfIndexMap,
   getTrackerNamesAndIds,
   isSfeAgentRunning,
   hardwareShowKw,
   firewallShowKw,
   distributedShowKw,
   exporterKw,
   exporterNameMatcher,
   trackerKw,
   trackingShowKw,
   trackerNameMatcher,
   IP_GROUP,
   groupKw,
)
from CliPlugin.FlowTrackingCounterCliLib import (
   FlowGroupCounterKey,
   FlowGroupCounterEntry,
   TemplateIdType,
   CollectorStatisticsKey,
   CollectorInfo,
   addExporters,
   FlowCounterKey,
   FlowCounterEntry,
)
import CliPlugin.FlowTrackingCounterModel as ftCounterModel
from CliPlugin.SfeCliLib import nodeSfe
from CliPlugin.VrfCli import VrfExprFactory
from CliToken.Platform import (
   platformMatcherForShow
)
from CliToken.Flow import (
   flowMatcherForShow,
)
from CliToken.SfeFlowCliToken import (
   nodeDetail,
   nodeSrcIp,
   nodeIPv4,
   nodeDstIp,
   nodeSrcPort,
   nodePort,
   nodeDstPort,
   nodeProtocol,
   nodeProtocolValue
)
from FlowTrackerCliUtil import (
   ftrTypeHardware,
   ftrTypeDfw,
)
from IpLibConsts import ALL_VRF_NAME
import Cell
import LazyMount
from ShowCommand import ShowCliCommandClass
import SmashLazyMount
import SharedMem
import Smash
import Tac
import Tracing
import SfeAgent

traceHandle = Tracing.Handle( 'SfeFlowCliShow' )
t0, t1 = traceHandle.trace0, traceHandle.trace1

DpsConstants = Tac.Type( 'Dps::DpsConstants' )

# Module state below starts as None and is assigned by Plugin() (see the
# `global` statement there).

# Interface status directories consulted by ifindexToIntf().
ethIntfStatusDir_ = ethLagIntfStatusDir_ = None
tunnelIntfStatusDir_ = dpsIntfStatusDir_ = None

# Entity managers, active-agent directory and VRF state.
activeAgentDir = entityManager = shmemEm = None
sfeVrfIdMap = vrfNameStatus = None

# Flow tracking configuration (hardware / distributed firewall) and counters.
sfeConfig = dfwSfeConfig = None
sfeCounters = ipfixStats = None

# Application classification state.
applicationIdMap = appRecognitionConfig = classificationStatus = None

# Other mounted state.
avtCliConfig = None

def getFlowTable( trackerName ):
   """Mount and return the Smash flow table of tracker `trackerName`.

   The table is mounted through shmemEm with a 'reader' mount info from
   flowtracking/<ftrTypeHardware>/flowTable/<trackerName> as type
   SfeFlowTracker::FlowTrackerFlowTable.
   """
   return shmemEm.doMount(
      f'flowtracking/{ftrTypeHardware}/flowTable/{trackerName}',
      'SfeFlowTracker::FlowTrackerFlowTable',
      Smash.mountInfo( 'reader' ) )

def sfeGuard( mode, token ):
   """CLI guard for tokens that require the Sfe agent.

   Returns None when AgentDirectory reports an 'Sfe' agent for mode.sysname,
   otherwise CliParser.guardNotThisPlatform.
   """
   sfeAgentPresent = AgentDirectory.agent( mode.sysname, 'Sfe' )
   return None if sfeAgentPresent else CliParser.guardNotThisPlatform

def getIntfFromIfindex( ifindex, ethIntfStatusDir, ethLagIntfStatusDir,
      tunnelIntfStatusDir, dpsIntfStatusDir ):
   """Translate an ifindex into an interface name.

   The status directories are searched in order: Ethernet, Ethernet LAG,
   tunnel, then DPS. The first one whose ifindex map contains `ifindex`
   provides the answer. Returns 'unknown' when none of them know it.
   """
   # Later directories are only consulted if the earlier ones miss.
   for statusDir in ( ethIntfStatusDir, ethLagIntfStatusDir,
                      tunnelIntfStatusDir ):
      indexToName = getIfIndexMap( statusDir )
      if ifindex in indexToName:
         return indexToName[ ifindex ]
   return getIfIndexMap( dpsIntfStatusDir ).get( ifindex, 'unknown' )

def ifindexToIntf( ifindex ):
   """Resolve `ifindex` to an interface name using the mounted status dirs."""
   return getIntfFromIfindex( ifindex,
                              ethIntfStatusDir=ethIntfStatusDir_,
                              ethLagIntfStatusDir=ethLagIntfStatusDir_,
                              tunnelIntfStatusDir=tunnelIntfStatusDir_,
                              dpsIntfStatusDir=dpsIntfStatusDir_ )

def executeCommand( mode, params, model ):
   """Run the 'sfe-ipfix-show-flow' request in the Sfe agent.

   `params` is rendered as space-separated 'key=value' pairs (values via
   str(), in dict order) and sent with runCliPrintSocketCommand. That call's
   result, rendered into `model`, is returned.
   """
   pairs = ( '='.join( ( key, str( val ) ) ) for key, val in params.items() )
   argString = ' '.join( pairs )
   return runCliPrintSocketCommand(
      mode.entityManager, 'Sfe', 'sfe-ipfix-show-flow', argString, mode,
      keepalive=True, connErrMsg='Sfe agent is inactive', model=model )

# SHOW COMMANDS
#------------------------------------------------------------
# show flow tracking hardware flow-table [detail]
# [ tracker <tracker-name> ] [ group <group-name> ]
# [ src-ip <ip> ] [ dst-ip <ip> ] [ src-port <port> ] [ dst-port <port> ]
# [ protocol <protocol> ] [ vrf <vrf> ] [ application APP_NAME ]
# [ application-service APPSERVICE_NAME ]
#
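# The flow-table lookup itself is implemented in the Sfe agent (C++); the
# handler below only validates the filters and forwards them to the agent.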
#------------------------------------------------------------

def allAppNames( mode ):
   """Return the set of application names for completion.

   This is the union of the configured applications
   (appRecognitionConfig.app) and those in
   classificationStatus.applicationStatus.
   """
   configured = set( appRecognitionConfig.app )
   return configured.union( classificationStatus.applicationStatus )

def allServiceNames( mode ):
   """Return the set of service names in classificationStatus.serviceStatus."""
   return set( classificationStatus.serviceStatus )

def _singleMatchNode( matcher ):
   """Wrap `matcher` in a CliCommand.Node with maxMatches=1."""
   return CliCommand.Node( matcher=matcher, maxMatches=1 )

# 'application APP_NAME' filter; names are supplied by allAppNames().
nodeApp = _singleMatchNode(
   CliMatcher.KeywordMatcher( 'application', 'Specific application' ) )
nodeAppName = _singleMatchNode(
   CliMatcher.DynamicNameMatcher( allAppNames,
                                  'Name of an application',
                                  priority=CliParser.PRIO_LOW ) )

# 'application-service APPSERVICE_NAME' filter; names from allServiceNames().
nodeAppService = _singleMatchNode(
   CliMatcher.KeywordMatcher( 'application-service',
                              'Specific application service' ) )
nodeAppServiceName = _singleMatchNode(
   CliMatcher.DynamicNameMatcher( allServiceNames,
                                  helpdesc='Name of service of an application',
                                  priority=CliParser.PRIO_LOW ) )

# Flow group names: only IP_GROUP is offered.
# NOTE(review): passContext=True is set, yet the names function only takes
# `mode` -- confirm how DynamicNameMatcher invokes it when passContext is set.
allGroupNameMatcher = _singleMatchNode(
   CliMatcher.DynamicNameMatcher(
      lambda mode: ( IP_GROUP, ),
      "Flow group name",
      passContext=True ) )

# Application-classification (DPI) filters spliced into the flow-table filter
# expression: the syntax fragment and the data for its tokens.
classificationDpiExpression = ( '| ( application APP_NAME ) '
                                '| ( application-service APPSERVICE_NAME )' )
classificationDpiDict = {
   'application' : nodeApp,
   'APP_NAME' : nodeAppName,
   'application-service' : nodeAppService,
   'APPSERVICE_NAME' : nodeAppServiceName,
}

class ShowTrackingFilterExpression( CliCommand.CliExpression ):
   """Optional filters accepted after 'show flow tracking hardware flow-table'.

   Filters: tracker, flow group, source/destination IP and port, protocol,
   VRF, the application-classification filters from
   classificationDpiExpression / classificationDpiDict, and 'detail'.
   """
   expression = ( '[ { ( tracker TRACKER ) '
                  '| ( group GROUP ) '
                  '| ( src-ip SRC-IP ) '
                  '| ( dst-ip DST-IP ) '
                  '| ( src-port SRC-PORT ) '
                  '| ( dst-port DST-PORT ) '
                  '| ( protocol PROTOCOL ) '
                  '| ( VRF ) '
                  f'{classificationDpiExpression} '
                  '| ( detail ) } ]' )
   data = {
      'tracker' : CliCommand.Node( trackerKw, maxMatches=1 ),
      'TRACKER' : trackerNameMatcher,
      'group' : CliCommand.Node( groupKw, maxMatches=1 ),
      'GROUP' : allGroupNameMatcher,
      'src-ip' : nodeSrcIp,
      'SRC-IP' : nodeIPv4,
      'dst-ip' : nodeDstIp,
      'DST-IP' : nodeIPv4,
      'src-port' : nodeSrcPort,
      'SRC-PORT' : nodePort,
      'dst-port' : nodeDstPort,
      'DST-PORT' : nodePort,
      'protocol' : nodeProtocol,
      'PROTOCOL' : nodeProtocolValue,
      'VRF' : VrfExprFactory( helpdesc='Flow VRF',
                              inclDefaultVrf=True,
                              inclAllVrf=True,
                              maxMatches=1 ),
      'detail' : nodeDetail,
      # application / application-service tokens.
      **classificationDpiDict,
   }

def _protocolNumber( protocol ):
   """Convert a PROTOCOL argument string to an IP protocol number.

   A decimal string is converted with int(). Otherwise the name (e.g. 'tcp')
   is looked up as the Arnet::IpProtocolNumber enum value
   'ipProto' + protocol.capitalize(). Returns None when that lookup raises
   AttributeError (unknown protocol name).
   """
   if protocol.isdigit():
      return int( protocol )
   protoStr = 'ipProto' + protocol.capitalize()
   try:
      return Tac.enumValue( "Arnet::IpProtocolNumber", protoStr )
   except AttributeError:
      return None

def _isKnownClassificationName( name, statusCollection ):
   """Return True if `name` is in statusCollection or is 'unclassified'."""
   return name in statusCollection or name == 'unclassified'

class ShowFlowTrackingFlowTable( ShowCliCommandClass ):
   syntax = '''show flow tracking hardware
            flow-table [ FILTER ] '''
   data = {
      'flow' : flowMatcherForShow,
      'tracking' : trackingShowKw,
      'hardware' : hardwareShowKw,
      'flow-table' : CliCommand.guardedKeyword(
         'flow-table',
         helpdesc='Flow table',
         guard=sfeGuard ),
      'FILTER' : ShowTrackingFilterExpression
   }
   cliModel = SfeFtModel.TrackingModel

   @staticmethod
   def handler( mode, args ):
      """Validate the flow-table filters and forward them to the Sfe agent.

      Filters are collected into `params` in a fixed order and sent through
      executeCommand(). An unknown application, service, VRF or protocol
      name is reported with mode.addError, and None is returned without
      querying the agent.
      """
      params = {}

      # Filters forwarded unchanged; the value token is the upper-cased
      # keyword (e.g. 'src-ip' -> 'SRC-IP').
      for param in ( 'tracker', 'group', 'src-ip', 'dst-ip',
            'src-port', 'dst-port' ):
         if param in args:
            value = args.get( param.upper() )
            # Compare with None rather than testing truthiness so a valid
            # falsy value (e.g. port 0) is not silently dropped.
            if value is not None:
               params[ param ] = value

      # The remaining filters need validation before being forwarded.
      if 'application' in args:
         appName = args[ 'APP_NAME' ]
         if not _isKnownClassificationName(
               appName, classificationStatus.applicationStatus ):
            mode.addError( f'application {appName} was not found' )
            return None
         params[ 'app' ] = appName

      if 'application-service' in args:
         serviceName = args[ 'APPSERVICE_NAME' ]
         if not _isKnownClassificationName(
               serviceName, classificationStatus.serviceStatus ):
            mode.addError( f'service {serviceName} was not found' )
            return None
         params[ 'service' ] = serviceName

      if 'VRF' in args:
         vrfName = args[ 'VRF' ]
         # The VRF is passed by name so the agent can convert it to a
         # platform-dependent ID. ALL_VRF_NAME is forwarded as-is to match
         # all VRFs; any other name must exist.
         if vrfName != ALL_VRF_NAME and \
               not vrfNameStatus.vrfIdExists( vrfName ):
            mode.addError( f'vrf {vrfName} was not found' )
            return None
         params[ 'vrf' ] = vrfName

      if 'detail' in args:
         params[ 'detail' ] = True

      if 'protocol' in args:
         protocol = args[ 'PROTOCOL' ]
         if protocol:
            protoNumber = _protocolNumber( protocol )
            if protoNumber is None:
               mode.addError( f'Unknown protocol {protocol}' )
               return None
            params[ 'protocol' ] = protoNumber

      model = executeCommand( mode, params,
                              SfeFtModel.TrackingModel )
      return cliPrinted( model ) if model else None

addShowCommandClass( ShowFlowTrackingFlowTable )

#--------------------------
# Totals returned by addFlowGroups(), summed over a tracker's flow groups.
FgCount = namedtuple( 'FgCount', 'flows expiredFlows packets' )

def addFlowGroups( trModel, trName, ftId, ftrType ):
   """Add per-flow-group counters of tracker trName to trModel.flowGroups.

   For each ( fgName, fgId ) from getFlowGroupsAndIds( ftrType, trName ):
     - the entry keyed by FlowGroupCounterKey( ftId, fgId ) is read from
       sfeCounters.flowGroupCounter; a default FlowGroupCounterEntry is used
       when there is none;
     - its counters are copied into a FlowGroupCounters model, with expired
       flows computed as totalFlowsCreated - activeFlows.

   The entry under the sentinel key
   FlowGroupCounterKey( ftId, TemplateIdType.maxTemplateId ) carries a
   lastClearedTime. Converted to UTC, it replaces trModel.clearTime when that
   is unset or older.

   Returns:
      FgCount with flows, expiredFlows and packets summed over all groups.
   """
   totalFlows = totalExpired = totalPackets = 0
   groupSeen = False
   for fgName, fgId in getFlowGroupsAndIds( ftrType, trName ):
      groupSeen = True
      t1( 'process flow group', fgName )
      groupModel = ftCounterModel.FlowGroupCounters()

      groupKey = FlowGroupCounterKey( ftId, fgId )
      groupCounts = sfeCounters.flowGroupCounter.get( groupKey )
      if not groupCounts:
         t0( 'No counters for', groupKey.smashString() )
         groupCounts = FlowGroupCounterEntry()
      entry = groupCounts.flowEntry
      expired = entry.totalFlowsCreated - entry.activeFlows

      groupModel.flows = entry.totalFlowsCreated
      groupModel.activeFlows = entry.activeFlows
      groupModel.expiredFlows = expired
      groupModel.packets = entry.packets
      trModel.flowGroups[ fgName ] = groupModel

      totalFlows += entry.totalFlowsCreated
      totalExpired += expired
      totalPackets += entry.packets

   # The clear time should be the most recent lastClearedTime from either
   # sfeCounters (here) or ipfixStats (already in trModel.clearTime, if any).
   stampKey = FlowGroupCounterKey( ftId, TemplateIdType.maxTemplateId )
   stamp = sfeCounters.flowGroupCounter.get( stampKey )
   if stamp and stamp.key == stampKey:
      clearedAt = switchTimeToUtc( stamp.flowEntry.lastClearedTime )
      if trModel.clearTime is None or clearedAt > trModel.clearTime:
         trModel.clearTime = clearedAt

   if not groupSeen:
      t0( 'WARNING: no flow groups for tracker', trName )

   return FgCount( flows=totalFlows, expiredFlows=totalExpired,
                   packets=totalPackets )

def addTrackers( model, trFilter, expFilter, ftrType ):
   """Fill `model` (an FtrCounters) with per-tracker and overall counters.

   Args:
      model: FtrCounters model to populate.
      trFilter: if set, only the tracker with this name is reported.
      expFilter: exporter-name filter, passed through to addExporters().
      ftrType: ftrTypeHardware or ftrTypeDfw. Selects the trackers returned
         by getTrackerNamesAndIds() and the config used for exporters
         (sfeConfig for hardware, dfwSfeConfig otherwise).
   """
   # The exporter configuration depends only on the tracker type.
   ftrConfig = sfeConfig if ftrType == ftrTypeHardware else dfwSfeConfig
   allActiveFlows = 0
   for trName, ftId in getTrackerNamesAndIds( ftrType ):
      if trFilter and trFilter != trName:
         t1( 'tracker', trName, 'did not match filter' )
         continue
      t1( 'process tracker', trName )
      # NOTE(review): getFlowTable() always mounts the ftrTypeHardware
      # flow-table path, even when ftrType is ftrTypeDfw -- confirm intended.
      flowTable = getFlowTable( trName )
      if not flowTable:
         continue
      trModel = ftCounterModel.TrackerCounters()
      # Seed clearTime from the IPFIX statistics; addFlowGroups() replaces it
      # if the flow-group counters carry a more recent clear time.
      clearTimeKey = CollectorStatisticsKey( trName, "", CollectorInfo() )
      clearTime = ipfixStats.stats.get( clearTimeKey )
      if clearTime and clearTime.key == clearTimeKey:
         trModel.clearTime = switchTimeToUtc( clearTime.lastClearedTime )
      counts = addFlowGroups( trModel, trName, ftId, ftrType )
      trModel.flows = counts.flows
      trModel.activeFlows = counts.flows - counts.expiredFlows
      trModel.expiredFlows = counts.expiredFlows
      trModel.packets = counts.packets
      allActiveFlows += trModel.activeFlows
      addExporters( trModel, trName, expFilter, ftrConfig, ipfixStats, True )
      model.trackers[ trName ] = trModel

   # Overall counters come from the FlowCounterKey( 0 ) entry (a default
   # FlowCounterEntry when it is missing).
   ftrKey = FlowCounterKey( 0 )
   ftrCounts = sfeCounters.flowsCounter.get( ftrKey )
   if not ftrCounts:
      ftrCounts = FlowCounterEntry()
   if ftrCounts.flowEntry.lastClearedTime != 0:
      model.clearTime = switchTimeToUtc( ftrCounts.flowEntry.lastClearedTime )
   # NOTE(review): activeFlows is summed over the trackers reported above (so
   # it honours trFilter), while flows/expiredFlows come from the global
   # FlowCounterKey( 0 ) entry -- confirm this mix is intended.
   model.activeFlows = allActiveFlows
   model.flows = ftrCounts.flowEntry.totalFlowsCreated
   model.expiredFlows = ftrCounts.flowEntry.totalFlowsCreated - \
      ftrCounts.flowEntry.activeFlows
   model.packets = ftrCounts.flowEntry.packets

def showCountersCmd( mode, args ):
   """Handler for the 'show flow tracking ... counters' command.

   Always returns an FtrCounters model with `running` and `softwareFlowTable`
   set. Tracker/exporter counters are only added when isSfeAgentRunning()
   reports true. The 'firewall' keyword selects ftrTypeDfw; otherwise
   ftrTypeHardware is used.
   """
   model = ftCounterModel.FtrCounters()
   model.running = isSfeAgentRunning( entityManager, activeAgentDir )
   model.softwareFlowTable = True
   if not model.running:
      return model

   ftrType = ftrTypeDfw if 'firewall' in args else ftrTypeHardware
   addTrackers( model,
                args.get( 'TRACKER_NAME' ),
                args.get( 'EXPORTER_NAME' ),
                ftrType )
   return model

class ShowFtCounters( ShowCliCommandClass ):
   """'show flow tracking ( hardware | firewall distributed ) counters'.

   The hardware/firewall alternative is grouped in parentheses explicitly so
   the '|' only chooses between the two tracker types. It cannot be read as
   splitting the whole syntax in two.
   """
   syntax = '''show flow tracking
               ( hardware | ( firewall distributed ) )
               counters
               [ tracker TRACKER_NAME [ exporter EXPORTER_NAME ] ]'''

   data = {
      'flow' : flowMatcherForShow,
      'tracking' : trackingShowKw,
      'hardware' : hardwareShowKw,
      'firewall' : firewallShowKw,
      'distributed' : distributedShowKw,
      'counters' : CliCommand.guardedKeyword(
         'counters',
         helpdesc='Show flow tracking counters',
         guard=sfeGuard ),
      'tracker' : trackerKw,
      'TRACKER_NAME' : trackerNameMatcher,
      'exporter' : exporterKw,
      'EXPORTER_NAME' : exporterNameMatcher,
   }

   handler = showCountersCmd
   cliModel = ftCounterModel.FtrCounters

addShowCommandClass( ShowFtCounters )

#---------------------------------------------------------------------------------
# show platform sfe flow tracking counters
#---------------------------------------------------------------------------------
def doShowSfeFtCounters( mode, args ):
   """Handler for 'show platform sfe flow tracking counters'.

   Sends the 'Ftcnt' request to the Sfe agent's "sfeFlowTracker" socket
   command. The reply is parsed with ast.literal_eval (literals only, never
   eval) and is expected to be a dict, which is loaded via setAttrsFromDict.
   If the reply is not a dict literal, its text is reported via
   mode.addError and an empty PlatformCountersModel is returned.
   """
   buff = StringIO()
   AgentCommandRequest.runSocketCommand( mode.entityManager, SfeAgent.name(),
                                         "sfeFlowTracker", "Ftcnt", stringBuff=buff,
                                         timeout=50, keepalive=True )
   output = buff.getvalue()
   model = SfeFtModel.PlatformCountersModel()
   try:
      ftCounters = ast.literal_eval( output )
   except ( SyntaxError, ValueError, TypeError ):
      # literal_eval raises ValueError for non-literal text such as bare names,
      # and TypeError for e.g. unhashable dict keys, not only SyntaxError.
      ftCounters = None
   if not isinstance( ftCounters, dict ):
      # Unparseable or unexpected reply: surface the raw agent output.
      mode.addError( output )
      return model

   model.setAttrsFromDict( ftCounters )
   return model

class ShowSfeFtCountersCmd( ShowCliCommandClass ):
   """'show platform sfe flow tracking counters', handled by
   doShowSfeFtCounters (marked privileged)."""
   syntax = 'show platform sfe flow tracking counters'
   data = dict(
      platform=platformMatcherForShow,
      sfe=nodeSfe,
      flow=flowMatcherForShow,
      tracking=trackingShowKw,
      counters='Show flow tracker counters',
   )

   cliModel = SfeFtModel.PlatformCountersModel
   handler = doShowSfeFtCounters
   privileged = True

addShowCommandClass( ShowSfeFtCountersCmd )

#--------------------------
def Plugin( em ):
   """Record the entity manager and set up the state used by these commands.

   Sysdb state is mounted with LazyMount and Smash tables with
   SmashLazyMount. shmemEm is a SharedMem entity manager used by
   getFlowTable().
   """
   global activeAgentDir, appRecognitionConfig, applicationIdMap, avtCliConfig
   global classificationStatus, dfwSfeConfig, dpsIntfStatusDir_
   global entityManager, ethIntfStatusDir_, ethLagIntfStatusDir_, ipfixStats
   global sfeConfig, sfeCounters, sfeVrfIdMap, shmemEm, tunnelIntfStatusDir_
   global vrfNameStatus

   entityManager = em

   # Flow tracking config for hardware and distributed-firewall trackers.
   sfeConfig = LazyMount.mount(
      em, 'hardware/flowtracking/config/hardware',
      'HwFlowTracking::Config', 'r' )
   dfwSfeConfig = LazyMount.mount(
      em, 'hardware/flowtracking/config/dfw',
      'HwFlowTracking::Config', 'r' )
   activeAgentDir = LazyMount.mount(
      em, 'flowtracking/activeAgent', 'Tac::Dir', 'ri' )

   # Interface status directories used for ifindex -> name translation.
   ethIntfStatusDir_ = LazyMount.mount(
      em, "interface/status/eth/intf", "Interface::EthIntfStatusDir", "r" )
   ethLagIntfStatusDir_ = LazyMount.mount(
      em, "interface/status/eth/lag", "Interface::EthLagIntfStatusDir", "r" )
   tunnelIntfStatusDir_ = LazyMount.mount(
      em, "interface/status/tunnel/intf", "Interface::TunnelIntfStatusDir", "r" )
   dpsIntfStatusDir_ = LazyMount.mount(
      em, "interface/status/dps/intf", "Interface::DpsIntfStatusDir", "r" )

   # VRF name/ID state.
   vrfNameStatus = LazyMount.mount(
      em, Cell.path( "vrf/vrfNameStatus" ),
      "Vrf::VrfIdMap::NameToIdMapWrapper", "r" )
   sfeVrfIdMap = SmashLazyMount.mount(
      em, "vrf/hardware/vrfIdMapStatus", "Vrf::VrfIdMap::Status",
      SmashLazyMount.mountInfo( 'reader' ) )

   # IPFIX statistics and flow tracking counters (Smash, reader mounts).
   ipfixStats = SmashLazyMount.mount(
      em, 'flowtracking/hardware/ipfix/statistics',
      'Smash::Ipfix::CollectorStatistics',
      SmashLazyMount.mountInfo( 'reader' ) )
   sfeCounters = SmashLazyMount.mount(
      em, 'flowtracking/hardware/counters',
      'Smash::FlowTracker::FtCounters',
      SmashLazyMount.mountInfo( 'reader' ) )
   shmemEm = SharedMem.entityManager( sysdbEm=em )

   # Application classification state.
   applicationIdMap = LazyMount.mount(
      em, "flowtracking/status/applicationIdMap",
      "FlowTracking::ApplicationIdMap", "r" )
   appRecognitionConfig = LazyMount.mount(
      em, 'classification/app-recognition/config',
      'Classification::AppRecognitionConfig', 'r' )
   classificationStatus = LazyMount.mount(
      em, 'classification/app-recognition/status',
      'Classification::Status', 'r' )
   avtCliConfig = LazyMount.mount(
      em, 'avt/input/cli', 'Avt::AvtCliConfig', 'r' )
