# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import atexit
import json
import os
import threading
import traceback

import Agent
import Cell
import JsonApiConstants
import Plugins
import Tac
import Tracing
import AaaApiClient
import UwsgiConstants

from ApiBaseModels import ModelJsonSerializer
from ControllerdbEntityManager import Controllerdb
from UwsgiRequestContext import UwsgiRequestContext, HttpBadRequest, HttpException
from UwsgiRequestContext import HttpForbidden, HttpServiceAclDenied
from UrlMap import getHandler

traceHandle = Tracing.Handle( 'JsonUwsgiServer' )
# Shorthands for the handle's trace levels 0 through 4, in that order.
log, warn, info, trace, debug = ( traceHandle.trace0,
                                  traceHandle.trace1,
                                  traceHandle.trace2,
                                  traceHandle.trace3,
                                  traceHandle.trace4 )

# Controller-wide constants (e.g. the default ControllerDb socket name).
Constants = Tac.Value( "Controller::Constants" )

class JsonApiApp( Agent.Agent ):
   """Agent that serves the JSON API.

   On startup it mounts the Sysdb and ControllerDb state needed by request
   handlers (including paths requested by JSONApiPlugin plugins) and sets up
   service-ACL filtering. processRequest() then authenticates each uWSGI
   request and dispatches it to the handler returned by UrlMap.getHandler.
   """
   def __init__( self, sysEm ):
      """Create the agent, connect to ControllerDb and load plugin mounts.

      Args:
         sysEm: the Sysdb entity manager. Its sysname is reused for the
            ControllerDb connection and for the per-thread AaaApiClient.

      The mounts themselves are issued later, in doInit().
      """
      trace( 'init entry' )
      Agent.Agent.__init__( self, sysEm, agentName="JsonApiApp" )
      # Set by the Sysdb (mgDone) and ControllerDb (cmgDone) mount-group
      # callbacks in doInit(); warm() requires both.
      self.mgDone = False
      self.cmgDone = False
      self.sysname = sysEm.sysname()
      # Holds one AaaApiClient per thread; see aaaApiClient_.
      self.threadLocalData_ = threading.local()
      # Filled by _loadPlugins():
      #   sysdbPluginMounts: mount key -> Sysdb path
      #   cdbPluginMounts:   mount key -> ( path, typeName, mode ) in ControllerDb
      self.sysdbPluginMounts = {}
      self.cdbPluginMounts = {}
      # Every mounted entity by key; this dict is passed to each URL handler.
      self.mounts = {}
      # Built in doInit()'s mountDone callback (when serviceAclEnabled_ is set)
      # and handed to every UwsgiRequestContext.
      self.serviceAclFilterSm_ = None
      # Service ACL is always enabled now regardless of Epoch setting.
      # We keep this flag so in the future when we reimplement the service ACL
      # in the kernel we can come back and remove any code checking this flag.
      self.serviceAclEnabled_ = True

      trace( "Connecting to controllerdb" )
      # The socket name can be overridden with the CONTROLLERDBSOCKNAME
      # environment variable; otherwise the controller default is used.
      cdbSock = os.environ.get( 'CONTROLLERDBSOCKNAME',
                                Constants.controllerdbDefaultSockname )
      # mountRoot=False; doInit() mounts the root ("") explicitly as part of
      # its ControllerDb mount group.
      self.cEm = Controllerdb( sysEm.sysname(),
                               controllerdbSockname_=cdbSock,
                               dieOnDisconnect=True,
                               mountRoot=False )
      # ACL state, created in doInit(). aclConfig / aclCpConfig are local
      # Acl::Config / Acl::CpConfig instances that are passed to the
      # aggregator state machines built in mountDone.
      self.aclConfigAggregatorSm = None
      self.aclCpConfigAggregatorSm = None
      self.aclConfig = None
      self.aclCpConfig = None

      self._loadPlugins()
      trace( 'init exit' )

   @property
   def aaaApiClient_( self ):
      """Return the calling thread's AaaApiClient, creating it on first use."""
      # AaaApiClient.AaaApiClient is not thread safe, so each thread should
      # create its own client
      if not hasattr( self.threadLocalData_, 'aaaApiClient' ):
         self.threadLocalData_.aaaApiClient = AaaApiClient.AaaApiClient(
               self.sysname )
      return self.threadLocalData_.aaaApiClient

   def _loadPlugins( self ):
      """Collect mount requests from all JSONApiPlugin plugins.

      Falsy plugin results are skipped. Otherwise element 0 of the result is
      merged into sysdbPluginMounts and element 1 into cdbPluginMounts
      (presumably a pair of dicts -- verify against the plugins). On a
      duplicate key the later plugin's entry replaces the earlier one.
      """
      # TODO this function needs de-duping to make sure we do not
      # mount the same thing twice
      trace( "Loading JSONApiPlugin" )
      pd = Plugins.loadPlugins( "JSONApiPlugin" )
      for plugin in pd.plugins():
         if not plugin:
            continue
         self.sysdbPluginMounts.update( plugin[ 0 ] )
         self.cdbPluginMounts.update( plugin[ 1 ] )

   def handleMountFailure( self, mountUrl=None ):
      """ControllerDb mount-failure callback: log a warning and raise.

      mountUrl is accepted but not currently used.
      """
      warn( "Failed to mount from ControllerDb" )
      raise Exception( "Failed to mount from ControllerDb" )

   def doInit( self, sysEm ): # pylint: disable=arguments-renamed
      """Create local ACL state and start the ControllerDb and Sysdb mounts.

      The callbacks passed to each mount group's close() (cmgDone for
      ControllerDb, mountDone for Sysdb) set the flags checked by warm().
      """
      def mountDone():
         """Sysdb mount-group callback: build ACL state machines, finish init."""
         trace( 'mountDone entry' )
         # Aggregator SMs over the ACL input and CLI config dirs, given the
         # local self.aclConfig / self.aclCpConfig as their last argument
         # (presumably the merged output -- confirm against the Acl types).
         self.aclConfigAggregatorSm = Tac.newInstance(
                                                   "Acl::AclConfigAggregatorSm",
                                                   self.mounts[ 'aclConfigDir' ],
                                                   self.mounts[ 'aclCliConfig' ],
                                                   self.aclConfig )
         self.aclCpConfigAggregatorSm = Tac.newInstance(
                                                   "Acl::AclCpConfigAggregatorSm",
                                                   self.mounts[ 'aclCpConfigDir' ],
                                                   self.mounts[ 'aclCliCpConfig' ],
                                                   self.aclCpConfig )
         serviceMap = self.mounts[ 'openStackAgentConfig' ].serviceAclTypeVrfMap
         # Initialise osApiStatus.aclStatus for the 'OpenStack' service and
         # hand the resulting entity to the service ACL filter SM.
         self.mounts[ 'osApiStatus' ].aclStatus = ( 'OpenStack', )
         aclStatus = self.mounts[ 'osApiStatus' ].aclStatus
         if self.serviceAclEnabled_:
            self.serviceAclFilterSm_ = Tac.newInstance( 'Acl::ServiceAclFilterSm',
                                                        'OpenStack',
                                                        self.aclConfig,
                                                        serviceMap,
                                                        aclStatus )
         # Close every existing AAA session on this server's TTY (presumably
         # sessions left behind by a previous instance of this agent).
         for sessId, sess in self.mounts[ 'aaaSessStatus' ].session.items():
            if sess.tty == JsonApiConstants.TTY_NAME:
               self.aaaApiClient_.closeSession( sessId )
         self.mgDone = True
         trace( 'mountDone exit' )

      def cmgDone():
         """ControllerDb mount-group callback: mark ControllerDb mounts done."""
         trace( 'cmgDone entry' )
         self.cmgDone = True
         trace( 'cmgDone exit' )

      trace( 'doInit entry' )

      # Standalone ACL config objects, wired to the aggregator SMs in mountDone.
      self.aclConfig = Tac.newInstance( "Acl::Config" )
      self.aclCpConfig = Tac.newInstance( "Acl::CpConfig" )

      cMountGroup = self.cEm.mountGroup(
            mountFailureCallback=self.handleMountFailure,
            persistent=True )

      # We must mount the root or else other mounts fail
      self.mounts[ 'controllerdbRoot' ] = cMountGroup.mount( "", "Tac::Dir", "rt" )

      # Plugin-requested ControllerDb mounts.
      for name, mountInfo in self.cdbPluginMounts.items():
         path, obj, mode = mountInfo
         # pylint: disable-next=consider-using-f-string
         trace( "Adding ControllerDb mount %s, %s, %s under key %s" %
               ( path, obj, mode, name ) )
         self.mounts[ name ] = cMountGroup.mount( path, obj, mode )
      cMountGroup.close( cmgDone )

      # Sysdb paths this server itself needs (ACL, OpenStack, AAA status).
      sMountGroup = sysEm.mountGroup()
      self.mounts[ 'aclConfigDir' ] = sMountGroup.mountPath( 'acl/config/input' )
      self.mounts[ 'aclCliConfig' ] = sMountGroup.mountPath( 'acl/config/cli' )
      self.mounts[ 'aclCpConfigDir' ] = sMountGroup.mountPath( 'acl/cpconfig/input' )
      self.mounts[ 'aclCliCpConfig' ] = sMountGroup.mountPath( 'acl/cpconfig/cli' )
      self.mounts[ 'aclParamConfig' ] = sMountGroup.mountPath( 'acl/paramconfig' )
      # TODO - move this to something not in the openstack mount point
      self.mounts[ 'osApiStatus' ] = sMountGroup.mountPath(
         'mgmt/openstack/osApiStatus' )
      self.mounts[ 'aaaSessStatus' ] = sMountGroup.mountPath(
         Cell.path( 'security/aaa/status' ) )
      self.mounts[ 'openStackAgentConfig' ] = sMountGroup.mountPath(
         'mgmt/openstack/config' )

      # Plugin-requested Sysdb mounts.
      for name, path in self.sysdbPluginMounts.items():
         trace( f"Adding Sysdb mount {path} under key {name}" )
         self.mounts[ name ] = sMountGroup.mountPath( path )

      trace( 'closing mount group' )
      sMountGroup.close( mountDone )
      trace( 'doInit exit' )

   def warm( self ):
      """Return True once both the Sysdb and ControllerDb mounts are done."""
      return self.mgDone and self.cmgDone

   def _checkAuthRole( self, requestContext, userContext ):
      """Raise HttpForbidden if the user lacks the configured auth role.

      An empty openStackAgentConfig.authRole disables the check. The user's
      roles are fetched by userContext.aaaAuthnId. requestContext is unused.
      """
      cfgAuthRole = self.mounts[ 'openStackAgentConfig' ].authRole
      reqAuthRoles = self.aaaApiClient_.getSessionRoles( userContext.aaaAuthnId )
      if cfgAuthRole != '' and cfgAuthRole not in reqAuthRoles:
         raise HttpForbidden( 'User is not authorized' )

   def processRequest( self, request ):
      """Filter, authenticate, authorize and dispatch one request.

      Returns:
         A tuple ( status, contentType, extraHeaders, body ). On success the
         status is '200 OK' and body is a list of encoded JSON chunks. On an
         HttpException the status comes from the exception's code/name,
         extraHeaders is its additionalHeaders, and body is a single bytes
         object holding {"error": message}. Any other exception yields a
         '500 Internal Server Error' with {"error": str(e)}. extraHeaders is
         None except in the HttpException case.

      Once authenticate() has returned, the AAA session identified by
      userContext.aaaAuthnId is closed whether or not the handler succeeds.
      """
      try:
         requestContext = UwsgiRequestContext( self.sysname, request,
            serviceAclFilterSm=self.serviceAclFilterSm_,
            aaaApiClient=self.aaaApiClient_ )

         # The service-ACL check runs before authentication.
         if not requestContext.aclPermitConnection():
            raise HttpServiceAclDenied( 'Filtered by service ACL' )

         userContext = requestContext.authenticate( tty=JsonApiConstants.TTY_NAME,
                  service=JsonApiConstants.PAM_SERVICE_NAME )
         try:
            self._checkAuthRole( requestContext, userContext )
            rType = requestContext.getRequestType()
            parsedUrl = requestContext.getParsedUrl()
            func, kwargs = getHandler( rType, parsedUrl )
            if func is not None:
               # pylint: disable-next=consider-using-f-string
               trace( 'calling handler %s' % func.__name__ )
               result = func( requestContext, self.mounts, **kwargs )
               # Serialize the handler's result to a list of bytes chunks.
               result = [
                  chunk.encode()
                  for chunk in ModelJsonSerializer().iterencode( result )
               ]
            else:
               raise HttpBadRequest( 'Invalid endpoint requested' )
            trace( 'processRequest exit', result )
            return ( '200 OK', 'application/json', None, result )
         finally:
            self.aaaApiClient_.closeSession( userContext.aaaAuthnId )
      except HttpException as e:
         traceback.print_exc()
         trace( 'processRequest HttpException', e )
         msg = json.dumps( { 'error': e.message } )
         msg = msg.encode()
         return ( f'{e.code} {e.name}', 'application/json',
                  e.additionalHeaders, msg )
      except Exception as e: # pylint: disable=broad-except
         trace( 'processRequest Exception', e )
         traceback.print_exc()  # Log stack trace to agent log.
         msg = json.dumps( { 'error': str( e ) } )
         msg = msg.encode()
         return ( '500 Internal Server Error', 'application/json', None, msg )

class JsonApplication:
   """WSGI-style callable that hosts a JsonApiApp agent and forwards each
   request to it.

   Construction starts the agent container and the Tac activity thread, then
   blocks until the agent reports warm().
   """

   def __init__( self ):
      trace( 'init entry' )
      self.container_ = Agent.AgentContainer( [ JsonApiApp ],
                                           agentTitle="JsonApiApp" )
      self.container_.startAgents()
      # pylint: disable-msg=W0212
      atexit.register( lambda: os._exit( 0 ) ) # on exiting we have to be brutal
      Tac.activityThread().start( daemon=True )
      Tac.waitFor( self.apiAgentWarm, description="JsonApiApp to be warm",
                   maxDelay=1, sleep=True )
      self.apiAgent_ = self.container_.agents_[ 0 ]
      trace( 'init exit' )

   def apiAgentWarm( self ):
      """Return truthy once the container has an agent and that agent is warm."""
      return self.container_.agents_ and self.container_.agents_[ 0 ].warm()

   @staticmethod
   def _prepareResponse( contentType, headers, body ):
      """Build the response header list and the list of body chunks.

      Args:
         contentType: value for the Content-type header.
         headers: extra headers from processRequest, or None. This is copied,
            never mutated: it may be an HttpException's additionalHeaders
            object, which could be shared between responses.
         body: a list of bytes chunks, a single bytes object, or None.

      Returns:
         ( headerList, chunks ) where headerList is a new list ending with
         Content-type and Content-length, and chunks is always a list.
      """
      headerList = list( headers ) if headers else []
      headerList.append( ( 'Content-type', contentType ) )
      if body is None:
         chunks = []
      elif isinstance( body, list ):
         chunks = body
      else:
         # A single bytes body; WSGI expects an iterable of bytestrings.
         chunks = [ body ]
      length = sum( len( chunk ) for chunk in chunks )
      headerList.append( ( 'Content-length', str( length ) ) )
      return headerList, chunks

   def __call__( self, request, start_response ):
      """Handle one request via the agent and start the WSGI response."""
      responseCode, contentType, extraHeaders, body = \
         self.apiAgent_.processRequest( request )
      headers, chunks = self._prepareResponse( contentType, extraHeaders, body )
      start_response( responseCode,
                      UwsgiConstants.DEFAULT_HEADERS + headers )
      return chunks
