#!/usr/bin/env python3
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from CliDynamicPlugin.EosSdkRpcModel import (
      ConfigurationErrors,
      EnabledServices,
      EosSdkRpcAggregateStatus,
      EosSdkRpcTransportStatus,
      IpEndpoint,
      ListeningEndpoints,
      LocalInterfaceEndpoint,
      LocalhostUnixSocketEndpoint,
      UserAuthType )
from CliPlugin.VrfCli import DEFAULT_VRF
import LazyMount
import Tac

# Module-level mount handles. They are None until Plugin() rebinds them to the
# LazyMount.mount() results for the paths noted on each line.
daemonConfigDir = None  # 'daemon/agent/config'
serverConfigDir = None  # 'eossdkrpc/server/config'
serverStatusDir = None  # 'eossdkrpc/server/status'

def checkErrorsOverlap( seenErrors, applicableErrors ):
   '''Report which of the seen errors are among the applicable ones.

   Returns False when seenErrors is empty/None or is exactly the list
   [ 'NoConfigError' ]. Otherwise returns the set of names present in both
   arguments; that set may be empty, so callers should rely on truthiness
   rather than comparing against True/False.
   '''
   # NOTE(review): the sentinel here is spelled 'NoConfigError' while the
   # status loop in this file filters on "noConfigError" - confirm which
   # casing is intended.
   if seenErrors and seenErrors != [ 'NoConfigError' ]:
      return set( applicableErrors ) & set( seenErrors )
   return False

def checkIfErrorsPerEndpoint( endpoints, configErrors ):
   '''Work out, per endpoint kind, whether an error keeps it from listening.

   endpoints is read for its localInterface, localhostLoopback and
   localhostUnixSocket attributes. configErrors may be None/falsy; otherwise
   it is iterated for names and every name whose attribute value is truthy
   counts as an active error.

   Returns a 3-tuple ordered ( local interface, localhost loopback, unix
   socket ). Despite the function name, a TRUTHY element means the endpoint
   is NOT usable: it is either absent from endpoints, or one of its
   applicable errors (or an endpoint independent SSL error) is active. A
   falsy element (False or an empty set) means the endpoint is present with
   no applicable error.
   '''
   activeErrors = []
   if configErrors:
      activeErrors = [ errName for errName in configErrors
                       if getattr( configErrors, errName ) ]

   # Verify no endpoint independent errors, which would prevent startup
   sslErrors = [ 'noSslProfileFound', 'invalidSslProfileState',
                 'noSslCertOrKeyFile', 'noTrustedCertsFound' ]
   if checkErrorsOverlap( activeErrors, sslErrors ):
      return ( True, True, True )

   # An endpoint that is not configured keeps the "blocked" default of True.
   intfBlocked = ( checkErrorsOverlap( activeErrors,
                         [ 'unavailableActiveIp', 'localIntfAddressInUse' ] )
                   if endpoints.localInterface else True )

   loopbackBlocked = ( checkErrorsOverlap( activeErrors,
                             [ 'loopbackVrfSkipped', 'loopbackAddressInUse' ] )
                       if endpoints.localhostLoopback else True )

   # NOTE(review): 'UnixSocketInUse' starts with a capital letter unlike the
   # other error names checked here - confirm it matches the attribute name.
   unixSockBlocked = ( checkErrorsOverlap( activeErrors, [ 'UnixSocketInUse' ] )
                       if endpoints.localhostUnixSocket else True )

   return ( intfBlocked, loopbackBlocked, unixSockBlocked )

def getAgentPid( sysname, agentName ):
   '''Return the PID of the "EosSdkRpcAgent-<agentName>" process, or None.

   sysname is not used; it is kept so existing callers keep working.

   The PID is looked up by running `pidof`. None is returned when the command
   fails (Tac.SystemCommandError) or prints no PID. When several processes
   match, the first PID printed by pidof is returned.
   '''
   try:
      output = Tac.run( [ 'pidof', f"EosSdkRpcAgent-{agentName}" ],
                        stdout=Tac.CAPTURE )
   except Tac.SystemCommandError:
      # agent not running
      return None

   pid = output.strip( "\n" )
   if not isinstance( pid, str ):
      return pid
   # pidof prints every matching PID on a single line separated by spaces, so
   # calling int() on the whole line raises ValueError as soon as more than
   # one process matches (or the line is empty). Use the first token instead.
   pids = pid.split()
   return int( pids[ 0 ] ) if pids else None

# -------------------------------------------------------------------------------
# The "show management api EosSdkRpc" command
# -------------------------------------------------------------------------------
def showStatusEosSdkRpc( mode, args ):
   '''Build the EosSdkRpcAggregateStatus model for the show command.

   mode supplies entityManager.sysname() for the agent PID lookup; args is
   not used.

   One EosSdkRpcTransportStatus is stored in the model's transports for every
   name in serverConfigDir. A name with no entry in serverStatusDir gets a
   placeholder (disabled, not running, no endpoints). Otherwise the transport
   is filled in from serverStatusDir[ name ] and daemonConfigDir[ name ]. The
   aggregate's enabled flag becomes True as soon as one such daemon config is
   enabled. An empty/unset serverConfigDir yields a model with enabled False
   and no transports.
   '''
   aggregate = EosSdkRpcAggregateStatus()
   aggregate.enabled = False
   if not serverConfigDir:
      return aggregate

   for name in serverConfigDir:
      entry = EosSdkRpcTransportStatus()

      statusPublished = serverStatusDir and name in serverStatusDir
      if not statusPublished:
         # Transport status entity has not yet been created by the
         # EosSdkRpcAgent, so report a placeholder for this transport.
         entry.enabled = False
         entry.pid = None
         entry.running = False
         entry.endpoints = None
         entry.configErrors = None
         entry.sslProfile = ''
         entry.serverSecurity = "unknown"
         entry.metadataUserAuth = UserAuthType.disabled
         entry.allServicesEnabled = False
         aggregate.transports[ name ] = entry
         continue

      agentStatus = serverStatusDir[ name ]
      agentConfig = daemonConfigDir[ name ]
      entry.enabled = agentConfig.enabled
      if agentConfig.enabled:
         aggregate.enabled = True
      entry.pid = getAgentPid( mode.entityManager.sysname(), name )

      # Collect every reported error other than "noConfigError" whose value is
      # truthy. The ConfigurationErrors model is only created once there is at
      # least one such error, so it stays None otherwise.
      activeErrors = None
      reportedErrors = agentStatus.configError
      if reportedErrors:
         for errName in reportedErrors:
            if errName == "noConfigError" or not reportedErrors.get( errName ):
               continue
            if not activeErrors:
               activeErrors = ConfigurationErrors()
            # NOTE(review): this concatenation yields errName unchanged.
            setattr( activeErrors, errName[ 0 ] + errName[ 1 : ], True )
         entry.configErrors = activeErrors

      # Describe each configured endpoint; all start as not listening. The
      # VRF name comes from the local interface when present, else from the
      # loopback, else it stays DEFAULT_VRF.
      listening = ListeningEndpoints()
      listening.vrfName = DEFAULT_VRF

      intfStatus = agentStatus.localInterface
      if intfStatus:
         intfEndpoint = LocalInterfaceEndpoint()
         intfEndpoint.intfId = intfStatus.intfId
         intfEndpoint.ipAddress = intfStatus.address
         intfEndpoint.port = intfStatus.port
         intfEndpoint.isListening = False
         listening.localInterface = intfEndpoint
         listening.vrfName = intfStatus.vrfName

      loopbackStatus = agentStatus.localhostLoopback
      if loopbackStatus:
         loopbackEndpoint = IpEndpoint()
         loopbackEndpoint.ipAddress = loopbackStatus.address
         # A falsy port in the status is reported as 9543.
         loopbackEndpoint.port = loopbackStatus.port or 9543
         loopbackEndpoint.isListening = False
         listening.localhostLoopback = loopbackEndpoint
         if not intfStatus:
            listening.vrfName = loopbackStatus.vrfName

      unixPath = agentStatus.unixEndpointAddress
      if unixPath:
         listening.localhostUnixSocket = LocalhostUnixSocketEndpoint()
         unixEndpoint = listening.localhostUnixSocket
         unixEndpoint.socketPath = unixPath
         unixEndpoint.isListening = False

      # The per-endpoint flags are truthy when an endpoint is absent or has an
      # applicable error, hence the negations below.
      isRunning = False
      if entry.enabled and entry.pid:
         intfDown, loopbackDown, unixSockDown = checkIfErrorsPerEndpoint(
               listening, activeErrors )
         isRunning = not ( intfDown and loopbackDown and unixSockDown )
         if listening.localInterface:
            listening.localInterface.isListening = not intfDown
         if listening.localhostLoopback:
            listening.localhostLoopback.isListening = not loopbackDown
         if listening.localhostUnixSocket:
            listening.localhostUnixSocket.isListening = not unixSockDown
      entry.running = isRunning

      entry.endpoints = listening
      entry.sslProfile = agentStatus.sslProfile
      entry.serverSecurity = agentStatus.serverSecurity or None
      entry.metadataUserAuth = agentStatus.usernameAuthType
      entry.allServicesEnabled = agentStatus.enableAllServices
      if not entry.allServicesEnabled and agentStatus.servicesEnabled:
         services = EnabledServices()
         for svcName in agentStatus.servicesEnabled:
            # Attribute name is the service name with its first letter
            # lower-cased.
            attrName = svcName[ : 1 ].lower() + svcName[ 1 : ]
            setattr( services, attrName, True )
         entry.enabledServices = services

      aggregate.transports[ name ] = entry
   return aggregate

def Plugin( entityManager ):
   '''Bind the module-level directory handles used by the show command.

   Each of daemonConfigDir, serverConfigDir and serverStatusDir is set to the
   result of LazyMount.mount() on its path, with type 'Tac::Dir' and mode
   'ri'.
   '''
   global daemonConfigDir, serverConfigDir, serverStatusDir

   def mountDir( path ):
      return LazyMount.mount( entityManager, path, 'Tac::Dir', 'ri' )

   daemonConfigDir = mountDir( 'daemon/agent/config' )
   serverConfigDir = mountDir( 'eossdkrpc/server/config' )
   serverStatusDir = mountDir( 'eossdkrpc/server/status' )
