# Copyright (c) 2007-2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function
import base64
import binascii
import json
import os
import re
from socket import IPPROTO_TCP

import CliCommon
import Tac
import Tracing
import UwsgiAaa
import UwsgiConstants
import UwsgiSessionManager
import six
from  six.moves import urllib

traceHandle = Tracing.Handle( 'UwsgiRequestContext' )
# Short aliases for this module's trace levels (trace1 through trace4).
warn, info, trace, debug = ( traceHandle.trace1, traceHandle.trace2,
                             traceHandle.trace3, traceHandle.trace4 )

class HttpException( Exception ):
   """ Base exception describing an HTTP error response: a short name (such as
   'Bad Request'), a message, a status code, the message's content type and a
   list of extra ( header, value ) pairs to send. Each instance is traced at
   the warn level when created. """

   def __init__( self, name, message, code, contentType='text/plain',
                 additionalHeaders=None ):
      Exception.__init__( self, message )
      self.name, self.message, self.code = name, message, code
      self.contentType = contentType
      # None or an empty sequence both become a fresh list for this instance.
      self.additionalHeaders = additionalHeaders or []
      warn( 'HttpException:', self )

   def __str__( self ):
      fields = ( self.code, self.name, self.message )
      return 'http %s: %s: %s' % fields

class HttpBadRequest( HttpException ):
   """ Raised due to a protocol violation; becomes an HTTP 400 response. """
   def __init__( self, message ):
      HttpException.__init__( self, name='Bad Request', message=message, code=400 )
      
class HttpUnauthorized( HttpException ):
   """ Raised to send an authentication challenge (HTTP 401).

   When useBasicAuthHeaders is true, a WWW-Authenticate header advertising
   Basic authentication for UwsgiConstants.AUTH_REALM is attached. """

   def __init__( self, message, useBasicAuthHeaders=False ):
      addHeaders = ( [ ( 'WWW-Authenticate',
                         'Basic realm="%s"' % UwsgiConstants.AUTH_REALM ) ]
                     if useBasicAuthHeaders else [] )
      HttpException.__init__( self, name='Unauthorized', message=message,
                              code=401, additionalHeaders=addHeaders )

class HttpNotFound( HttpException ):
   """ Raised when the requested resource does not exist (HTTP 404). """
   def __init__( self, message ):
      HttpException.__init__( self, name='Not Found', message=message, code=404 )

class HttpForbidden( HttpException ):
   """ Raised due to a permissions violation (HTTP 403). """

   def __init__( self, message ):
      HttpException.__init__( self, name='Forbidden', message=message, code=403 )

class HttpMethodNotAllowed( HttpException ):
   """ Raised when the request method cannot be used on a resource (HTTP 405). """

   def __init__( self, message ):
      HttpException.__init__( self, name='Method Not Allowed', message=message,
                              code=405 )

class HttpServiceAclDenied( HttpException ):
   """ Raised when the service ACL denies the connection. Uses the non-standard
   status 444 so that nginx closes the connection. """

   def __init__( self, message ):
      HttpException.__init__( self, name='Service ACL Denied', message=message,
                              code=444 )

class UserContext( object ):
   """ Identity and session information for the user behind a request, as
   produced by UwsgiRequestContext.authenticate() / login(). """
   def __init__( self, user, passwd, uid, gid, privLevel, aaaAuthnId, sessionId,
                 expiryTime ):
      trace( 'UserContext.__init__ entry' )
      # credentials as supplied by (or resolved for) the requester
      self.user, self.passwd = user, passwd
      # ids and privilege level granted by AAA
      self.uid, self.gid, self.privLevel = uid, gid, privLevel
      # AAA authentication id plus optional session id and its expiry
      self.aaaAuthnId = aaaAuthnId
      self.sessionId, self.expiryTime = sessionId, expiryTime
      trace( 'UserContext.__init__ exit' )

class UwsgiRequestContext( object ):
   """ Per-request state for the uwsgi application.

   A request context is obtained from the agent at the start of each request.
   It wraps the WSGI environ dict and provides helpers to:
   - read the request body and headers;
   - authenticate the requester (client certificate, session cookie or HTTP
     basic auth);
   - manage login/logout sessions;
   - evaluate the service ACL for the connection. """

   def __init__( self, sysname, request, serviceAclFilterSm=None,
         sessionManager=None, aaaApiClient=None ):
      """ Initializes attributes tracking the request and reads its body.

      sysname: stored as-is (not read elsewhere in this class).
      request: the WSGI environ dict for this request.
      serviceAclFilterSm: service ACL filter; None means ACLs are not enforced.
      sessionManager: session manager used for session-based authentication.
      aaaApiClient: AAA client used for password/certificate authentication. """
      trace( 'RequestContext.__init__ entry' )
      self.requestStartTime_ = Tac.now()
      self.request_ = request
      self.sysname_ = sysname
      self.aaaApiClient_ = aaaApiClient
      self.sessionManager_ = sessionManager
      self.serviceAclFilterSm_ = serviceAclFilterSm

      # Trace the headers, eliding credentials. The Authorization header is only
      # base64-encoded, and the Cookie header carries the session id. Either one
      # is enough to act as the user if it ends up in a log.
      for header, value in six.iteritems( request ):
         if header.upper() in ( 'HTTP_AUTHORIZATION', 'HTTP_COOKIE' ):
            value = '<elided>'
         debug( 'req. header', header, value )

      # BUG343599: a request MUST always be read in a uwsgi application.
      # so we do that here as soon as we create the context
      self.getRequestContent()
      trace( 'RequestContext.__init__ exit' )

   def authenticate( self, tty, service, user=None, passwd=None ):
      """ Authenticates and authorizes the requester, returning a UserContext.

      When user is given, that user/passwd pair is used directly. Otherwise
      credentials are taken from the request, in this order:
      1. a verified client certificate, authorized via
         createAndAuthorizeSession rather than a password check;
      2. a 'Session' cookie;
      3. an HTTP basic Authorization header.

      If AAA is disabled for this request, a UserContext with uid/gid 0 and
      CliCommon.MAX_PRIV_LVL is returned without contacting AAA.

      Raises HttpUnauthorized when AAA rejects the credentials. The basic-auth
      challenge is included only when no session id was used. Also propagates
      HttpUnauthorized/HttpBadRequest from _parseAuthorizationHeader. """
      if user is not None:
         # this means that the frontend has decided that it really wants a
         # particular username/password. Let them have their wish
         assert passwd is not None, 'NULL password not allowed, try empty str'
         sessionId = None
         disableAuth = False
      else:
         user, passwd, sessionId, disableAuth = self._getUserFromReq()
      if user:
         user = six.ensure_str( user )
      if passwd:
         passwd = six.ensure_str( passwd )
      if self.aaaDisabled():
         return UserContext( user=user, passwd=passwd, uid=0, gid=0,
                             privLevel=CliCommon.MAX_PRIV_LVL,
                             aaaAuthnId=None, sessionId=None, expiryTime=None )
      try:
         if sessionId:
            authEntry, expiryTime = self.sessionManager_.getSession( sessionId )
            user = authEntry.username
         else:
            remoteHost = self.getRequesterIp()
            if disableAuth:
               aaaResult = self.aaaApiClient_.createAndAuthorizeSession(
                     user, service, 'certificate', remoteHost=remoteHost, tty=tty,
                     validUserOnly=False )
            else:
               aaaResult = self.aaaApiClient_.authenticateAndAuthorizeSession(
                     user, passwd,
                     service, remoteHost=remoteHost, tty=tty )
            authEntry = UwsgiAaa.parseAaaResults( aaaResult, user )
            debug( 'authEntry for user', user, authEntry )
            sessionId = None
            expiryTime = None
      except UwsgiAaa.AuthenticationError as e:
         # Don't add basic auth header with a sessionId
         raise HttpUnauthorized( 'Unable to authenticate user: %s' % str( e ),
                                  sessionId is None  )
      return UserContext( user=user, passwd=passwd, uid=authEntry.uid,
                          gid=authEntry.gid, privLevel=authEntry.privLevel,
                          aaaAuthnId=authEntry.aaaAuthnId, sessionId=sessionId,
                          expiryTime=expiryTime )

   def _getUserFromReq( self ):
      """ Extracts credentials from the request.

      Tries, in order: a verified client certificate, a 'Session' cookie, then
      an HTTP basic Authorization header. Returns the 4-tuple
      ( user, passwd, sessionId, disableAuth ). disableAuth is True only for
      the certificate case, where the user is authorized without a password. """
      # first we try to get the user from the certificate
      user = self._parseSSLClientCertificate()
      if user:
         return user, '', None, True

      # doesn't have a certificate lets try to find a sessionId
      sessionId = self._parseSessionId()
      if sessionId:
         return None, None, sessionId, False

      # nope, no session either, lets look for a basic auth header
      user, passwd = self._parseAuthorizationHeader()
      return user, passwd, None, False

   def deauthenticate( self, userContext ):
      """ Ends the requester's authentication.

      Releases the session named by the 'Session' cookie if there is one.
      Otherwise closes the AAA session identified by userContext.aaaAuthnId. """
      sessionId = self._parseSessionId()
      if sessionId:
         self.sessionManager_.releaseSession( sessionId )
      else:
         self.aaaApiClient_.closeSession( userContext.aaaAuthnId )

   def login( self, tty, service ):
      """ Handles a login request and returns a UserContext.

      If the JSON body carries a username and password, a new session is
      created. If the body has neither, the request's 'Session' cookie is
      validated instead. Raises HttpBadRequest for a malformed body, and
      HttpUnauthorized if authentication fails or no session id is present. """
      trace( 'RequestContext.login entry' )
      # If the request body object is empty, assume this
      # is a request to validate the supplied sessionId.
      # Otherwise treat it as a login request, ignoring
      # the supplied sessionId.
      user, passwd = self._getUserPasswd()
      try:
         if user is not None or passwd is not None:
            trace( '_getUserInfo with username and/or passwd' )
            if user is None or passwd is None:
               raise HttpBadRequest( 'Malformed login request' )
            ( authEntry,
              sessionId,
              expiryTime ) = self.sessionManager_.createSession(
                    self.getRequesterIp(), self.getUserAgent(), user,
                    passwd.encode(),
                    tty=tty, service=service )
         else:
            sessionId = self._parseSessionId()
            trace( 'login session id validation', sessionId )
            if not sessionId:
               raise HttpUnauthorized( 'No session id' )
            # NOTE(review): on this path user stays None, whereas authenticate()
            # uses authEntry.username -- confirm callers expect this.
            authEntry, expiryTime = self.sessionManager_.getSession( sessionId,
                                                            incrementUsageCnt=False )
      except UwsgiAaa.AuthenticationError as e:
         raise HttpUnauthorized( 'Unable to authenticate user: %s' % str( e ),
                                 False )
      trace( 'RequestContext.login exit' )
      return UserContext( user=user, passwd=passwd, uid=authEntry.uid,
                          gid=authEntry.gid, privLevel=authEntry.privLevel,
                          aaaAuthnId=authEntry.aaaAuthnId, sessionId=sessionId,
                          expiryTime=expiryTime )

   def logout( self ):
      """ Logs out the session named by the request's 'Session' cookie.

      Raises HttpBadRequest if there is no session id, or if the session
      manager raises SessionLogoutError. """
      trace( 'RequestContext.logout entry' )
      try:
         sessionId = self._parseSessionId()
         if not sessionId:
            raise HttpBadRequest( 'Invalid logout request.' )
         self.sessionManager_.logoutSession( sessionId )
      except UwsgiSessionManager.SessionLogoutError:
         raise HttpBadRequest( 'Invalid logout request.' )
      trace( 'RequestContext.logout exit' )

   def getRequestStartTime( self ):
      """ The time at which the request was created """
      return self.requestStartTime_

   def getRedirectUrl( self, redirectFile ):
      """ Get the request redirect url for main page: redirectFile resolved
      against '<REQUEST_SCHEME>://<HTTP_HOST>'. """
      return urllib.parse.urljoin( '%s://%s' %
                                      ( self.request_.get( 'REQUEST_SCHEME' ),
                                        self.request_.get( 'HTTP_HOST' ) ),
                                   redirectFile )

   def getDocRoot( self ):
      """ Get the root path of static documents. """
      return self.request_.get( 'DOCUMENT_ROOT', '/' )

   def getStaticFile( self, path ):
      """ Opens the file at path in binary mode and returns an iterable of its
      contents in 32 KiB blocks, suitable as a WSGI response body.

      Uses the server's wsgi.file_wrapper when available; PEP 3333 requires the
      wrapper's close() to close the file. Otherwise the file is read through a
      generator that closes it once exhausted, or when the server calls close()
      on the response iterable after iteration has started. """
      fd = open( path, 'rb' )
      block_size = 32 * 1024
      if 'wsgi.file_wrapper' in self.request_:
         return self.request_[ 'wsgi.file_wrapper' ]( fd, block_size )
      return self._iterFileBlocks( fd, block_size )

   @staticmethod
   def _iterFileBlocks( fd, blockSize ):
      """ Yields successive blockSize reads from fd until EOF, closing fd when
      the generator finishes or is closed. """
      try:
         for block in iter( lambda: fd.read( blockSize ), b'' ):
            yield block
      finally:
         fd.close()

   @Tac.memoize
   def aaaDisabled( self ):
      """ Returns True if AAA is bypassed for this request.

      That happens when DISABLE_AAA=1 is set in the process environment. It
      also happens when the request sets DISABLE_AAA=1 and the client is on
      loopback (127.0.0.1 / ::1) or a unix domain socket. """
      # if in the environment we have AAA disabled then it is always disabled
      if os.getenv( 'DISABLE_AAA', '0' ) == '1':
         return True

      # if in the request we don't have the DISABLE_AAA flag then we aren't disabled
      if self.request_.get( 'DISABLE_AAA' ) != '1':
         return False

      # A missing REMOTE_ADDR is treated as non-local, so AAA stays enabled
      # rather than the request failing with KeyError.
      remoteAddr = self.request_.get( 'REMOTE_ADDR' ) or ''

      # if on localhost (either ipv4 or ipv6) then we can disable AAA
      if remoteAddr in ( '127.0.0.1', '::1' ):
         return True

      # if on UDS then we can disable AAA
      return remoteAddr.startswith( 'unix:' )

   @Tac.memoize
   def getRequestContent( self ):
      """Returns the request body content.

      Returns None when the environ has no wsgi.input, and '' when the content
      length is 0. Otherwise returns the raw data read from wsgi.input (bytes on
      Python 3). @Tac.memoize presumably caches the result so the input stream
      is only read once -- __init__ relies on calling this eagerly. """
      trace( 'getRequestContent entry' )

      if not self.request_.get( 'wsgi.input' ):
         return None

      requestContentLength = self.getContentLength()
      if requestContentLength == 0:
         return ''

      requestContent = self.request_[ 'wsgi.input' ].read( requestContentLength )
      # Only a bounded prefix is traced, cut short before any 'password' key so
      # login credentials are not logged.
      if len( requestContent ) <= 300:
         # Strip out password in login request
         endIdx = requestContent.find( b'password' )
         if endIdx == -1:
            endIdx = len( requestContent )
         debug( 'request content', requestContent[ :endIdx ] )
      else:
         endIdx = requestContent[ :297 ].find( b'password' )
         if endIdx == -1:
            endIdx = 297
         debug( 'request content', requestContent[ :endIdx ] + b'...' )

      trace( 'getRequestContent exit content len', len( requestContent ) )
      return requestContent

   def getHeader( self, header, defaultValue=None ):
      """ Returns the environ entry named header, or defaultValue if absent. """
      return self.request_.get( header, defaultValue )

   def getContentLength( self ):
      """ Returns the body length from HTTP_CONTENT_LENGTH, else CONTENT_LENGTH,
      else 0. An empty value also counts as 0. Raises ValueError if the value is
      not an integer. """
      if 'HTTP_CONTENT_LENGTH' in self.request_:
         return int( self.request_[ 'HTTP_CONTENT_LENGTH' ] or 0 )
      elif 'CONTENT_LENGTH' in self.request_:
         return int( self.request_[ 'CONTENT_LENGTH' ] or 0 )
      return 0

   def getRequestType( self ):
      """ Returns the HTTP method (REQUEST_METHOD). """
      return self.request_[ 'REQUEST_METHOD' ]

   def getParsedUrl( self ):
      """ Returns REQUEST_URI parsed with urlparse. """
      return urllib.parse.urlparse( self.request_[ 'REQUEST_URI' ] )

   def getUrlQuery( self ):
      """
      Parses URL query in the form ?key1=value1&key2=value2&...
      """
      return urllib.parse.parse_qs( self.getParsedUrl().query )

   def getUserAgent( self ):
      """ Returns the user agent making this http request """
      return self.request_.get( 'HTTP_USER_AGENT' )

   def getEndPoint( self ):
      """ Parses and returns endpoint. """
      return self.request_.get( 'PATH_INFO', '/' )

   def getRequesterIp( self ):
      """ Returns the requester's IP address (REMOTE_ADDR), or None if absent.

      An IPv4 client may be reported as an IPv4-mapped IPv6 address of the form
      ::ffff:<ipv4 addr>. That prefix is removed so the plain IPv4 address is
      returned. """
      requesterIp = self.request_.get( 'REMOTE_ADDR' )
      return self._removeIfStartsWith( requesterIp, '::ffff:' )

   @staticmethod
   def _removeIfStartsWith( address, prefix ):
      """ Returns address without a leading prefix. None, or an address not
      starting with prefix, is returned unchanged. """
      if address is not None:
         if address.startswith( prefix ):
            return address.replace( prefix, '', 1 )
      return address

   def _parseCertDn( self, clientDn ):
      """ Parses a '/'-separated distinguished name such as '/C=US/CN=alice'
      into a dict mapping attribute name to value.

      Only the first '=' of each component separates name from value, so values
      containing '=' are kept intact. Raises ValueError for a component that
      has no '=' at all. """
      result = {}
      for field in clientDn.split( '/' ):
         if not field:
            continue
         fieldName, fieldValue = field.split( '=', 1 )
         result[ fieldName ] = fieldValue
      return result

   def getRemotePort( self ):
      """ Get the remote port from the request, as an integer. Return None
      if not present or if it cannot be converted to an integer"""
      try:
         return int( self.request_.get( 'REMOTE_PORT' ) )
      except( ValueError, TypeError ):
         return None

   def getServerPort( self ):
      """ Get the server port from the request, as an integer. Return None
      if not present or if it cannot be converted to an integer"""
      try:
         return int( self.request_.get( 'SERVER_PORT' ) )
      except( ValueError, TypeError ):
         return None

   def getLocalAddr( self ):
      """ Get the local address (SERVER_ADDR) from the request, or None if not
      present. An IPv4 address reported as ::ffff:<ipv4 addr> is returned
      without the '::ffff:' prefix. """
      localAddr = self.request_.get( 'SERVER_ADDR' )
      return self._removeIfStartsWith( localAddr, '::ffff:' )

   def getVrfName( self ):
      """ Get the nginx VRF name from request"""
      return self.request_.get( 'VRF_NAME' )

   def getServerName( self ):
      """ Get the Server Name from the request. """
      return self.request_.get( 'SERVER_NAME' )

   def aclPermitConnection( self ):
      """ Returns True if Service Acl permits this connection"""
      if self.serviceAclFilterSm_ is None:
         # Implies service Acl is not enabled in this software revision
         return True
      srcAddr = self.getRequesterIp()
      dstAddr = self.getLocalAddr()
      vrf = self.getVrfName()

      # if connection is through a unixServer or localHttpServer, permit
      if self.getServerName() == 'localhost':
         return True

      srcIp = Tac.Value( 'Arnet::IpGenAddr', srcAddr )
      dstIp = Tac.Value( 'Arnet::IpGenAddr', dstAddr )
      srcPort = self.getRemotePort()
      dstPort = self.getServerPort()
      prot = IPPROTO_TCP
      standardAcl = False
      return self.serviceAclFilterSm_.matchesConnection(
         vrf, srcIp, dstIp, srcPort, dstPort, prot, standardAcl )

   def _parseSSLClientCertificate( self ):
      """ Returns the username from a successfully verified client certificate,
      or None.

      The SPIFFE ID is tried first, producing '<realm>.<role>'. Otherwise the
      first space-separated field of the subject DN's CN is used. A malformed
      DN yields None, so the other authentication methods are tried. """
      if self.request_.get( 'SSL_CLIENT_VERIFY' ) != 'SUCCESS':
         trace( '_parseSSLClientCertificate certificate not valid' )
         return None

      spiffeId = self.request_.get( 'SSL_CLIENT_SPIFFE' )
      if not spiffeId:
         trace( '_parseSSLClientCertificate certificate does not have SPIFFE ID' )
      else:
         spiffePattern = ( r"spiffe://(\S+)\.(\S+)\.(?P<realm>\S+)\.\S+\.\S+\.\S+"
                           r"/role/(?P<role>\S+)" )
         try:
            spiffeMatch = re.search( spiffePattern, spiffeId )
            securityRealm = spiffeMatch.group( "realm" )
            userRole = spiffeMatch.group( "role" )
            user = securityRealm + "." + userRole
            trace( '_parseSSLClientCertificate user from SPIFFE ID is', user )
            return user
         except AttributeError:
            # re.search returned None: the SPIFFE ID does not match the pattern
            trace( '_parseSSLClientCertificate can not get user from SPIFFE ID' )

      clientDn = self.request_.get( 'SSL_CLIENT_S_DN' )
      if not clientDn:
         trace( '_parseSSLClientCertificate invalid client distinguished name' )
         return None

      try:
         fields = self._parseCertDn( clientDn )
      except ValueError:
         # A malformed DN must not authenticate anyone. Fall back to the other
         # authentication methods instead of failing the whole request.
         trace( '_parseSSLClientCertificate malformed client distinguished name' )
         return None
      if 'CN' not in fields:
         trace( '_parseSSLClientCertificate no common name found' )
         return None

      cnFields = fields[ 'CN' ].split( ' ' )
      # NOTE(review): up to two space-separated fields are accepted and only the
      # first is used as the username, although the trace text below says a
      # single field is expected -- confirm which is intended.
      if len( cnFields ) > 2:
         trace( '_parseSSLClientCertificate common name does not'
                ' have 1 fields (username)' )
         return None

      return cnFields[ 0 ]

   def _parseAuthorizationHeader( self ):
      """ Locates and parses the HTTP basic authorization header.

      Returns the 2-tuple ( user, password ) as decoded from the header (bytes
      on Python 3). If AAA is disabled, returns
      ( UwsgiConstants.LOCAL_USERNAME, None ) instead.

      Raises HttpUnauthorized (with a basic-auth challenge) when the header is
      missing, and HttpBadRequest when it is malformed or uses another scheme. """

      trace( '_parseAuthorizationHeader entry' )
      if self.aaaDisabled():
         return UwsgiConstants.LOCAL_USERNAME, None

      authorizationHeader = self.request_.get( 'HTTP_AUTHORIZATION' )
      if not authorizationHeader:
         raise HttpUnauthorized( 'No authentication header found',
                                 useBasicAuthHeaders=True )
      headerparts = authorizationHeader.split()
      if len( headerparts ) != 2:
         raise HttpBadRequest( 'Invalid authentication header format' )
      scheme = headerparts[ 0 ]
      if scheme.lower() != 'basic':
         raise HttpBadRequest( 'Authentication scheme "%s" is unsupported' % scheme )
      try:
         auth = headerparts[ 1 ]
         if isinstance( auth, str ):
            auth = auth.encode()
         plainauth = base64.b64decode( auth )
      except ( TypeError, binascii.Error ):
         raise HttpBadRequest( 'Invalid basic authentication value' )
      authparts = plainauth.split( b':', 1 )
      if len( authparts ) != 2:
         raise HttpBadRequest( 'Invalid basic authentication value' )
      trace( '_parseAuthorizationHeader exit' )
      return ( authparts[ 0 ], authparts[ 1 ] )

   def _parseSessionId( self ):
      """ Returns the value of the 'Session' cookie, or None when there is no
      Cookie header or no cookie with that exact name.

      The Cookie header is a ';'-separated list of name=value pairs. Matching
      the name exactly keeps any following cookies out of the returned id, and
      ignores cookies whose name merely ends in 'Session'. """
      cookie = self.request_.get( 'HTTP_COOKIE' )
      if cookie is None:
         return None

      for pair in cookie.split( ';' ):
         name, sep, value = pair.strip().partition( '=' )
         if sep and name == 'Session':
            return value
      return None

   def _getUserPasswd( self ):
      """Parses username/password from the JSON post data.

      Returns the 2-tuple ( username, password ). Each member is a native
      string, or None when the key is absent or null. Non-string JSON values
      are converted with str().

      Raises HttpBadRequest when the body is missing, is not valid JSON, or is
      not a JSON object."""
      requestContent = self.getRequestContent()
      try:
         reqDict = json.loads( requestContent )
      except ( TypeError, ValueError ):
         # TypeError: no body at all (getRequestContent returned None).
         # ValueError: the body is empty or is not valid JSON.
         raise HttpBadRequest( 'Malformed login request' )
      if not isinstance( reqDict, dict ):
         # e.g. a JSON list or string: there are no keys to look up
         raise HttpBadRequest( 'Malformed login request' )

      def safeget( key ):
         val = reqDict.get( key )
         if val is None:
            return None
         if isinstance( val, ( six.text_type, six.binary_type ) ):
            # six.ensure_str avoids str() raising UnicodeEncodeError on
            # non-ASCII unicode under Python 2. It is a no-op for str on
            # Python 3.
            return six.ensure_str( val )
         return str( val )

      return safeget( 'username' ), safeget( 'password' )
