# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function
from CliPlugin import ServerProbeModel
from CliPlugin.VrfCli import DEFAULT_VRF
from CliPlugin.ServerProbeCli import MonitorServerMode
import LazyMount
import Tac
import Intf
from TypeFuture import TacLazyType
import ReversibleSecretCli
import operator

# Sysdb mount handles; all four start unbound and are assigned by Plugin().
config = status = counterConfig = counterStatus = None
# TAC type handle for ServerProbe::ServerProbeService; its members
# (e.g. serverProbeService.dot1x) name the services a protocol can probe for.
serverProbeService = TacLazyType( 'ServerProbe::ServerProbeService' )

# ------------------------------------------------------------
# switch(config)# monitor server <PROTOCOL_NAME>
# ------------------------------------------------------------

# Regular handler

def monitorServerModeCmd_handler( mode, args ):
   """Enter the 'monitor server <PROTOCOL>' child configuration mode.

   Args:
      mode: the current CLI mode; its session is moved into the child mode.
      args: parsed CLI arguments; 'PROTOCOL' selects the protocol.
   """
   target = mode.childMode( MonitorServerMode, protocol=args[ 'PROTOCOL' ] )
   mode.session_.gotoChildMode( target )

# no/default handler

def monitorServerModeCmd_noOrDefaultHandler( mode, args ):
   """Reset a monitored protocol's configuration to its defaults.

   Handles 'no|default monitor server <PROTOCOL>'. It clears the enabled
   services and restores the default probe interval and failure count. For
   RADIUS it also restores the status-server probe method and an empty
   access-request key (empty username, default secret).

   Args:
      mode: the CLI mode the command was issued from (unused).
      args: parsed CLI arguments; 'PROTOCOL' names the protocol.
   """
   protocol = args[ 'PROTOCOL' ]
   protoEnum = Tac.enumValue( "ServerProbe::Protocol", protocol )
   # If the protocol was never configured there is nothing to reset; indexing
   # the collection directly would otherwise fail with KeyError.
   try:
      protoConfig = config.protocolConfig[ protoEnum ]
   except KeyError:
      return
   protoConfig.serviceList.clear()
   protoConfig.interval = protoConfig.defaultInterval
   protoConfig.failCount = protoConfig.defaultFailCount
   if protocol == 'radius':
      # Use an explicit enum value, matching the other probe-method handlers.
      protoConfig.probeMethod = Tac.enumValue( "ServerProbe::ProbeMethod",
                                               'radiusStatusServer' )
      protoConfig.radiusAccessRequestKey = Tac.Value(
            "ServerProbe::RadiusAccessRequestKey", username='',
            password=ReversibleSecretCli.getDefaultSecret() )

# ------------------------------------------------------------
# switch(config-monitor-server-<PROTOCOL_NAME>)# service dot1x
# ------------------------------------------------------------

# Regular handler

def serviceDot1xCmd_defaultHandler( mode, args ):
   """Enable the dot1x service for the mode's protocol ('service dot1x')."""
   serviceList = config.protocolConfig.newMember( mode.getProtocol() ).serviceList
   serviceList.add( serverProbeService.dot1x )

# no/default handler

def serviceDot1xCmd_noOrDefaultHandler( mode, args ):
   """Remove the dot1x service from the mode's protocol ('no service dot1x')."""
   serviceList = config.protocolConfig.newMember( mode.getProtocol() ).serviceList
   del serviceList[ serverProbeService.dot1x ]

# ------------------------------------------------------------
# switch(config-monitor-server-<PROTOCOL_NAME>)# probe interval <INT> seconds
# ------------------------------------------------------------

# Regular handler

def probeIntervalCmd_defaultHandler( mode, args ):
   """Set the probe interval for the mode's protocol.

   Handles 'probe interval <INT> seconds'; the value comes from
   args[ 'INTERVAL' ].
   """
   protocolEntry = config.protocolConfig.newMember( mode.getProtocol() )
   protocolEntry.interval = args[ 'INTERVAL' ]

# no/default handler

def probeIntervalCmd_noOrDefaultHandler( mode, args ):
   """Restore the protocol's default probe interval ('no probe interval')."""
   protocolEntry = config.protocolConfig.newMember( mode.getProtocol() )
   # The default value is carried on the protocol config entity itself.
   protocolEntry.interval = protocolEntry.defaultInterval

# ------------------------------------------------------------
#  switch(config-monitor-server-<PROTOCOL_NAME>)# probe threshold failure <1-256>
# ------------------------------------------------------------

# Regular handler

def probeFailureThresholdCmd_handler( mode, args ):
   """Set the failure threshold ('probe threshold failure <1-256>').

   The configured failCount is updated only when PROBE_STATE is 'failure';
   any other state leaves the configuration unchanged.
   """
   protocolEntry = config.protocolConfig.newMember( mode.getProtocol() )
   state, count = args[ 'PROBE_STATE' ], args[ 'COUNT' ]
   if state != 'failure':
      return
   protocolEntry.failCount = count

# no/default handler

def probeFailureThresholdCmd_noOrDefaultHandler( mode, args ):
   """Restore the default failure threshold ('no probe threshold failure')."""
   isFailure = args[ 'PROBE_STATE' ] == 'failure'
   protocolEntry = config.protocolConfig.newMember( mode.getProtocol() )
   if not isFailure:
      return
   protocolEntry.failCount = protocolEntry.defaultFailCount

# ------------------------------------------------------------
#  switch(config-monitor-server-<PROTOCOL_NAME>)# probe method ...
# ------------------------------------------------------------
# TODO: If another protocol is added to serverProbe, restrict this command to
#       RADIUS only, because the 'status-server' and 'access-request' probe
#       methods do not apply to other protocols.

# Regular handler

def probeMethodCmd_handler( mode, args ):
   """Select the RADIUS probe method ('probe method ...').

   - 'status-server': probe method becomes radiusStatusServer (shown as
     statusServerPacket). The access-request key is reset to an empty
     username with the default secret.
   - 'access-request': the key is set from USERNAME/PASSWORD, then the probe
     method becomes radiusAccessRequest (shown as accessRequestPacket).

   The command is ignored for non-RADIUS protocols.
   """
   if mode.protocol != 'radius':
      return
   # newMember(), as used by the other handlers in this mode, avoids the
   # KeyError that direct indexing may raise when this is the first command
   # configured for the protocol.
   pConfig = config.protocolConfig.newMember( mode.getProtocol() )
   if 'status-server' in args:
      pConfig.probeMethod = Tac.enumValue( "ServerProbe::ProbeMethod",
                                           'radiusStatusServer' )
      pConfig.radiusAccessRequestKey = Tac.Value(
            "ServerProbe::RadiusAccessRequestKey", username='',
            password=ReversibleSecretCli.getDefaultSecret() )
   elif 'access-request' in args:
      # The key is written before the method. NOTE(review): presumably this is
      # so consumers reacting to probeMethod see the new credentials; confirm
      # before reordering.
      pConfig.radiusAccessRequestKey = Tac.Value(
            "ServerProbe::RadiusAccessRequestKey",
            username=args.get( 'USERNAME' ), password=args.get( 'PASSWORD' ) )
      pConfig.probeMethod = Tac.enumValue( "ServerProbe::ProbeMethod",
                                           'radiusAccessRequest' )

# no/default handler

def probeMethodCmd_noOrDefaultHandler( mode, args ):
   """Restore the default RADIUS probe method ('no|default probe method').

   The probe method reverts to radiusStatusServer. The access-request key is
   replaced by an empty username with the default secret. The command is
   ignored for non-RADIUS protocols.
   """
   if mode.protocol != 'radius':
      return
   # newMember(), as used by the other handlers in this mode, avoids a
   # possible KeyError when the protocol has no configuration entry yet.
   pConfig = config.protocolConfig.newMember( mode.getProtocol() )
   pConfig.probeMethod = Tac.enumValue( "ServerProbe::ProbeMethod",
                                        'radiusStatusServer' )
   pConfig.radiusAccessRequestKey = Tac.Value(
         "ServerProbe::RadiusAccessRequestKey", username='',
         password=ReversibleSecretCli.getDefaultSecret() )

# ------------------------------------------------------------
#  switch> show monitor server <PROTOCOL> [detail]
# ------------------------------------------------------------

# Helper functions

def showServerProbeHost( probeStatusInfo ):
   """Build the show-model entry for a single probed server.

   Args:
      probeStatusInfo: one entry of status.status. Its hostSpec, protocol,
         liveness, timing and failCount attributes are copied into the model.

   Returns:
      A ServerProbeModel.ServerProbeStats with the server details and its
      probe counters. The counters are 0 when none have been recorded.
   """
   hostSpec = probeStatusInfo.hostSpec
   ret = ServerProbeModel.ServerProbeStats()
   ret.serverInfo = ServerProbeModel.ServerInfo(
         host=hostSpec.hostname,
         port=hostSpec.port,
         alive=probeStatusInfo.alive,
         lastProbeTime=probeStatusInfo.lastProbeTime,
         lastResponseTime=probeStatusInfo.lastResponseTime,
         lastStatusChangeTime=probeStatusInfo.lastStatusChangeTime,
         failCount=probeStatusInfo.failCount )
   # Only non-default VRFs and non-zero accounting ports are reported.
   if hostSpec.vrf != DEFAULT_VRF:
      ret.serverInfo.vrf = hostSpec.vrf
   if hostSpec.acctPort:
      ret.serverInfo.acctPort = hostSpec.acctPort

   # The protocol's counter status may not exist yet. Indexing it directly
   # would make the whole show command fail.
   protocolCounters = counterStatus.protocolCounters.get(
         probeStatusInfo.protocol, None )
   counters = None
   if protocolCounters is not None:
      ret.lastClearTime = protocolCounters.lastClearTime
      counters = protocolCounters.hostCounters.get( hostSpec, None )
   # NOTE(review): lastClearTime is left unset when the protocol has no
   # counter status; confirm the model treats it as optional.
   if counters is not None:
      ret.probesSent = counters.probesSent
      ret.responseAccepted = counters.successfulResponse
      ret.responseDropped = counters.droppedResponse
   else:
      ret.probesSent = 0
      ret.responseAccepted = 0
      ret.responseDropped = 0
   return ret

def showServerProbe( mode, protocol, detail ):
   """Build the model for 'show monitor server <PROTOCOL> [detail]'.

   Args:
      mode: the CLI mode (unused).
      protocol: protocol name, used both as the config key and to select
         status entries.
      detail: whether detailed output was requested.

   Returns:
      A ServerProbeModel.ShowServerProbe with the reported probe method and
      one entry per server probed for `protocol`, ordered by hostSpec.
   """
   ret = ServerProbeModel.ShowServerProbe()
   ret.protocol = protocol
   ret.detailIs( detail )

   # An unconfigured protocol is treated like one with no services enabled,
   # instead of failing with KeyError.
   try:
      protoConfig = config.protocolConfig[ protocol ]
   except KeyError:
      protoConfig = None

   # The probe method is reported only when at least one service is enabled.
   if protoConfig is None or not protoConfig.serviceList:
      ret.probeMethod = None
   elif protoConfig.probeMethod == 'radiusStatusServer':
      ret.probeMethod = 'statusServerPacket'
   elif protoConfig.probeMethod == 'radiusAccessRequest':
      ret.probeMethod = 'accessRequestPacket'

   # Filter first, then sort only this protocol's servers. Sorting is stable,
   # so the resulting order is unchanged.
   matching = [ info for info in status.status.values()
                if info.protocol == protocol ]
   for info in sorted( matching, key=operator.attrgetter( 'hostSpec' ) ):
      ret.probedServers.append( showServerProbeHost( info ) )

   return ret

# Regular handler

def showMonitorServerCmd_handler( mode, args ):
   """Handle 'show monitor server <PROTOCOL> [detail]' and return its model."""
   return showServerProbe( mode, args[ 'PROTOCOL' ], 'detail' in args )

# -------------------------------------------------------------------------------
# switch# clear monitor server <PROTOCOL> counters
# -------------------------------------------------------------------------------

# Helper function

def clearCounters( mode, protocol ):
   """Request a counter reset for `protocol` and wait for it to take effect.

   Nothing is done when counterStatus has no entry for the protocol.
   Otherwise the clear is logged and a request timestamp is written to
   counterConfig. The function then waits (up to 5s) until the status side
   reports a lastClearTime at or after that timestamp. If the wait times
   out, a warning is added to `mode` instead of raising.
   """
   if protocol in counterStatus.protocolCounters:
      Intf.Log.logClearCounters( 'monitor server %s' % protocol )
      requestTime = Tac.utcNow()
      counterConfig.protocolCounterConfig[ protocol ].clearCounterRequestTime = \
            requestTime

      def _cleared():
         # Re-read the status entry on every poll.
         protoCounters = counterStatus.protocolCounters[ protocol ]
         return protoCounters.lastClearTime >= requestTime

      waitDescription = ( 'Monitor server clear counter request for %s to complete'
                          % protocol )
      try:
         Tac.waitFor( _cleared, description=waitDescription,
                      warnAfter=None, sleep=True, maxDelay=0.5, timeout=5 )
      except Tac.Timeout:
         mode.addWarning(
               'Monitor server counters for %s may not have been reset yet'
               % protocol )

# Regular handler

def clearMonitorServerCountersCmd_handler( mode, args ):
   """Handle 'clear monitor server <PROTOCOL> counters'."""
   clearCounters( mode, args[ 'PROTOCOL' ] )

def Plugin( entityManager ):
   """Bind this module's Sysdb handles through LazyMount.

   The config and counter-config paths are mounted with "w". The server
   status and counter status paths are mounted with "r".
   """
   global config, status, counterStatus, counterConfig

   def _mount( path, tacType, flags ):
      return LazyMount.mount( entityManager, path, tacType, flags )

   config = _mount( "serverProbe/config", "ServerProbe::Config", "w" )
   status = _mount( "serverProbe/allServerStatus",
                    "ServerProbe::ServerProbeStatus", "r" )
   counterStatus = _mount( "serverProbe/counter/status",
                           "ServerProbe::ServerProbeCounterStatus", "r" )
   counterConfig = _mount( "serverProbe/counter/config",
                           "ServerProbe::ServerProbeCounterConfig", "w" )
