# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import os
import signal

import AaaPluginLib
from AaaPluginLib import TR_ERROR, TR_WARN, TR_AUTHEN, TR_AUTHZ, TR_INFO, TR_DEBUG
from BothTrace import traceX as bt
from BothTrace import Var as bv
import Cell
import ldap
import Ldap
import Logging
import MgmtSecuritySslStatusSm
import Plugins
import Tac
from Tracing import traceX
from IpLibConsts import DEFAULT_VRF
from M2Crypto import util

# Error logged by LdapPlugin.acquireSession when none of the servers for a
# method list could be added to a new LDAP session (the format argument is the
# method name).
AAA_NO_VALID_LDAP_SERVERS = Logging.LogHandle(
              "AAA_NO_VALID_LDAP_SERVERS",
              severity=Logging.logError,
              fmt="No valid LDAP servers for method list '%s'",
              explanation="The configuration contains an authentication "
                          "method list that is not associated with any valid "
                          "LDAP servers. One common cause for this error is "
                          "a hostname for which DNS resolution fails.",
              recommendedAction="Correct the LDAP server configuration." )

# Pool of LDAP sessions keyed by host group. The config reactors in this module
# (ConfigReactor and ServerConfigReactor) clear it when relevant configuration
# changes, so that a pooled session with outdated configuration data is never
# reused.
# NOTE: pooling is not enabled yet -- acquireSession asserts that the pool holds
# no session for the group and always builds a fresh Ldap.Session.
_sessionPool = AaaPluginLib.SessionPool( 'ldap' )

# Tac types from the SSL management model, used by getSslProfile (profile state
# check) and LdapSslReactor (constants and supported features).
Constants = Tac.Type( "Mgmt::Security::Ssl::Constants" )
ProfileState = Tac.Type( "Mgmt::Security::Ssl::ProfileState" )
SslFeature = Tac.Type( "Mgmt::Security::Ssl::SslFeature" )

def _invalidateSessionPool( hostgroup=None ):
   """Discard pooled LDAP sessions held by the module-level session pool.

   `hostgroup` is forwarded unchanged to SessionPool.clear(). Every caller in
   this module passes no argument (None), presumably meaning "all groups" --
   confirm against AaaPluginLib.SessionPool.
   """
   pool = _sessionPool
   bt( TR_WARN, 'invalidate LDAP session pool' )
   pool.clear( hostgroup )

def getConfigAttr( spec, config, attrName ):
   """Return attribute `attrName` for the LDAP server identified by `spec`.

   The per-host value (config.host[spec].serverConfig) wins when the host
   exists and its value is truthy; otherwise the value from
   config.defaultConfig is returned.
   """
   hostEntry = config.host.get( spec )
   value = getattr( hostEntry.serverConfig, attrName ) if hostEntry else None
   if not value:
      # Unset (falsy) per-host values fall back to the default server config.
      value = getattr( config.defaultConfig, attrName )
   traceX( TR_DEBUG, "Configured attribute", attrName, "for host", spec )
   return value

class Authenticator( AaaPluginLib.BasicUserAuthenticator ):
   """Checks a user's password against the LDAP servers of an AAA method.

   Instances are created by LdapPlugin.createAuthenticator.
   """
   def __init__( self, plugin, method, type, # pylint: disable-msg=W0622
                 service, remoteHost, remoteUser, tty, user, privLevel ):
      # The owning LdapPlugin: used to acquire LDAP sessions and to record
      # which server authenticated the user.
      self.plugin = plugin
      AaaPluginLib.BasicUserAuthenticator.__init__( self, plugin.aaaConfig,
                                                    method, type, service,
                                                    remoteHost, remoteUser, tty,
                                                    user, privLevel )

   @staticmethod
   def _ldapErrorDescription( error ):
      """Return a human-readable description of an ldap.LDAPError.

      python-ldap passes a dict (normally with a 'desc' key) as the first
      exception argument. Python 3 exceptions have no `.message` attribute,
      so that attribute must not be used here.
      """
      details = error.args[ 0 ] if error.args else None
      if isinstance( details, dict ) and 'desc' in details:
         return details[ 'desc' ]
      return str( error )

   def checkPassword( self, user, password ):
      """Authenticate self.user with `password` against the method's servers.

      Returns the authentication result dict with 'state', 'authenStatus',
      'messages', 'user' and 'authToken'. On success it also has
      'sessionData' (the attributes returned by sendAuthenReq), and the
      authenticating server is recorded in status.activeServerForSg for the
      method's group. On a non-success reply that entry is removed.
      """
      failType = 'unavailable'
      failText = ''
      sg = AaaPluginLib.extractGroupFromMethod( self.method )

      username = self.user

      msg = []

      session = self.plugin.acquireSession( self.method )
      if session:
         try:
            traceX( TR_AUTHEN, "authenticating user", user )
            result, failText, attrs = session.sendAuthenReq( username, password )
            if result == 'success':
               r = { 'state': self.succeeded,
                     'authenStatus': 'success',
                     'messages': msg,
                     'user': self.user,
                     'authToken': self.password,
                     'sessionData': attrs }
               activeSpec = session.getActiveServer()
               traceX( TR_AUTHEN, "Authenticated by server", activeSpec )
               self.plugin.status.activeServerForSg[ sg ] = activeSpec
               return r
            # Non-success reply: forget the active server for this group and
            # report the server's result as the failure type.
            del self.plugin.status.activeServerForSg[ sg ]
            failType = result
            session.close()
         except ldap.LDAPError as error:
            # Extract the description without relying on `error.message`
            # (absent on Python 3), so the session is always closed below.
            errorMsg = self._ldapErrorDescription( error )
            bt( TR_ERROR, "sendAuthenReq exception", bv( errorMsg ) )
            session.close()
      else:
         # acquireSession returns None when no server could be added.
         failText = 'Error in adding server (DNS issue?)'

      if not failText:
         failText = 'Error in authentication'
      msg = [ Tac.Value( "AaaApi::AuthenMessage", style='error', text=failText ) ]

      return { 'state': self.failed, 'authenStatus': failType, 'messages': msg,
               'user': user, 'authToken': password }

class LdapPlugin( AaaPluginLib.Plugin ):
   """AAA plugin that authenticates and authorizes users against LDAP servers.

   Password authentication is done by Authenticator. Shell authorization runs
   the configured group-policy search over a fresh LDAP session. Per-command
   authorization and accounting are not supported.
   """
   def __init__( self, config, status, aaaConfig, ipStatus, ip6Status,
                 allVrfStatusLocal, sslConfig, sslStatus ):
      AaaPluginLib.Plugin.__init__( self, aaaConfig, "ldap",
                                    allVrfStatusLocal=allVrfStatusLocal )
      self.config = config
      self.status = status
      self.aaaConfig = aaaConfig
      self.ipStatus = ipStatus
      self.ip6Status = ip6Status
      self.sslConfig = sslConfig
      self.sslStatus = sslStatus

   def ready( self ):
      """Return True when at least one LDAP server is configured."""
      if self.config.host:
         return True

      bt( TR_ERROR, "plugin not ready: no LDAP servers configured" )
      return False

   def handlesAuthenMethod( self, method ):
      return self._handlesGroupAuthenMethod( method, "ldap", "ldap" )

   def getActiveGroupPolicy( self, spec ):
      """Return the active group policy name for server `spec`.

      The per-host value is used when set, otherwise the default config's.
      """
      return getConfigAttr( spec, self.config, 'activeGroupPolicy' )

   def getSslProfile( self, spec ):
      """Return the SSL settings to use for server `spec`.

      Returns:
        {}   -- no SSL profile is configured for the server (nor by default);
        dict -- 'trustedCert', 'crl', 'tlsVersion' and 'cipherSuite' from a
                valid profile that has at least one trusted certificate;
        None -- a profile is configured but cannot be used (unknown, not in
                the valid state, or without a trusted certificate).
      """
      profileName = getConfigAttr( spec, self.config, 'sslProfile' )
      if not profileName:
         return {}

      profileStatus = self.sslStatus.profileStatus.get( profileName )
      # Use .get() for the config as well: if the status and config
      # collections disagree (e.g. while a profile is being deleted), treat the
      # profile as invalid instead of raising KeyError into acquireSession.
      profileConfig = self.sslConfig.profileConfig.get( profileName )
      if ( profileStatus is None or profileStatus.state != ProfileState.valid or
           profileConfig is None ):
         traceX( TR_ERROR, 'Invalid SSL profile' )
         return None

      if not profileConfig.trustedCert:
         traceX( TR_ERROR, 'No trusted certificate configured' )
         return None

      return { "trustedCert": list( profileConfig.trustedCert ),
               "crl": list( profileConfig.crl ),
               "tlsVersion": profileConfig.tlsVersion,
               "cipherSuite": profileConfig.cipherSuite }

   def createAuthenticator( self, method, type, # pylint: disable-msg=W0622
                            service, remoteHost, remoteUser,
                            tty, user=None, privLevel=0 ):
      """Return a new Authenticator bound to this plugin."""
      traceX( TR_AUTHEN, "LDAP create authenticator for method:", method, "type:",
              type, "service:", service, "user:", user, "privLevel:", privLevel )
      return Authenticator( self, method, type, service,
                            remoteHost, remoteUser,
                            tty, user, privLevel )

   def openSession( self, authenticator ):
      # No per-session state: the authenticator itself is the session token.
      return authenticator

   def closeSession( self, token ):
      pass

   def authorizeShell( self, method, user, session ):
      """Authorize a shell for `user` using the LDAP group policy.

      Returns a ( status, message, attrs ) tuple.
      """
      traceX( TR_AUTHZ, "LDAP authorizeShell user", user )
      if session is None or session.property.get( method ) is None:
         # We do not have user privilege data, most likely the authentication
         # wasn't handled by LDAP
         return ( 'authzUnavailable',
                  'LDAP authorization requires LDAP authentication', {} )
      sg = AaaPluginLib.extractGroupFromMethod( method )
      sessionData = session.property.get( session.authenMethod )

      userDn = None
      # The active-server entry is removed after failed requests (here and in
      # Authenticator.checkPassword), so it may be absent. Use .get() so that
      # case is reported as "not authenticated" below instead of KeyError.
      activeSpec = self.status.activeServerForSg.get( sg )
      if sessionData:
         userDn = sessionData.attr.get( 'userDn' )
      if not userDn or not activeSpec:
         return ( "denied", "not authenticated by LDAP", {} )

      activeGroupPolicy = self.getActiveGroupPolicy( activeSpec )
      groupPolicy = self.config.groupPolicy.get( activeGroupPolicy )
      if not groupPolicy:
         traceX( TR_AUTHZ, "No policy to authorize applied" )
         attrs = {}
         if sessionData.attr.get( 'priv-lvl' ):
            attrs[ AaaPluginLib.privilegeLevel ] = 1
         return ( "allowed", "no ldap group policy configured", attrs )

      ss = self.acquireSession( method )
      if not ss:
         return ( "authzUnavailable",
                  "LDAP server became unavailable during authorization", {} )
      result, failText, attrs = ss.sendAuthzReq(
         userDn, groupPolicy.searchFilter, groupPolicy.groupRolePrivilege )
      if result == 'allowed':
         self.status.activeServerForSg[ sg ] = ss.getActiveServer()
      else:
         del self.status.activeServerForSg[ sg ]

      return ( result, failText, attrs )

   def authorizeShellCommand( self, method, user, session, mode, privlevel, tokens ):
      # NotSupported
      return ( "denied", "authorization is not supported by Ldap", {} )

   def sendCommandAcct( self, method, user, session, privlevel, timestamp, tokens,
                        cmdType=None, **kwargs ):
      # NotSupported
      return ( "acctFail", "accounting is not supported by Ldap" )

   def acquireSession( self, methodName ):
      """Build a new Ldap.Session over the servers selected by `methodName`.

      Returns None, after logging AAA_NO_VALID_LDAP_SERVERS, when no server
      could be added to the session.
      """
      traceX( TR_DEBUG, "acquire session for method", methodName )
      hostGroup, serverSpecs = AaaPluginLib.serversForMethodName( methodName,
                                                                  self.aaaConfig,
                                                                  'ldap' )
      if hostGroup is not None:
         # If we ever enable session pool for LDAP, we need to enhance and
         # test ConfigReactor for things like defaultConfig and timeout changes.
         session = _sessionPool.get( hostGroup )
         assert not session

      ldapConfig = self.config

      session = Ldap.Session( hostGroup )

      addedServers = 0
      # If no aaa group servers are found, then follow the order of ldap
      # server configuration. Otherwise, honor the aaa group server config
      if not serverSpecs:
         # If the server dict is empty I consider all configured servers to be
         # fair game.  Either methodName indicates that all servers should be
         # included, or it's a server group with no members, but I expect the
         # former because I shouldn't get to this point without any servers
         # since handlesAuthenMethod should have screened the request out.
         defaultSortedList = sorted( ldapConfig.host.values(),
                                     key=lambda host: host.index )

         # generate round robin order based on known active server
         activeSpec = self.status.activeServerForSg.get( hostGroup )
         activeHost = None
         hostList = defaultSortedList

         if activeSpec:
            traceX( TR_DEBUG, "Active server", activeSpec )
            activeHost = ldapConfig.host.get( activeSpec )

         if activeHost:
            try:
               # Rotate the list so the last known active server is tried first.
               listIndex = defaultSortedList.index( activeHost )
               hostList = ( defaultSortedList[ listIndex : ] +
                            defaultSortedList[ : listIndex ] )
            except ValueError:
               # Multi-threaded race: a host could just have been deleted.
               traceX( TR_ERROR, "Active server was deleted" )
      else:
         hostList = []
         for spec in serverSpecs:
            try:
               host = ldapConfig.host[ spec ]
               hostList.append( host )
            except KeyError:
               # Server group contains members that were not in the
               # host collection
               groupname = AaaPluginLib.extractGroupFromMethod( methodName )
               # pylint: disable-next=consider-using-f-string
               serverStr = "%s:%d" % ( spec.hostname, spec.port )
               if spec.vrf and spec.vrf != DEFAULT_VRF:
                  # pylint: disable-next=consider-using-f-string
                  serverStr += " (vrf %s)" % spec.vrf
               traceX( TR_WARN, "server group", groupname,
                       "contains unconfigured server:", serverStr )

      for h in hostList:
         assert h.vrf
         ns = self.getNsFromVrf( h.vrf )
         if not ns:
            bt( TR_ERROR, "cannot get namespace for vrf", bv( h.vrf ) )
            continue

         cb = Ldap.CounterCallback( self.status, h.spec )
         profile = self.getSslProfile( h.spec )
         # Skip servers whose configured SSL profile is unusable rather than
         # silently connecting to them without SSL.
         if not profile and getConfigAttr( h.spec, self.config, 'sslProfile' ):
            continue
         try:
            bt( TR_AUTHEN, 'add server', bv( h.spec.stringValue() ) )

            session.addServer(
               h, ns=ns, counterCallback=cb,
               baseDn=getConfigAttr( h.spec, self.config, 'baseDn' ),
               userDn=getConfigAttr( h.spec, self.config, 'userRdnAttribute' ),
               sslProfile=profile,
               activeGroupPolicy=getConfigAttr( h.spec, self.config,
                                                'activeGroupPolicy' ),
               searchUsernamePassword=getConfigAttr( h.spec, self.config,
                                                     'searchUsernamePassword' ),
               timeout=getConfigAttr( h.spec, self.config, 'ldapTimeout' ) )

            addedServers += 1
         except ldap.LDAPError as e:
            bt( TR_ERROR, "failed to add server", bv( h.hostname ),
                "vrf", bv( h.vrf ), ":", bv( str( e ) ) )

      if addedServers == 0:
         # Having no usable server indicates a configuration error.
         Logging.log( AAA_NO_VALID_LDAP_SERVERS, methodName )
         session.close()
         session = None
      return session

   def releaseSession( self, session ):
      # Do not put back in pool as we are not yet sharing session
      # across operations.
      pass

   def hasUnknownUser( self ):
      return True

class ServerConfigReactor( Tac.Notifiee ):
   """Watches one Ldap::ServerConfig and tracks its SSL profile.

   While the server config names an SSL profile, an LdapSslReactor exists for
   that profile. Any sslProfile change invalidates the LDAP session pool.
   """
   notifierTypeName = "Ldap::ServerConfig"

   def __init__( self, defaultConfig, sslConfig, sslStatus, config, status, agent ):
      self.serverConfig = defaultConfig
      self.sslConfig = sslConfig
      self.sslStatus = sslStatus
      self.config = config
      self.status = status
      self.agent = agent
      self.sslReactor = None
      # Process the current profile before registering for notifications.
      self.handleSslProfile()
      Tac.Notifiee.__init__( self, self.serverConfig )

   def _closeSslReactor( self ):
      """Close and forget the current SSL profile reactor, if any."""
      if self.sslReactor:
         self.sslReactor.close()
      self.sslReactor = None

   @Tac.handler( 'sslProfile' )
   def handleSslProfile( self ):
      _invalidateSessionPool()
      profileName = self.serverConfig.sslProfile
      bt( TR_INFO, "handle SSL profile", bv( profileName ) )
      # Retire the reactor for the previous profile before creating one for
      # the new profile. Otherwise a profile change would leak the old reactor,
      # which would keep reacting (including FIPS handling) to the old
      # profile's status.
      self._closeSslReactor()
      if profileName:
         self.sslReactor = LdapSslReactor( self.sslConfig, self.sslStatus,
                                           profileName, self.config,
                                           self.status, self.agent )

   def close( self ):
      self._closeSslReactor()
      Tac.Notifiee.close( self )

class ConfigReactor( AaaPluginLib.ConfigReactor ):
   """Reacts to Ldap::Config changes.

   Forwards host changes to the base-class handler, owns a
   ServerConfigReactor for config.defaultConfig, and invalidates the LDAP
   session pool on defaultConfig removal and on VRF changes.
   """
   notifierTypeName = "Ldap::Config"
   counterTypeName = "Ldap::Counters"

   def __init__( self, notifier, status, allVrfStatusLocal, ipStatus, ip6Status,
                 sslConfig, sslStatus, agent ):
      AaaPluginLib.ConfigReactor.__init__( self, notifier,
                                           status, allVrfStatusLocal,
                                           ipStatus, ip6Status )
      self.config = notifier
      self.sslConfig = sslConfig
      self.sslStatus = sslStatus
      self.agent = agent
      self.defaultConfigReactor = None
      # Process the current defaultConfig immediately.
      self.handleDefaultConfig()

   def invalidateSessionPool( self ):
      """Drop pooled sessions and forget the active server of every group."""
      _invalidateSessionPool()
      self.status_.activeServerForSg.clear()

   @Tac.handler( 'host' )
   def handleHost( self, hostspec=None ):
      self.handleHostEntry( hostspec )

   def _closeDefaultConfigReactor( self ):
      """Close and forget the current defaultConfig reactor, if any."""
      if self.defaultConfigReactor:
         self.defaultConfigReactor.close()
      self.defaultConfigReactor = None

   @Tac.handler( 'defaultConfig' )
   def handleDefaultConfig( self ):
      if self.config.defaultConfig:
         # Close the reactor for any previous defaultConfig first so that
         # replacing defaultConfig does not leak it (and its SSL reactor).
         self._closeDefaultConfigReactor()
         self.defaultConfigReactor = ServerConfigReactor(
            self.config.defaultConfig, self.sslConfig, self.sslStatus, self.config,
            self.status_, self.agent )
      else:
         self.invalidateSessionPool()
         self._closeDefaultConfigReactor()

   def handleVrfState( self, vrfName ):
      # Theoretically we could use the vrfName to figure out if we
      # actually care about this change or not, but for simplicity we
      # just hit it with a big hammer and throw out any persistent
      # singleConn connections any time any VRF changes.  Since this
      # is all driven by config change events rather than protocol or
      # packet events, this isn't really a big deal
      _invalidateSessionPool()

class CounterConfigReactor( AaaPluginLib.CounterConfigReactor ):
   """AaaPluginLib counter-config reactor bound to the Ldap::Counters type."""
   counterTypeName = "Ldap::Counters"

class LdapSslReactor( MgmtSecuritySslStatusSm.SslStatusSm ):
   """Tracks one SSL profile's status on behalf of LDAP and applies its FIPS mode.

   FIPS mode (M2Crypto util.fipsModeSet) follows the profile's fipsMode while
   the profile is valid. It is turned off when the profile becomes invalid or
   is deleted. The requested mode is mirrored into the LDAP status' `fips`.
   """
   __supportedFeatures__ = [ SslFeature.sslFeatureCertKey,
                             SslFeature.sslFeatureTrustedCert,
                             SslFeature.sslFeatureChainedCert,
                             SslFeature.sslFeatureCrl,
                             SslFeature.sslFeatureTls,
                             SslFeature.sslFeatureFips,
                             SslFeature.sslFeatureCipher ]

   def __init__( self, config, status, profileName,
                 ldapConfig, ldapStatus, containingAgent ):
      # NOTE: `config` (the SSL config) is accepted for interface compatibility
      # but is not used here; only the SSL status is passed to the base class.
      self.constants_ = Constants
      self.ldapConfig_ = ldapConfig
      self.containingAgent_ = containingAgent
      self.ldapStatus = ldapStatus
      super().__init__( status, profileName,
                        'Ldap' )

   def handleProfileState( self ):
      traceX( TR_INFO, "handleProfileState" )
      if self.profileStatus_:
         # Compare against the Tac enum, as getSslProfile does.
         if self.profileStatus_.state == ProfileState.valid:
            fipsConfig = self.profileStatus_.fipsMode
            self.handleFipsMode( fipsConfig, self.containingAgent_ )
         else:
            self.handleFipsMode( False, self.containingAgent_ )

   def handleProfileDelete( self ):
      traceX( TR_INFO, "handleProfileDelete" )
      self.handleFipsMode( False, self.containingAgent_ )

   def handleFipsMode( self, needFips, agent ):
      """Bring FIPS mode in line with `needFips`.

      Before the agent has finished initializing, the mode is switched
      directly. Afterwards, the entity log is flushed and this process is then
      killed with SIGKILL, presumably so that the agent restarts with the new
      mode (confirm with the agent's restart handling).
      """
      traceX( TR_INFO, "handleFipsMode", needFips )
      fipsStatus = util.fipsModeGet()

      def maybeKillAaa( timedOut ):
         if not timedOut:
            os.kill( os.getpid(), signal.SIGKILL )
         assert False, "Timeout in flushing attrlog to Sysdb"

      if needFips != fipsStatus:
         if not agent.initialized:
            # Trace the mode actually being set (it may be False as well).
            traceX( TR_INFO, "Setting fips to", needFips )
            util.fipsModeSet( needFips )
         else:
            agent.flushEntityLog( maybeKillAaa )
      self.ldapStatus.fips = needFips

# Module-level registry for the reactors created by Plugin()'s _finish callback.
# They are stored only here, presumably so they stay alive (are not garbage
# collected) for the lifetime of the agent.
_reactors = {}

@Plugins.plugin( provides=[], requires=[] )
def Plugin( ctx ):
   """Plugin entry point: mount LDAP/SSL/IP state and return the LdapPlugin.

   All mounts are added to one mount group. The config and counter reactors
   are created in _finish, which is passed to mountGroup.close() --
   presumably invoked once the mounts have completed (confirm). The
   LdapPlugin itself is returned immediately.
   """
   mountGroup = ctx.entityManager.mountGroup()
   config = mountGroup.mount( 'security/aaa/ldap/config', 'Ldap::Config',
                              'r' )
   counterConfig = mountGroup.mount( 'security/aaa/ldap/counterConfig',
                                     'AaaPlugin::CounterConfig', 'r' )
   # Status is the only writable mount here; it is cell-local (Cell.path).
   status = mountGroup.mount( Cell.path( 'security/aaa/ldap/status' ),
                              'Ldap::Status', 'wf' )
   sslConfig = mountGroup.mount( 'mgmt/security/ssl/config',
                                 'Mgmt::Security::Ssl::Config', 'r' )
   sslStatus = mountGroup.mount( 'mgmt/security/ssl/status',
                                 'Mgmt::Security::Ssl::Status', 'r' )
   aaaConfig = ctx.aaaAgent.config
   # Extra IP-status mounts added to the same underlying mount group; the
   # meaning of the two True flags is defined by Ira::IraIpStatusMounter.
   Tac.Type( "Ira::IraIpStatusMounter" ).doMountEntities( mountGroup.cMg_, True,
                                                          True )
   ipStatus = mountGroup.mount( 'ip/status', 'Ip::Status', 'r' )
   ip6Status = mountGroup.mount( 'ip6/status', 'Ip6::Status', 'r' )
   allVrfStatusLocal = mountGroup.mount( Cell.path( 'ip/vrf/status/local' ),
                                         'Ip::AllVrfStatusLocal', 'r' )

   def _finish():
      # Create the reactors that watch LDAP config and counter config.
      _reactors[ "LdapConfigReactor" ] = ConfigReactor( config, status,
                                                        allVrfStatusLocal,
                                                        ipStatus,
                                                        ip6Status,
                                                        sslConfig,
                                                        sslStatus,
                                                        ctx.aaaAgent )
      _reactors[ "LdapCounterConfigReactor" ] = \
            CounterConfigReactor( counterConfig, status )

   mountGroup.close( _finish )
   return LdapPlugin( config, status, aaaConfig, ipStatus, ip6Status,
                      allVrfStatusLocal, sslConfig, sslStatus )
