# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import errno
import ldap
import ldap.filter
import os
import re
import shutil
import tempfile
import threading

from Ark import synchronized
import AaaPluginLib
from AaaPluginLib import TR_ERROR, TR_WARN, TR_AUTHEN, TR_AUTHZ, TR_INFO, TR_DEBUG
import Arnet
from Arnet.NsLib import DEFAULT_NS
from BothTrace import traceX as bt
from BothTrace import Var as bv
from LdapUtil import RunInNetworkNamespace
import Tac
from Tracing import traceX

# SSL constants type; used below for certPath() and the tlsv1* version masks.
Constants = Tac.Type( "Mgmt::Security::Ssl::Constants" )

# Counter names accepted by CounterCallback (anything else is rejected there).
ldapCounterAttrs = (
   "bindRequests", "bindFails", "bindSuccesses", "bindTimeouts",
   "searchRequests", "searchFails", "searchSuccesses", "searchTimeouts",
)

# Serializes CounterCallback.__call__ (applied through @synchronized).
counterSysdbStatusLock = threading.Lock()

class _Server:
   """State kept by Session for one LDAP server.

   Stores the configuration handed to Session.addServer, the python-ldap
   connection object once it is initialized, and hostStr, the
   "ldap://<host>:<port>" URI for the server (IPv6 literals are wrapped in
   brackets)."""
   def __init__( self, host, port, ns, baseDn, userDn, sslProfile,
                 activeGroupPolicy, searchUsernamePassword ):
      # Configuration, kept verbatim
      self.host_ = host
      self.port_ = port
      self.ns_ = ns
      self.baseDn = baseDn
      self.userDn = userDn
      self.sslProfile = sslProfile
      self.activeGroupPolicy = activeGroupPolicy
      self.searchUsernamePassword = searchUsernamePassword

      # Runtime state, filled in later by Session
      self.server = None # ldap object
      self.statusCallback = None
      self.caCertDir = ""

      uriHost = host.hostname
      # An IPv6 literal must be bracketed inside a URI
      if re.search( Arnet.Ip6AddrRe, uriHost ):
         uriHost = f"[{uriHost}]"
      self.hostStr = f"ldap://{uriHost}:{str( port )}"

class Session( AaaPluginLib.Session ):
   """A session with one or more LDAP servers in a failover arrangement.

   Usage: create a Session and add servers with addServer.  Each server is
   connected immediately, and StartTLS is negotiated when an SSL profile is
   given.  Then call sendAuthenReq and/or sendAuthzReq.  Call close() when
   done; __del__ calls it as a fallback.

   sendAuthenReq tries the servers in the order they were added.  It returns
   a ( result, failText, attrs ) tuple:
     - 'success': the user was authenticated.
     - 'fail': the search user or the user could not be authenticated.
     - 'unavailable': timeout or unreachable server; the next server is
       tried.

   sendAuthzReq works the same way and returns 'allowed' or
   'authzUnavailable'.  Other LDAP errors are reported as 'fail'.

   Every bind/search request, success, failure and timeout is reported to
   the per-server counter callback passed to addServer, if any."""
   def __init__( self, hostgroup ):
      traceX( TR_INFO, f"creating LDAP session for {hostgroup}" )
      AaaPluginLib.Session.__init__( self, hostgroup )
      self.servers_ = []        # _Server objects, in failover order
      self.activeServer_ = None # host spec of the server used last

   def getActiveServer( self ):
      """Return the host spec of the server used by the last request, or None."""
      return self.activeServer_

   @staticmethod
   def _ldapErrorDesc( error ):
      """Return the 'desc' text of an ldap.LDAPError.

      python-ldap normally passes a dict with a 'desc' key as the first
      argument.  Fall back to str( error ) otherwise, e.g. for an exception
      instantiated locally without arguments."""
      if error.args and isinstance( error.args[ 0 ], dict ):
         return error.args[ 0 ].get( 'desc', str( error ) )
      return str( error )

   def makex509HashMap( self, x509FileList ):
      """Group certificate / CRL files by their OpenSSL hash.

      Each name in x509FileList is resolved to a full path with
      Constants.certPath().  Returns a dict mapping the hash printed by
      `openssl x509 -hash` (certificates) or `openssl crl -hash` (CRLs) to
      the list of full paths sharing it, in input order."""
      traceX( TR_DEBUG, "Ldap::Session::makex509HashMap: ", str( x509FileList ) )
      x509HashMap = {}
      for x509File in x509FileList:
         traceX( TR_DEBUG, "Working on file:", x509File )
         # Use the full path from here on; it is also the symlink target
         x509Path = Constants.certPath( x509File )
         with open( x509Path ) as handle:
            rawText = handle.read()
         if "BEGIN CERTIFICATE" in rawText:
            opensslSubCmd = "x509"
         elif "BEGIN X509 CRL" in rawText:
            opensslSubCmd = "crl"
         else:
            opensslSubCmd = ""
         assert opensslSubCmd, f"{x509Path} is neither a certificate nor a CRL"
         certHash = Tac.run( [ "openssl", opensslSubCmd,
                               "-hash", "-noout",
                               "-in", x509Path ],
                             stdout=Tac.CAPTURE,
                             stderr=Tac.DISCARD ).strip()
         x509HashMap.setdefault( certHash, [] ).append( x509Path )
      return x509HashMap

   def createCaCertSymlinks( self, hashMap, hashDir, isCrl=False ):
      """Create OpenSSL hash symlinks in hashDir for every file in hashMap.

      Files sharing a hash get consecutive suffixes: <hash>.0, <hash>.1, ...
      for certificates and <hash>.r0, <hash>.r1, ... for CRLs.  If the disk
      is full (ENOSPC), the error is traced and link creation stops.  Any
      other OSError is re-raised."""
      traceX( TR_DEBUG, "Ldap::Session::createCaCertSymlinks: hashDir", hashDir )
      # CRLs have an 'r' prefix on their symlink number
      numPrefix = "r" if isCrl else ""
      for hashVal, fileList in hashMap.items():
         for idx, fileName in enumerate( fileList ):
            traceX( TR_DEBUG, "hash", hashVal, "idx", idx, "file", fileName )
            symlinkPath = os.path.join( hashDir, f"{hashVal}.{numPrefix}{idx}" )
            try:
               os.symlink( fileName, symlinkPath )
            except OSError as e:
               bt( TR_ERROR, "symlink", symlinkPath, "error:", str( e.strerror ) )
               if e.errno == errno.ENOSPC:
                  # Gracefully handle no space; bubble all other errors up
                  return
               raise

   def createCaCertDir( self, trustCertList, crlList ):
      """Build a temporary OpenSSL-style CA directory for the certs and CRLs.

      Returns the directory path, or "" if it could not be created because
      the disk is full.  If populating the directory raises, the directory
      is removed before re-raising so it does not leak."""
      traceX( TR_DEBUG, "Ldap::Session::createCaCertDir" )
      try:
         caCertDir = tempfile.mkdtemp( prefix="LdapCaCertDir-" )
      except OSError as e:
         bt( TR_ERROR, "mkdtemp error:", str( e.strerror ) )
         if e.errno == errno.ENOSPC:
            # Gracefully handle no space; bubble all other errors up
            return ""
         raise
      traceX( TR_DEBUG, "TLS caCertDir is", caCertDir )
      try:
         certHashMap = self.makex509HashMap( trustCertList )
         crlHashMap = self.makex509HashMap( crlList )
         self.createCaCertSymlinks( certHashMap, caCertDir )
         self.createCaCertSymlinks( crlHashMap, caCertDir, isCrl=True )
      except Exception: # pylint: disable=broad-except
         # Nobody else knows this directory yet, so clean it up here
         shutil.rmtree( caCertDir, ignore_errors=True )
         raise
      return caCertDir

   def addServer( self, aaaHost, timeout, ns=DEFAULT_NS, counterCallback=None,
                  baseDn=None, userDn=None, sslProfile=None, activeGroupPolicy=None,
                  searchUsernamePassword=None ):
      """Register a server and open an LDAP connection to it inside namespace ns.

      The server is appended to the failover list before connecting, so
      close() also cleans it up if this method raises.  A truthy timeout
      (seconds) is set as OPT_TIMEOUT.  It is also the combined budget for
      the bind and search of later requests.  When sslProfile is given, TLS
      is configured from it and StartTLS is negotiated (see _startTls).

      Raises ldap.LDAPError if initialization or StartTLS fails."""
      ss = _Server( aaaHost, aaaHost.port, ns, baseDn, userDn,
                    sslProfile, activeGroupPolicy, searchUsernamePassword )
      ss.statusCallback = counterCallback
      self.servers_.append( ss )
      traceX( TR_INFO, "addServer:", ss.host_.spec, ":", ss.port_, "ns:", ss.ns_ )

      try:
         with RunInNetworkNamespace( ns ):
            ss.server = ldap.initialize( ss.hostStr )
            traceX( TR_INFO, "Initialized server in namespace", ns )
            ss.server.protocol_version = ldap.VERSION3
            # Disable auto-chasing for referrals
            ss.server.set_option( ldap.OPT_REFERRALS, 0 )
            if timeout:
               ss.server.set_option( ldap.OPT_TIMEOUT, int( timeout ) )
            if ss.sslProfile:
               self._startTls( ss )
      except ldap.LDAPError as error:
         # not collecting any stats for initialize or tls related failures.
         bt( TR_ERROR, "addServer exception:",
             bv( self._ldapErrorDesc( error ) ) )
         raise

   def _startTls( self, ss ):
      """Apply ss.sslProfile's TLS settings to ss.server and run StartTLS.

      Builds the CA directory from the profile's "trustedCert" / "crl"
      files and requires a valid server certificate.  CRL checking is
      enabled when CRLs are configured.  The cipher suite and the protocol
      range come from the profile; the "tlsVersion" bitmask selects the
      range, with a TLS1_2 minimum when no bit is set.  Must be called
      inside the server's network namespace."""
      profile = ss.sslProfile
      ss.caCertDir = self.createCaCertDir( profile[ "trustedCert" ],
                                           profile[ "crl" ] )
      ss.server.set_option( ldap.OPT_X_TLS_CACERTDIR, ss.caCertDir )
      ss.server.set_option( ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND )
      if profile[ "crl" ]:
         ss.server.set_option( ldap.OPT_X_TLS_CRLCHECK, ldap.OPT_X_TLS_CRL_ALL )
      ss.server.set_option( ldap.OPT_X_TLS_CIPHER_SUITE,
                            profile[ "cipherSuite" ] )

      # Newest first: the first enabled version is the maximum, the last
      # enabled one the minimum.
      tlsVersionMap = [ ( Constants.tlsv1_3, ldap.OPT_X_TLS_PROTOCOL_TLS1_3 ),
                        ( Constants.tlsv1_2, ldap.OPT_X_TLS_PROTOCOL_TLS1_2 ),
                        ( Constants.tlsv1_1, ldap.OPT_X_TLS_PROTOCOL_TLS1_1 ),
                        ( Constants.tlsv1, ldap.OPT_X_TLS_PROTOCOL_TLS1_0 ) ]
      enabled = [ version for mask, version in tlsVersionMap
                  if profile[ "tlsVersion" ] & mask ]
      # Use TLS1_2 for min version by default
      minVersion = enabled[ -1 ] if enabled else ldap.OPT_X_TLS_PROTOCOL_TLS1_2
      ss.server.set_option( ldap.OPT_X_TLS_PROTOCOL_MIN, minVersion )
      if enabled:
         ss.server.set_option( ldap.OPT_X_TLS_PROTOCOL_MAX, enabled[ 0 ] )

      # Build a fresh TLS context so the options above take effect
      ss.server.set_option( ldap.OPT_X_TLS_NEWCTX, 0 )
      ss.server.start_tls_s()

   @staticmethod
   def _searchTimeLimit( ss, bindTime ):
      """Return the seconds left for a search after a bind that took bindTime.

      The connection's OPT_TIMEOUT is the budget for bind + search together.
      addServer only sets OPT_TIMEOUT for a truthy timeout.  When it is
      unset, python-ldap presumably reports None (or a negative "infinite"
      value).  In that case None is returned and the search is not time
      limited.  Otherwise the result may be <= 0 if the bind used up the
      whole budget."""
      timeout = ss.server.get_option( ldap.OPT_TIMEOUT )
      if timeout is None or timeout < 0:
         return None
      return int( timeout ) - bindTime

   def handleLdapException( self, error, ss, action, isAuthen ):
      """Count an LDAP error against ss and map it to a result tuple.

      action is 'bind' or 'search'; it selects the <action>Timeouts or
      <action>Fails counter.  Timeouts and an unreachable server map to
      'unavailable' ('authzUnavailable' when not isAuthen).  Anything else
      maps to 'fail'.  Returns ( status, errorMsg, attrs ), where attrs is
      None for authentication and {} for authorization."""
      operation = 'authentication' if isAuthen else 'authorization'
      unavailableStatus = 'unavailable' if isAuthen else 'authzUnavailable'
      attrs = None if isAuthen else {}
      # A search limited with OPT_TIMELIMIT (set from the remaining timeout
      # budget) reports TIMELIMIT_EXCEEDED when that limit expires; it is a
      # timeout just like a client-side ldap.TIMEOUT.
      if isinstance( error, ( ldap.TIMEOUT, ldap.TIMELIMIT_EXCEEDED ) ):
         traceX( TR_ERROR, f"{action} timeout during {operation}" )
         self.serverError( ss, f'{action}Timeouts' )
         return ( unavailableStatus, f"LDAP server {action} timeout", attrs )
      if isinstance( error, ldap.SERVER_DOWN ):
         traceX( TR_ERROR, f"LDAP server unavailable during {operation}" )
         self.serverError( ss, f'{action}Fails' )
         return ( unavailableStatus, "LDAP server unavailable", attrs )
      searchUsername = ss.searchUsernamePassword.username
      errorMsg = self._ldapErrorDesc( error )
      bt( TR_ERROR, f"Failing to {action} with", bv( searchUsername ),
          ":", bv( errorMsg ) )
      self.serverError( ss, f'{action}Fails' )
      return ( "fail", errorMsg, attrs )

   def sendAuthenReq( self, username, password ):
      """Authenticate username / password, failing over between servers.

      Servers are tried in the order they were added until one returns
      'success' or 'fail'.  An 'unavailable' result moves on to the next
      server.  Returns the ( result, failText, attrs ) tuple of the last
      server tried, or ( None, None, None ) when no server was added.  On
      success attrs also carries the server's 'sslProfile', if any."""
      result, failText, attrs = None, None, None
      self.activeServer_ = None
      for ss in self.servers_:
         result, failText, attrs = self._sendAuthenReq( ss, username, password )
         self.activeServer_ = ss.host_.spec
         if result == 'success':
            if ss.sslProfile:
               attrs[ 'sslProfile' ] = ss.sslProfile
            break
         if result == 'fail':
            # authentication failed for searchUser or user
            break
         if result == 'unavailable':
            # server error, try a different server
            bt( TR_ERROR, "Server error for", ss.host_.spec )

      return ( result, failText, attrs )

   def _sendAuthenReq( self, ss, username, password ):
      """Authenticate username / password against a single server.

      Binds as the configured search user and searches ss.baseDn for
      ( <ss.userDn>=<username> ).  It then tries a simple bind with password
      as each returned DN.  Returns ( "success", "", { 'userDn': dn } ) on
      success, ( "fail", message, None ) when no DN accepted the password,
      or the tuple from handleLdapException on a bind/search error."""
      traceX( TR_AUTHEN, "Authenticating against server",
              ss.host_.spec.hostname )
      failMsg = f"User {username} authentication failed"
      if not password:
         # A simple bind with a DN and an empty password is an
         # "unauthenticated" bind (RFC 4513, section 5.1.2).  Many servers
         # report it as successful without checking any credential, so it
         # must never count as a login.
         bt( TR_ERROR, "Rejecting empty password for user", bv( username ) )
         return ( "fail", failMsg, None )

      searchUsername = ss.searchUsernamePassword.username
      searchPassword = ss.searchUsernamePassword.password

      traceX( TR_AUTHEN, "Binding to", searchUsername, ss.ns_ )
      # username comes from the login attempt: escape filter metacharacters
      # ( '*', '(', ')', '\' ) so it cannot widen or restructure the search.
      userCn = ss.userDn + "=" + ldap.filter.escape_filter_chars( username )
      with RunInNetworkNamespace( ss.ns_ ):
         try:
            # OPT_TIMEOUT sets timeout for Ldap operations, and OPT_TIMELIMIT sets
            # timeout for search operations. The timeout configured for Ldap server
            # is for the total operation(combined bind and search times)
            startTime = Tac.now()
            self.serverBind( ss, searchUsername, searchPassword.getClearText() )
            bindTime = round( Tac.now() - startTime )
            self.serverBindSuccess( ss, unbind=False )
         except ldap.LDAPError as error:
            return self.handleLdapException( error, ss, 'bind', isAuthen=True )
         traceX( TR_AUTHEN, "Admin successfully authenticated" )

         searchTimeLimit = self._searchTimeLimit( ss, bindTime )
         if searchTimeLimit is not None and searchTimeLimit <= 0:
            # The bind already used up the whole timeout budget
            return self.handleLdapException( ldap.TIMEOUT(), ss, 'search',
                                             isAuthen=True )
         try:
            if searchTimeLimit is not None:
               ss.server.set_option( ldap.OPT_TIMELIMIT, searchTimeLimit )
            userMatches = self.serverSearch( ss, f"(& ({userCn}))",
                                             attrlist=None )
         except ldap.LDAPError as error:
            return self.handleLdapException( error, ss, 'search', isAuthen=True )
         self.serverSearchSuccess( ss )
         traceX( TR_AUTHEN, "Search returned results:", userMatches )

         for user in userMatches:
            userDn = user[ 0 ]
            if not userDn:
               # Nothing to bind to (presumably a search reference)
               continue
            try:
               traceX( TR_DEBUG, "Attempting bind to userDn:", userDn )
               self.serverBind( ss, userDn, password )
               traceX( TR_AUTHEN, "User", userDn, "is successfully authenticate" )
               # Close the connection
               self.serverBindSuccess( ss )
               return ( "success", "", { 'userDn': userDn } )
            except ldap.LDAPError as error:
               counterAttr = ( 'bindTimeouts' if isinstance( error, ldap.TIMEOUT )
                               else 'bindFails' )
               self.serverError( ss, counterAttr )
               bt( TR_ERROR, "server error:",
                   bv( error.args[ 0 ] if error.args else str( error ) ) )

         return ( "fail", failMsg, None )

   def sendAuthzReq( self, userDn, searchFilter, groupRolePrivilege ):
      """Find the role / privilege for userDn, failing over between servers.

      Servers are tried in order until one returns 'allowed' or 'denied'.
      Any other result ('authzUnavailable', 'fail') moves on to the next
      server.  Returns the ( result, failText, attrs ) tuple of the last
      server tried, or ( 'authzUnavailable', 'no usable server', None )
      when no server was added."""
      result, failText, attrs = ( 'authzUnavailable', 'no usable server', None )
      self.activeServer_ = None
      for ss in self.servers_:
         result, failText, attrs = self._sendAuthzReq( ss, userDn, searchFilter,
                                                       groupRolePrivilege )
         self.activeServer_ = ss.host_.spec
         if result in ( 'allowed', 'denied' ):
            break
         if result == 'authzUnavailable':
            bt( TR_ERROR, "Server", self.activeServer_.stringValue(), "unavailable" )
            # try next Server

      return ( result, failText, attrs )

   def _sendAuthzReq( self, ss, userDn, searchFilter, groupRolePrivilege ):
      """Map the groups userDn belongs to on a single server to a role.

      Binds as the search user.  Searches ss.baseDn for entries of object
      class searchFilter.group whose searchFilter.member attribute equals
      userDn, and takes a group name from each returned DN (the value of
      its leading RDN).  The first entry of groupRolePrivilege whose group
      is one of those names supplies the role and privilege level.
      Returns ( 'allowed', '', attrs ), where attrs is empty when no group
      matched, or the tuple from handleLdapException on an LDAP error."""
      assert ss
      searchUsername = ss.searchUsernamePassword.username
      traceX( TR_AUTHZ, "Binding to", searchUsername, ss.ns_ )
      with RunInNetworkNamespace( ss.ns_ ):
         try:
            startTime = Tac.now()
            self.serverBind( ss, searchUsername,
                             ss.searchUsernamePassword.password.getClearText() )
            bindTime = round( Tac.now() - startTime )
            self.serverBindSuccess( ss, unbind=False )
         except ldap.LDAPError as error:
            return self.handleLdapException( error, ss, 'bind', isAuthen=False )
         traceX( TR_AUTHZ, "Admin", searchUsername, "successfully authenticated" )
         # filter_format escapes every value, including userDn
         query = ldap.filter.filter_format( "(&(objectclass=%s)(%s=%s))",
                                            [ searchFilter.group,
                                              searchFilter.member,
                                              userDn ] )
         searchTimeLimit = self._searchTimeLimit( ss, bindTime )
         if searchTimeLimit is not None and searchTimeLimit <= 0:
            # The bind already used up the whole timeout budget
            return self.handleLdapException( ldap.TIMEOUT(), ss, 'search',
                                             isAuthen=False )
         try:
            if searchTimeLimit is not None:
               ss.server.set_option( ldap.OPT_TIMELIMIT, searchTimeLimit )
            matchedGroups = self.serverSearch( ss, query, attrsonly=1 )
            self.serverSearchSuccess( ss )
         except ldap.LDAPError as error:
            return self.handleLdapException( error, ss, 'search',
                                             isAuthen=False )
         traceX( TR_AUTHZ, "Matched groups:", matchedGroups )

         groupNameRe = re.compile( r"[^,]=([^,]+)," )
         groupSet = set()
         for group in matchedGroups:
            if not group[ 0 ]:
               # No DN (presumably a search reference)
               continue
            matchGroup = groupNameRe.search( group[ 0 ] )
            if matchGroup:
               groupSet.add( matchGroup.group( 1 ) )

         attrs = {}
         traceX( TR_AUTHZ, "Trying to find matching role for group",
                 ', '.join( groupSet ) )
         for grp in groupRolePrivilege.values():
            if grp.group in groupSet:
               bt( TR_AUTHZ, "Matched role", bv( grp.role ),
                   "privilege", bv( grp.privilege ),
                   "for group", bv( grp.group ) )
               attrs[ AaaPluginLib.roles ] = [ grp.role ]
               attrs[ AaaPluginLib.privilegeLevel ] = grp.privilege
               break
         try:
            ss.server.unbind_s()
            ss.server = None
         except ldap.LDAPError as error:
            return self.handleLdapException( error, ss, 'bind', isAuthen=False )
         traceX( TR_DEBUG, "Updated attrs", attrs )
         return ( "allowed", "", attrs )

   def _countEvent( self, ss, counterAttr ):
      """Report one counterAttr event for ss through its status callback."""
      cb = ss.statusCallback
      if cb:
         cb( counterAttr, 1 )
      else:
         traceX( TR_DEBUG, "No", counterAttr, "callback found for",
                 ss.host_.hostname )

   def serverBind( self, ss, username, password ):
      """Count a bind request, then do a synchronous simple bind on ss.server."""
      self._countEvent( ss, 'bindRequests' )
      ss.server.bind_s( username, password )

   def serverBindSuccess( self, ss, unbind=True ):
      """Count a successful bind; if unbind, also unbind and drop ss.server."""
      if unbind:
         ss.server.unbind_s()
         ss.server = None
      self._countEvent( ss, 'bindSuccesses' )

   def serverSearch( self, ss, query, attrlist=None, attrsonly=False ):
      """Count a search request, then search ss.baseDn's subtree for query."""
      self._countEvent( ss, 'searchRequests' )
      return ss.server.search_s( ss.baseDn, ldap.SCOPE_SUBTREE, query,
                                 attrlist=attrlist or [],
                                 attrsonly=attrsonly )

   def serverSearchSuccess( self, ss ):
      """Count a successful search."""
      self._countEvent( ss, 'searchSuccesses' )

   def serverError( self, ss, counterAttr ):
      """Count a failure / timeout under counterAttr (e.g. 'bindFails')."""
      self._countEvent( ss, counterAttr )

   def close( self ):
      """Unbind from every server and delete their temporary CA directories.

      Best effort per server: an LDAP error while unbinding is traced, and
      the remaining cleanup still runs.  Safe to call more than once; it
      also runs from __del__."""
      # servers_ is missing if __init__ failed before setting it; __del__
      # still runs in that case.
      for ss in getattr( self, 'servers_', () ):
         if ss.server:
            traceX( TR_INFO, "Unbind from server" )
            try:
               ss.server.unbind_s()
            except ldap.LDAPError as error:
               bt( TR_WARN, "unbind error:", bv( self._ldapErrorDesc( error ) ) )
            finally:
               ss.server = None
         if ss.caCertDir:
            try:
               shutil.rmtree( ss.caCertDir )
            except OSError as excep:
               # Already removed is fine; anything else is unexpected
               if excep.errno != errno.ENOENT:
                  raise
            ss.caCertDir = ""

   def __del__( self ):
      # If close was not called explicitly
      self.close()

class CounterCallback:
   """Callable that bumps one LDAP counter of a single host in `status`.

   Session calls it as callback( attrName, delta ).  attrName must be one of
   ldapCounterAttrs.  Calls are serialized with counterSysdbStatusLock."""

   def __init__( self, status, hostspec ):
      # status: object whose `counter` collection is keyed by hostspec; a
      # falsy status turns the callback into trace-only (used for testing).
      self.status = status
      self.hostspec = hostspec

   @synchronized( counterSysdbStatusLock )
   def __call__( self, attrName, delta ):
      # Bind failures and timeouts are traced at error level
      traceLevel = ( TR_ERROR if attrName in ( 'bindFails', 'bindTimeouts' )
                     else TR_INFO )
      bt( traceLevel, "LDAP counter", bv( attrName ), "host",
          bv( self.hostspec.stringValue() ) )
      if not self.status:
         return # used for testing
      if attrName not in ldapCounterAttrs:
         bt( TR_ERROR, "unknown counter type", bv( attrName ) )
         return
      counters = self.status.counter
      if self.hostspec not in counters:
         # Operator probably removed the ldap host in the middle of the
         # authentication request.
         bt( TR_WARN, "LDAP counters missing entry for",
             bv( self.hostspec.stringValue() ) )
         return
      # Read-modify-write a non-const copy, then store it back
      c = Tac.nonConst( counters[ self.hostspec ] )
      new = getattr( c, attrName ) + delta
      traceX( TR_DEBUG, "Counter", attrName, "set to", new )
      setattr( c, attrName, new )
      counters[ self.hostspec ] = c
