# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function
import six
import hashlib
import threading
import uuid

import AaaApiClient
import Ark
import Syscall
import Tac
import _Tac
import Tracing
import UwsgiAaa
import UwsgiAaaCacheLib

# Module trace handle, plus severity-style aliases for its trace levels 1 to 4.
traceHandle = Tracing.Handle( 'UwsgiSessionManager' )
warn, info, trace, debug = ( traceHandle.trace1, traceHandle.trace2,
                             traceHandle.trace3, traceHandle.trace4 )

class SessionLogoutError( Exception ):
   """ Error representing a failure to log out a session.

   Raised by UwsgiSessionManager.logoutSession when the requested session ID
   is not found in the session cache. """

class UwsgiSessionManager:
   def __init__( self, sysname, sessionCacheTimeout ):
      """ Create a manager with an empty session cache.

      sysname is passed to every per-thread AaaApiClient that gets created.
      sessionCacheTimeout is the session timeout in seconds; it is added to
      the current time when createSession inserts a new session. """
      trace( 'UwsgiSessionManager.__init__' )
      self.sysname_ = sysname
      # Per-thread storage for the AaaApiClient (see aaaApiClient_).
      self.threadLocalData_ = threading.local()
      # _cleanupSession is registered as the cache's entry cleanup callback.
      cleanupFn = self._cleanupSession
      self.sessionCache_ = UwsgiAaaCacheLib.Cache( entryCleanupFn=cleanupFn )
      self.sessionCacheTimeout_ = sessionCacheTimeout
      # Two-way mapping between credential cache keys and session IDs.
      self.cacheKeyToSessionId_ = {}
      self.sessionIdToCacheKey_ = {}

   @property
   def aaaApiClient_( self ):
      # AaaApiClient.AaaApiClient is not thread safe, so each thread should
      # create it's own client
      if not hasattr( self.threadLocalData_, 'aaaApiClient' ):
         self.threadLocalData_.aaaApiClient = AaaApiClient.AaaApiClient(
               self.sysname_ )
      return self.threadLocalData_.aaaApiClient

   def _generateCacheKey( self, user, passwdHash, requesterIp, userAgent ):
      return ( user, passwdHash, requesterIp, userAgent )

   def _hashPassword( self, passwd ):
      try:    
         # Sometimes these could be unicode.
         # pylint: disable-msg=E1101
         return hashlib.sha1( passwd ).hexdigest() 
      except TypeError:
         trace( 'authenticate exit bad passwd' )
         # pylint: disable-next=raise-missing-from
         raise UwsgiAaa.AuthenticationError( 'Illegal password format' )

   def _cleanupSession( self, key, entry ):
      """ Drop the bookkeeping for session ID key and close its AAA session.

      key is the session ID and entry is the cached entry, whose aaaAuthnId is
      passed to closeSession. Raises KeyError if the session ID or its cache
      key is missing from the lookup tables.

      NOTE: This should always be called with the activity lock held. """
      cacheKey = self.sessionIdToCacheKey_[ key ]
      # pop() without a default raises KeyError for a missing key, like del.
      self.cacheKeyToSessionId_.pop( cacheKey )
      self.sessionIdToCacheKey_.pop( key )
      # The activity lock is released while the AAA session is closed, since
      # that call could, in theory, stall.
      with Tac.ActivityUnlockHolder():
         self.aaaApiClient_.closeSession( entry.aaaAuthnId )

   def logoutSession( self, sessionId ):
      """ Remove sessionId from the session cache.

      Raises SessionLogoutError if the cache does not know the session. """
      debug( 'logoutSession', sessionId )
      cache = self.sessionCache_
      try:
         cache.cleanupEntry( sessionId )
      except UwsgiAaaCacheLib.UwsgiAaaCacheSessionNotFound:
         # Report an unknown session ID to the caller as a logout failure.
         # pylint: disable-next=raise-missing-from
         raise SessionLogoutError( 'Illegal logout request' )

   # Only one session may be created at a time. Two simultaneous login
   # requests with the same credentials would produce the same cache key, and
   # both could then end up calling authenticateAndAuthorizeSession.
   @Ark.synchronized()
   def createSession( self, requesterIp, userAgent, user, passwd, tty, service ):
      """ Log a user in, reusing a cached session when the cache key matches.

      The cache key is built from user, the password hash, requesterIp and
      userAgent. When a session ID is already recorded for that key, the
      session is fetched through getSession (which raises
      UwsgiAaa.AuthenticationError if it is missing or expired). Otherwise
      the credentials are checked through the AAA client and a new session
      with a fresh UUID is inserted into the cache.

      Returns a ( authEntry, sessionId, expiryTime ) tuple.

      NB: This function should never be called with the activity lock held,
      otherwise a deadlock can occur. """
      debug( 'createSession', user )
      assert _Tac.activityLockOwner() != Syscall.gettid(), 'Deadlock can occur'
      passwdHash = self._hashPassword( six.ensure_binary( passwd ) )
      cacheKey = self._generateCacheKey( user, passwdHash, requesterIp, userAgent )
      with Tac.ActivityLockHolder():
         cachedId = self.cacheKeyToSessionId_.get( cacheKey )
         if cachedId is not None:
            # This get is part of a login, not a runCmd: don't increment usage.
            cachedEntry, cachedExpiry = self.getSession( cachedId,
                                                         incrementUsageCnt=False )
            return ( cachedEntry, cachedId, cachedExpiry )

      # No cached session: authenticate without the activity lock held.
      client = self.aaaApiClient_
      aaaResult = client.authenticateAndAuthorizeSession( user, passwd, service,
                                                          requesterIp, tty=tty )
      newEntry = UwsgiAaa.parseAaaResults( aaaResult, user )
      # insert into cache with the activity lock held
      with Tac.ActivityLockHolder():
         # Keep drawing UUIDs until one is not already a key in the cache.
         while True:
            newId = str( uuid.uuid4() )
            if not self.sessionCache_.hasKey( newId ):
               break
         self.cacheKeyToSessionId_[ cacheKey ] = newId
         self.sessionIdToCacheKey_[ newId ] = cacheKey
         newExpiry = Tac.now() + self.sessionCacheTimeout_
         self.sessionCache_.insert( newId, newEntry, newExpiry )
         return ( newEntry, newId, newExpiry )

   def getSession( self, sessionId, incrementUsageCnt=True ):
      """ Look up sessionId in the session cache and return the cache's result.

      createSession unpacks that result as ( authEntry, expiryTime ).
      incrementUsageCnt is forwarded to the cache unchanged.

      Raises UwsgiAaa.AuthenticationError when the session is unknown
      ( 'Bad session Id' ) or has expired ( 'Session Expired' ). """
      debug( 'getSession', sessionId )
      try:
         session = self.sessionCache_.get( sessionId, incrementUsageCnt )
      except UwsgiAaaCacheLib.UwsgiAaaCacheSessionNotFound:
         # pylint: disable-next=raise-missing-from
         raise UwsgiAaa.AuthenticationError( 'Bad session Id' )
      except UwsgiAaaCacheLib.UwsgiAaaCacheSessionExpired:
         # pylint: disable-next=raise-missing-from
         raise UwsgiAaa.AuthenticationError( 'Session Expired' )
      else:
         return session

   def releaseSession( self, sessionId ):
      """ Forward a release of sessionId to the session cache.

      NOTE(review): presumably this undoes the usage count taken by
      getSession; confirm against UwsgiAaaCacheLib. """
      debug( 'releaseSession', sessionId )
      cache = self.sessionCache_
      cache.release( sessionId )

   def updateSessionTimeout( self, sessionCacheTimeout ):
      """ Replace the session timeout, in seconds, used for new sessions.

      The stored value is read only when createSession computes the expiry
      time of a session it is inserting. """
      # The value traced here is the incoming timeout, logged before it is
      # stored.
      debug( 'current session timeout value', sessionCacheTimeout )
      self.sessionCacheTimeout_ = sessionCacheTimeout