# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from typing import Iterator, Optional
import threading
from collections import namedtuple
from contextlib import contextmanager

import QuickTrace
import Tac

# Seconds of inactivity after which an IP's eth header entry expires.
IPETHHEADER_TABLE_TIMEOUT = 300

# QuickTrace shorthands: level 0 is used for warnings, level 1 for normal tracing.
warn = QuickTrace.trace0
trace = QuickTrace.trace1
bv = QuickTrace.Var

# Eth header data recorded per IP (fields: vlan, src, dst).
EthHeader = namedtuple( 'EthHeader', 'vlan src dst' )

# Alias documenting that a float holds a Tac time value in seconds.
TacSeconds = float

class Dot1xWebIpEthHeaderTableLocked:
   '''
   This class owns the table where we store the eth header corresponding to each IP,
   which we have to share between threads.

   All functions here must run with the lock acquired, and they can't use most Tac
   functions due to lock restrictions. Anything Tac-derived (the current time, the
   timer's scheduled time) is therefore passed in by the caller.
   '''

   def __init__( self ) -> None:
      # IP -> eth header last recorded for that IP.
      self.ipEthHeaderTable: dict[ str, EthHeader ] = {}
      # IP -> time at which that IP's entry expires unless it is used again.
      # Kept with exactly the same keys as ipEthHeaderTable.
      self.ipExpiration: dict[ str, TacSeconds ] = {}
      self.lock_ = threading.Lock()

   @contextmanager
   def _lock( self ) -> Iterator[ None ]:
      '''
      Holds self.lock_ for the duration of the with-block.

      Protects self members, as we use them in Dot1xWebAgentLib.HttpReqHandler
      instances that run in threads other than main, and in Dot1xL2ForwarderSm.
      '''
      # threading.Lock is a context manager; it is released even if the body raises.
      with self.lock_:
         yield

   def _renewEntry( self, ip: str, now: TacSeconds ) -> TacSeconds:
      '''
      Sets the expiration of `ip` to `now + IPETHHEADER_TABLE_TIMEOUT` and returns
      that expiration time.

      Internal function that doesn't lock - must be called within the lock already
      acquired.
      '''
      expiration = now + IPETHHEADER_TABLE_TIMEOUT
      self.ipExpiration[ ip ] = expiration
      return expiration

   def getIpEthHeader( self, now: TacSeconds, ip: str ) -> Optional[ EthHeader ]:
      '''
      Returns the eth header stored for `ip`, or None if there is no entry.

      A hit counts as activity: the entry's expiration is pushed to
      `now + IPETHHEADER_TABLE_TIMEOUT`.
      '''
      with self._lock():
         entry = self.ipEthHeaderTable.get( ip )
         if entry is None:
            return None
         self._renewEntry( ip, now )
         return entry

   def setIpEthHeader( self, now: TacSeconds, timeMin: TacSeconds, ip: str,
                       ethHeader: EthHeader ) \
                       -> tuple[ bool, Optional[ TacSeconds ] ]:
      '''
      Stores `ethHeader` for `ip` and restarts that entry's expiration countdown.

      `timeMin` is the time the caller's expiration timer is currently scheduled
      for. Returns ( isnew, expiration ): `isnew` is True if `ip` had no entry
      before this call; `expiration` is the entry's new expiration time when it
      is earlier than `timeMin` (the caller should then reschedule its timer),
      otherwise None.
      '''
      with self._lock():
         isnew = ip not in self.ipEthHeaderTable
         self.ipEthHeaderTable[ ip ] = ethHeader
         expiration = self._renewEntry( ip, now )
         if expiration < timeMin:
            return isnew, expiration
         return isnew, None

   def handleIpEthHeaderExpirationTimer( self,
                                         now: TacSeconds ) -> Optional[ TacSeconds ]:
      '''
      This method is called from Dot1xWebIpEthHeaderTable to handle expiring
      eth headers.

      Removes every entry whose expiration is <= `now`, then returns the
      earliest expiration among the entries that remain, or None if the table
      is now empty. This "leaks" out the time when we should call it again so
      that Dot1xWebIpEthHeaderTable can schedule it.
      '''
      with self._lock():
         nextExpiration = None
         # Iterate over a snapshot, since expired entries are deleted in the loop.
         for ip, expiration in list( self.ipExpiration.items() ):
            if expiration <= now:
               del self.ipEthHeaderTable[ ip ]
               del self.ipExpiration[ ip ]
               # A removed entry must not influence the next deadline; otherwise
               # the timer would be rescheduled to a time that has already passed.
               continue
            if nextExpiration is None or expiration < nextExpiration:
               nextExpiration = expiration
         return nextExpiration

   def __contains__( self, ip: str ) -> bool:
      '''This implements the "in" operator that we use in tests'''
      with self._lock():
         return ip in self.ipEthHeaderTable

   def finish( self ) -> None:
      '''Used in tests to prevent old tables from expiring IPs and polluting logs'''
      with self._lock():
         self.ipEthHeaderTable = {}
         self.ipExpiration = {}

class Dot1xWebIpEthHeaderTable:
   '''
   Lock-free front end for Dot1xWebIpEthHeaderTableLocked.

   Because it holds no lock of its own, this wrapper is free to call Tac functions
   (Tac.now(), the expiration ClockNotifiee) and hand their results to the locked
   table. All table data lives in the wrapped object; this class must not own any.
   '''

   def __init__( self ) -> None:
      self._dot1xWebIpEthHeaderTableLocked = Dot1xWebIpEthHeaderTableLocked()
      expirationHandler = self._handleIpEthHeaderExpirationTimer
      self.timer = Tac.ClockNotifiee( handler=expirationHandler )

   def _rescheduleTimer( self, when: Optional[ TacSeconds ] ) -> None:
      '''Moves the expiration timer to `when`; does nothing if `when` is None.'''
      if when is None:
         return
      self.timer.timeMin = when

   def getIpEthHeader( self, ip: str ) -> Optional[ EthHeader ]:
      '''
      Returns the eth header stored for `ip`, or None if there is no entry.
      Can be called from other threads.
      '''
      table = self._dot1xWebIpEthHeaderTableLocked
      return table.getIpEthHeader( Tac.now(), ip )

   def setIpEthHeader( self, ip: str, ethHeader: EthHeader ) -> bool:
      '''
      Stores `ethHeader` for `ip`; can be called from other threads.

      When the locked table reports an expiration earlier than the timer's
      current schedule, the timer is moved to it. Returns the locked table's
      "is new" flag.
      '''
      table = self._dot1xWebIpEthHeaderTableLocked
      now = Tac.now()
      scheduledAt = self.timer.timeMin
      isNew, earlierExpiration = table.setIpEthHeader(
         now, scheduledAt, ip, ethHeader )
      self._rescheduleTimer( earlierExpiration )
      return isNew

   def _handleIpEthHeaderExpirationTimer( self ) -> None:
      '''
      Called from tacc's main loop when we should check for IP eth header
      expiration.

      The sweep itself is delegated to the locked table, because checking when a
      header was last used requires the lock. The expiration time it returns, if
      any, is used to schedule the next call.
      '''
      table = self._dot1xWebIpEthHeaderTableLocked
      upcoming = table.handleIpEthHeaderExpirationTimer( Tac.now() )
      self._rescheduleTimer( upcoming )

   def __contains__( self, ip: str ) -> bool:
      '''Implements the "in" operator by asking the locked table.'''
      table = self._dot1xWebIpEthHeaderTableLocked
      return ip in table

   def finish( self ) -> None:
      '''Empties the locked table so it stops expiring IPs (used in tests).'''
      self._dot1xWebIpEthHeaderTableLocked.finish()
