# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import atexit
import json
import io
import os
import re
import threading
import traceback
import time
import types
import mimetypes
import uwsgi

import AaaApiClient
import Agent
import Ark
import CapiConstants
import Cell
import CliCommon
import EapiClientLib
import Tac
import Tracing
import UwsgiConstants
import UwsgiRequestContext
import UwsgiSessionManager
from IpLibConsts import DEFAULT_VRF
from io import IOBase

# Dummy import to create dependency on Epoch-lib which
# provides HwEpoch::Status
import HwEpochPolicy # pylint: disable-msg=unused-import

# Dummy import to create dependency on Acl-lib which
# provides Acl::ServiceAclFilterSm
import AclLib # pylint: disable-msg=unused-import

traceHandle = Tracing.Handle( 'CapiUwsgiServer' )
log = traceHandle.trace0
warn = traceHandle.trace1
info = traceHandle.trace2
trace = traceHandle.trace3
debug = traceHandle.trace4

class ExecRequestReactor( Tac.Notifiee ):
   """ Watches HttpService::ExecRequest and clears the CAPI statistics held in
   capiStatus whenever a newer statistics-reset time is requested. """

   notifierTypeName = 'HttpService::ExecRequest'

   def __init__( self, execRequest, capiStatus ):
      """ Attach to execRequest and apply any reset that is already pending.

      execRequest: the notifier entity carrying lastStatsResetTime.
      capiStatus: the status entity whose counters get cleared.
      """
      trace( 'ExecRequestReactor.__init__ enter' )
      super().__init__( execRequest )
      self.capiStatus_ = capiStatus
      # A reset may have been requested before this reactor existed.
      self.resetStatistics( execRequest.lastStatsResetTime )
      trace( 'ExecRequestReactor.__init__ exit' )

   @Tac.handler( 'lastStatsResetTime' )
   def handleLastStatsResetTime( self ):
      """ Handler registered for the notifier's lastStatsResetTime attribute. """
      trace( 'ExecRequestReactor.handleLastStatsResetTime enter' )
      requestedTime = self.notifier_.lastStatsResetTime
      self.resetStatistics( requestedTime )
      trace( 'ExecRequestReactor.handleLastStatsResetTime exit' )

   def resetStatistics( self, lastStatsResetTime ):
      """ Zero every counter in capiStatus, drop the per-user entries and stamp
      the reset time, unless the requested time is not newer than the reset
      time already recorded in capiStatus. """
      status = self.capiStatus_
      if lastStatsResetTime <= status.lastStatsResetTime:
         # Nothing newer than the last reset we performed.
         return

      for counterName in ( 'hitCount', 'requestCount', 'commandCount',
                           'bytesInCount', 'bytesOutCount' ):
         setattr( status, counterName, 0 )
      for timeName in ( 'lastHitTime', 'requestDuration' ):
         setattr( status, timeName, 0.0 )
      status.user.clear()
      status.lastStatsResetTime = Tac.now()

class CapiConfigReactor( Tac.Notifiee ):
   """ Watches HttpService::Config and forwards session-timeout changes to the
   session manager. """
   notifierTypeName = 'HttpService::Config'

   def __init__( self, capiConfig, sessionManager ):
      """ Attach to capiConfig and remember the session manager to update. """
      trace( 'CapiConfigReactor.__init__ enter' )
      super().__init__( capiConfig )
      self.sessionManager_ = sessionManager
      trace( 'CapiConfigReactor.__init__ exit' )

   @Tac.handler( 'sessionTimeout' )
   def handleSessionTimeout( self ):
      """ Handler registered for the notifier's sessionTimeout attribute. """
      trace( 'CapiConfigReactor.handleSessionTimeout enter' )
      # The configured value is in minutes; the session manager takes seconds.
      timeoutSeconds = self.notifier_.sessionTimeout * 60
      self.sessionManager_.updateSessionTimeout( timeoutSeconds )
      trace( 'CapiConfigReactor.handleSessionTimeout exit' )

class CapiApp( Agent.Agent ):
   def __init__( self, entityManager ):
      """ Create the agent with every mount, reactor and state machine slot
      empty; doInit() fills them in. """
      trace( 'init entry' )
      super().__init__( entityManager, agentName="CapiApp" )

      # Per-thread storage and the lock that serializes statistics updates.
      self.threadLocalData_ = threading.local()
      self.statsUpdateLock = threading.Lock()
      self.sysname = entityManager.sysname()

      # Entities assigned in doInit().
      self.execRequest_ = None
      self.capiStatus_ = None
      self.capiConfig_ = None
      self.httpServerStatus_ = None
      self.aaaSessStatus_ = None
      self.aclConfig_ = None
      self.aclConfigDir_ = None
      self.aclCliConfig_ = None

      # Helpers, reactors and state machines built once mounting is done.
      self.sessionManager_ = None
      self.execRequestReactor_ = None
      self.capiConfigReactor_ = None
      self.aclConfigAggregatorSm_ = None
      self.serviceAclFilterSm_ = None
      self.warm_ = False

      # Service ACL is always enabled now regardless of Epoch setting.
      # We keep this flag so in the future when we reimplement the service ACL
      # in the kernel we can come back and remove any code checking this flag.
      self.serviceAclEnabled = True
      trace( 'init exit' )

   @property
   def aaaApiClient_( self ):
      """ The AaaApiClient belonging to the calling thread, created on first
      use. AaaApiClient.AaaApiClient is not thread safe, so clients are kept in
      thread-local storage rather than shared. """
      perThread = self.threadLocalData_
      try:
         return perThread.aaaApiClient
      except AttributeError:
         # First access from this thread: build and remember its client.
         client = AaaApiClient.AaaApiClient( self.sysname )
         perThread.aaaApiClient = client
         return client

   def doInit( self, entityManager ):
      """ Mount the entities this agent uses and register mountDone() to finish
      the setup.

      mountDone() is handed to mountGroup.close(); presumably it runs once the
      mounts have completed (confirm against the mount group API). It builds
      the session manager, the reactors and the ACL state machines, closes
      leftover sessions and finally marks the agent warm.
      """
      trace( 'doInit entry' )
      mountGroup = entityManager.mountGroup()
      # Set up the LogManager. This is basically everything that the
      # Agent base class normally does for us.
      Ark.configureLogManager( "Capi" )
      self.execRequest_ = mountGroup.mountPath( 'mgmt/httpserver/execRequest' )
      self.capiStatus_ = mountGroup.mountPath(
            'mgmt/httpserver/service/http-commands' )
      self.capiConfig_ = mountGroup.mountPath( 'mgmt/capi/config' )
      self.httpServerStatus_ = mountGroup.mountPath( 'mgmt/httpserver/status' )
      # Unlike its neighbours, aclConfig_ is a freshly created local instance,
      # not a mount; it is handed to both ACL state machines in mountDone().
      self.aclConfig_ = Tac.newInstance( "Acl::Config" )
      self.aclConfigDir_ = mountGroup.mountPath( 'acl/config/input' )
      self.aclCliConfig_ = mountGroup.mountPath( 'acl/config/cli' )
      self.aaaSessStatus_ = mountGroup.mountPath(
            Cell.path( 'security/aaa/status' ) )

      def mountDone():
         """ Build everything that needs the mounted entities, then flip
         warm_ to True as the last step. """
         trace( 'mountDone entry' )
         # sessionTimeout is multiplied by 60 here, the same minutes-to-seconds
         # conversion CapiConfigReactor.handleSessionTimeout applies.
         self.sessionManager_ = UwsgiSessionManager.UwsgiSessionManager(
               entityManager.sysname(), self.capiConfig_.sessionTimeout * 60 )
         self.execRequestReactor_ = ExecRequestReactor( self.execRequest_,
                                                        self.capiStatus_ )
         self.capiConfigReactor_ = CapiConfigReactor( self.capiConfig_,
                                                      self.sessionManager_ )
         self.aclConfigAggregatorSm_ = Tac.newInstance(
                                          "Acl::AclConfigAggregatorSm",
                                          self.aclConfigDir_,
                                          self.aclCliConfig_,
                                          self.aclConfig_ )
         self.capiStatus_.aclStatus = ( 'Capi', )
         aclStatus = self.capiStatus_.aclStatus
         serviceMap = self.capiConfig_.serviceAclTypeVrfMap
         # serviceAclEnabled is hard-wired to True in __init__.
         if self.serviceAclEnabled:
            self.serviceAclFilterSm_ = Tac.newInstance(
                                             'Acl::ServiceAclFilterSm',
                                             'Capi',
                                              self.aclConfig_,
                                              serviceMap,
                                              aclStatus )
         # BUG239350 Clear any leftover CAPI sessions.
         # Code looks sloppy, but regular dict iteration is not thread-safe.
         # Each entry is re-fetched with get() and may be gone by then, hence
         # the check on sess before its tty is compared.
         for sessId in self.aaaSessStatus_.session:
            sess = self.aaaSessStatus_.session.get( sessId )
            if sess and sess.tty == CapiConstants.TTY_NAME:
               self.aaaApiClient_.closeSession( sessId )

         self.warm_ = True
         trace( 'mountDone exit' )

      trace( 'closing mount group' )
      mountGroup.close( mountDone )
      trace( 'doInit exit' )

   def warm( self ):
      """ Return the warm_ flag: False from __init__ until mountDone() in
      doInit() sets it to True as its last step. """
      return self.warm_

   def updateStatistics( self, requestContext, bytesOut, requestCount,
                         commandCount, userContext ):
      """ Add one hit to the global statistics and, when the request has a
      user name, to that user's statistics as well.

      requestContext: supplies the content length and request start time.
      bytesOut: amount of response data produced for this hit.
      requestCount: requests run; None means nothing was read back from the
         backend and the hit is not recorded at all.
      commandCount: commands run.
      userContext: its 'user' attribute selects the per-user entry.
      """
      trace( 'updateStatistics entry' )
      if requestCount is None:
         # No data read from socket to backend maybe due to a crash.
         return

      def accumulate( counters ):
         # Applies this hit to one statistics entity (global or per-user).
         now = Tac.now()
         counters.lastHitTime = now
         counters.hitCount += 1
         counters.bytesInCount += requestContext.getContentLength()
         counters.bytesOutCount += bytesOut
         counters.requestCount += requestCount
         counters.commandCount += commandCount
         if requestCount > 0:
            started = requestContext.getRequestStartTime()
            counters.requestDuration += now - started

      with self.statsUpdateLock:
         accumulate( self.capiStatus_ )
         username = userContext.user
         if not username:
            trace( 'updateStatistics exit' )
            return

         perUser = self.capiStatus_.user
         if username in perUser:
            entry = perUser[ username ]
         else:
            entry = perUser.newMember( username )
         accumulate( entry )
      trace( 'updateStatistics exit', username )

   def _processCommandApiGenerator( self, requestContext, userContext ):
      """ Generator producing the response body of a command API request.

      With an empty request body it yields a single JSON-RPC parse-error
      document. Otherwise it opens an EapiClient for the authenticated user,
      sends the request content and yields every item of the result generator
      as it arrives, then records the statistics.

      However the generator ends (exhausted, an exception, or closed by the
      consumer while suspended at a yield), the user is deauthenticated and
      the client, if one was created, is closed.
      """
      debug( "_processCommandApiGenerator start", requestContext )
      ctx = None
      try:
         requestContent = requestContext.getRequestContent()
         if not requestContent:
            # This yield sits inside the try so that closing the generator
            # here still reaches the cleanup below.
            yield json.dumps( {
               "jsonrpc": "2.0",
               "error": {
                  "code": CliCommon.JsonRpcErrorCodes.PARSE_ERROR,
                  "message": "No data provided in HTTP POST body" },
               "id": None
            } ).encode()
            return

         ctx = EapiClientLib.EapiClient( sysname=self.sysname,
                  disableAaa=requestContext.aaaDisabled(),
                  privLevel=userContext.privLevel,
                  uid=userContext.uid,
                  gid=userContext.gid,
                  aaaAuthnId=userContext.aaaAuthnId,
                  realTty=CapiConstants.TTY_NAME,
                  logName=userContext.user,
                  sshConnection=requestContext.getRequesterIp() )
         # connect statelessly since it will only run 1 set of commands
         ctx.connect( stateless=True )
         bytesOut = 0
         # pylint: disable-msg=protected-access
         resultGenerator, statsFunc = ctx.sendRpcRequest( requestContent )
         debug( "_processCommandApiGenerator running result generator" )
         for i in resultGenerator:
            debug( "_processCommandApiGenerator genertor result", i )
            # None items are passed through to the consumer but carry no
            # data, so they are not counted.
            if i is not None:
               bytesOut += len( i )
            yield i

         requestCount, commandCount = statsFunc()
         self.updateStatistics( requestContext, bytesOut, requestCount,
               commandCount, userContext )
      finally:
         # Nested so that a failing deauthenticate() cannot leak the client.
         try:
            requestContext.deauthenticate( userContext )
         finally:
            if ctx:
               ctx.close()

   # pylint: disable-next=inconsistent-return-statements
   def _processCommandApi( self, requestContext ):
      """ Serve a command API request.

      A request carrying the HTTP_SEC_WEBSOCKET_KEY header is handed to
      handleWebsocketRequest() and nothing is returned. Any other request is
      authenticated and answered with a ( status, content type, headers, body )
      tuple whose body is the generator from _processCommandApiGenerator().
      """
      trace( '_processCommandApi entry' )
      websocketKey = requestContext.getHeader( 'HTTP_SEC_WEBSOCKET_KEY' )
      if websocketKey:
         self.handleWebsocketRequest( requestContext )
         return

      userContext = requestContext.authenticate(
            tty=CapiConstants.TTY_NAME,
            service=CapiConstants.PAM_SERVICE_NAME )
      responseBody = self._processCommandApiGenerator( requestContext,
                                                       userContext )
      trace( '_processCommandApi exit' )
      return ( '200 OK', 'application/json', None, responseBody )

   def _getCookieHeader( self, sessionId, expiryTime ):
      """ Build the ( "Set-Cookie", value ) header tuple for a session cookie.

      expiryTime is shifted by the difference between Tac.utcNow() and
      Tac.now() before being formatted as a GMT timestamp; presumably it is
      expressed on the Tac.now() clock (confirm against the session manager).
      """
      utcExpiry = expiryTime + Tac.utcNow() - Tac.now()
      expiryStamp = time.strftime( "%a, %d-%b-%Y %H:%M:%S GMT",
                                   time.gmtime( utcExpiry ) )
      cookie = "Session=%s;Expires=%s;Path=/;HttpOnly" % ( sessionId,
                                                           expiryStamp )
      return ( "Set-Cookie", cookie )

   def _processLogin( self, requestContext ):
      """ Log the requester in and answer with a Set-Cookie header carrying
      the new session id and its expiry time. """
      trace( 'Processing login' )
      userContext = requestContext.login(
            tty=CapiConstants.TTY_NAME,
            service=CapiConstants.PAM_SERVICE_NAME )
      sessionCookie = self._getCookieHeader( userContext.sessionId,
                                             userContext.expiryTime )
      cookieHeader = [ sessionCookie ]
      trace( 'Generated cookie', cookieHeader )
      return ( '200 OK', 'application/json', cookieHeader, '{ "Login" : true }' )

   def _processLogout( self, requestContext ):
      """ Log the requester out and answer with a cookie named 'Deleted' whose
      expiry time is 1, so that the client discards its session cookie. """
      trace( 'Processing logout' )
      requestContext.logout()
      # An expiry this early puts the cookie in the past, which clears it on
      # the client.
      expiredCookie = self._getCookieHeader( 'Deleted', 1 )
      cookieHeader = [ expiredCookie ]
      trace( 'Generated cookie', cookieHeader )
      return ( '200 OK', 'application/json', cookieHeader, '{ "Logout" : true }' )

   def _processRedirectUrl( self, requestContext, redirectFile ):
      """ Build a 301 response whose Location header is the redirect URL the
      request context computes for redirectFile. """
      location = requestContext.getRedirectUrl( redirectFile )
      return ( '301 Moved Permanently', 'text/html',
               [ ( 'Location', location ) ], 'Page redirect' )

   def _processMainPage( self, requestContext ):
      """ Redirect a request for one of the main pages to 'eapi/'. """
      trace( 'Process main page request' )
      landingPage = 'eapi/'
      return self._processRedirectUrl( requestContext, landingPage )

   def _processStaticContent( self, requestContext ):
      """ Serve a file from the document root.

      Returns a ( status, content type, headers, body ) tuple, or a redirect
      when the path is a directory requested without a trailing slash.
      Raises HttpNotFound when neither the file nor its '.gz' variant exists,
      and HttpMethodNotAllowed for request types other than HEAD and GET.
      """
      trace( 'Processing static request' )
      urlPath = requestContext.getParsedUrl().path
      docRoot = requestContext.getDocRoot()
      # NOTE(review): the URL path is appended to the document root as is;
      # presumably '..' segments are rejected or normalized before this
      # point - confirm.
      filePath = docRoot + urlPath

      # Unknown paths below /eapi/ are answered with the /eapi/ directory.
      if urlPath.startswith( '/eapi/' ) and not os.path.exists( filePath ):
         filePath = docRoot + '/eapi/'

      if os.path.isdir( filePath ):
         if not filePath.endswith( '/' ):
            # A directory requested without the trailing slash is redirected
            # to the same URL with one.
            return self._processRedirectUrl( requestContext,
                                             os.path.join( urlPath, '' ) )
         # If a request ends with a slash,
         # NGINX treats it as a request for a directory
         # and tries to find an index file in the directory.
         # we should do the same to keep behavior consistent.
         filePath = os.path.join( filePath, 'index.html' )

      if not os.path.exists( filePath ):
         # Fall back to a '.gz' file of the same name.
         filePath += '.gz'
         if not os.path.exists( filePath ):
            raise UwsgiRequestContext.HttpNotFound( 'Page not found' )
      if requestContext.getRequestType() not in ( 'HEAD', 'GET' ):
         raise UwsgiRequestContext.HttpMethodNotAllowed(
               'Request method does not support' )

      mimeType, encoding = mimetypes.guess_type( filePath )
      lastModified = time.strftime( "%a, %d %b %Y %H:%M:%S GMT",
                                    time.gmtime( os.path.getmtime( filePath ) ) )
      headers = [ ( 'Content-Length', str( os.path.getsize( filePath ) ) ),
                  ( 'Last-Modified', lastModified ) ]
      if encoding:
         headers.append( ( 'Content-Encoding', encoding ) )
      return ( '200 OK', mimeType or 'application/octet-stream', headers,
               requestContext.getStaticFile( filePath ) )

   def _checkCapiEnabled( self, requestContext ):
      """ Tell whether the CAPI service is enabled in the VRF of the request
      (DEFAULT_VRF when the request names none), according to the HTTP server
      status. """
      vrf = requestContext.getVrfName() or DEFAULT_VRF
      vrfStatus = self.httpServerStatus_.vrfStatus
      if vrf not in vrfStatus:
         return False
      services = vrfStatus[ vrf ].vrfService
      serviceName = CapiConstants.SERVICE_NAME
      if serviceName not in services:
         return False
      return services[ serviceName ].enabled

   def _processEndpoint( self, requestContext ):
      """ Route the request by its end point.

      Login, logout and command end points are only considered while CAPI is
      enabled for the request's VRF; a requester whose address starts with
      'unix:' is always routed to the command API in that case. The main pages
      are redirected and everything else is served as static content.
      """
      endPoint = requestContext.getEndPoint()

      def startsWithEndpoint( name ):
         # True when endPoint begins with '/' followed by the given name.
         return re.search( '^/%s.*' % name, endPoint )

      if self._checkCapiEnabled( requestContext ):
         if startsWithEndpoint( CapiConstants.LOGIN_ENDPOINT ):
            return self._processLogin( requestContext )
         if startsWithEndpoint( CapiConstants.LOGOUT_ENDPOINT ):
            return self._processLogout( requestContext )
         if ( startsWithEndpoint( CapiConstants.COMMAND_ENDPOINT )
              or requestContext.getRequesterIp().startswith( 'unix:' ) ):
            return self._processCommandApi( requestContext )

      mainPages = ( '/', '/explorer.html', '/overview.html',
                    '/documentation.html' )
      if endPoint in mainPages:
         return self._processMainPage( requestContext )
      return self._processStaticContent( requestContext )

   def processRequest( self, request ):
      """Common implementation of all HTTP requests.

      Builds the request context, applies the service ACL and dispatches on
      the end point. Returns whatever the end point handler returns: a
      ( status, content type, headers, body ) tuple, or None when the request
      was served as a websocket.

      An HttpException is turned into a response built from its own code,
      name, content type, headers and message. Any other exception is logged
      and turned into a 500 response whose body is the exception text.
      """
      trace( 'processRequest entry' )
      try:
         requestContext = UwsgiRequestContext.UwsgiRequestContext(
               self.sysname,
               request,
               serviceAclFilterSm=self.serviceAclFilterSm_,
               sessionManager=self.sessionManager_,
               aaaApiClient=self.aaaApiClient_ )
         if not requestContext.aclPermitConnection():
            raise UwsgiRequestContext.HttpServiceAclDenied('Filtered by service ACL')
         return self._processEndpoint( requestContext )
      except UwsgiRequestContext.HttpException as e:
         trace( 'processRequest HttpException', e )
         return ( f'{e.code} {e.name}', e.contentType, e.additionalHeaders,
                  e.message )
      except Exception as e: # pylint: disable=broad-except
         trace( 'processRequest Exception', e )
         traceback.print_exc()  # Log stack trace to agent log.
         # Return the text rather than the exception object:
         # CapiApplication.__call__ takes len() of the body, which an
         # exception instance does not support.
         return ( '500 Internal Server Error', 'text/html', None, str( e ) )

   def _handleWsJsonrpcRequest( self, requestContext, userContext ):
      """ Serve JSON-RPC requests over an established websocket until the
      front-end disconnects.

      Each received message is a JSON document whose 'request' member is
      forwarded to the backend through one EapiClient kept open for the whole
      connection. The collected result is sent back as
      '{"response": <result> }' and the statistics are updated.

      Returns when a websocket or backend call raises OSError; any other
      exception propagates to the caller.
      """
      with EapiClientLib.EapiClient( sysname=self.sysname,
            disableAaa=requestContext.aaaDisabled(),
            privLevel=userContext.privLevel,
            uid=userContext.uid,
            gid=userContext.gid,
            aaaAuthnId=userContext.aaaAuthnId,
            realTty=CapiConstants.TTY_NAME,
            logName=userContext.user,
            sshConnection=requestContext.getRequesterIp() ) as ctx:
         trace( 'handleWebsocketRequest ctx created and connected to backend' )
         while True:
            try:
               # pylint: disable-msg=no-member
               trace( 'handleWebsocketRequest waiting on intput' )
               request = json.loads( uwsgi.websocket_recv() )[ 'request' ]
               trace( 'handleWebsocketRequest input received' )
               resultGenerator, statsFunc = ctx.sendRpcRequest(
                     json.dumps( request ) )
               with io.BytesIO() as responseBuffer:
                  for i in resultGenerator:
                     # The first None item ends the collection of the result.
                     if i is None:
                        break
                     responseBuffer.write( i )
                  payload = responseBuffer.getvalue()
               result = payload.decode()
               requestCount, commandCount = statsFunc()

               trace( 'handleWebsocketRequest cmd done running' )
               uwsgi.websocket_send( f'{{"response": {result} }}' )
               trace( 'handleWebsocketRequest result sent to client' )
               # Count the encoded bytes, as the HTTP path does; the length
               # of the decoded text is smaller for non-ASCII output.
               bytesOut = len( payload )
               self.updateStatistics( requestContext, bytesOut, requestCount,
                                      commandCount, userContext )
            except OSError:
               # this means that the front-end disconnected
               return

   def handleWebsocketRequest( self, requestContext ):
      """ Serve a websocket connection: handshake, negotiation, then requests.

      The first message from the client is a JSON document with a
      'connectionType' and an 'authentication' member holding 'username' and
      'password'. The credentials are authenticated and the outcome is
      reported to the client; on success the connection is served according
      to its type. Only 'jsonrpc' is supported.

      OSError is taken to mean the front-end disconnected and ends the
      connection quietly. Other failures are reported to the client as an
      'error' document on a best-effort basis. If authentication succeeded,
      the AAA session is closed on every exit path.
      """
      trace( 'handleWebsocketRequest enter', requestContext )
      # pylint: disable-msg=no-member
      userContext = None
      httpSecWebsocketKey = requestContext.getHeader( 'HTTP_SEC_WEBSOCKET_KEY' )
      httpOrigin = requestContext.getHeader( 'HTTP_ORIGIN', '' )
      uwsgi.websocket_handshake( httpSecWebsocketKey, httpOrigin )
      trace( 'handleWebsocketRequest handshake complete' )

      def reportError( code, name, message ):
         # Runs inside the except handlers below, where an OSError from a
         # client that is already gone would otherwise escape this method.
         result = { 'error': { 'code': code,
                               'name': name,
                               'message': message } }
         try:
            uwsgi.websocket_send( json.dumps( result ) )
         except OSError:
            trace( 'handleWebsocketRequest client gone, error not delivered' )

      try:
         negotiationInfo = json.loads( uwsgi.websocket_recv() )
         connectionType = negotiationInfo[ 'connectionType' ]
         authInfo = negotiationInfo[ 'authentication' ]
         username = str( authInfo[ 'username' ] )
         password = str( authInfo[ 'password' ] )
         try:
            userContext = requestContext.authenticate(
                  tty=CapiConstants.TTY_NAME,
                  service=CapiConstants.PAM_SERVICE_NAME,
                  user=username, passwd=password )
         except UwsgiRequestContext.HttpUnauthorized as e:
            response = { 'connectionType': connectionType,
                         'authentication': { 'status': 'error',
                                            'message': str( e ) } }
            uwsgi.websocket_send( json.dumps( response ) )
            return
         response = { 'connectionType': connectionType,
                      'authentication': { 'status': 'success' } }
         uwsgi.websocket_send( json.dumps( response ) )
         trace( 'handleWebsocketRequest authenticated', authInfo[ 'username' ] )
         if connectionType != 'jsonrpc':
            # Raised explicitly instead of asserted: assert statements are
            # removed under -O, which would silently skip this validation.
            raise ValueError( 'Unknown connection type' )
         self._handleWsJsonrpcRequest( requestContext, userContext )
      except OSError:
         # this means that the front-end disconnected
         return
      except UwsgiRequestContext.HttpException as e:
         trace( 'handleWebsocketRequest HttpException', e )
         reportError( e.code, e.name, e.message )
         return
      except Exception as e: # pylint: disable=broad-except
         trace( 'handleWebsocketRequest Exception', e )
         traceback.print_exc()  # Log stack trace to agent log.
         reportError( 500, 'Internal Server Error', str( e ) )
         return
      finally:
         if userContext:
            self.aaaApiClient_.closeSession( userContext.aaaAuthnId )

class CapiApplication:
   """ WSGI application object: starts the CapiApp agent, waits for it to be
   warm and then hands every request to it. """

   def __init__( self ):
      """ Start the agent container and the activity thread, then wait until
      the agent reports warm. """
      trace( 'CapiApplication init entry' )
      self.container_ = Agent.AgentContainer( [ CapiApp ],
                                              agentTitle='CapiApp' )
      self.container_.startAgents()
      # pylint: disable-msg=protected-access
      atexit.register( lambda: os._exit( 0 ) ) # on exiting we have to be brutal
      Tac.activityThread().start( daemon=True )
      Tac.waitFor( self.capiAgentWarm, description="CapiApp to be warm",
                   maxDelay=1, sleep=True )
      self.capiAgent_ = self.container_.agents_[ 0 ]
      trace( 'CapiApplication init exit' )

   def capiAgentWarm( self ):
      """ Falsy until the container has an agent and that agent's warm()
      returns a true value. """
      return self.container_.agents_ and self.container_.agents_[ 0 ].warm()

   def __call__( self, request, start_response ):
      """ WSGI entry point.

      Answers 408 without processing anything when a write to the connection
      fails. Otherwise the agent processes the request; when it returns
      nothing (the request was served as a websocket) an empty response is
      returned without calling start_response. Generator and file-like bodies
      are returned as they are; any other body is sent as one chunk of bytes
      with a Content-length header when it is not empty.
      """
      trace( '__call__ entry' )
      # If the client is already gone, don't bother, return immediatly with a
      # dummy response (which will leave a 'Broken pipe' skid mark in the logs).
      fd = uwsgi.connection_fd() # pylint: disable-msg=no-member
      try:
         os.write( fd, b"" )
      except OSError:
         start_response( '408 Request Timeout',
                         [ ( 'Content-Type', 'text/plain' ) ] )
         return [ b'Client already gone' ]

      requestResponse = self.capiAgent_.processRequest( request )
      if not requestResponse:
         return ''

      ( responseCode, contentType, headers, body ) = requestResponse
      # Work on a copy: the list may belong to the caller (for instance an
      # exception's additionalHeaders) and must not grow with every request.
      headers = list( headers ) if headers else []
      headers.append( ( 'Content-type', contentType ) )
      if isinstance( body, ( types.GeneratorType, IOBase ) ):
         generator = body
      else:
         if body:
            if not isinstance( body, ( bytes, bytearray ) ):
               # str bodies are encoded as before; anything else (such as an
               # exception object) is sent as its text instead of failing
               # in len().
               body = str( body ).encode()
            headers.append( ( 'Content-length', str( len( body ) ) ) )
         else:
            # An empty or missing body is still sent as a bytes chunk.
            body = b''
         generator = [ body ]
      start_response( responseCode, UwsgiConstants.DEFAULT_HEADERS + headers )
      trace( '__call__ exit' )
      return generator
