# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import six

import Tac
import Tracing

traceHandle = Tracing.Handle( 'UwsgiAaaCacheLib' )
# Short aliases for trace levels 1 through 4 of this module's trace handle.
warn, info, trace, debug = ( traceHandle.trace1, traceHandle.trace2,
                             traceHandle.trace3, traceHandle.trace4 )

class UwsgiAaaCacheBaseException( Exception ):
   ''' Common base class for the errors raised by the cache in this module. '''

class UwsgiAaaCacheSessionNotFound( UwsgiAaaCacheBaseException ):
   ''' Raised when the cache holds no entry for the requested key. '''

class UwsgiAaaCacheSessionExpired( UwsgiAaaCacheBaseException ):
   ''' Raised when the requested entry exists but its expiry time has passed. '''

class _CacheEntry:
   ''' One cached object, plus the time it expires and a count of the
       callers currently using it. '''

   def __init__( self, obj ):
      # A fresh entry has no users and never expires until told otherwise.
      self.usageCounter_ = 0
      self.expiryTime_ = Tac.endOfTime
      self.obj_ = obj

   def setExpiryTime( self, expiryTime ):
      ''' Records expiryTime as the moment this entry expires '''
      self.expiryTime_ = expiryTime

   def getExpiryTime( self ):
      ''' Returns the moment this entry expires '''
      return self.expiryTime_

   def incrUsageCount( self ):
      ''' Registers one more user; entries with users are not expired by the
          cache's timer '''
      trace( '_CacheEntry.incrUsageCount Marking', self, 'as in use' )
      self.usageCounter_ = self.usageCounter_ + 1

   def isInUse( self ):
      ''' True while at least one user holds this entry '''
      return 0 < self.usageCounter_

   def decrUsageCount( self ):
      ''' Unregisters one user, so the entry can be expired again once no
          users remain '''
      trace( '_CacheEntry.decrUsageCount Marking', self, 'as no longer used' )
      # Releasing more often than acquiring is a programming error.
      assert 0 < self.usageCounter_
      self.usageCounter_ = self.usageCounter_ - 1
      trace( '_CacheEntry.decrUsageCount exiting', self, 'usage cnt',
             self.usageCounter_ )

   def getObject( self ):
      ''' Returns the cached object '''
      cached = self.obj_
      trace( '_CacheEntry.getObject', cached )
      return cached


class Cache:
   ''' Generic key -> object cache with a per-entry expiry time.

   Every entry carries a usage count (taken by get, dropped by release). The
   expiry timer only removes entries that have expired AND are not in use;
   cleanup() and cleanupEntry() remove entries unconditionally. When
   entryCleanupFn is given it is called as entryCleanupFn( key, obj ) for each
   entry removed from the cache. '''

   def __init__( self, entryCleanupFn=None ):
      self.entries_ = {}  # key -> _CacheEntry
      self.entryCleanupFn_ = entryCleanupFn
      # Expiry timer: timeMin is when _applyTimeouts should next run to expire
      # idle entries; Tac.endOfTime means there is nothing to expire.
      self.timeoutNotifiee_ = Tac.ClockNotifiee( handler=self._applyTimeouts,
                                                 timeMin=Tac.endOfTime )

   @Tac.withActivityLock
   def cleanup( self ):
      ''' Unconditionally releases all of the entries from the cache, calling
          entryCleanupFn for each one (including entries still in use). This
          should only really be used for testing'''
      trace( 'Cache.cleanup entry' )
      # We need to call all of our entries' cleanup funcs. Iterate over a
      # snapshot: entryCleanupFn is caller-supplied and may call back into
      # this cache (e.g. cleanupEntry), which would otherwise mutate the dict
      # mid-iteration and raise RuntimeError.
      for key, entry in list( six.iteritems( self.entries_ ) ):
         trace( 'Entry', entry, 'in use:', entry.isInUse() )
         if self.entryCleanupFn_:
            trace( 'Cleaning up entry', entry )
            self.entryCleanupFn_( key, entry.getObject() )
      self.entries_.clear()

      # Update the timer now that we have no entries
      self.timeoutNotifiee_.timeMin = Tac.endOfTime
      trace( 'Cache.cleanup exit' )

   @Tac.withActivityLock
   def get( self, key, incrementUsageCnt=True ):
      ''' Returns ( obj, expiryTime ) for the entry stored under key.

          If incrementUsageCnt is True the entry is marked as in use, which
          keeps the expiry timer from removing it; balance this with a later
          release( key ).

          Raises UwsgiAaaCacheSessionNotFound if there is no entry for key,
          and UwsgiAaaCacheSessionExpired if its expiry time has passed. '''
      trace( 'Cache.get entry key:', key )
      entry = self.entries_.get( key )
      if entry is None:
         raise UwsgiAaaCacheSessionNotFound()

      expiryTime = entry.getExpiryTime()
      if expiryTime < Tac.now():
         # Expired entries are left for the timer / release() to reclaim.
         raise UwsgiAaaCacheSessionExpired()
      if incrementUsageCnt:
         entry.incrUsageCount()
      trace( 'Cache.get exit', entry.getObject() )
      return entry.getObject(), expiryTime

   @Tac.withActivityLock
   def insert( self, key, obj, expiryTime ):
      ''' Inserts obj under key, to expire at expiryTime. The new entry starts
          out unused. obj must not be None and key must not already be in the
          cache (both are asserted). '''
      trace( 'Cache.insert entry', obj )
      assert obj is not None
      assert key not in self.entries_, f'{key} duplicate keys!'
      entry = _CacheEntry( obj )
      entry.setExpiryTime( expiryTime )
      self.entries_[ key ] = entry

      # BUG767483: Make sure that we kick start the timer to expire this
      # session if it is never used
      self.timeoutNotifiee_.timeMin = self._getNextDueTime()
      trace( 'Cache.insert exit', obj )

   @Tac.withActivityLock
   def release( self, key ):
      ''' Drops one usage reference (as taken by get) on the entry for key.

          If the entry is now unused and already expired it is removed at
          once. If it is unused but not yet expired, the expiry timer is moved
          earlier if needed so the entry still expires on time. Does nothing
          if there is no entry for key. '''
      trace( 'Cache.release entry', key )
      entry = self.entries_.get( key )
      if entry is None:
         return

      entry.decrUsageCount()
      self._maybeCleanupEntry( key )

      if key in self.entries_ and not entry.isInUse():
         # The entry was not cleaned up (it has not expired yet) but it no
         # longer has any users, so the timer must fire no later than the
         # min of its current due time and this entry's expiry time.
         dueTime = min( self.timeoutNotifiee_.timeMin, entry.getExpiryTime() )
         self.timeoutNotifiee_.timeMin = dueTime
      trace( 'Cache.release exit', key )

   @Tac.withActivityLock
   def cleanupEntry( self, key ):
      ''' Unconditionally removes the entry for key (even if it is still in
          use) and then calls entryCleanupFn on its object, if configured.

          Raises UwsgiAaaCacheSessionNotFound if there is no entry for key. '''
      trace( 'cleanupEntry entry', key )
      # Stored values are always _CacheEntry instances, never None.
      entry = self.entries_.get( key )
      if entry is None:
         raise UwsgiAaaCacheSessionNotFound()
      debug( 'Deleting entry for', entry )
      del self.entries_[ key ]

      if not self.entries_:
         # Nothing left to expire. If entries remain, the timer may now fire
         # earlier than necessary; _applyTimeouts simply recomputes it then.
         self.timeoutNotifiee_.timeMin = Tac.endOfTime

      if self.entryCleanupFn_:
         self.entryCleanupFn_( key, entry.getObject() )
      trace( 'Cache.cleanupEntry exit', key )

   @Tac.withActivityLock
   def getExpiryTime( self, key ):
      ''' Returns the expiry time of the entry for key; raises KeyError if
          there is no such entry. '''
      return self.entries_[ key ].getExpiryTime()

   @Tac.withActivityLock
   def hasKey( self, key ):
      ''' Returns True if the cache holds an entry (expired or not) for key '''
      return key in self.entries_

   def _maybeCleanupEntry( self, key ):
      ''' Removes the entry for key via cleanupEntry if it exists, is not in
          use, and its expiry time has been reached. '''
      entry = self.entries_.get( key )
      if not entry or entry.isInUse():
         return

      if Tac.now() < entry.getExpiryTime():
         return

      # Only delete if it has expired from the cache, and the
      # entry is no longer in use. Otherwise we may be
      # deleting an entry for a user that is still running
      # commands (i.e. the request took longer than
      # gracePeriod)
      self.cleanupEntry( key )

   def _getNextDueTime( self ):
      ''' Returns the earliest expiry time among entries not in use, or
          Tac.endOfTime when there are none. In-use entries are skipped since
          the timer never expires them; release() re-arms the timer for them. '''
      idleEntries = [ e for e in self.entries_.values() if not e.isInUse() ]
      if not idleEntries:
         debug( 'No valid entries found in cache' )
         # There are no valid entries that will expire, so return end of time
         return Tac.endOfTime
      return min( e.getExpiryTime() for e in idleEntries )

   def _applyTimeouts( self ):
      '''Timer callback. Discards stale entries and computes next callback
      time. This is called from the activity thread'''
      trace( 'Cache._applyTimeouts entry' )
      for key in list( self.entries_ ): # API call changes this collection.
         self._maybeCleanupEntry( key )

      # We do not have to worry about dueTime being in the past.
      # In those cases, _applyTimeouts() will be called immediately.
      dueTime = self._getNextDueTime()
      trace( 'Cache._updateTimer() setting dueTime=', dueTime )
      self.timeoutNotifiee_.timeMin = dueTime
      trace( 'Cache._applyTimeouts exit' )
