# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
'''
Dot1xWebAgentLib.py has the core of the agent that handles dot1x web authentication
'''

import Agent
import Ark
import Tac
import Cell
import QuickTrace
import http.server
import socketserver
from time import time
import threading
import SharedMem
import Smash
from multiprocessing.pool import ThreadPool
from Arnet.NsLib import socketAt, DEFAULT_NS
import ssl
from MgmtSecuritySslStatusSm import SslStatusSm, ProfileState
from GenericReactor import GenericReactor
from Dot1xWebL2ForwarderLib import WebAuthNs, WebAuthWireIntfName
from Dot1xWebL2ForwarderLib import Dot1xL2Forwarder
from Dot1xWebIpEthHeaderTable import Dot1xWebIpEthHeaderTable

# QuickTrace shorthands used throughout this file: `warn` (level 0) is used for
# exceptions/errors, `trace` (level 1) for regular tracing, and `bv` wraps values
# passed to either of them.
warn = QuickTrace.trace0
trace = QuickTrace.trace1
bv = QuickTrace.Var

# Number of worker threads in the ThreadPool that handles HTTP(S) connections.
DOT1XWEB_NUM_THREADS = 2

# Provides the dot1xWebHttpPort / dot1xWebHttpsPort listening ports.
PrivateTcpPorts = Tac.Type( "Arnet::PrivateTcpPorts" )

def protoName( https ):
   '''Return the URL scheme for a server: 'https' when `https` is truthy, else
   'http'.'''
   if https:
      return 'https'
   return 'http'

def thId():
   '''Return the name of the calling thread, wrapped for QuickTrace (bv).

   Wrapping trace* in a helper that prepends the thread name would be the more
   "pythonic" approach, but it does not work: trace* plays tricks with the call
   frames (details: /src/Ark/BothTrace.py). So callers pass thId() explicitly as
   the first trace argument instead.
   '''
   currentThreadName = threading.current_thread().name
   return bv( currentThreadName )

class Dot1xHttpServer( http.server.HTTPServer, Tac.Notifiee ):
   '''
   Dot1xHttpServer: server class that adapts python's http.server.HTTPServer to
   work with a listening socket that is registered in tacc's activity loop and to use
   a thread pool instead of handling the requests inline.

   We also prevent instances from resolving host names.

   Flow: new connections on the listening socket are accepted from the main
   (tacc) thread in handleReadableCount; each accepted connection is then handled
   by process_request_in_pool in one of the threadPool workers.
   '''
   notifierTypeName = 'Tac::FileDescriptor'

   def __init__( self, agent=None, threadPool=None, sslContext=None ):
      '''
      Replace the constructor of http.server.HTTPServer to register the listening
      socket FD in tacc so that we can handle new connections in the main thread.

      As we are replacing the constructor, might as well move some other
      initializations here:
      - HTTPS socket wrapping if applicable
      - ThreadPool
      - Hard-code server_address and HttpReqHandler

      Args:
        agent: the owning agent. It must be given despite the None default,
          because agent.getIpPort() is called unconditionally.
        threadPool: pool whose apply_async() is used by process_request.
        sslContext: ssl.SSLContext for an HTTPS server, or None for plain HTTP.

      References:
      - Parent http.server.HTTPServer:
        https://github.com/python/cpython/blob/2.7/Lib/BaseHTTPServer.py#L114
      - SocketServer.TCPServer
        https://github.com/python/cpython/blob/2.7/Lib/SocketServer.py#L161
      - multiprocessing doc (ThreadPool):
        https://docs.python.org/2/library/multiprocessing.html
      '''
      self.sslContext = sslContext
      server_address = agent.getIpPort( https=bool( self.sslContext ) )
      # The following code comes from TCPServer.__init__
      # We are importing it to replace the socket creation, because we need it to
      # be created inside a WebAuthNs
      # Note: TCPServer is our grandparent; our parent, BaseHTTPServer.HTTPServer,
      # doesn't have a constructor.
      # Refs:
      # https://github.com/python/cpython/blob/2.7/Lib/BaseHTTPServer.py#L102
      # https://github.com/python/cpython/blob/2.7/Lib/SocketServer.py#L413
      # {{{
      # pylint: disable=non-parent-init-called
      socketserver.BaseServer.__init__( self, server_address, HttpReqHandler )
      self.socket = socketAt( self.address_family, self.socket_type, ns=WebAuthNs )
      # TCPServer.server_bind reads allow_reuse_address while binding, so it has
      # to be set before server_bind() to have any effect.
      self.allow_reuse_address = True
      try:
         self.server_bind()
         self.server_activate()
      except BaseException:
         self.server_close()
         raise
      # }}}
      self.threadPool = threadPool
      self.agent = agent
      # Initialize tac reactor aspect: watch the listening socket for readability
      # (i.e. pending connections) from tacc's activity loop.
      self.tacFd = Tac.newInstance( 'Tac::FileDescriptor',
                                    # pylint: disable-next=consider-using-f-string
                                    'fd%d' % ( self.fileno() ) )
      self.tacFd.nonBlocking = True
      self.tacFd.notificationInterface = 'levelTriggered'
      self.tacFd.notifyOnReadable = True
      self.tacFd.descriptor = self.fileno()
      Tac.Notifiee.__init__( self, self.tacFd )

   def server_bind( self ):
      '''
      Replace HTTPServer.server_bind to prevent instances from trying to resolve host
      names

      Original function:
      https://github.com/python/cpython/blob/2.7/Lib/BaseHTTPServer.py#L106
      '''
      socketserver.TCPServer.server_bind( self )
      host, port = self.socket.getsockname()[ : 2 ]
      self.server_name = host
      self.server_port = port

   def getInterfaceID( self, client_address ):
      '''
      Map a client address to the dot1x interface its supplicant is on.

      Looks up the client IP (client_address[ 0 ]) in the agent's
      dot1xWebIpEthHeaderTable, then returns the Arnet::IntfId of the first
      dot1xIntfStatus entry whose authSupplicant contains the header's source MAC.
      Returns None when the IP or the MAC is unknown.

      Within this class it is only called from updateCaptivePortalCounters, which
      holds the activity lock.
      '''
      if ethHeader := self.agent.dot1xWebIpEthHeaderTable.getIpEthHeader(
            client_address[ 0 ] ):
         for intf, dis in self.agent.dot1xStatus.dot1xIntfStatus.items():
            if ethHeader.src in dis.authSupplicant:
               return Tac.Value( "Arnet::IntfId", intf )
      return None

   @Tac.withActivityLock
   def updateCaptivePortalCounters( self, counterType, client_address ):
      '''
      Increment one Captive Portal counter of the client's interface.

      counterType is "request", "redirect" or "invalid". The http* or https*
      variant is picked depending on whether this server has an sslContext. Any
      other counterType writes the counter back unchanged. Nothing happens when
      the client cannot be mapped to an interface that has a counter.
      '''
      trace( thId(), "update captive portal counters" )
      intfId = self.getInterfaceID( client_address )
      if intfId and intfId in self.agent.webAgentStatus.captivePortalCounters:
         # Counters are values: take a mutable copy, bump it, and write it back.
         counter = Tac.nonConst( self.agent.webAgentStatus.
                                 captivePortalCounters[ intfId ] )
      else:
         trace( thId(), bv( client_address ), "interface doesn't exist" )
         return

      if bool( self.sslContext ):
         if counterType == "request":
            counter.httpsRequest += 1
         elif counterType == "redirect":
            counter.httpsRedirect += 1
         elif counterType == "invalid":
            counter.httpsInvalid += 1
      else:
         if counterType == "request":
            counter.httpRequest += 1
         elif counterType == "redirect":
            counter.httpRedirect += 1
         elif counterType == "invalid":
            counter.httpInvalid += 1
      self.agent.webAgentStatus.addCaptivePortalCounters( counter )

   def _handle_request_noblock( self ):
      """Handle one request, without blocking.

      Overridden to update the Captive Portal "request" counter for every accepted
      connection. Error handling mirrors upstream socketserver:
      - An ordinary Exception from process_request is reported through
        handle_error, which also counts it as "invalid". The request socket is
        then closed and the exception is NOT re-raised, so a single bad
        connection does not escape into the tacc activity loop through
        handleReadableCount.
      - Anything else (e.g. KeyboardInterrupt/SystemExit) still closes the request
        socket and is re-raised.
      """
      try:
         request, client_address = self.get_request()
      except OSError:
         return
      self.updateCaptivePortalCounters( "request", client_address )
      if self.verify_request( request, client_address ):
         try:
            self.process_request( request, client_address )
         # pylint: disable-next=broad-except
         except Exception:
            self.handle_error( request, client_address )
            self.shutdown_request( request )
         except BaseException:
            self.shutdown_request( request )
            raise
      else:
         self.shutdown_request( request )

   def process_request_in_pool( self, request, client_address ):
      '''
      Method executed in a secondary thread from the pool to handle a request.

      Essentially a copy of python's ThreadingMixIn.process_request_thread:
      https://github.com/python/cpython/blob/2.7/Lib/SocketServer.py#L585

      The request socket is always shut down in `finally`, even if handle_error
      itself raises. An exception escaping here would only be stored in the
      AsyncResult returned by apply_async, which nobody reads, so without the
      `finally` the socket would leak silently.
      '''
      now = time()
      trace( thId(), bv( client_address ), 'process_request_in_pool start' )
      try:
         self.finish_request( request, client_address )
      # pylint: disable=broad-except
      except Exception as e:
         warn( thId(), bv( client_address ), 'EXCEPTION in process_request_in_pool:',
               bv( str( e ) ) )
         self.handle_error( request, client_address )
      finally:
         self.shutdown_request( request )
         trace( thId(), bv( client_address ), 'process_request_in_pool done, took',
                bv( time() - now ) )

   def handle_error( self, request, client_address ):
      '''
      Handle the errors related to the server

      Overridden to increment the "invalid" Captive Portal counter before
      deferring to the base class.
      '''
      self.updateCaptivePortalCounters( "invalid", client_address )
      super().handle_error( request, client_address )

   def process_request( self, request, client_address ):
      '''
      Use the thread pool to handle incoming requests

      Original:
      https://github.com/python/cpython/blob/2.7/Lib/SocketServer.py#L315

      This is equivalent to what SocketServer.ThreadingMixin does:
      https://github.com/python/cpython/blob/2.7/Lib/SocketServer.py#L685

      It's unfortunate that SocketServer doesn't offer a ThreadPoolMixin.
      '''
      now = time()
      trace( thId(), bv( client_address ), 'process_request scheduling' )
      self.threadPool.apply_async( self.process_request_in_pool,
                                   ( request, client_address ) )
      trace( thId(), bv( client_address ), 'process_request scheduled, took',
             bv( time() - now ) )

   @Tac.handler( 'readableCount' )
   def handleReadableCount( self ):
      '''
      Tac reactor to new connections in the listening socket. Called from tac's
      activity loop in the main thread.

      It calls SocketServer.handle_request:
      https://github.com/python/cpython/blob/2.7/Lib/SocketServer.py#L251

      handle_request gets the following steps executed:
      - handle_request checks timeouts, does a select: harmless, we got here because
        we already have a connection
      - _handle_request_noblock
      - self.get_request, which is a simple "return self.socket.accept()"
        https://github.com/python/cpython/blob/2.7/Lib/SocketServer.py#L461
      - self.verify_request, which is just a "return True"
        https://github.com/python/cpython/blob/2.7/Lib/SocketServer.py#L307
      - self.process_request, defined above.
      '''
      trace( thId(), 'Dot1xHttpServer.handleReadableCount got new connection' )
      self.handle_request()
      # ^ Ends up calling process_request, see python's upstream at
      # https://github.com/python/cpython/blob/2.7/Lib/SocketServer.py#L251

   def shutdown( self ):
      '''
      Close the listening socket (server_close).

      This replaces BaseServer.shutdown, which is meant for serve_forever(); that
      loop is not used here. Note that self.tacFd is left untouched; callers
      handle it.
      '''
      trace( 'Dot1xHttpServer.shutdown: calling server_close' )
      self.server_close()
      trace( 'Dot1xHttpServer.shutdown: done' )

class HttpReqHandler( http.server.BaseHTTPRequestHandler ):
   """Our HTTP request handler

   Socket handling is done by functions in the
   http.server.BaseHTTPRequestHandler base class.
   The read() done in that class, that gets the request proper, blocks - and this is
   the reason we chose to leave Tac's activity loop in the main thread and handle
   the HTTP(S) requests via thread pool.

   Every GET is answered with a 302 redirection to the captive portal URL that
   applies to the client. That URL is looked up once per connection, in __init__."""

   timeout = 60
   # ^ Timeout for reading the client request - copied from nginx:
   # http://nginx.org/en/docs/http/ngx_http_core_module.html#client_header_timeout

   def __init__( self, request, client_address, server ):
      '''
      Resolve the client's MAC and captive portal URL, wrap the socket in TLS when
      the server has an sslContext, then hand over to
      BaseHTTPRequestHandler.__init__, which processes the request.
      '''
      self.request = request
      self.client_address = client_address
      trace( thId(), bv( client_address ), 'HttpReqHandler' )
      self.mac = self.getClientMac( server )
      self.cp = self.getCaptivePortal( server, self.mac )
      if server.sslContext:
         if self.timeout is not None:
            # Done in StreamRequestHandler.setup for regular sockets.
            # For HTTPS, we have to set the timeout before we wrap the socket.
            self.request.settimeout( self.timeout )
         self.request = server.sslContext.wrap_socket( self.request,
                                                       server_side=True )
      http.server.BaseHTTPRequestHandler.__init__( self, self.request,
                                                   self.client_address,
                                                   server )

   def address_string( self ):
      '''
      Replace BaseHTTPRequestHandler.address_string to prevent instances from
      trying to resolve host names

      Original function:
      https://github.com/python/cpython/blob/2.7/Lib/BaseHTTPServer.py#L500
      '''
      host, _ = self.client_address[ : 2 ]
      return host

   @Tac.withActivityLock
   def getClientMac( self, server ):
      '''
      Return the source MAC recorded for the client IP in the agent's
      dot1xWebIpEthHeaderTable, or "" if there is no entry.
      '''
      mac = server.agent.dot1xWebIpEthHeaderTable.getIpEthHeader(
         self.client_address[ 0 ] )
      trace( thId(), bv( self.client_address ), 'getClientMac got', bv( mac ) )
      if not mac:
         return ""
      return mac.src

   @Tac.withActivityLock
   def getCaptivePortal( self, server, mac ):
      '''
      Get the captive portal config from sysdb

      Decorate withActivityLock so that we don't read from sysdb while the main
      thread is updating the value.

      The per-supplicant captivePortal of the first dot1xIntfStatus that knows
      `mac` wins. Otherwise (or when mac is "") the global
      dot1xStatus.captivePortal is returned.
      '''
      dot1xStatus = server.agent.dot1xStatus
      if mac == "":
         trace( thId(), bv( self.client_address ), 'getCaptivePortal for', bv( mac ),
                '(no mac) got', bv( dot1xStatus.captivePortal ), 'from dot1xStatus' )
         return dot1xStatus.captivePortal
      for intf in dot1xStatus.dot1xIntfStatus:
         dis = dot1xStatus.dot1xIntfStatus[ intf ]
         if mac in dis.supplicant:
            supp = dis.supplicant[ mac ]
            trace( thId(), bv( self.client_address ), 'getCaptivePortal for',
                   bv( mac ), 'got', bv( supp.captivePortal ),
                   'from dot1xIntfStatus' )
            return supp.captivePortal
      trace( thId(), bv( self.client_address ), 'getCaptivePortal for', bv( mac ),
             'got', bv( dot1xStatus.captivePortal ), 'from dot1xStatus' )
      return dot1xStatus.captivePortal

   def do_GET( self ):
      """Invoked for every HTTP GET request. Defined by the base class.

      Answers with a 302 redirection to the captive portal and bumps the "redirect"
      counter. If building or sending the redirection fails, a header-only 500
      response is sent instead and the "invalid" counter is bumped.
      """
      try:
         trace( thId(), bv( self.client_address ), 'do_GET', 'redirecting to',
                bv( self.cp ) )
         # temporary redirection code 302
         self.send_response( 302 )
         self.send_header( "Location", self.cp )
         self.end_headers()
      # pylint: disable=broad-except
      except Exception as e:
         # internal server error
         warn( thId(), bv( self.client_address ), 'EXCEPTION in do_GET:', bv( e ) )
         # send_response/send_header only append to _headers_buffer (a CPython
         # http.server detail); drop any 302 lines buffered before the failure so
         # they are not flushed ahead of the 500 status line.
         # pylint: disable-next=attribute-defined-outside-init
         self._headers_buffer = []
         self.send_response( 500 )
         # end_headers() is what actually writes the buffered status line;
         # without it the 500 never reaches the client.
         self.end_headers()
         self.server.updateCaptivePortalCounters( "invalid", self.client_address )
      else:
         # Counted only once the redirection has really been written. A failure
         # here propagates to Dot1xHttpServer.process_request_in_pool instead of
         # triggering a 500 after the 302 was already sent.
         self.server.updateCaptivePortalCounters( "redirect", self.client_address )
      trace( thId(), bv( self.client_address ), 'do_GET done' )

   # Ignore pylint warning caused by the "format" argument:
   # pylint: disable=redefined-builtin
   def log_message( self, format, *args ):
      '''
      Replace BaseHTTPRequestHandler.log_message to redirect messages to trace

      Ref: https://github.com/python/cpython/blob/2.7/Lib/BaseHTTPServer.py#L449
      '''
      msg = format % args
      trace( thId(), bv( self.client_address ), bv( msg ) )

   def log_error( self, format, *args ):
      '''
      Called whenever there is error in the requestHandler class.

      Bumps the "invalid" counter and traces the message. When the message is a
      request timeout, it also writes a hard-coded 302 redirection to the captive
      portal.
      '''
      self.server.updateCaptivePortalCounters( "invalid", self.client_address )
      self.log_message( format, *args )
      msg = format % args
      # Send out a redirection if we are timing out the client:
      if msg.startswith( 'Request timed out' ):
         # Just send a hard-coded redirection
         self.wfile.write( b'HTTP/1.0 302 Found\r\n'
                           b'Location: %s\r\n\r\n' % bytes( self.cp, 'ascii' ) )
         trace( thId(), bv( self.client_address ),
                'log_error timeout, sent redirection to', bv( self.cp ) )

class Dot1xWebSslStatusSm( SslStatusSm ):
   '''
   Follows the status of one SSL profile on behalf of the Dot1xWeb agent.

   When the profile is valid, its certKeyPath is handed to
   dot1xweb.updateHttpsCertificate(). When it is not valid, or is missing or
   deleted, None is handed over instead.
   '''
   def __init__( self, dot1xweb, sslStatus, profileName ):
      trace( thId(), "Dot1xWebSslStatusSm being created for profile",
             bv( profileName ) )
      self.dot1xweb = dot1xweb
      SslStatusSm.__init__( self, sslStatus, profileName, 'Dot1xWeb' )
      knownProfiles = self.sslStatus_.profileStatus
      if profileName not in knownProfiles:
         # SslStatusSm doesn't call us if the profile is missing, so report the
         # deletion ourselves.
         self.handleProfileDelete()
      trace( thId(), "Dot1xWebSslStatusSm created for profile",
             bv( profileName ) )

   def handleProfileState( self ):
      status = self.profileStatus_
      profile = self.profileName_
      isValid = status.state == ProfileState.valid
      if not isValid:
         trace( thId(), "handleProfileState for", bv( profile ),
                "not valid, state", bv( status.state ) )
         for err in status.error.values():
            trace( thId(), "handleProfileState for", bv( profile ),
                   "got error", bv( err.errorAttr ), bv( err.errorType ) )
         self.dot1xweb.updateHttpsCertificate( None )
         return
      trace( thId(), "handleProfileState for", bv( profile ),
             "valid, state", bv( status.state ) )
      self.dot1xweb.updateHttpsCertificate( status.certKeyPath )

   def handleProfileDelete( self ):
      trace( thId(), "handleProfileDelete for", bv( self.profileName_ ) )
      # No profile means no certificate to serve HTTPS with.
      self.dot1xweb.updateHttpsCertificate( None )

class Dot1xIntfConfigReactor( Tac.Notifiee ):
   '''
   Reactor on one Dot1x::Dot1xIntfConfig entry.

   When the entry's dot1xEnabled changes, it adds a fresh captive portal counter
   for the interface if dot1x, the captive portal and the interface are all
   enabled. Otherwise it deletes the interface's counter.
   '''
   notifierTypeName = "Dot1x::Dot1xIntfConfig"

   def __init__( self, notifier, parent ):
      # parent provides dot1xConfig and webAgentStatus (Dot1xWebCounterSm passes
      # itself through reactorArgs).
      self.parent = parent
      Tac.Notifiee.__init__( self, notifier )

   @Tac.handler( 'dot1xEnabled' )
   def handleDot1xEnabled( self ):
      intfId = self.notifier_.intfId
      config = self.parent.dot1xConfig
      status = self.parent.webAgentStatus
      wanted = ( config.dot1xEnabled and
                 config.captivePortal.enabled and
                 config.dot1xIntfConfig[ intfId ].dot1xEnabled )
      if not wanted:
         del status.captivePortalCounters[ intfId ]
         return
      status.addCaptivePortalCounters(
         Tac.Value( "Dot1x::CaptivePortalCounter", intfId=intfId ) )

   def close( self ):
      Tac.Notifiee.close( self )

class Dot1xWebCounterSm():
   '''
   Keeps webAgentStatus.captivePortalCounters in sync with the dot1x config and
   serves counter-clear requests from configReq.

   While dot1x and the captive portal are both globally enabled, every interface
   with dot1xEnabled gets a counter. Otherwise all counters are dropped.
   '''
   def __init__( self, dot1xConfig, configReq, webAgentStatus ):
      trace( thId(), "Dot1x Web CounterSm created" )
      self.dot1xConfig = dot1xConfig
      self.configReq = configReq
      self.webAgentStatus = webAgentStatus
      self.cpCounterClearReactor = GenericReactor( self.configReq,
                                                   [ 'clearCaptivePortalCounter' ],
                                                   self.clearCaptivePortalCounters )
      self.dot1xEnabledReactor = GenericReactor( self.dot1xConfig,
                                                 [ 'dot1xEnabled' ],
                                                 self.maybeAddOrRemoveIntfCounters )
      self.captivePortalReactor = GenericReactor( self.dot1xConfig,
                                                  [ 'captivePortal' ],
                                                  self.maybeAddOrRemoveIntfCounters )
      self.dot1xIntfConfigReactor = Tac.collectionChangeReactor(
                                                  self.dot1xConfig.dot1xIntfConfig,
                                                  Dot1xIntfConfigReactor,
                                                  reactorArgs=( self, ) )
      # for init
      self.maybeAddOrRemoveIntfCounters()

   def maybeAddOrRemoveIntfCounters( self, notifiee=None ):
      '''
      Reconcile the per-interface counters with the current config.

      notifiee is unused; it only lets this method be used directly as a
      GenericReactor callback. Counters of interfaces that stay enabled are left
      untouched, i.e. they are not reset.
      '''
      counters = self.webAgentStatus.captivePortalCounters
      if not ( self.dot1xConfig.dot1xEnabled and
               self.dot1xConfig.captivePortal.enabled ):
         counters.clear()
         return

      # iterate through all the interfaces in dot1xConfig
      for intfId, dot1xIntfConfig in self.dot1xConfig.dot1xIntfConfig.items():
         if dot1xIntfConfig.dot1xEnabled:
            if intfId not in counters:
               counter = Tac.Value( "Dot1x::CaptivePortalCounter", intfId=intfId )
               self.webAgentStatus.addCaptivePortalCounters( counter )
         else:
            del counters[ intfId ]

      # Remove counters for which the dot1x IntfConfig doesn't exist anymore.
      # Iterate over a snapshot of the keys: deleting from a collection while
      # iterating it directly is unsafe (a Python dict raises RuntimeError).
      for intfId in list( counters ):
         if intfId not in self.dot1xConfig.dot1xIntfConfig:
            del counters[ intfId ]

   def clearCaptivePortalCounter( self, intfId ):
      '''Reset the counter of `intfId` by writing back a fresh counter value.'''
      counter = Tac.Value( "Dot1x::CaptivePortalCounter", intfId=intfId )
      self.webAgentStatus.addCaptivePortalCounters( counter )

   def clearCaptivePortalCountersAll( self ):
      '''Reset every existing counter.'''
      trace( thId(), 'clear captive portal counters all' )
      # Snapshot the keys so the collection is not iterated while being written.
      for intfId in list( self.webAgentStatus.captivePortalCounters ):
         self.clearCaptivePortalCounter( intfId )

   def clearCaptivePortalCounters( self, notifiee=None ):
      '''
      Serve a clear request from configReq.clearCaptivePortalCounter.

      An empty intfId clears all counters. Otherwise only that interface's counter
      is cleared, if it exists. notifiee is unused (GenericReactor callback).
      '''
      trace( thId(), 'clear captive portal counters' )
      intfId = self.configReq.clearCaptivePortalCounter.intfId
      if not intfId:
         self.clearCaptivePortalCountersAll()
      elif intfId in self.webAgentStatus.captivePortalCounters:
         self.clearCaptivePortalCounter( intfId )
      trace( thId(), 'clear captive portal counters done' )

class Dot1xWeb( Agent.Agent ):
   '''
   Dot1xWeb agent: serves the dot1x captive portal.

   When the captive portal is enabled and the webauth wire interface exists in
   kniStatus, it runs a Dot1xL2Forwarder plus an HTTP server. It also runs an
   HTTPS server when the configured SSL profile provides a valid certificate.
   Per-interface captive portal counters are maintained by Dot1xWebCounterSm.
   '''
   def __init__( self, entityManager ):
      trace( 'Dot1x Web Agent init entry' )
      self.agentName = name()
      Agent.Agent.__init__( self, entityManager, agentName=self.agentName )
      self.dot1xConfig = None
      self.dot1xStatus = None
      self.webAgentStatus = None
      self.configReq = None
      self.httpServers = {}
      self.warm_ = None
      self.arpStatus = None
      self.cpReactor = None
      self.intfReactor = None
      self.cpCounterClearReactor = None
      self.kniStatus = None
      self.threadPool = None
      self.sslProfileNameReactor = None
      self.sslStatus = None
      self.netStatus = None
      self.sslStatusSm = None
      # SSL context built from the current SSL profile's certificate (None when
      # there is no valid profile); its presence makes maybeCreateServers start
      # the HTTPS server.
      self.sslContext = None
      self.dot1xWebCounterSm = None
      self.sysname = entityManager.sysname()
      self.sEm = SharedMem.entityManager( self.sysname, entityManager.isLocalEm() )
      self.dot1xL2Forwarder = None
      self.dot1xWebIpEthHeaderTable = Dot1xWebIpEthHeaderTable()
      trace( 'Dot1x Web Agent init exit' )

   def doInit( self, entityManager ):
      trace( 'doInit entry' )
      mg = entityManager.mountGroup()
      Ark.configureLogManager( self.agentName )
      self.dot1xConfig = mg.mount( 'dot1x/config', 'Dot1x::Config', 'r' )
      self.dot1xStatus = mg.mount( 'dot1x/status', 'Dot1x::Status', 'r' )
      self.webAgentStatus = mg.mount( 'dot1x/webAgentStatus',
                                      'Dot1x::WebAgentStatus', 'w' )
      self.configReq = mg.mount( 'dot1x/configReq', 'Dot1x::ConfigReq', 'r' )
      self.sslStatus = mg.mount( 'mgmt/security/ssl/status',
                                 'Mgmt::Security::Ssl::Status', 'r' )
      self.netStatus = mg.mount( f'cell/{Cell.cellId()}/sys/net/status',
                                 'System::NetStatus', 'r' )
      shMemEm = SharedMem.entityManager( sysdbEm=entityManager )
      self.kniStatus = shMemEm.doMount( f"kni/ns/{DEFAULT_NS}/status",
                                        "KernelNetInfo::Status",
                                        Smash.mountInfo( 'keyshadow' ) )

      def mountDone():
         trace( 'mountDone entry' )
         self.traceCaptivePortal()
         self.webAgentStatus.running = False
         self.webAgentStatus.httpsRunning = False
         # Create L2 forwarder and HTTP servers if everything ok
         # The HTTPS is never created by this call because self.sslContext is not
         # initialized yet, even when the SSL profile is already configured
         self.maybeCreateServers()
         self.cpReactor = GenericReactor( self.dot1xConfig, [ 'captivePortal' ],
                                          self.handleCaptivePortal )
         self.intfReactor = GenericReactor( self.kniStatus, [ 'interface' ],
                                            self.handleIntf )
         self.sslProfileNameReactor = GenericReactor(
            self.dot1xConfig, [ 'captivePortalSslProfileName' ],
            self.handleSslProfileName )
         self.dot1xWebCounterSm = Dot1xWebCounterSm( self.dot1xConfig,
                                                     self.configReq,
                                                     self.webAgentStatus )
         # Forced call to self.handleSslProfileName to initialize self.sslContext if
         # the SSL profile is already configured. This may trigger another call to
         # self.maybeCreateServers() to create the HTTPS server.
         self.handleSslProfileName()
         self.warm_ = True
         trace( 'mountDone exit' )

      trace( 'closing mount group' )
      mg.close( mountDone )
      trace( 'doInit exit' )

   # no sysdb read or write
   def warm( self ):
      return self.warm_

   def getIpPort( self, https ):
      '''Return the ( address, port ) the HTTP or HTTPS server listens on.'''
      return ( '0.0.0.0',
               PrivateTcpPorts.dot1xWebHttpsPort if https
               else PrivateTcpPorts.dot1xWebHttpPort )

   def maybeCreateServers( self ):
      '''
      Start L2 forwarder and web servers if appropriate

      Nothing is started unless the captive portal is enabled and the webauth
      wire interface is ready. The HTTPS server is only started when
      self.sslContext is set.

      We don't need the activity lock here because only the main thread is running.
      '''
      if ( not self.dot1xConfig.captivePortal.enabled or
           not self.isWebAuthWireIntfReady() ):
         trace( thId(), "maybeCreateServers: skipping, "
                "enable", bv( self.dot1xConfig.captivePortal.enabled ),
                "intfready", bv( self.isWebAuthWireIntfReady() ) )
         return
      if not self.dot1xL2Forwarder:
         self.dot1xL2Forwarder = Dot1xL2Forwarder( self.dot1xConfig,
                                                   self.netStatus,
                                                   self.webAgentStatus,
                                                   self.configReq,
                                                   self.dot1xWebIpEthHeaderTable )
      if not self.threadPool:
         self.threadPool = ThreadPool( processes=DOT1XWEB_NUM_THREADS )
      if 'http' not in self.httpServers:
         trace( thId(), "maybeCreateServers: creating HTTP server" )
         self.httpServers[ 'http' ] = Dot1xHttpServer(
            agent=self, threadPool=self.threadPool )
      self.webAgentStatus.running = True
      if 'https' not in self.httpServers:
         # self.sslContext tells Dot1xHttpServer to start an HTTPS server
         if not self.sslContext:
            trace( thId(), "maybeCreateServers: skipping HTTPS server, "
                   "no sslContext" )
         else:
            trace( thId(), "maybeCreateServers: creating HTTPS server" )
            self.httpServers[ 'https' ] = Dot1xHttpServer(
               agent=self, threadPool=self.threadPool, sslContext=self.sslContext )
            self.webAgentStatus.httpsRunning = True

   def stopServers( self ):
      '''
      Stop web servers if they are running

      We can't get the activity lock here because the handler threads might want to
      get the captive portal config after a timeout, and we wait for them to finish
      when we self.threadPool.join - that would cause a deadlock.

      self.sslContext is deliberately kept. It reflects the configured SSL profile
      and is only replaced by updateHttpsCertificate. Clearing it here meant that
      re-enabling the captive portal never brought the HTTPS server back, because
      an unchanged profile name does not re-trigger handleSslProfileName.
      '''
      if self.threadPool:
         trace( 'stopServers close threads' )
         self.threadPool.close()
         trace( 'stopServers join threads' )
         self.threadPool.join()
         trace( 'stopServers joined threads' )
      for proto, server in self.httpServers.items():
         trace( 'stopServers shutting down', bv( proto ) )
         server.shutdown()
         # Same teardown as updateHttpsCertificate: also release the
         # Tac::FileDescriptor that was watching the listening socket.
         server.tacFd.close()
      self.httpServers = {}
      self.threadPool = None
      if self.dot1xL2Forwarder:
         self.dot1xL2Forwarder.finish()
         self.dot1xL2Forwarder = None
      trace( 'stopServers done, ready for agent shutdown' )
      self.webAgentStatus.running = False
      self.webAgentStatus.httpsRunning = False

   def traceCaptivePortal( self ):
      '''Trace every attribute of dot1xConfig.captivePortal.'''
      for attr in self.dot1xConfig.captivePortal.attributes:
         val = getattr( self.dot1xConfig.captivePortal, attr )
         trace( thId(), 'captivePortal', bv( attr ), '=', bv( val ) )

   def handleCaptivePortal( self, notifiee=None ):
      '''React to dot1xConfig.captivePortal: start or stop the servers.'''
      trace( thId(), 'handleCaptivePortal' )
      self.traceCaptivePortal()
      if self.dot1xConfig.captivePortal.enabled:
         self.maybeCreateServers()
      else:
         self.stopServers()
      trace( thId(), 'handleCaptivePortal done' )

   def handleSslProfileName( self, notifiee=None ):
      '''(Re)create the SSL profile watcher for captivePortalSslProfileName.'''
      trace( thId(), 'handleSslProfileName for',
             bv( self.dot1xConfig.captivePortalSslProfileName ) )
      if self.dot1xConfig.captivePortalSslProfileName:
         self.sslStatusSm = Dot1xWebSslStatusSm(
            self, self.sslStatus, self.dot1xConfig.captivePortalSslProfileName )
      else:
         self.sslStatusSm = None
         self.updateHttpsCertificate( None )
      trace( thId(), 'handleSslProfileName done' )

   def updateHttpsCertificate( self, certKeyPath ):
      '''
      Tear down the HTTPS server, then rebuild it for `certKeyPath` if given.

      certKeyPath is used as both the certificate-chain and the key file of
      load_cert_chain, or is None/empty when there is no valid SSL profile.

      self.sslContext is always reset, even when no HTTPS server was running (e.g.
      captive portal disabled, or webauth interface not ready yet). Otherwise a
      context from a since-invalidated profile would later be picked up by
      maybeCreateServers.
      '''
      httpsServer = self.httpServers.get( 'https' )
      if httpsServer:
         trace( thId(), "updateHttpsCertificate: shutting down HTTPS server" )
         httpsServer.shutdown()
         httpsServer.tacFd.close()
         del self.httpServers[ 'https' ]
         self.webAgentStatus.httpsRunning = False
      self.sslContext = None
      if not certKeyPath:
         return
      trace( thId(), "updateHttpsCertificate: creating SSL context for",
             bv( certKeyPath ) )
      # The context is only used with wrap_socket( server_side=True ), so the
      # server-only protocol is appropriate; plain PROTOCOL_TLS is deprecated
      # since Python 3.10.
      sslContext = ssl.SSLContext( ssl.PROTOCOL_TLS_SERVER )
      try:
         sslContext.load_cert_chain( certKeyPath, certKeyPath )
      except OSError as e:  # ssl.SSLError is a subclass of OSError
         # Keep serving plain HTTP rather than letting the error escape the
         # reactor; HTTPS stays down until the profile is updated again.
         warn( thId(), "updateHttpsCertificate: cannot load certificate from",
               bv( certKeyPath ), bv( str( e ) ) )
         return
      self.sslContext = sslContext
      self.maybeCreateServers()

   def isWebAuthWireIntfReady( self ):
      '''Check if the webauth platform interface is ready'''
      intfNames = { intf.deviceName
                    for intf in self.kniStatus.interface.values() }
      if WebAuthWireIntfName in intfNames:
         trace( 'isWebAuthWireIntfReady intf', WebAuthWireIntfName, 'found' )
         return True
      trace( 'isWebAuthWireIntfReady intf', WebAuthWireIntfName, 'not found in',
             bv( intfNames ) )
      return False

   def handleIntf( self, notifiee=None, key=None ):
      '''
      Reacts to interfaces appearing in kniStatus (i.e. in Linux)

      Calls maybeCreateServers if "key" points to the webauth interface.
      '''
      if self.webAgentStatus.running:
         return
      if key in self.kniStatus.interface:
         if self.kniStatus.interface[ key ].deviceName == WebAuthWireIntfName:
            trace( thId(), 'handleIntf',
                   bv( self.kniStatus.interface[ key ].deviceName ) )
            self.maybeCreateServers()

def name():
   ''' Call this to establish an explicit dependency on the Dot1xWeb
   agent executable, to be discovered by static analysis. '''
   # NOTE(review): keep this a plain string literal return; presumably the
   # static analysis looks for it verbatim - confirm before refactoring.
   return 'Dot1xWeb'
