# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function
from CliPlugin.SslModel import Certificate
from CliPlugin.SslModel import CertificateSigningRequest
from CliPlugin.SslModel import Crl
from CliPlugin.SslModel import DiffieHellmanParameters
from CliPlugin.SslModel import DistinguishedName
from CliPlugin.SslModel import Extension
from CliPlugin.SslModel import PublicKey
from CliPlugin.SslModel import RevokedCertificate
from CliPlugin.SslModel import utc
from dateutil import parser
import ssl
import datetime
import time
import Tac
import re

# Tac type naming the public key algorithms. Its RSA and ECDSA members are
# compared below against the algorithm SslCertKey reports for a certificate
# or a certificate signing request.
PublicKeyAlgorithm = Tac.Type( "Mgmt::Security::Ssl::PublicKeyAlgorithm" )

def _getDates( cert, isFile=True ):
   """Return the ( notBefore, notAfter ) date strings openssl prints for cert.

   cert is the path of a certificate file when isFile is True; otherwise it
   is the certificate text itself, which is handed to Tac.run as input.
   The strings are returned exactly as captured from the
   'openssl x509 -dates' output.
   """
   cmd = [ 'openssl', 'x509' ]
   runArgs = dict( stdout=Tac.CAPTURE, stderr=Tac.CAPTURE,
                   ignoreReturnCode=True )
   if isFile:
      cmd += [ '-in', cert ]
   else:
      runArgs[ 'input' ] = cert
   cmd += [ '-dates', '-noout' ]

   output = Tac.run( cmd, **runArgs )
   datesMatch = re.match( r'notBefore=(.*?)\s*notAfter=(.*)', output )
   return datesMatch.group( 1 ), datesMatch.group( 2 )

def getDatesEpoch( certificate, isFile=True ):
   """Return the validity dates of certificate as seconds since 1970-01-01 UTC.

   certificate and isFile are passed through to _getDates. The result is the
   tuple ( notBefore, notAfter ), each truncated to an int.
   """
   epochStart = datetime.datetime( 1970, 1, 1, tzinfo=utc )

   def toEpochSeconds( dateStr ):
      # 'GMT' in the openssl output is mapped to the utc tzinfo.
      parsedDt = parser.parse( dateStr, tzinfos={ 'GMT': utc } )
      return int( ( parsedDt - epochStart ).total_seconds() )

   # M2Crypto errors out for some (far away) dates, hence the dates are
   # extracted with the openssl command instead.
   notBeforeStr, notAfterStr = _getDates( certificate, isFile=isFile )
   return toEpochSeconds( notBeforeStr ), toEpochSeconds( notAfterStr )

def _fillDistinguishedName( dnModel, x509Name ):
   """Copy the populated fields of x509Name into dnModel.

   commonName is always assigned (the empty string when the name has no CN);
   every other attribute is assigned only when x509Name has a value for it.
   """
   dnModel.commonName = x509Name.CN or ""
   optionalFields = (
      ( 'email', 'Email' ),
      ( 'organization', 'O' ),
      ( 'organizationUnit', 'OU' ),
      ( 'locality', 'L' ),
      ( 'stateOrProvince', 'SP' ),
      ( 'country', 'C' ),
   )
   for modelAttr, nameField in optionalFields:
      fieldValue = getattr( x509Name, nameField )
      if fieldValue:
         setattr( dnModel, modelAttr, fieldValue )

def getCertificateModel( certificate, isFile=True ):
   """Build a Certificate model describing an X.509 certificate.

   certificate is the path of the certificate file when isFile is True;
   otherwise it is the certificate text itself.

   The returned model carries the version, serial number, subject, issuer,
   validity dates (seconds since the epoch, see getDatesEpoch), public key
   details and the extensions that could be read from the certificate.
   Extensions whose name, value or critical flag cannot be read, or whose
   name or value is empty, are left out.
   """
   # Imported here rather than at file scope: these SSL modules are marked
   # as memory hogs in Cli, so they are loaded only when an SSL related
   # command is executed.
   from M2Crypto import X509, m2 # pylint: disable-msg=C0415
   import SslCertKey # pylint: disable-msg=C0415
   certModel = Certificate()
   certModel.subject = DistinguishedName()
   certModel.issuer = DistinguishedName()

   cert = X509.load_cert_string( SslCertKey.removeTrusted( certificate,
                                                           isFile=isFile ) )

   # The model stores get_version() + 1 (presumably because the encoded
   # X.509 version is zero based).
   certModel.version = cert.get_version() + 1
   certModel.serialNumber = cert.get_serial_number()

   subject = cert.get_subject()
   _fillDistinguishedName( certModel.subject, subject )
   issuer = cert.get_issuer()
   _fillDistinguishedName( certModel.issuer, issuer )

   certModel.notBefore, certModel.notAfter = getDatesEpoch( certificate, isFile )

   pkey = cert.get_pubkey()

   publicKey = PublicKey()
   if isFile:
      with open( certificate, 'r' ) as fp:
         certData = fp.read()
   else:
      certData = certificate
   publicKey.encryptionAlgorithm = SslCertKey.getCertPublicKeyAlgo( certData )
   publicKey.size = SslCertKey.getCertPublicKeySize( certData )
   if publicKey.encryptionAlgorithm == PublicKeyAlgorithm.RSA:
      rsakey = pkey.get_rsa()
      # get_modulus() is parsed as a hexadecimal string.
      publicKey.modulus = int( pkey.get_modulus(), 16 )
      # pylint: disable-next=no-member
      publicKey.publicExponent = int( m2.bn_to_hex( m2.mpi_to_bn( rsakey.e ) ),
                                      16 )
   elif publicKey.encryptionAlgorithm == PublicKeyAlgorithm.ECDSA:
      publicKey.ellipticCurve = ( SslCertKey.getEcCurveFromCert( certificate,
                                                                 isFile )
                                  or "Unknown" )

   certModel.publicKey = publicKey

   # In openssl1 the 'authorityKeyIdentifier' value was prepended with
   # "keyid:". Maintain this behaviour on el9/openssl3.
   addKeyIdPrefix = ssl.OPENSSL_VERSION_INFO[ 0 ] == 3
   hexIdRe = r'([0-9a-fA-F]{2}:)+[0-9a-fA-F]{2}'

   for i in range( cert.get_ext_count() ):
      ext = cert.get_ext_at( i )
      if not ext:
         continue

      try:
         extName = ext.get_name()
         extValue = ext.get_value()
         extCritical = ext.get_critical()
      except Exception: # pylint: disable-msg=W0703
         # Best effort: an extension that cannot be decoded is skipped.
         # Only Exception is caught so that KeyboardInterrupt and SystemExit
         # still propagate.
         continue

      if not extName or not extValue:
         continue

      extValue = extValue.strip()
      if ( addKeyIdPrefix and extName == "authorityKeyIdentifier" and
           re.match( hexIdRe, extValue ) and
           not extValue.startswith( "keyid" ) ):
         extValue = "keyid:" + extValue

      certModel.extension[ extName ] = Extension( value=extValue,
                                                  critical=bool( extCritical ) )
   return certModel

def getRsaPublicKeyModel( keyFile ):
   """Build a PublicKey model from the RSA key loaded from keyFile.

   The model gets the key size (len() of the loaded key), the public
   exponent and modulus as integers, and "RSA" as encryption algorithm.
   Errors raised by RSA.load_key are not handled here.
   """
   from M2Crypto import RSA, m2 # pylint: disable-msg=C0415

   def mpiToInt( mpi ):
      # Convert via the hexadecimal text form of the big number.
      # pylint: disable-next=no-member
      return int( m2.bn_to_hex( m2.mpi_to_bn( mpi ) ), 16 )

   rsaKey = RSA.load_key( keyFile )
   model = PublicKey()
   exponent, modulus = rsaKey.pub()
   model.size = len( rsaKey )
   model.publicExponent = mpiToInt( exponent )
   model.modulus = mpiToInt( modulus )
   model.encryptionAlgorithm = "RSA"
   return model

def getEcPublicKeyModel( keyFile ):
   """Build a PublicKey model from the EC key loaded from keyFile.

   The model gets the key size (len() of the loaded key), "ECDSA" as
   encryption algorithm, and the curve name reported by
   SslCertKey.getEcCurveFromKey, or "Unknown" when that returns nothing.
   """
   from M2Crypto import EC # pylint: disable-msg=C0415
   import SslCertKey # pylint: disable-msg=C0415
   ecKey = EC.load_key( keyFile )
   model = PublicKey()
   model.size = len( ecKey )
   model.encryptionAlgorithm = "ECDSA"
   curveName = SslCertKey.getEcCurveFromKey( keyFile )
   model.ellipticCurve = curveName if curveName else "Unknown"
   return model

def getPublicKeyModel( keyFile ):
   """Return a PublicKey model for the key stored in keyFile.

   The key is first loaded as an RSA key; when that raises RSA.RSAError the
   file is loaded as an EC key instead.
   """
   from M2Crypto import RSA # pylint: disable-msg=C0415
   try:
      keyModel = getRsaPublicKeyModel( keyFile )
   except RSA.RSAError:
      # Not an RSA key, fall back to EC.
      keyModel = getEcPublicKeyModel( keyFile )
   return keyModel

def getCrlModel( crl, isFile=True ):
   """Build a Crl model from a certificate revocation list.

   crl is the path of the CRL file when isFile is True; otherwise it is the
   CRL text itself, which is handed to Tac.run as input.

   The returned model carries the CRL number (0 when openssl does not print
   a hexadecimal number), the issuer, the last/next update times and, for
   every revoked certificate, its serial number and revocation date. All
   times are seconds since 1970-01-01 UTC.
   """
   def runOpensslCrl( options ):
      # Run 'openssl crl -noout <options>' on the CRL and return its output.
      cmd = [ 'openssl', 'crl', '-noout' ] + options
      if isFile:
         return Tac.run( cmd + [ '-in', crl ], stdout=Tac.CAPTURE,
                         stderr=Tac.CAPTURE, ignoreReturnCode=True )
      return Tac.run( cmd, stdout=Tac.CAPTURE, stderr=Tac.CAPTURE,
                      ignoreReturnCode=True, input=crl )

   crlModel = Crl()
   crlModel.issuer = DistinguishedName()
   output = runOpensslCrl( [ '-crlnumber', '-issuer', '-lastupdate',
                             '-nextupdate' ] )

   # Every output line is expected to have the form 'name=value'.
   values = dict( k.split( '=', 1 ) for k in output.strip().split( '\n' ) )
   try:
      # Note: openssl1 crlnumber are not prepended by 0x but openssl3 are.
      # int( ..., 16 ) accepts both forms.
      crlModel.crlNumber = int( values[ 'crlNumber' ], 16 )
   except ValueError:
      crlModel.crlNumber = 0

   if values[ 'issuer' ].startswith( '/' ):
      issuer = dict( k.split( '=', 1 ) for k in
                     values[ 'issuer' ].strip( '/' ).split( '/' ) )
   else: # openssl3 uses `,` as separator
      issuer = dict( [ s.strip() for s in k.split( '=', 1 ) ]
                     for k in values[ 'issuer' ].strip().split( ',' ) )
   # NOTE(review): the issuer locality ( 'L' ) is not copied here although
   # getCertificateModel fills it in - confirm whether that is intended.
   crlModel.issuer.commonName = issuer.get( 'CN' )
   crlModel.issuer.email = issuer.get( 'emailAddress' )
   crlModel.issuer.organization = issuer.get( 'O' )
   crlModel.issuer.organizationUnit = issuer.get( 'OU' )
   crlModel.issuer.stateOrProvince = issuer.get( 'ST' )
   crlModel.issuer.country = issuer.get( 'C' )
   lastUpdate = parser.parse( values[ 'lastUpdate' ], tzinfos={ 'GMT' : utc } )
   nextUpdate = parser.parse( values[ 'nextUpdate' ], tzinfos={ 'GMT' : utc } )

   epochDt = datetime.datetime( 1970, 1, 1, tzinfo=utc )
   crlModel.lastUpdate = int( ( lastUpdate - epochDt ).total_seconds() )
   crlModel.nextUpdate = int( ( nextUpdate - epochDt ).total_seconds() )

   # Get list of revoked certs in CRL
   txtOutput = runOpensslCrl( [ '-text' ] )
   pattern = r'\s+Serial Number: (?P<serial>\w+)\s+Revocation Date: (?P<date>.*)'
   for match in re.finditer( pattern, txtOutput, re.MULTILINE ):
      serialNumber, revocationDate = match.group( 'serial', 'date' )
      revokedCert = RevokedCertificate()
      revokedCert.serialNumber = serialNumber
      # The printed date is in GMT, so it must be converted as UTC.
      # time.mktime() would interpret it in the local timezone and shift the
      # result by the UTC offset whenever the local timezone is not UTC.
      revokedDt = datetime.datetime.strptime(
            revocationDate, '%b %d %H:%M:%S %Y GMT' ).replace( tzinfo=utc )
      revokedCert.revocationDate = ( revokedDt - epochDt ).total_seconds()
      crlModel.revokedList.append( revokedCert ) # pylint: disable-msg=E1101

   return crlModel

def getDhparamsModel( dhparamsFile ):
   """Build a DiffieHellmanParameters model from the parameters in dhparamsFile.

   The model gets the size (len() of the loaded parameters times 8,
   presumably converting bytes to bits) and the prime and generator as
   integers.
   """
   from M2Crypto import DH, m2 # pylint: disable-msg=C0415

   def mpiToInt( mpi ):
      # Convert via the hexadecimal text form of the big number.
      # pylint: disable-next=no-member
      return int( m2.bn_to_hex( m2.mpi_to_bn( mpi ) ), 16 )

   model = DiffieHellmanParameters()
   params = DH.load_params( dhparamsFile )
   model.size = len( params ) * 8
   model.prime = mpiToInt( params.p )
   model.generator = mpiToInt( params.g )
   return model

def getCertificateSigningRequestModel( csrFile ):
   """Build a CertificateSigningRequest model from the PEM file csrFile.

   The model carries the PEM text as read from the file, the request
   version, the subject and the public key details. In the subject,
   commonName falls back to the empty string and every other field to None
   when the request has no value for it.
   """
   # Imported here rather than at file scope: these SSL modules are marked
   # as memory hogs in Cli, so they are loaded only when an SSL related
   # command is executed.
   from M2Crypto import X509, m2 # pylint: disable-msg=C0415
   import SslCertKey # pylint: disable-msg=C0415

   with open( csrFile ) as csrFp:
      pemText = csrFp.read()

   model = CertificateSigningRequest()
   model.pemValue = pemText
   request = X509.load_request_string( pemText )
   model.version = request.get_version()

   reqSubject = request.get_subject()
   dn = DistinguishedName()
   dn.commonName = reqSubject.CN or ''
   optionalFields = (
      ( 'email', 'Email' ),
      ( 'organization', 'O' ),
      ( 'organizationUnit', 'OU' ),
      ( 'locality', 'L' ),
      ( 'stateOrProvince', 'SP' ),
      ( 'country', 'C' ),
   )
   for modelAttr, nameField in optionalFields:
      setattr( dn, modelAttr, getattr( reqSubject, nameField ) or None )
   model.subject = dn

   keyModel = PublicKey()
   keyModel.encryptionAlgorithm = SslCertKey.getCsrPublicKeyAlgo( pemText )
   keyModel.size = SslCertKey.getCsrPublicKeySize( pemText )
   algorithm = keyModel.encryptionAlgorithm
   if algorithm == PublicKeyAlgorithm.RSA:
      reqKey = request.get_pubkey()
      rsaKey = reqKey.get_rsa()
      # get_modulus() is parsed as a hexadecimal string.
      keyModel.modulus = int( reqKey.get_modulus(), 16 )
      # pylint: disable-next=no-member
      exponentHex = m2.bn_to_hex( m2.mpi_to_bn( rsaKey.e ) )
      keyModel.publicExponent = int( exponentHex, 16 )
   elif algorithm == PublicKeyAlgorithm.ECDSA:
      curveName = SslCertKey.getEcCurveFromCsr( csrFile )
      keyModel.ellipticCurve = curveName or "Unknown"

   model.publicKey = keyModel
   return model

# List the allowed ciphersuite names based on openssl ciphers command output
# Parameters:
#       cipherListStr: the cipher list string for TLSv1.2 and below
#       cipherSuiteStr: the ciphersuite string for TLSv1.3
#
# For example:
# To get TLSv1.2 and below ciphersuites:
# > openssl ciphers -s -tls1 -tls1_1 -tls1_2 'EECDH+AESGCM+ECDSA:AES256+EECDH:!TLSv1'
# ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES256
# -GCM-SHA384:ECDHE-ECDSA-AES256-CCM8:ECDHE-ECDSA-AES256-CCM:
# ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384
# To get TLSv1.3 ciphersuites:
# > openssl ciphers -s -tls1_3 -ciphersuites TLS_AES_128_GCM_SHA256:
# TLS_AES_128_CCM_SHA256
# TLS_AES_128_GCM_SHA256:TLS_AES_128_CCM_SHA256
def listCipherSuiteNames( cipherListStr="", cipherSuiteStr="", verbose=False ):
   """Return the stripped output of 'openssl ciphers' for the given strings.

   cipherListStr is the cipher list string for TLSv1.2 and below and is
   passed as the positional argument. cipherSuiteStr is the ciphersuite
   string for TLSv1.3 and is passed with '-ciphersuites'. When only one of
   the two is given, '-s' and the matching protocol options are added as
   well. verbose adds '-v'.
   """
   cmd = [ "openssl", "ciphers" ]
   if verbose:
      cmd.append( "-v" )

   if cipherSuiteStr and not cipherListStr:
      # To check ciphersuites for TLSv1.3, we need to use
      # a different option '-ciphersuites'
      cmd.extend( [ "-s", "-tls1_3" ] )
   if cipherSuiteStr:
      cmd.extend( [ "-ciphersuites", cipherSuiteStr ] )

   if cipherListStr and not cipherSuiteStr:
      cmd.extend( [ "-s", "-tls1", "-tls1_1", "-tls1_2" ] )
   if cipherListStr:
      cmd.append( cipherListStr )

   cipherOutput = Tac.run( cmd, stdout=Tac.CAPTURE, stderr=Tac.CAPTURE,
                           ignoreReturnCode=True )
   return cipherOutput.strip()
