# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=import-outside-toplevel
# pylint: disable=superfluous-parens
# pylint: disable=consider-using-f-string
# pylint: disable=consider-using-with
# pylint: disable=raise-missing-from

from __future__ import absolute_import, division, print_function
import datetime
from dateutil import parser
import errno
import grp
import os
import pwd
import re
import tempfile
import time
import hashlib
import Tac
import Tracing
import Logging
import socket
import subprocess
import CliGlobal
import SslMonitorExpiry
import six
from six.moves import map
from TypeFuture import TacLazyType

# A command in Eos-initscripts.spec generates the 'eosadmin' UNIX group, which
# SslCertKey.py attempts to access through the grp library.
# pkgdeps: rpm Eos-initscripts

# Trace handle for this module, plus short aliases for its trace levels
# (trace0 through trace4) used throughout the file.
traceHandle = Tracing.Handle( 'MgmtSecuritySslCertKey' )
error = traceHandle.trace0
warn  = traceHandle.trace1
info  = traceHandle.trace2
trace = traceHandle.trace3
debug = traceHandle.trace4

# Tac types shared by the SSL management code.
ErrorType = Tac.Type( "Mgmt::Security::Ssl::ErrorType" )

Constants = Tac.Type( "Mgmt::Security::Ssl::Constants" )
CertificateInfo = Tac.Type( "Mgmt::Security::Ssl::CertificateInfo" )
PublicKeyAlgorithm = Tac.Type( "Mgmt::Security::Ssl::PublicKeyAlgorithm" )
CertLocation = TacLazyType( "Mgmt::Security::Ssl::CertLocation" )

# NOTE(review): not referenced in this part of the file; presumably consulted
# elsewhere to decide whether FIPS mode is used — confirm.
USE_FIPS_FLAG = True

# Maps public key algorithm names to PublicKeyAlgorithm values.
# See RFC3279 section 2.3
PUBKEY_ALGO_TO_ENUM = { 'rsaEncryption': PublicKeyAlgorithm.RSA,
                        'id-ecPublicKey': PublicKeyAlgorithm.ECDSA,
                      }

# Log handles emitted when an SSL private key, certificate or certificate
# signing request is created, updated, deleted or imported.
SECURITY_SSL_KEY_CERT_CREATED = Logging.LogHandle(
              "SECURITY_SSL_KEY_CERT_CREATED",
              severity=Logging.logInfo,
              fmt="SSL %s %s has been created with the %s hash of %s%s.",
              explanation="The SSL private key, certificate or certificate signing "
                          "request has been created. This can happen when a SSL "
                          "private key or certificate is generated in the system.",
              recommendedAction=Logging.NO_ACTION_REQUIRED )
SECURITY_SSL_KEY_CERT_UPDATED = Logging.LogHandle(
              "SECURITY_SSL_KEY_CERT_UPDATED",
              severity=Logging.logInfo,
              fmt="SSL %s %s has been updated with the %s hash of %s%s from %s.",
              explanation="The SSL private key, certificate or certificate signing "
                          "request has been updated. This can happen when an "
                          "existing SSL private key or certificate is updated "
                          "in the system.",
              recommendedAction=Logging.NO_ACTION_REQUIRED )
SECURITY_SSL_KEY_CERT_DELETED = Logging.LogHandle(
              "SECURITY_SSL_KEY_CERT_DELETED",
              severity=Logging.logInfo,
              fmt="SSL %s %s has been deleted with the previous %s hash of %s%s.",
              explanation="The SSL private key, certificate or certificate signing "
                          "request has been deleted. This can happen when a SSL "
                          "private key or certificate is deleted from the system.",
              recommendedAction=Logging.NO_ACTION_REQUIRED )
SECURITY_SSL_KEY_CERT_IMPORTED = Logging.LogHandle(
              "SECURITY_SSL_KEY_CERT_IMPORTED",
              severity=Logging.logInfo,
              fmt="SSL %s %s has been imported with the %s hash of %s%s.",
              explanation="The SSL private key or certificate has been imported. "
                          "This can happen when a SSL private key or certificate "
                          "is installed in the system.",
              recommendedAction=Logging.NO_ACTION_REQUIRED )

# Caches filled by createProfileCache() and keyed by PEM text:
#   dateCache            - validity dates in seconds since the epoch
#   trustedCertHashCache - `openssl x509 -hash` output for certificates
#   crlHashCache         - `openssl crl -hash` output for CRLs
# This module is imported by CliPlugin, so the caches live in a Cli
# thread-safe global (CliGlobal) instead of plain module-level dicts.
gv = CliGlobal.CliGlobal( dict( dateCache=dict(), # pylint: disable=use-dict-literal
                          # pylint: disable-next=use-dict-literal
                          trustedCertHashCache=dict(), crlHashCache=dict() ) )

def dirCreate( path ):
   """Create `path` (including missing parents) if needed, then give it the
   SSL directory owner, group and permission bits.

   A directory that already exists is not an error. Any other OSError from
   os.makedirs is traced and re-raised.
   """
   trace( "dirCreate start for: ", path )
   try:
      os.makedirs( path )
   except OSError as e:
      if e.errno != errno.EEXIST:
         error( "Cannot create directory: ", path, "errno: ", e.strerror )
         raise
      trace( "Directory already exists: ", path )

   ownerUid = pwd.getpwnam( Constants.sslDirOwner ).pw_uid
   groupGid = grp.getgrnam( Constants.sslDirGroup ).gr_gid
   os.chown( path, ownerUid, groupGid )
   os.chmod( path, Constants.sslDirPerm )

def createSslDirs():
   """Create every SSL directory, one after another.

   The rotation base directory must be the last one created, because
   isSslDirsCreated() only inspects that directory.
   """
   pathGetters = (
      lambda: Constants.baseDir,
      Constants.certsDirPath,
      Constants.autoCertsDirPath,
      Constants.keysDirPath,
      Constants.autoKeysDirPath,
      Constants.profileBaseDirPath,
      Constants.rotationBaseDirPath,
   )
   for getPath in pathGetters:
      dirCreate( getPath() )

def isSslDirsCreated():
   """Return True if the SSL directory tree appears to be fully set up.

   Only the rotation base directory is inspected, since createSslDirs()
   creates it last. It must exist, be owned by the SSL directory owner and
   group, and have exactly the SSL directory permission bits. Any
   EnvironmentError while checking is traced and treated as "not created".
   """
   lastDir = Constants.rotationBaseDirPath()
   try:
      if not os.path.isdir( lastDir ):
         trace( "Dir not present", lastDir )
         return False

      expectedUid = pwd.getpwnam( Constants.sslDirOwner ).pw_uid
      expectedGid = grp.getgrnam( Constants.sslDirGroup ).gr_gid
      st = os.stat( lastDir )
      trace( "statInfo uid:gid:mode", st.st_uid,
             st.st_gid, oct( st.st_mode ) )

      if ( st.st_uid, st.st_gid ) != ( expectedUid, expectedGid ):
         return False
      # Compare only the permission bits (including setuid/setgid/sticky).
      return ( st.st_mode & 0o7777 ) == Constants.sslDirPerm
   except EnvironmentError as e:
      trace( "isSslDirsCreated error:", str( e ) )
      return False

def getAllPem( pem, isFile=True ):
   """Return every PEM-armored span ("-----BEGIN...-----" up to the next
   "-----END...-----") in order of appearance.

   `pem` is a file path, or the text itself when isFile is False.
   """
   if not isFile:
      filetext = pem
   else:
      with open( pem, 'r' ) as pemFile:
         filetext = pemFile.read()
   pemRe = re.compile( "-----BEGIN.*?-----.*?-----END.*?-----", re.S )
   matches = pemRe.findall( filetext )
   debug( "Matches are:", matches )
   return matches

def _getPemCount( pem, isFile=True ):
   """Return how many PEM-armored spans `pem` contains (see getAllPem)."""
   trace( "_getPemCount start:", pem, "isFile:", isFile )
   pemCount = len( getAllPem( pem, isFile=isFile ) )
   trace( "_getPemCount end:", pem )
   return pemCount

def _parseDates( *dates ):
   """Parse each date string with dateutil and return a list of integer
   seconds since the Unix epoch, in argument order. A "GMT" zone name is
   interpreted as UTC.
   """
   # pytz imports zipfile which is a memory hog modul blacklisted
   # by Cli. Since CliPlugin imports SslCertKey, conditionally import
   # pytz
   import pytz
   epoch = datetime.datetime( 1970, 1, 1, tzinfo=pytz.utc )
   gmtZones = { 'GMT' : pytz.utc }
   secondsSinceEpoch = []
   for dateStr in dates:
      parsed = parser.parse( dateStr, tzinfos=gmtZones )
      secondsSinceEpoch.append( int( ( parsed - epoch ).total_seconds() ) )
   return secondsSinceEpoch

def createProfileCache( certOrCrls ):
   """Pre-compute dates and hashes for a set of certificates and CRLs.

   Any existing cache contents are cleared first. For each PEM string in
   certOrCrls an `openssl crl` (CRLs) or `openssl x509` (certificates)
   process is spawned with the PEM written to its stdin. All processes are
   started before any is waited on, so they run concurrently.

   Results are stored keyed by the PEM string:
      gv.dateCache            - [ notBefore, notAfter ] via _parseDates()
      gv.crlHashCache         - the `-hash` value openssl printed for a CRL
      gv.trustedCertHashCache - the `-hash` value printed for a certificate

   If any openssl output cannot be parsed, all caches are cleared, an error
   is traced and the function returns; nothing is raised.

   NOTE(review): assumes certOrCrls is an iterable of PEM strings; duplicate
   entries overwrite each other in `subproc`, leaving the earlier process
   unreaped until garbage collection.
   """
   if gv.dateCache or gv.trustedCertHashCache or gv.crlHashCache:
      trace( "Clearing old profile cache" )
      clearProfileCache()

   subproc = dict() # pylint: disable=use-dict-literal
   for certOrCrl in certOrCrls:
      if hasCrl( certOrCrl ):
         cmd = [ 'openssl', 'crl', '-lastupdate', '-nextupdate', '-hash', '-noout' ]
      else:
         cmd = [ 'openssl', 'x509', '-dates', '-hash', '-noout' ]
      subproc[ certOrCrl ] = subprocess.Popen(
         cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
         universal_newlines=True )
      subproc[ certOrCrl ].stdin.write( certOrCrl )
      # Closing stdin signals EOF so openssl can finish reading the PEM.
      subproc[ certOrCrl ].stdin.close()

   # Wait for all the subprocesses to finish. Since we have spawned them
   # in parallel, the max wait time should be more or less equal to the
   # wait time of one subprocess.
   # NOTE(review): wait() is called before stdout/stderr (both PIPEs) are
   # read; the Python docs warn this can deadlock if a child fills a pipe
   # buffer. The expected openssl output is only a few lines, so this is
   # presumably safe — revisit if larger output becomes possible.
   for certOrCrl in subproc: # pylint: disable=consider-using-dict-items
      subproc[ certOrCrl ].wait()

   for certOrCrl in subproc: # pylint: disable=consider-using-dict-items
      output = subproc[ certOrCrl ].stdout.read()
      if hasCrl( certOrCrl ):
         # for OpenSSL 1.0.2k-fips  26 Jan 2017, we get output like:
         # lastUpdate=May  4 23:27:07 2023 GMT
         # nextUpdate=Jun 11 23:27:07 7023 GMT
         # de903576
         # While for OpenSSL 3.0.1 14 Dec 2021 (Library: OpenSSL 3.0.1 14 Dec 2021)
         # we get ouptut like
         # lastUpdate=May  4 23:27:07 2023 GMT
         # nextUpdate=Jun 11 23:27:07 7023 GMT
         # issuer name hash=de903576
         matchObj = re.match( r'lastUpdate=(.*?)\s*nextUpdate=(.*)\s*\n'
                              r'(issuer name hash=|)(.*)', output )
      else:
         # Expected output: "notBefore=...", "notAfter=...", then the hash.
         matchObj = re.match( r'notBefore=(.*?)\s*notAfter=(.*)\s*\n(.*)', output )
      if matchObj is None:
         # One bad entry invalidates the whole cache rather than leaving it
         # partially filled.
         clearProfileCache()
         error( "Unable to get dates and hash for certs or crl for\n %s"
                % certOrCrl )
         return
      gv.dateCache[ certOrCrl ] = _parseDates( matchObj.group( 1 ),
                                               matchObj.group( 2 ) )
      if hasCrl( certOrCrl ):
         # the third capture group may have either an empty string ""
         # or "issuer name hash=", while the 4th capture group will have the crl name
         gv.crlHashCache[ certOrCrl ] = matchObj.group( 4 )
      else:
         gv.trustedCertHashCache[ certOrCrl ] = matchObj.group( 3 )

def clearProfileCache():
   """Empty the date and hash caches that createProfileCache() fills."""
   for cache in ( gv.dateCache, gv.trustedCertHashCache, gv.crlHashCache ):
      cache.clear()

def getCertificateDates( certData ):
   """Return the certificate's [ notBefore, notAfter ] as integer seconds since
   the Unix epoch.

   A value cached in gv.dateCache by createProfileCache() is returned as is;
   otherwise the PEM (a "TRUSTED CERTIFICATE" header is accepted) is parsed
   with M2Crypto.

   Raises:
      SslCertKeyError: M2Crypto could not parse the certificate.
   """
   if certData in gv.dateCache:
      trace( "Returning cert date from cache" )
      return gv.dateCache[ certData ]
   # Import outside the try: if the import itself failed inside it, the
   # except clause below would hit an unbound X509 and hide the real error.
   from M2Crypto import X509
   try:
      cert = X509.load_cert_string( removeTrusted( certData, isFile=False ) )
      notBefore = cert.get_not_before().get_datetime().strftime( "%Y%m%d%H%M%SZ" )
      notAfter = cert.get_not_after().get_datetime().strftime( "%Y%m%d%H%M%SZ" )
      return _parseDates( notBefore, notAfter )
   except X509.X509Error:
      raise SslCertKeyError( "Can't get certificate dates" )

def getCrlDates( crlData ):
   """Return the CRL's two validity dates (M2Crypto get_not_before /
   get_not_after) as integer seconds since the Unix epoch.

   A value cached in gv.dateCache by createProfileCache() is returned as is;
   otherwise the PEM CRL is parsed with M2Crypto.

   Raises:
      SslCertKeyError: M2Crypto could not parse the CRL.
   """
   if crlData in gv.dateCache:
      trace( "Returning crl date from cache" )
      return gv.dateCache[ crlData ]
   # Import outside the try: if the import itself failed inside it, the
   # except clause below would hit an unbound X509 and hide the real error.
   from M2Crypto import X509
   try:
      crl = X509.load_crl_string( crlData.encode( "utf-8" ) )
      # pylint: disable-next=no-member
      notBefore = crl.get_not_before().get_datetime().strftime( "%Y%m%d%H%M%SZ" )
      # pylint: disable-next=no-member
      notAfter = crl.get_not_after().get_datetime().strftime( "%Y%m%d%H%M%SZ" )
      return _parseDates( notBefore, notAfter )
   except X509.X509Error:
      raise SslCertKeyError( "Can't get certificate dates" )

def getCertificateOrCrlDates( certData ):
   """Return the validity dates of a PEM CRL or certificate, chosen by
   whether the text contains an "X509 CRL" marker."""
   dateGetter = getCrlDates if hasCrl( certData ) else getCertificateDates
   return dateGetter( certData )

def validateCertificateData( certData,
                             validateExtended=False,
                             validateCa=False,
                             validateStartDate=True,
                             validateExpiryDate=True,
                             treatCertExpiredAsWarning=False,
                             isTrust=False,
                             validateHostname=False,
                             treatHostnameMismatchAsWarning=False,
                             validateFqdn=False,
                             validateExpiryMonitoringConfig=None ):
   """
   Check a certificate (or CRL) for errors and collect warnings.

   Parameters:
      certData (str): PEM certificate or CRL data
      validateExtended (bool): Require an extendedKeyUsage extension
      validateCa (bool): Require the CA basic constraint
      validateStartDate (bool): Fail if the certificate is not yet valid
      validateExpiryDate (bool): Fail if the certificate has expired
      treatCertExpiredAsWarning (bool): Report expiry as a warning instead
         of an error
      isTrust (bool): Whether this is a trusted certificate; selects which
         missing-CA-constraint error is reported
      validateHostname (bool): Require the system hostname to match the
         certificate
      treatHostnameMismatchAsWarning (bool): Report a hostname mismatch as
         a warning instead of an error
      validateFqdn (bool): Fail if a certificate name is a wildcard (checked
         by validateHostnameFqdn)
      validateExpiryMonitoringConfig (config): Expiry monitoring config used
         to warn about certificates approaching expiration; None disables
         this check

   The date and expiry-monitoring checks apply to certificates and CRLs;
   the remaining checks run only when certData contains a certificate.

   Returns:
      ( ErrorType, [ ( ErrorType, str ) ] ): the first error found (or
      ErrorType.noError) together with the warnings gathered up to that
      point. Each warning pairs an ErrorType with extra text: the notAfter
      time in epoch seconds for expiry-monitoring warnings, otherwise "".
   """
   now = int( Tac.utcNow() )
   notBefore, notAfter = getCertificateOrCrlDates( certData )
   warnings = []

   if validateStartDate and notBefore > now:
      return ErrorType.certNotYetValid, warnings

   if validateExpiryDate and notAfter <= now:
      if not treatCertExpiredAsWarning:
         return ErrorType.certExpired, warnings
      warnings.append( ( ErrorType.certExpired, "" ) )

   if validateExpiryMonitoringConfig:
      expiryWarning = SslMonitorExpiry.validateCertificateExpiry(
         validateExpiryMonitoringConfig, now, notAfter )
      if expiryWarning is not ErrorType.noError:
         warnings.append( ( expiryWarning, str( notAfter ) ) )

   if not hasCertificate( certData ):
      return ErrorType.noError, warnings

   if validateExtended:
      trace( "validating extended key usage" )
      result = validateExtendedKeyUsage( certData )
      if result is not ErrorType.noError:
         return result, warnings

   if validateCa:
      result = validateBasicConstraint( certData, isTrust )
      if result is not ErrorType.noError:
         return result, warnings

   if validateHostname:
      result = validateHostnameMatch( certData )
      if result is not ErrorType.noError:
         if not treatHostnameMismatchAsWarning:
            return result, warnings
         warnings.append( ( result, "" ) )

   if validateFqdn:
      result = validateHostnameFqdn( certData )
      if result is not ErrorType.noError:
         return result, warnings

   return ErrorType.noError, warnings

def validateCrlCa( crlData, caCerts, profileName, validateExtended=False ):
   """Check that a CRL's issuer is one of the profile's CA certificates.

   Only runs when validateExtended is True; otherwise ErrorType.noError is
   returned. The CRL issuer name hash is compared with each CA's subject
   name hash (caCerts.get( profileName, [] )); the CRL signature itself is
   not verified here. The first CA whose subject matches decides the result.

   Returns:
      ErrorType.noError: matching CA found and it passes the CRL-sign purpose
         check (or validateExtended is False).
      ErrorType.noCrlSign: matching CA found but it fails the CRL-sign check.
      ErrorType.crlNotSignedByCa: no CA subject matches the CRL issuer.
   """
   from M2Crypto import X509, m2
   if not validateExtended:
      return ErrorType.noError

   crl = X509.load_crl_string( crlData.encode( "utf-8" ) )
   # The CRL issuer is the same for every candidate CA; hash it only once.
   crlIssuer = crl.get_issuer().as_hash()
   for caCert in caCerts.get( profileName, [] ):
      cert = X509.load_cert_string( removeTrusted( caCert, isFile=False ) )
      if cert.get_subject().as_hash() != crlIssuer:
         continue
      # pylint: disable-next=no-member
      if cert.check_purpose( m2.X509_PURPOSE_CRL_SIGN, 0 ):
         return ErrorType.noError
      return ErrorType.noCrlSign
   return ErrorType.crlNotSignedByCa

def isClientCert( certData ):
   """Return True if OpenSSL's SSL-client purpose check accepts the cert."""
   from M2Crypto import X509, m2
   pemText = removeTrusted( certData, isFile=False )
   cert = X509.load_cert_string( pemText )
   # pylint: disable-next=no-member
   purposeResult = cert.check_purpose( m2.X509_PURPOSE_SSL_CLIENT, 0 )
   return purposeResult == 1

def isServerCert( certData ):
   """Return True if OpenSSL's SSL-server purpose check accepts the cert."""
   from M2Crypto import X509, m2
   pemText = removeTrusted( certData, isFile=False )
   cert = X509.load_cert_string( pemText )
   # pylint: disable-next=no-member
   purposeResult = cert.check_purpose( m2.X509_PURPOSE_SSL_SERVER, 0 )
   return purposeResult == 1

def getCertInfo( certData ):
   """Build a CertificateInfo holding the PEM and its two validity dates."""
   notBefore, notAfter = getCertificateOrCrlDates( certData )
   return CertificateInfo( certData, notBefore, notAfter )

def getCommonName( certData ):
   """Return the certificate subject's CN, or '' when the subject has none."""
   from M2Crypto import X509
   pemText = removeTrusted( certData, isFile=False )
   cn = X509.load_cert_string( pemText ).get_subject().CN
   return '' if cn is None else cn

def extractPems( name, begin, end, isFile, certLocation=CertLocation.certs ):
   """Return ( pemText, lineNumber ) tuples for every span of text matching
   the regex `begin`, then anything (non-greedy), then the regex `end`.

   When isFile is True, `name` is resolved to a path with
   Constants.autoCertPath (certLocation == CertLocation.autoCerts) or
   Constants.certPath; if that file cannot be opened, an error is traced and
   an empty list is returned. Otherwise `name` is the text itself.
   Each pemText gets a trailing newline; lineNumber is the 1-based line on
   which the match starts.
   """
   if not isFile:
      data = name
   else:
      path = ( Constants.autoCertPath( name )
               if certLocation == CertLocation.autoCerts
               else Constants.certPath( name ) )
      try:
         with open( path, 'r' ) as pemFile:
            data = pemFile.read()
      except IOError as e:
         error( "Cannot open file", name,
               "errno", e.errno )
         return []
   pemRe = re.compile( '(' + begin + '.*?' + end + ')', re.DOTALL )
   return [ ( found.group( 0 ) + '\n',
              data.count( '\n', 0, found.start() ) + 1 )
            for found in pemRe.finditer( data ) ]

def extractCerts( cert, isFile=True, certLocation=CertLocation.certs ):
   """Return ( pemText, lineNumber ) tuples for each CERTIFICATE or
   TRUSTED CERTIFICATE block (see extractPems)."""
   return extractPems( cert,
                       '-----BEGIN (TRUSTED )?CERTIFICATE-----',
                       '-----END (TRUSTED )?CERTIFICATE-----',
                       isFile,
                       certLocation )

def extractCrls( crl, isFile=True ):
   """Return ( pemText, lineNumber ) tuples for each X509 CRL block (see
   extractPems)."""
   return extractPems( crl,
                       '-----BEGIN X509 CRL-----',
                       '-----END X509 CRL-----',
                       isFile )

def validateExtendedKeyUsage( certData ):
   """Return ErrorType.noError if the certificate has an extendedKeyUsage
   extension, else ErrorType.noExtendedKeyUsage."""
   from M2Crypto import X509
   cert = X509.load_cert_string( removeTrusted( certData, isFile=False ) )
   hasEku = any( cert.get_ext_at( idx ).get_name() == 'extendedKeyUsage'
                 for idx in range( cert.get_ext_count() ) )
   return ErrorType.noError if hasEku else ErrorType.noExtendedKeyUsage

def validateBasicConstraint( certData, isTrust ):
   """Return ErrorType.noError if M2Crypto's check_ca() reports 1 for the
   certificate; otherwise the trust- or chain-specific missing-CA error."""
   from M2Crypto import X509
   cert = X509.load_cert_string( removeTrusted( certData, isFile=False ) )
   if cert.check_ca() == 1:
      return ErrorType.noError
   return ( ErrorType.noCABasicConstraintTrust if isTrust
            else ErrorType.noCABasicConstraintChain )

def validateHostnameMatch( certData ):
   """Return ErrorType.noError if socket.gethostname() equals a subject CN
   or appears as a subjectAltName value, else ErrorType.hostnameMismatch.

   The subjectAltName check looks for ":<hostname>" followed by "," or at
   the end of the extension's text value.
   """
   from M2Crypto import X509, m2
   hostname = socket.gethostname()
   trace( "Hostname is", hostname )
   cert = X509.load_cert_string( removeTrusted( certData, isFile=False ) )
   subject = cert.get_subject()
   # pylint: disable-next=no-member
   for cnEntry in subject.get_entries_by_nid( m2.NID_commonName ):
      trace( "Common name is", cnEntry )
      if cnEntry.get_data().as_text() == hostname:
         return ErrorType.noError
   try:
      needle = ":" + hostname
      sanStr = cert.get_ext( "subjectAltName" ).get_value()
      trace( "Subject alt names are", sanStr )
      if sanStr.endswith( needle ) or ( needle + "," ) in sanStr:
         return ErrorType.noError
   except LookupError:
      trace( "Certificate does not use subjectAltName" )
   return ErrorType.hostnameMismatch

def validateHostnameFqdn( certData ):
   """Reject certificates whose names use wildcards.

   Returns ErrorType.notFqdn if any subject commonName, or any "DNS:" entry
   of the subjectAltName extension, matches the hostname pattern below and
   contains a "*"; otherwise ErrorType.noError.
   """
   from M2Crypto import X509, m2
   hostnameRe = ( r"^@$|^(\*)$|^(\*\.)?(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]"
                  r"*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]"
                  r"*[A-Za-z0-9](\.?))$" )
   hostnamePattern = re.compile( hostnameRe )
   dnsPrefix = "DNS:"
   cert = X509.load_cert_string( removeTrusted( certData, isFile=False ) )
   # pylint: disable-next=no-member
   for commonName in cert.get_subject().get_entries_by_nid( m2.NID_commonName ):
      commonNameStr = commonName.get_data().as_text()
      if hostnamePattern.match( commonNameStr ) and "*" in commonNameStr:
         return ErrorType.notFqdn
   try:
      # The string from get_value looks like "IP Address:10.1.2.3, DNS:*.com"
      sanList = cert.get_ext( "subjectAltName" ).get_value().split( ", " )
      for sanEntry in sanList:
         if not sanEntry.startswith( dnsPrefix ):
            continue
         # Slice off the literal prefix. str.lstrip( "DNS:" ) would strip any
         # leading run of the characters 'D', 'N', 'S' and ':' and could eat
         # part of the name itself (e.g. "DNS:Sx.com" -> "x.com").
         sanDns = sanEntry[ len( dnsPrefix ): ]
         if sanDns and hostnamePattern.match( sanDns ) and "*" in sanDns:
            return ErrorType.notFqdn
   except LookupError:
      trace( "Certificate does not use subjectAltName" )
   return ErrorType.noError

def removeTrusted( cert, isFile=True ):
   """Return the PEM text with every "TRUSTED CERTIFICATE" turned into
   "CERTIFICATE", since M2Crypto does not accept TRUSTED headers.

   `cert` is a file path, or the PEM text itself when isFile is False.
   """
   if not isFile:
      pemText = cert
   else:
      with open( cert, 'r' ) as pemFile:
         pemText = pemFile.read()
   return pemText.replace( "TRUSTED CERTIFICATE", "CERTIFICATE" )

def hasCrl( buf ):
   """Return True if the text contains an "X509 CRL" marker."""
   crlMarker = 'X509 CRL'
   return crlMarker in buf

def hasCertificate( buf ):
   """Return True if the text contains a "CERTIFICATE" marker (this also
   matches "TRUSTED CERTIFICATE" and "CERTIFICATE REQUEST")."""
   certMarker = 'CERTIFICATE'
   return certMarker in buf

def validateCertificateOrCrl( filename ):
   """Validate a file as a CRL or a certificate bundle, chosen by its
   contents, with no limit on the number of PEM entities.

   Raises SslCertKeyError if the file contains neither.
   """
   with open( filename ) as pemFile:
      contents = pemFile.read()
   if hasCrl( contents ):
      validateCrl( filename, maxPemCount=None )
   elif hasCertificate( contents ):
      validateCertificate( filename, maxPemCount=None )
   else:
      raise SslCertKeyError( "Invalid certificate or CRL" )

def validateCrl( crlFile, maxPemCount=1 ):
   """ Validates the following about the CRL file:
   1) It has no more than <maxPemCount> PEM entities (None skips this
      check).
   2) It contains at least one PEM CRL, and every PEM entity in the file is
      a CRL.
   3) Every CRL in the file can be parsed by M2Crypto.

   Raises SslCertKeyError when any check fails, or when the file cannot be
   read while extracting the CRLs.
   """
   from M2Crypto import X509
   trace ( "validateCrl start:", crlFile )
   pemCount = _getPemCount( crlFile )
   debug( "PEM tags count is", pemCount )
   if ( maxPemCount is not None ) and ( pemCount > maxPemCount ):
      error( "validateCrl: crl ", crlFile, "has", pemCount, "PEM tags" )
      raise SslCertKeyError( "The number of PEM entities in the file "
                             "exceeds the limit %d" % maxPemCount )

   try:
      with open( crlFile, 'r' ) as fp:
         crlData = fp.read()
      crls = extractCrls( crlData, isFile=False )
      if not crls:
         error( "No PEM format CRL in the file" )
         raise SslCertKeyError( "Invalid CRL" )
      if len( crls ) != pemCount:
         # Log the count (not every PEM text) under this function's name.
         error( "validateCrl: crls count", len( crls ),
                "does not match pem count", pemCount )
         raise SslCertKeyError( "Multiple types of entities in CRL file not "
                                "supported" )
      for crl, _ in crls:
         try:
            X509.load_crl_string( crl.encode() )
         except X509.X509Error as e:
            error( "validateCrl exception:", e )
            raise SslCertKeyError( "Invalid CRL" )

   except IOError:
      raise SslCertKeyError( "Unable to read CRL file" )
   
def validateCertificate( cert, isFile=True, validateDates=False, maxPemCount=1 ):
   """ Validates the following about certificate:
   1) Whether it is a valid certificate file: it contains at least one PEM
      certificate, every PEM entity in it is a certificate, and each one can
      be parsed by M2Crypto.
   2) Whether each certificate has an RSA or ECDSA key (DSA is not supported).
   3) That the file has at most <maxPemCount> PEM entities.
      None skips this check.
   4) When validateDates is True, that no certificate is expired or not yet
      valid.

   `cert` is a file path, or the PEM text itself when isFile is False.
   Raises SslCertKeyError describing the first failed check; when the input
   holds several certificates, the message names the offending one by line
   number and common name.
   """
   from M2Crypto import X509
   trace ( "validateCertificate start:", cert )

   pemCount = _getPemCount( cert, isFile=isFile )
   debug( "PEM tags count is", pemCount )
   if maxPemCount is not None and pemCount > maxPemCount:
      error( "validateCertificate: cert", cert, "has", pemCount, "PEM tags" )
      raise SslCertKeyError( "The number of PEM entities in the file "
                             "exceeds the limit %d" % maxPemCount )

   try:
      if isFile:
         with open( cert, 'r' ) as fp:
            certData = fp.read()
      else:
         certData = cert

      certs = extractCerts( certData, isFile=False )
      if not certs:
         error( "No PEM format certificate in the file" )
         raise SslCertKeyError( "Invalid certificate" )
      if len( certs ) != pemCount:
         # Some PEM entities in the input are not certificates.
         error( "validateCertificate: certs count", len( certs ),
                "does not match pem count", pemCount )
         raise SslCertKeyError( "Multiple types of entities in certificate file "
                                "not supported" )
      for certificate, lineno in certs:
         try:
            # M2Crypto cannot handle TRUSTED
            # in pem header. So remove TRUSTED before
            # giving it to M2Crypto
            X509.load_cert_string( removeTrusted( certificate, isFile=False ) )
         except X509.X509Error as e:
            error( "validateCertificate exception:", e )
            raise SslCertKeyError( "Invalid certificate" )

         commonName = getCommonName( certificate )
         pubKeyAlgo = getCertPublicKeyAlgo( certificate )
         if pubKeyAlgo == PublicKeyAlgorithm.UNSUPPORTED:
            if len( certs ) > 1: # pylint: disable=no-else-raise
               raise SslCertKeyError( "Certificate at line %d with CN: %s does "
                    "not have supported key (RSA, ECDSA)" % ( lineno, commonName ) )
            else:
               raise SslCertKeyError( "Certificate does not have supported key"
                                      " (RSA, ECDSA)" )

         if validateDates:
            result, _ = validateCertificateData( certificate )
            if result == ErrorType.certNotYetValid:
               if len( certs ) > 1: # pylint: disable=no-else-raise
                  raise SslCertKeyError( "Certificate at line %d with CN: %s is not"
                     " yet valid" % ( lineno, commonName ) )
               else:
                  # Must be raised, not just constructed; otherwise a single
                  # not-yet-valid certificate would pass validation.
                  raise SslCertKeyError( "Certificate is not yet valid" )
            elif result == ErrorType.certExpired:
               if len( certs ) > 1: # pylint: disable=no-else-raise
                  raise SslCertKeyError( "Certificate at line %d with CN: %s has"
                     " expired" % ( lineno, commonName ) )
               else:
                  raise SslCertKeyError( "Certificate has expired" )
   except IOError:
      raise SslCertKeyError( "Unable to read Certificate file" )

def getEcCurveFromCert( cert, isFile=True ):
   """ Return the named EC curve of the certificate's key, or None.

   `cert` is a file path, or PEM text when isFile is False. A key generated
   from a named curve carries an ASN1 OID field naming the curve; a key
   generated with explicit parameters lacks it, and then TLS (and possibly
   other things) won't work.
   """
   return _getAsn1Oid( cert, pemType="x509", isFile=isFile )

def getEcCurveFromKey( key, isFile=True ):
   """ Return the named EC curve of an EC private key, or None.

   `key` is a file path, or PEM text when isFile is False. A key generated
   from a named curve carries an ASN1 OID field naming the curve; a key
   generated with explicit parameters lacks it, and then TLS (and possibly
   other things) won't work.
   """
   return _getAsn1Oid( key, pemType="ec", isFile=isFile )

def getEcCurveFromCsr( csr, isFile=True ):
   """ Return the named EC curve of a CSR's key (from `openssl req -text`),
   or None when no ASN1 OID is printed. `csr` is a file path, or PEM text
   when isFile is False.
   """
   return _getAsn1Oid( csr, pemType="req", isFile=isFile )

def _getAsn1Oid( pem, pemType, isFile=True ):
   """Run `openssl <pemType> -text -noout` and return the text after
   "ASN1 OID: " in its output, or None if no such line is found.

   With isFile the PEM is read from the file `pem` via -in; otherwise the
   PEM text is fed on stdin. The openssl return code is ignored.
   """
   cmd = [ "openssl", pemType, "-text", "-noout" ]
   if isFile:
      cmd.extend( [ "-in", pem ] )
      stdinData = None
   else:
      stdinData = pem
   output = Tac.run( cmd, stdout=Tac.CAPTURE, stderr=Tac.CAPTURE,
                     ignoreReturnCode=True, input=stdinData )
   oidMatch = re.search( r'ASN1 OID: (.*)', output )
   if oidMatch is not None:
      asn1Oid = oidMatch.group( 1 )
      debug( "ASN1 OID is ", asn1Oid )
      return asn1Oid
   debug( "No ASN1 OID found" )
   debug( "Output of cmd ", " ".join( cmd ), " is ", output )
   return None

def validateEcPrivateKey( key, isFile=True ):
   """ Validates if the key has a valid EC key
   which is not password protected and it has exactly
   one PEM entity.
   """
   from M2Crypto import EC, util, Err
   trace( "validateEcPrivateKey start:", key )
   try:
      if isFile:
         EC.load_key( key, callback=util.no_passphrase_callback )
      else:
         EC.load_key_string( key.encode(), callback=util.no_passphrase_callback )
   except SystemError as e:
      error( "validateEcPrivateKey exception:", e )
      # In python3, this is thrown when the password callback fails
      raise SslCertKeyError( "Password protected keys are not supported" )
   except ValueError as e:
      error( "validateEcPrivateKey exception:", e )
      if six.PY2:
         errMsg = Err.get_error_message()
         error( "validateEcPrivateKey error message:", errMsg )
         if errMsg == "bad password read":
            raise SslCertKeyError( "Password protected keys are not supported" )
      raise KeyTypeMismatchError( "Invalid EC key" )

   # Make sure EC key was generated with a named curve
   if not getEcCurveFromKey( key, isFile ):
      raise SslCertKeyError( "EC key must be generated using a named curve" )

def validateRsaPrivateKey( key, isFile=True ):
   """ Validates that `key` is an RSA private key that M2Crypto can load
   without a passphrase.

   `key` is a file path, or the PEM text itself when isFile is False. The
   number of PEM entities is NOT checked here; validatePrivateKey does that
   separately through validatePrivateKeyPemCount.

   Raises:
      SslCertKeyError: the key is password protected.
      KeyTypeMismatchError: M2Crypto raised RSA.RSAError while loading the
         key ("Invalid RSA key"); validatePrivateKey uses this to try the
         next key type.
   """
   from M2Crypto import RSA, util
   trace( "validateRsaPrivateKey start:", key )
   try:
      if isFile:
         RSA.load_key( key, callback=util.no_passphrase_callback )
      else:
         RSA.load_key_string( key.encode(), callback=util.no_passphrase_callback )
   except SystemError as e:
      error( "validateRsaPrivateKey exception:", e )
      # In python3, this is thrown when the password callback fails
      raise SslCertKeyError( "Password protected keys are not supported" )
   except RSA.RSAError as e:
      error( "validateRsaPrivateKey exception:", e )
      # On Python 2 a passphrase failure surfaces as RSAError; tell it apart
      # from a non-RSA key by the error text.
      if six.PY2:
         if str( e ) == "bad password read":
            raise SslCertKeyError( "Password protected keys are not supported" )
      raise KeyTypeMismatchError( "Invalid RSA key" )

def validatePrivateKeyPemCount( key, isFile=True ):
   """Raise SslCertKeyError if the key file (or text, when isFile is False)
   holds more than one PEM entity."""
   count = _getPemCount( key, isFile=isFile )
   debug( "PEM tags count is", count )
   if count <= 1:
      return
   error( "validatePrivateKeyPemCount: key", key, "has", count, "PEM tags" )
   raise SslCertKeyError( "Multiple PEM entities in single file not supported" )

def validatePrivateKey( key, isFile=True ):
   """ Validate that the key is a usable RSA or EC key in a single PEM.

   Each key-type validator is tried in turn (RSA first, then EC); a
   KeyTypeMismatchError moves on to the next one. The first validator that
   succeeds is followed by the PEM count check. Other errors from the
   validators propagate. If no type matches, SslCertKeyError is raised.
   """
   for validator in ( validateRsaPrivateKey, validateEcPrivateKey ):
      try:
         validator( key, isFile )
      except KeyTypeMismatchError:
         continue
      validatePrivateKeyPemCount( key, isFile=isFile )
      return
   raise SslCertKeyError( "Invalid or unsupported key" )

def validateCSR( csrFile ):
   """ Validate that the file at csrFile holds a CSR that M2Crypto can load
   and contains no more than one PEM entity.

   Raises SslCertKeyError on either failure.
   """
   from M2Crypto import X509
   trace( "validateCSR start:", csrFile )
   try:
      X509.load_request( csrFile )
   except ( ValueError, X509.X509Error ) as e:
      error( "validateCSR exception:", str( e ) )
      raise SslCertKeyError( "Invalid CSR" )

   count = _getPemCount( csrFile )
   debug( "PEM tags count is", count )
   if count <= 1:
      return
   error( "validateCSR: ", csrFile, "has", count, "PEM tags" )
   raise SslCertKeyError( "Multiple PEM entities in single file not supported" )

def isCertificateMatchesRsaKey( cert, keyData ):
   """Return True if the RSA private key in keyData has the same modulus as
   the public key of the loaded X509 `cert`.

   Returns False when the key does not load as RSA or the certificate's
   public key modulus cannot be read (ValueError / RSA.RSAError).
   """
   from M2Crypto import RSA, m2
   try:
      rsaKey = RSA.load_key_string( keyData.encode( "utf-8" ) )
      certModulus = cert.get_pubkey().get_modulus()
   except ( ValueError, RSA.RSAError ):
      return False
   _, mpiModulus = rsaKey.pub()
   keyModulus = m2.bn_to_hex( m2.mpi_to_bn( mpiModulus ) ) # pylint: disable=no-member
   return str( keyModulus ) == str( certModulus )

def isCertificateMatchesEcKey( cert, keyData ):
   """Return True if the DER public key of the loaded X509 `cert` equals the
   DER public part of the EC private key in keyData; False if the key does
   not load as EC (ValueError)."""
   from M2Crypto import EC
   try:
      ecKey = EC.load_key_string( keyData.encode( "utf-8" ) )
   except ValueError:
      return False
   certDer = cert.get_pubkey().as_der()
   return certDer == ecKey.pub().get_der()

def isCertificateMatchesKey( certData, keyData ):
   """Return True if the PEM certificate's public key matches the PEM
   private key, trying RSA first and then EC."""
   from M2Crypto import X509
   cert = X509.load_cert_string( removeTrusted( certData, isFile=False ) )
   if isCertificateMatchesRsaKey( cert, keyData ):
      return True
   return isCertificateMatchesEcKey( cert, keyData )

def verifyCertCorrect( cert, issuerCert ):
   """
   Return True if `cert` verifies against issuerCert's public key and the
   current time lies strictly between its notBefore and notAfter dates.
   Otherwise trace an error and return False; a certificate whose dates
   cannot be parsed is treated as bad.
   """
   signatureOk = cert.verify( issuerCert.get_pubkey() )
   if not signatureOk:
      error( "Cert failed signature check" )
      return False
   try:
      certPem = cert.as_pem().decode( "utf-8" )
      notBefore, notAfter = getCertificateDates( certPem )
   except SslCertKeyError:
      error( "Failed to parse cert date" )
      return False
   now = time.time()
   if notBefore < now < notAfter:
      return True
   error( "Cert was not in valid time" )
   return False

def verifyCertChain( certFile, chainedCert, certDict, verifyEndsInRootCA=False ):
   """
   Verify that the leaf certificate in certFile chains up through the
   certificates contained in the chainedCert files.

   certFile: name of the leaf certificate entry in certDict. Only the first
      entry (in dict iteration order) of that file is used as the leaf.
   chainedCert: iterable of certDict names holding the issuer certificates;
      every entry of every one of these files is considered.
   certDict: maps a certificate name to a dict of {lineno: certificate data}.
   verifyEndsInRootCA: when True, the last certificate reached must also pass
      verifyCertCorrect against itself (i.e. verify as self-signed).

   Starting from the leaf, each certificate's issuer is looked up by
   issuer-name hash among the chain certificates (indexed by subject-name
   hash) and checked with verifyCertCorrect. The number of hops taken equals
   the number of distinct subject hashes in the chain.

   Returns True on success. Returns False if a name is missing from certDict,
   an issuer cannot be found, or any verification fails. Errors raised by
   X509.load_cert_string / removeTrusted are not caught here.
   """
   from M2Crypto import X509
   trace ( "verifyCertChain start:", certFile, chainedCert,
           " verifyEndsInRootCA:", verifyEndsInRootCA )
   if certFile not in certDict:
      return False
   # The leaf is the first entry of certFile.
   certData = next( iter( certDict[ certFile ].values() ) )
   cert = X509.load_cert_string( removeTrusted( certData, False ) )
   # Index every chain certificate by subject-name hash. Certificates sharing a
   # subject hash overwrite each other (the last one seen wins).
   chainDict = {}
   for c in chainedCert:
      if c not in certDict:
         return False
      # NOTE: rebinds certData; the leaf data is no longer needed at this point.
      for lineno, certData in certDict[ c ].items():
         x509Cert = X509.load_cert_string( removeTrusted( certData, False ) )
         try:
            subject = x509Cert.get_subject()
            subjectHash = subject.as_hash()
            chainDict[ subjectHash ] = x509Cert
            debug( "Insert cert:", c, "line:", lineno,
                   "subject:", subject, "hash:", subjectHash,
                   "into chainDict" )
         except Exception as e: # pylint: disable=broad-except
            error( "Exception while getting subject hash", str( e ) )
            return False

   chainLength = len( chainDict )
   trace( "Chain length is", chainLength )
   # Need to find issuer for certificates starting from the leaf cert.
   # This number will be same as chainLength.
   for __ in range( chainLength ):
      # Subject/issuer text is only used for logging; failures fall back to "".
      try:
         subjectStr = cert.get_subject().as_text()
      except Exception as e: # pylint: disable=broad-except
         debug( "Exception while getting subject str: ", str( e ) )
         subjectStr = ""

      try:
         issuerStr = cert.get_issuer().as_text()
      except Exception as e: # pylint: disable=broad-except
         debug( "Exception while getting issuer str: ", str( e ) )
         issuerStr = ""

      debug( "Finding issuer", issuerStr, "for cert", subjectStr )
      try:
         issuer = chainDict[ cert.get_issuer().as_hash() ]
      except KeyError:
         # The chain is not complete yet.
         error( "No issuer found" )
         return False
      except Exception as e: # pylint: disable=broad-except
         error( "Exception raised: ", str( e ) )
         return False

      # Signature + validity-period check of the current cert by its issuer.
      if not verifyCertCorrect( cert, issuer ):
         error( "verifyCertCorrect: cert", subjectStr,
                "failed check by", issuerStr )
         return False
      # Walk one step up the chain.
      cert = issuer
   # Optionally require the top of the chain to verify against itself.
   if verifyEndsInRootCA and ( not verifyCertCorrect( cert, cert ) ):
      error( "Failed to verify final cert as self-signed" )
      return False
   return True

def validateTrustedChain( profileConfig, certDict, crlDict ):
   """
   Check that every trusted certificate configured in profileConfig chains up
   to a configured self-signed certificate, and that CRL coverage is
   consistent along each chain.

   profileConfig: provides .trustedCert (iterable of certificate names) and
      .crl (iterable of CRL names).
   certDict: maps a certificate name to {lineno: certificate data}.
   crlDict: maps a CRL name to a dict whose values are CRL data.

   Returns an ErrorType value:
   - noError if a configured cert or CRL is missing from certDict/crlDict
     (those cases are handled elsewhere, see comments below), or if every
     configured trusted cert was validated;
   - certTrustChainNotValid if trusted certs are configured but none is
     self-signed, or some certs could not be linked to a validated issuer;
   - missingCrlForTrustChain if a cert and its validated issuer disagree on
     whether they have issued any configured CRL.
   """
   from M2Crypto import X509
   trustedCert = profileConfig.trustedCert
   # (certName, lineno) pairs still waiting to be linked to a validated issuer.
   configuredTrustedCert = []
   for crl in profileConfig.crl:
      if crl not in crlDict:
         # Return with noError and let _checkCrl in SslReactor handle this
         return ErrorType.noError
   for certName in trustedCert:
      if certName not in certDict:
         # This has been handle in _checkTrustedCert in SslReactor
         # Return with noError
         return ErrorType.noError
      for lineno in certDict[ certName ]:
         configuredTrustedCert.append( ( certName, lineno ) )
   # (certName, lineno) -> whether that cert has issued a configured CRL.
   validatedTrustedCert = {}

   def hasIssuedCrl( cert ):
      # Check if a given cert has issued any configured CRL
      for crlName in profileConfig.crl:
         for crlData in crlDict[ crlName ].values():
            crl = X509.load_crl_string( crlData.encode( "utf-8" ) )
            if ( crl.get_issuer().as_hash() == cert.get_subject().as_hash() ):
               return True
      return False

   # Find all self-signed certs and mark them if they've issued any CRL
   for certName in trustedCert:
      for lineno, certData in certDict[ certName ].items():
         cert = X509.load_cert_string( removeTrusted( certData, isFile=False ) )
         if ( cert.get_subject().as_hash() == cert.get_issuer().as_hash() ):
            configuredTrustedCert.remove( ( certName, lineno ) )
            validatedTrustedCert[ ( certName, lineno ) ] = hasIssuedCrl( cert )
   if trustedCert and len( validatedTrustedCert ) == 0:
      return ErrorType.certTrustChainNotValid
   noChangesMade = False
   # Find all certs that are signed by validated trusted certs. Repeat until a
   # full pass validates nothing new (fixed point).
   # pylint: disable-msg=R1702
   while not noChangesMade:
      noChangesMade = True
      curConfiguredTrustedCert = configuredTrustedCert[:]
      for certName, lineno in curConfiguredTrustedCert:
         cert = X509.load_cert_string( removeTrusted( certDict[ certName ]
                                            [ lineno ], isFile=False ) )
         for validatedCertName, issuerHasCrl in validatedTrustedCert.items():
            c, l = validatedCertName
            validatedCert = X509.load_cert_string(
                  removeTrusted( certDict[ c ][ l ],
                                 isFile=False ) )
            # X509.verify() wraps OpenSSL's X509_verify(), which returns -1
            # (truthy) when the signature could not be checked at all, so an
            # explicit 1 is required for the issuer link to count.
            if ( validatedCert.get_subject().as_hash() ==
                  cert.get_issuer().as_hash()
                  and cert.verify( validatedCert.get_pubkey() ) == 1 ):
               certHasCrl = hasIssuedCrl( cert )
               if issuerHasCrl == certHasCrl: # pylint: disable=no-else-break
                  configuredTrustedCert.remove( ( certName, lineno ) )
                  # Safe despite iterating validatedTrustedCert: we break
                  # out of that loop immediately after inserting.
                  validatedTrustedCert[ ( certName, lineno ) ] = certHasCrl
                  noChangesMade = False
                  break
               else:
                  # NOTE(review): also reached when the cert has issued a CRL
                  # but its issuer has not; the message below only describes
                  # the opposite case.
                  cn = getCommonName( certDict[ certName ][ lineno ] )
                  if len( certDict[ certName ] ) > 1:
                     error( "Certificate with CN: ", cn, " at line ", lineno,
                            " in file ", certName, " has not signed any CRL" )
                  else:
                     error( "Cert ", certName, " has not signed any CRL" )
                  return ErrorType.missingCrlForTrustChain
   if len( configuredTrustedCert ) == 0:
      return ErrorType.noError
   else:
      return ErrorType.certTrustChainNotValid

def isSelfSignedRootCertificate( certData ):
   """
   Return True if the certificate's subject name equals its issuer name.

   Only the str() forms of the two names are compared; the signature is not
   verified. certData is passed through removeTrusted( ..., isFile=False )
   before parsing.
   """
   from M2Crypto import X509
   pem = removeTrusted( certData, isFile=False )
   x509Cert = X509.load_cert_string( pem )
   subjectName = str( x509Cert.get_subject() )
   issuerName = str( x509Cert.get_issuer() )
   return subjectName == issuerName

def generateRsaPrivateKey( keyFilepath, keyBits, useTmpFile=True ):
   """
   Generates new RSA private key in keyFilepath.

   keyFilepath: destination path of the private key.
   keyBits: RSA key size, passed to 'openssl genrsa'.
   useTmpFile: when True, the key is generated into a temporary file in the
      same directory (mode Constants.sslKeyPerm). That file is renamed onto
      keyFilepath only after validateRsaPrivateKey and
      validatePrivateKeyPemCount accept it, so a failed generation does not
      overwrite keyFilepath. On failure the temporary file is removed.
      When False, openssl writes keyFilepath directly.

   Raises SslCertKeyError if the temporary file cannot be created. Whatever
   the validate* helpers raise is propagated unchanged.
   """
   trace( "generateRsaPrivateKey start:", keyFilepath, ":", str( keyBits ) )
   keyFilepath = os.path.abspath( keyFilepath )

   if useTmpFile:
      keyDir = os.path.dirname( keyFilepath )
      try:
         tempKeyHandle, tempKeyFilepath = tempfile.mkstemp( dir=keyDir )
      except ( OSError, IOError ) as e:
         raise SslCertKeyError( "%s" % e.strerror )

      # Close file handle since we will no longer write to it directly in this code
      os.close( tempKeyHandle )
   else:
      tempKeyFilepath = keyFilepath

   succeeded = False
   try:
      if useTmpFile:
         os.chmod( tempKeyFilepath, Constants.sslKeyPerm )

      cmd = [ "openssl" ]
      # Use FIPS mode for this (only needed for openssl earlier than v3)
      if USE_FIPS_FLAG:
         cmd += [ "--fips" ]
      cmd += [ "genrsa", "-out", tempKeyFilepath, str( keyBits ) ]
      # The exit status is ignored; success is judged by validating the file
      # openssl produced.
      output = Tac.run( cmd, stdout=Tac.CAPTURE,
                        stderr=Tac.CAPTURE, asRoot=True,
                        ignoreReturnCode=True )
      debug( "genrsa output:", output )

      validateRsaPrivateKey( tempKeyFilepath )
      validatePrivateKeyPemCount( tempKeyFilepath )

      if useTmpFile:
         # Rename to final name since key seems correct
         os.rename( tempKeyFilepath, keyFilepath )
      succeeded = True
   finally:
      # Don't leave a stray (possibly half-written) temporary key behind.
      if useTmpFile and not succeeded:
         try:
            os.remove( tempKeyFilepath )
         except OSError:
            pass

def generateEcdsaPrivateKey( keyFilepath, useTmpFile=True, curve="prime256v1" ):
   """
   Generates new ECDSA private key in keyFilepath.

   keyFilepath: destination path of the private key.
   useTmpFile: when True, the key is generated into a temporary file in the
      same directory (mode Constants.sslKeyPerm). That file is renamed onto
      keyFilepath only after validateEcPrivateKey and
      validatePrivateKeyPemCount accept it, so a failed generation does not
      overwrite keyFilepath. On failure the temporary file is removed.
      When False, openssl writes keyFilepath directly.
   curve: named curve passed to 'openssl ecparam -name'. Defaults to
      prime256v1; see 'openssl ecparam -list_curves' for built-in names.

   Raises SslCertKeyError if the temporary file cannot be created. Whatever
   the validate* helpers raise is propagated unchanged.
   """
   trace( "generateEcdsaPrivateKey start:", keyFilepath )
   keyFilepath = os.path.abspath( keyFilepath )

   if useTmpFile:
      keyDir = os.path.dirname( keyFilepath )
      try:
         tempKeyHandle, tempKeyFilepath = tempfile.mkstemp( dir=keyDir )
      except ( OSError, IOError ) as e:
         raise SslCertKeyError( "%s" % e.strerror )

      # Close file handle since we will no longer write to it directly in this code
      os.close( tempKeyHandle )
   else:
      tempKeyFilepath = keyFilepath

   succeeded = False
   try:
      if useTmpFile:
         os.chmod( tempKeyFilepath, Constants.sslKeyPerm )

      cmd = [ "openssl", "ecparam", "-name", curve, "-genkey", "-noout",
              "-out", tempKeyFilepath ]
      # The exit status is ignored; success is judged by validating the file
      # openssl produced.
      output = Tac.run( cmd, stdout=Tac.CAPTURE,
                        stderr=Tac.CAPTURE, asRoot=True,
                        ignoreReturnCode=True )
      debug( "ecparam output:", output )

      validateEcPrivateKey( tempKeyFilepath )
      validatePrivateKeyPemCount( tempKeyFilepath )

      if useTmpFile:
         # Rename to final name since key seems correct
         os.rename( tempKeyFilepath, keyFilepath )
      succeeded = True
   finally:
      # Don't leave a stray (possibly half-written) temporary key behind.
      if useTmpFile and not succeeded:
         try:
            os.remove( tempKeyFilepath )
         except OSError:
            pass

def genOpensslConf( signRequest=False, xKeyUsage=None,
                    sanIp=None, sanDns=None, sanEmailAddress=None,
                    sanUri=None, basicConstraints=False ):
   """
   Build the text of a minimal openssl 'req' configuration file.

   The [v3_ext] section is referenced as req_extensions when signRequest is
   true (CSR) and as x509_extensions otherwise (self-signed certificate).
   It receives, in this order and only when non-empty:
   - subjectAltName: IP, DNS, email and URI entries, comma separated;
   - extendedKeyUsage: xKeyUsage verbatim;
   - basicConstraints=CA:TRUE when basicConstraints is true.
   """
   sanEntries = []
   for prefix, values in ( ( "IP", sanIp ), ( "DNS", sanDns ),
                           ( "email", sanEmailAddress ), ( "URI", sanUri ) ):
      if values:
         sanEntries.extend( ( prefix + ":%s" ) % value for value in values )

   extensionKind = "req" if signRequest else "x509"
   confLines = [ "[req]",
                 "distinguished_name=req_distinguished_name",
                 "%s_extensions=v3_ext" % extensionKind,
                 "[req_distinguished_name]",
                 "[v3_ext]" ]
   if sanEntries:
      confLines.append( "subjectAltName=%s" % ",".join( sanEntries ) )
   if xKeyUsage:
      confLines.append( "extendedKeyUsage=%s" % xKeyUsage )
   if basicConstraints:
      confLines.append( "basicConstraints=CA:TRUE" )
   return "".join( line + "\n" for line in confLines )
   
def generateCertificate( commonName="self.signed",
                         keyFilepath=None,
                         certFilepath=None,
                         signRequest=False,
                         genNewKey=True,
                         keyType=PublicKeyAlgorithm.RSA,
                         newKeyBits=2048,
                         digest="sha256",
                         curve="prime256v1",
                         country=None,
                         state=None,
                         locality=None,
                         orgName=None,
                         orgUnitName=None,
                         emailAddress=None,
                         sanIp=None,
                         sanDns=None,
                         sanEmailAddress=None,
                         sanUri=None,
                         xKeyUsage=None,
                         validity=30000 ):
   """Generates either a certificate request or a self-signed x509 certificate.
   Returns the certificate as a (certificate, key) string pair. If an error occurs,
   throws SslCertKeyError carrying the error text as its message.

   If genNewKey is set, a new key is generated: RSA with newKeyBits bits or
   ECDSA on `curve`, depending on keyType.
   If genNewKey is not set, keyFilepath must be supplied.
   If genNewKey and keyFilepath are set, the new key will be saved to keyFilepath.
   If certFilepath is set, the cert or CSR will be saved to certFilepath.

   signRequest: produce a CSR instead of a self-signed certificate. Without
      it, 'openssl req -x509 -days <validity>' is used.
   digest: message digest name passed to openssl as '-<digest>'.
   country/state/locality/orgName/orgUnitName/emailAddress: optional subject
      fields; missing ones are left empty in the -subj string.
   sanIp/sanDns/sanEmailAddress/sanUri/xKeyUsage: forwarded to genOpensslConf.
   """
   debug( "locals:", locals() )

   assert commonName
   if genNewKey:
      assert newKeyBits or curve
   else:
      assert keyFilepath and os.path.isfile( keyFilepath )

   keyFilepath = os.path.abspath( keyFilepath ) if keyFilepath else None
   certFilepath = os.path.abspath( certFilepath ) if certFilepath else None
   certDir = os.path.dirname( certFilepath ) if certFilepath else None
   keyDir = os.path.dirname( keyFilepath ) if keyFilepath else None

   # When a destination is given, the temporary file is created in the same
   # directory so that the final os.rename() stays within that directory.
   tmpCertFile = tempfile.NamedTemporaryFile( dir=certDir, delete=False, mode="w+" )
   os.chmod( tmpCertFile.name, Constants.sslCertPerm )
   tmpKeyFile = tempfile.NamedTemporaryFile( dir=keyDir, delete=False, mode="w+" )
   os.chmod( tmpKeyFile.name, Constants.sslKeyPerm )
   confFile = tempfile.NamedTemporaryFile( mode="w" )
   conf = genOpensslConf( signRequest=signRequest,
                          xKeyUsage=xKeyUsage,
                          sanIp=sanIp, sanDns=sanDns,
                          sanEmailAddress=sanEmailAddress,
                          sanUri=sanUri )
   debug( "openssl conf:" )
   debug( "\n" + conf )
   confFile.write( conf )
   # openssl reads the config by name, so it must be on disk first.
   confFile.flush()
   errorOccured = False

   try:
      if genNewKey:
         if keyType == PublicKeyAlgorithm.RSA:
            trace( "Generating new RSA key" )
            generateRsaPrivateKey( tmpKeyFile.name, newKeyBits, useTmpFile=False )
         elif keyType == PublicKeyAlgorithm.ECDSA:
            trace( "Generating new ECDSA key" )
            generateEcdsaPrivateKey( tmpKeyFile.name, useTmpFile=False, curve=curve )
         else:
            assert False, "Unsupported key type: {}".format( keyType )
      else:
         trace( "Using RSA key from", keyFilepath )
         with open( keyFilepath, 'r' ) as f:
            tmpKeyFile.write( f.read() )
            tmpKeyFile.flush()

      def paramStr( param ):
         return "" if not param else param

      subjStr = "/C={country}/ST={state}/L={locality}/O={org}"
      subjStr += "/OU={orgUnit}/CN={commonName}/emailAddress={email}/"
      subjStr = subjStr.format( country=paramStr( country ),
                                state=paramStr( state ),
                                locality=paramStr( locality ),
                                org=paramStr( orgName ),
                                orgUnit=paramStr( orgUnitName ),
                                commonName=paramStr( commonName ),
                                email=paramStr( emailAddress ) )

      cmd = [ "openssl" ]
      if USE_FIPS_FLAG:
         cmd += [ "--fips" ]
      cmd += [ "req", "-new", "-%s" % digest,
              "-subj", subjStr, "-key", "%s" % tmpKeyFile.name,
              "-out", "%s" % tmpCertFile.name, "-config",
              "%s" % confFile.name ]
      if not signRequest:
         cmd += [ "-x509", "-days", "%d" % validity ]

      trace( "Openssl cmd is", " ".join( cmd ) )

      # The exit status is ignored; the validate* calls below decide success.
      output = Tac.run( cmd, stdout=Tac.CAPTURE, stderr=Tac.CAPTURE,
                        ignoreReturnCode=True )
      debug( "Openssl output is:" )
      debug( "\n" + output )

      if signRequest:
         validateCSR( tmpCertFile.name )
      else:
         validateCertificate( tmpCertFile.name )

      # openssl wrote both files through their names. When an existing key
      # was copied in above, tmpKeyFile's position is at EOF, and read()
      # would return "". Rewind both before reading the full contents.
      tmpCertFile.seek( 0 )
      tmpKeyFile.seek( 0 )
      return ( tmpCertFile.read(), tmpKeyFile.read() )
   except SslCertKeyError as e:
      errorOccured = True
      error( "SslCertKeyError:", str( e ) )
      raise SslCertKeyError( str( e ) )
   except Exception as e:
      errorOccured = True
      error( "StandardError: ", str( e ) )
      raise
   finally:
      tmpCertFile.close()
      tmpKeyFile.close()
      confFile.close()

      if errorOccured:
         os.remove( tmpCertFile.name )
         os.remove( tmpKeyFile.name )
      else:
         # Success: move the outputs into place, or discard those the caller
         # did not ask to keep on disk.
         if certFilepath:
            os.rename( tmpCertFile.name, certFilepath )
         else:
            os.remove( tmpCertFile.name )

         if keyFilepath and genNewKey:
            os.rename( tmpKeyFile.name, keyFilepath )
         else:
            os.remove( tmpKeyFile.name )

def getCertPem( certFilepath ):
   """
   Return what 'openssl x509 -in certFilepath' prints, as captured by Tac.run.

   stdout and stderr are both requested with Tac.CAPTURE, and the exit status
   is ignored (ignoreReturnCode=True). A bad file therefore presumably does
   not raise here; callers should validate the result.
   """
   opensslCmd = [ "openssl", "x509", "-in", certFilepath ]
   return Tac.run( opensslCmd,
                   stdout=Tac.CAPTURE,
                   stderr=Tac.CAPTURE,
                   ignoreReturnCode=True )

class SslCertKeyError( Exception ):
   """Error raised by the SSL certificate/key helpers in this module.

   The exception message carries the human-readable error text.
   """

class KeyTypeMismatchError( SslCertKeyError ):
   """SslCertKeyError subclass, presumably for key-type mismatches.

   Because it derives from SslCertKeyError, handlers that catch
   SslCertKeyError also catch it.
   """

def getCertPublicKeyAlgo( certData ):
   """Return the PublicKeyAlgorithm of the X.509 certificate in certData."""
   return _getPublicKeyAlgo( pemData=certData, pemType="x509" )

def getCsrPublicKeyAlgo( csrData ):
   """Return the PublicKeyAlgorithm of the certificate signing request in csrData."""
   return _getPublicKeyAlgo( pemData=csrData, pemType="req" )

def _getCertOrRequestData( pemData, pemType ):
   """
   Return the M2Crypto as_text() dump of a certificate or a CSR.

   pemType "x509" parses pemData as a certificate; any other value parses it
   as a certificate request. pemData is passed through
   removeTrusted( ..., isFile=False ) first.
   """
   from M2Crypto import X509
   loader = ( X509.load_cert_string if pemType == "x509"
              else X509.load_request_string )
   parsed = loader( removeTrusted( pemData, isFile=False ) )
   return parsed.as_text()

def _getPublicKeyAlgo( pemData, pemType ):
   """
   Map the "Public Key Algorithm:" line of the text dump to a
   PublicKeyAlgorithm value.

   Names missing from PUBKEY_ALGO_TO_ENUM map to
   PublicKeyAlgorithm.UNSUPPORTED. Raises SslCertKeyError if the dump has no
   such line.
   """
   dumpText = _getCertOrRequestData( pemData, pemType )
   found = re.search( r'Public Key Algorithm: (.*)', dumpText )
   if found:
      return PUBKEY_ALGO_TO_ENUM.get( found.group( 1 ),
                                      PublicKeyAlgorithm.UNSUPPORTED )
   debug( "Certificate content is ", dumpText )
   raise SslCertKeyError( "Can't get public key algorithm" )

def getCertPublicKeySize( certData ):
   """Return the public key size, in bits, of the X.509 certificate in certData."""
   return _getPublicKeySize( pemData=certData, pemType="x509" )

def getCsrPublicKeySize( csrData ):
   """Return the public key size, in bits, of the CSR in csrData."""
   return _getPublicKeySize( pemData=csrData, pemType="req" )

def _getPublicKeySize( pemData, pemType ):
   """
   Return the bit count from the "Public-Key: (N bit)" line of the text dump
   as an int.

   Raises SslCertKeyError if no such line is present.
   """
   dumpText = _getCertOrRequestData( pemData, pemType )
   found = re.search( r'Public-Key: \((\d+) bit\)', dumpText )
   if not found:
      debug( "Certificate content is ", dumpText )
      raise SslCertKeyError( "Can't get size of the public key" )
   bitCount = found.group( 1 )
   debug( "Public key size is ", bitCount )
   return int( bitCount )

def getRsaPublicKey( privateKeyPath, publicKeyPath ):
   """
   Write the public half of the RSA private key at privateKeyPath to
   publicKeyPath.

   An RSA.RSAError from M2Crypto is logged and re-raised as SslCertKeyError.
   The message names only the key's file name, not its full path.
   """
   from M2Crypto import RSA
   keyFileName = privateKeyPath.rpartition( "/" )[ 2 ]
   try:
      privateKey = RSA.load_key( privateKeyPath )
      privateKey.save_pub_key( publicKeyPath )
   except RSA.RSAError as e:
      error( "getRsaPublicKey exception:", e )
      raise SslCertKeyError( "Unable to extract public key from SSL private key %s"
                             % keyFileName )

def getEcPublicKey( privateKeyPath, publicKeyPath ):
   """
   Write the public half of the EC private key at privateKeyPath to
   publicKeyPath.

   Load/save failures reported by M2Crypto as ValueError or EC.ECError are
   logged and re-raised as SslCertKeyError. getPublicKey relies on that
   exception type to fall through to the next key type. The message names
   only the key's file name, not its full path.
   """
   from M2Crypto import EC
   try:
      EC.load_key( privateKeyPath ).save_pub_key( publicKeyPath )
   except ( ValueError, EC.ECError ) as e:
      error( "getEcPublicKey exception:", e )
      raise SslCertKeyError( "Unable to extract public key from SSL private key %s"
                             % privateKeyPath.split( "/" )[ -1 ] )

def getPublicKey( privateKeyPath, publicKeyPath ):
   """
   Write the public half of the private key at privateKeyPath to
   publicKeyPath, trying RSA first and then EC.

   Returns as soon as one extractor succeeds. If both fail, raises
   SslCertKeyError carrying the message of the last failure (the EC one).
   """
   lastFailure = None
   for extractor in ( getRsaPublicKey, getEcPublicKey ):
      try:
         extractor( privateKeyPath, publicKeyPath )
         return
      except SslCertKeyError as e:
         lastFailure = e
   raise SslCertKeyError( str( lastFailure ) )

def generateKeyCertHash( filePath, logType, hashAlgo='SHA-256', blockSize=65536 ):
   """
   Return the hex SHA-256 digest used in the SSL key/certificate syslogs.

   For logType "sslkey:" the digest covers the public key extracted from the
   private key at filePath (via getPublicKey into a temporary file). For any
   other logType it covers the contents of filePath itself.

   filePath: path of the key or certificate; must exist.
   logType: "sslkey:" for private keys; anything else hashes the file as-is.
   hashAlgo: only 'SHA-256' is supported (asserted).
   blockSize: read size used while streaming the file into the hash.

   Raises SslCertKeyError if filePath does not exist or, for keys, if the
   public key cannot be extracted.
   """
   if not os.path.exists( filePath ):
      raise SslCertKeyError( "SSL %s does not exist" % filePath )

   # For now, we always use SHA-256.
   assert hashAlgo == 'SHA-256'

   def sha256OfFile( path ):
      # Stream the file so large inputs are not read into memory at once.
      sha256sum = hashlib.sha256()
      with open( path, 'rb' ) as f:
         for block in iter( lambda: f.read( blockSize ), b'' ):
            sha256sum.update( block )
      return sha256sum.hexdigest()

   if logType == "sslkey:":
      # Hash the public half, never the private key itself. The temporary
      # file is closed (and therefore deleted) as soon as we are done with
      # it, including when getPublicKey raises.
      with tempfile.NamedTemporaryFile() as tmpPubKey:
         getPublicKey( filePath, tmpPubKey.name )
         return sha256OfFile( tmpPubKey.name )
   return sha256OfFile( filePath )

def getLogActionAndFileHash( filePath, logType, defaultAction, hashAlgo='SHA-256' ):
   """
   Decide which syslog action to report before a key/cert file is written.

   Returns ( "updated", <current hash> ) when filePath names an existing
   file. Otherwise returns ( defaultAction, "" ).
   """
   if not filePath or not os.path.exists( filePath ):
      return defaultAction, ""
   return "updated", generateKeyCertHash( filePath, logType, hashAlgo )

def generateSslKeyCertSysLog( filePath, logType, logAction,
                              oldFileHash='', hashAlgo='SHA-256' ):
   """
   Emit the SECURITY_SSL_KEY_CERT_* syslog that matches logAction.

   logType "sslkey:" reports a "private key" (hashing its corresponding public
   key); anything else reports a "certificate".
   - "created"/"imported": logs the current hash of filePath.
   - "updated": logs oldFileHash followed by the current hash.
   - "deleted": logs only oldFileHash.
   Any other logAction logs nothing.
   """
   trace( "generateSslKeyCertSysLog for", logAction, filePath )
   isKey = logType == "sslkey:"
   commonArgs = ( "private key" if isKey else "certificate",
                  filePath.split( "/" )[ -1 ],
                  hashAlgo,
                  "corresponding public " if isKey else "" )

   if logAction in ( "created", "imported" ):
      logHandle = ( SECURITY_SSL_KEY_CERT_CREATED if logAction == "created"
                    else SECURITY_SSL_KEY_CERT_IMPORTED )
      hashArgs = ( generateKeyCertHash( filePath, logType, hashAlgo ), )
   elif logAction == "updated":
      logHandle = SECURITY_SSL_KEY_CERT_UPDATED
      hashArgs = ( oldFileHash,
                   generateKeyCertHash( filePath, logType, hashAlgo ) )
   elif logAction == "deleted":
      logHandle = SECURITY_SSL_KEY_CERT_DELETED
      hashArgs = ( oldFileHash, )
   else:
      return
   Logging.log( logHandle, *( commonArgs + hashArgs ) )

# openssl commands are used to generate the hashes for the following reasons:
# - The hash computed by the M2Crypto library is not the same as the one
#   openssl computes.
# - Importing the crypto library was found to leak memory because of its cffi
#   backend. Fixes were made in pyzmq:
#   https://pyzmq.readthedocs.io/en/latest/changelog.html
# p.communicate() returns a tuple such as (b'05e71200\n', None), so we take
# index 0 and strip the trailing newline, so that ".0" can be appended when
# creating the file.
def generateTrustedCertHash( certData, tempDir ):
   """
   Write certData to <tempDir>/cert.pem and return ( that path, cert hash ).

   The hash comes from gv.trustedCertHashCache when certData is already a key
   there. Otherwise it is computed with 'openssl x509 -hash -noout'. This
   function does not store the computed hash in the cache.
   NOTE(review): openssl's stderr is merged into stdout and its exit status
   is not checked, so on failure the returned "hash" would be openssl's error
   text. Confirm callers tolerate that.
   """
   pemPath = tempDir + "/cert.pem"
   with open( pemPath, "w" ) as pemFile:
      pemFile.write( certData )

   if certData not in gv.trustedCertHashCache:
      hashCmd = [ 'openssl', 'x509', '-hash', '-noout', '-in', pemPath ]
      trace( "Generating hashfile with cert data as content at path ", pemPath )
      proc = subprocess.Popen( hashCmd, stdin=subprocess.PIPE,
                               stdout=subprocess.PIPE, stderr=subprocess.STDOUT )
      rawOutput, _ = proc.communicate()
      return pemPath, rawOutput.rstrip( b'\n' ).decode()
   return pemPath, gv.trustedCertHashCache[ certData ]

def generateCrlHash( crlData, tempDir ):
   """
   Write crlData to <tempDir>/crl.pem and return ( that path, CRL hash ).

   The hash comes from gv.crlHashCache when crlData is already a key there.
   Otherwise it is computed with 'openssl crl -hash -noout'. This function
   does not store the computed hash in the cache.
   NOTE(review): openssl's stderr is merged into stdout and its exit status
   is not checked, so on failure the returned "hash" would be openssl's error
   text. Confirm callers tolerate that.
   """
   pemPath = tempDir + "/crl.pem"
   with open( pemPath, "w" ) as pemFile:
      pemFile.write( crlData )

   if crlData not in gv.crlHashCache:
      hashCmd = [ 'openssl', 'crl', '-hash', '-noout', '-in', pemPath ]
      trace( "Generating hashfile with crl data as content at path ", pemPath )
      proc = subprocess.Popen( hashCmd, stdin=subprocess.PIPE,
                               stdout=subprocess.PIPE, stderr=subprocess.STDOUT )
      rawOutput, _ = proc.communicate()
      return pemPath, rawOutput.rstrip( b'\n' ).decode()
   return pemPath, gv.crlHashCache[ crlData ]

def renameDirAtomically( currDir, symOpenSslHash, symlinkDirPath ):
   """
   Make the symlink symOpenSslHash point at currDir in a single rename step.

   A temporary symlink named symlinkDirPath + "symtempDir" is created first
   and then renamed over symOpenSslHash; rename() replaces the destination
   in one operation. symlinkDirPath is concatenated without a separator, so
   it presumably ends with "/" (TODO confirm with callers).

   If a temporary link is left over from an interrupted earlier call, it is
   removed and recreated instead of failing with EEXIST forever.
   """
   tmpLink = symlinkDirPath + "symtempDir"
   try:
      os.symlink( currDir, tmpLink )
   except OSError as e:
      if e.errno != errno.EEXIST:
         raise
      # Stale link from a previous run that died between symlink and rename.
      os.remove( tmpLink )
      os.symlink( currDir, tmpLink )
   os.rename( tmpLink, symOpenSslHash )
