# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function
from CliModel import Bool
from CliModel import Dict
from CliModel import List
from CliModel import Model
from CliModel import Str
from CliModel import Enum
from CliModel import Int
from CliModel import Float
from CliModel import Submodel
import TableOutput
import Tac
import textwrap
import datetime
import SslMonitorExpiry
from Toggles import MgmtSecurityToggleLib

# Minimal UTC tzinfo, adapted from the example in the Python datetime
# documentation so that this module does not depend on pytz.
ZERO = datetime.timedelta()

class UTC( datetime.tzinfo ):
   """Fixed tzinfo for Coordinated Universal Time."""

   def tzname( self, dt ):
      return "UTC"

   def utcoffset( self, dt ):
      # UTC has no offset from itself.
      return ZERO

   def dst( self, dt ):
      # UTC never observes daylight saving time.
      return ZERO

utc = UTC()

# Tac enum types from the Mgmt::Security::Ssl package. Their ``attributes``
# supply the allowed values of the Enum fields in ProfileError and
# ProfileStatus below; ErrorType is also compared against in SslStatus.
ProfileState = Tac.Type( "Mgmt::Security::Ssl::ProfileState" )
ErrorType = Tac.Type( "Mgmt::Security::Ssl::ErrorType" )
ErrorAttr = Tac.Type( "Mgmt::Security::Ssl::ErrorAttr" )

def _printLineItem( label, content, space=30 ):
   """Print ``label`` followed by a colon and ``content`` on one line.

   When ``space`` is non-zero, ``label:`` is left-justified in a field that
   is ``space`` characters wide and ``content`` follows immediately. When
   ``space`` is 0, the two parts are separated by a single space.
   """
   heading = f'{label}:'
   if not space:
      print( f'{heading} {content}' )
      return
   print( f'{heading:{space}s}{content}' )

def printDN( dn ):
   """Print the fields of distinguished name ``dn``, one per line.

   The common name is always printed. Each remaining field is printed only
   when its value is truthy, in a fixed order.
   """
   _printLineItem( "      Common name", f"{dn.commonName}" )
   optionalFields = ( ( "Email address", "email" ),
                      ( "Organizational unit", "organizationUnit" ),
                      ( "Organization", "organization" ),
                      ( "Locality", "locality" ),
                      ( "State", "stateOrProvince" ),
                      ( "Country", "country" ) )
   for label, attrName in optionalFields:
      value = getattr( dn, attrName )
      if value:
         _printLineItem( f"      {label}", f"{value}" )

class PublicKey( Model ):
   """Public key information.

   Used by certificates, CSRs and the PublicKeys listing.
   """
   # Currently only RSA and ECDSA keys are supported
   encryptionAlgorithm = Enum( help="Encryption algorithm of the key",
                               values=( "RSA", "ECDSA", ) )
   size = Int( help="Size of the key in bits" )
   # The renderers in this file skip the modulus when it is 0 (the default).
   modulus = Int( help="Modulus of the key (a value of 0 means no modulus)",
                  default=0 )
   # Only RSA keys carry a public exponent. DSA is not a supported algorithm,
   # so the help text refers to "other key types" (parallel to ellipticCurve).
   publicExponent = Int( help=( "Public exponent of the RSA key. "
                                "This field is not present for other key types" ),
                         optional=True )
   ellipticCurve = Str( help=( "Elliptic curve used to generate the ECDSA key."
                               " This field is not present for other key types" ),
                        optional=True )

class DistinguishedName( Model ):
   """Subject or issuer name of a certificate, CSR or CRL.

   Only ``commonName`` is mandatory; ``printDN`` prints the optional fields
   only when they are set.
   """
   commonName = Str( help="Common name" )
   email = Str( help="Email address", optional=True )
   organization = Str( help="Name of the organization", optional=True )
   organizationUnit = Str( help="Name of the organizational unit",
                           optional=True )
   # Pass help by keyword like every other field, instead of relying on it
   # being Str's first positional parameter.
   locality = Str( help="Name of the locality", optional=True )
   stateOrProvince = Str( help="Name of the state or province", optional=True )
   country = Str( help="Name of the country", optional=True )

class Extension( Model ):
   """One X.509 version 3 extension.

   Certificate.extension maps extension names to these entries.
   """
   # Extension value text; Certificate.render wraps it at 85 columns.
   value = Str( help="Value of the X.509 version 3 extension" )
   # When True, Certificate.render prints "Critical" next to the name.
   critical = Bool( help="Whether the extension is critical" )
   
class Certificate( Model ):
   """An X.509 certificate and its human-readable rendering."""
   version = Int( help="X.509 version" )
   serialNumber = Int( help="Serial number" )
   subject = Submodel( help="Entity associated with the certificate",
                      valueType=DistinguishedName )
   issuer = Submodel( help="Entity who has signed and issued the certificate",
                      valueType=DistinguishedName )
   notBefore = Int( help=( "Timestamp on which the certificate "
                           "validity period begins" ) )
   notAfter = Int( help=( "Timestamp on which the certificate "
                          "validity period ends" ) )
   publicKey = Submodel( help="Public key information",
                         valueType=PublicKey )
   extension = Dict( help=( "Mapping from X.509 version 3 "
                            "extension name to extension value" ),
                     valueType=Extension,
                     optional=True )

   def render( self ):
      """Print the certificate fields, one labelled item per line.

      ``notBefore``/``notAfter`` are treated as seconds since the Unix epoch
      and printed in GMT. The serial number and modulus are printed in hex.
      Optional key fields and extensions are printed only when present.
      """
      _printLineItem( "   Version", f"{self.version}" )
      _printLineItem( "   Serial Number", f"{self.serialNumber:x}" )
      _printLineItem( "   Issuer", "" )
      printDN( self.issuer )

      epochDt = datetime.datetime( 1970, 1, 1, tzinfo=utc )
      notBeforeDt = epochDt + datetime.timedelta( seconds=self.notBefore )
      notAfterDt = epochDt + datetime.timedelta( seconds=self.notAfter )

      _printLineItem( "   Validity", "" )
      _printLineItem( "      Not before",
                      f"{notBeforeDt.strftime( '%b %d %H:%M:%S %Y' )} GMT" )
      _printLineItem( "      Not After",
                      f"{notAfterDt.strftime( '%b %d %H:%M:%S %Y' )} GMT" )

      _printLineItem( "   Subject", "" )
      printDN( self.subject )
      _printLineItem( "   Subject public key info", "" )
      _printLineItem( "      Encryption Algorithm",
                      f"{self.publicKey.encryptionAlgorithm}" )
      _printLineItem( "      Size", f"{self.publicKey.size} bits" )
      if self.publicKey.publicExponent:
         _printLineItem( "      Public exponent",
                         f"{self.publicKey.publicExponent}" )
      if self.publicKey.modulus:
         # Wrap the hex modulus so continuation lines line up under the
         # 30-column value field used by _printLineItem.
         hexmod = f"{self.publicKey.modulus:x}"
         output = textwrap.fill( hexmod, initial_indent=30 * ' ',
                                 subsequent_indent=30 * ' ',
                                 width=85 )
         _printLineItem( "      Modulus", output.lstrip() )
      if self.publicKey.ellipticCurve:
         _printLineItem( "      Elliptic curve", self.publicKey.ellipticCurve )

      # ``extension`` is declared optional=True. Guard against it being None
      # (as the other optional collections in this file are guarded) instead
      # of failing in len()/items().
      extensions = self.extension or {}
      if extensions:
         _printLineItem( "   X509v3 extensions", "" )
      for name, ext in sorted( extensions.items() ):
         _printLineItem( f"      {name}",
                         "Critical" if ext.critical else "",
                         space=0 )
         print( textwrap.fill( ext.value,
                               initial_indent='         ',
                               subsequent_indent='         ',
                               width=85 ) )

class Certificates( Model ):
   """Collection of certificates keyed by certificate identifier."""
   certificates = Dict( help=( "Mapping from certificate name to certificate "
                               "used in SSL/TLS" ),
                        valueType=Certificate )

   def render( self ):
      """Print every certificate, sorted by identifier.

      An identifier containing '/' is shown as "<name>, line <lineno>" (plus
      the common name when it is non-empty). Any other identifier is shown
      as-is.
      """
      for certId, cert in sorted( self.certificates.items() ):
         # Split on the last '/' only. Extra '/' characters then stay in the
         # name instead of breaking the two-way unpacking below.
         certificate = certId.rsplit( "/", 1 )
         if len( certificate ) > 1:
            name, lineno = certificate
            cn = cert.subject.commonName
            if cn != "":
               print( f"Certificate {name}, line {lineno}, common-name {cn}:" )
            else:
               print( f"Certificate {name}, line {lineno}:" )
         else:
            print( f"Certificate {certificate[ 0 ]}:" )
         cert.render()

class RevokedCertificate( Model ):
   """One entry of a CRL's revoked-certificate list."""
   # The serial number identifies the revoked certificate, not the CRL itself
   # (the CRL's own number is Crl.crlNumber).
   serialNumber = Str( help="Serial number of the revoked certificate" )
   # Crl.render truncates this with int() and adds it to the Unix epoch.
   revocationDate = Float( help="Revocation date" )

class CertificateSigningRequest( Model ):
   """A certificate signing request (CSR) and its PEM encoding."""
   version = Int( help="X.509 version" )
   subject = Submodel( help="Entity associated with the certificate",
                      valueType=DistinguishedName )
   publicKey = Submodel( help="Public key information",
                         valueType=PublicKey )
   pemValue = Str( help="CSR in PEM format" )

   def render( self ):
      """Print the CSR's version, subject, public key details and PEM text."""
      _printLineItem( "   Data", "" )
      _printLineItem( "   Version", f"{self.version}" )
      printDN( self.subject )
      _printLineItem( "   Subject public key info", "" )
      _printLineItem( "      Encryption Algorithm",
                      f"{self.publicKey.encryptionAlgorithm}" )
      _printLineItem( "      Size", f"{self.publicKey.size} bits" )
      if self.publicKey.publicExponent:
         _printLineItem( "      Public exponent",
                         f"{self.publicKey.publicExponent}" )
      if self.publicKey.modulus:
         # Continuation lines align with the 30-column value field.
         hexmod = f"{self.publicKey.modulus:x}"
         output = textwrap.fill( hexmod, initial_indent=30 * ' ',
                                 subsequent_indent=30 * ' ',
                                 width=85 )
         _printLineItem( "      Modulus", output.lstrip() )
      # Show the curve for ECDSA keys, as Certificate.render does; the field
      # is unset for other key types, so RSA output is unchanged.
      if self.publicKey.ellipticCurve:
         _printLineItem( "      Elliptic curve", self.publicKey.ellipticCurve )
      _printLineItem( "      PEM Value", "\n" + self.pemValue )

class Crl( Model ):
   """A certificate revocation list and its human-readable rendering."""
   crlNumber = Int( help="CRL number" )
   issuer = Submodel( help="Entity who has signed and issued the certificate",
                      valueType=DistinguishedName )
   lastUpdate = Int( help=( "Timestamp on which the CRL "
                            "validity period begins" ) )
   nextUpdate = Int( help=( "Timestamp on which the CRL "
                            "validity period ends" ) )
   revokedList = List( help="Serial number and Timestamp of the revoked "
                            "certificate", valueType=RevokedCertificate )

   def render( self ):
      """Print the CRL number (hex), issuer, validity window and revocations.

      Timestamps are offsets in seconds from the Unix epoch and are shown in
      GMT.
      """
      epoch = datetime.datetime( 1970, 1, 1, tzinfo=utc )

      def asGmt( seconds ):
         # One formatter shared by every timestamp this CRL prints.
         stamp = epoch + datetime.timedelta( seconds=seconds )
         return f"{stamp.strftime( '%b %d %H:%M:%S %Y' )} GMT"

      _printLineItem( "   CRL Number", f"{self.crlNumber:x}" )
      _printLineItem( "   Issuer", "" )
      printDN( self.issuer )
      lastText = asGmt( self.lastUpdate )
      nextText = asGmt( self.nextUpdate )

      _printLineItem( "   Validity", "" )
      _printLineItem( "      Last Update", lastText )
      _printLineItem( "      Next Update", nextText )

      revoked = self.revokedList
      _printLineItem( "   Revoked Certificates", "none" if not revoked else "" )
      for entry in revoked:
         _printLineItem( "    - Serial Number", f"{entry.serialNumber}" )
         _printLineItem( "      Revocation Date",
                         asGmt( int( entry.revocationDate ) ) )

class Crls( Model ):
   """Collection of CRLs keyed by CRL identifier."""
   crls = Dict( help=( "Mapping from CRL name to CRL used in SSL/TLS" ),
                valueType=Crl )

   def render( self ):
      """Print every CRL, sorted by identifier.

      An identifier containing '/' is shown as "<name>, line <lineno>".
      Any other identifier is shown as-is.
      """
      for crlId, crl in sorted( self.crls.items() ):
         # Split on the last '/' only. Extra '/' characters then stay in the
         # name instead of breaking the two-way unpacking below.
         crlSplit = crlId.rsplit( "/", 1 )
         if len( crlSplit ) > 1:
            name, lineno = crlSplit
            print( f"CRL {name}, line {lineno}:" )
         else:
            print( f"CRL {crlSplit[ 0 ]}:" )
         crl.render()

class PublicKeys( Model ):
   """Collection of stand-alone public keys keyed by key name."""
   publicKeys = Dict( help="Mapping from key name to public key used in SSL/TLS",
                      valueType=PublicKey )

   def render( self ):
      """Print each public key, ordered by key name.

      Optional key attributes are printed only when set.
      """
      for keyName in sorted( self.publicKeys ):
         key = self.publicKeys[ keyName ]
         print( f"Key {keyName}:" )
         _printLineItem( "   Encryption Algorithm", f"{key.encryptionAlgorithm}" )
         _printLineItem( "   Size", f"{key.size} bits" )
         if key.publicExponent:
            _printLineItem( "   Public exponent", f"{key.publicExponent}" )
         if key.modulus:
            # Continuation lines align with the 30-column value field.
            padding = " " * 30
            wrapped = textwrap.fill( f"{key.modulus:x}", width=85,
                                     initial_indent=padding,
                                     subsequent_indent=padding )
            _printLineItem( "   Modulus", wrapped.lstrip() )
         if key.ellipticCurve:
            _printLineItem( "   Elliptic curve", key.ellipticCurve )

class DiffieHellmanParameters( Model ):
   """Diffie-Hellman key-exchange parameters (rendered by DiffieHellman)."""
   size = Int( help="Size of the prime number in bits" )
   # DiffieHellman.render prints the prime in hex, wrapped at 85 columns.
   prime = Int( help="Prime number used in Diffie-Hellman key exchange" )
   generator = Int( help="Generator used in Diffie-Hellman key exchange" )
         
class DiffieHellman( Model ):
   """Diffie-Hellman parameter state, including reset status."""
   dhparamsResetInProgress = Bool(
            help="Whether Diffie-Hellman parameters is being reset" )
   dhparamsLastResetFailed = Bool(
            help="Whether last attempt to reset Diffie-Hellman parameters failed",
            optional=True )
   dhparamsLastSuccessfulReset = Int(
            help="Last successful Diffie-Hellman parameters reset timestamp",
            optional=True )
   diffieHellmanParameters = Submodel(
            help="Diffie-Hellman parameters",
            valueType=DiffieHellmanParameters,
            optional=True )

   def render( self ):
      """Print reset status and, if present, the current DH parameters.

      While a reset is in progress, only that fact is printed. The
      last-successful-reset time is formatted with
      ``datetime.fromtimestamp`` (local time, no zone suffix).
      """
      if self.dhparamsResetInProgress:
         print( "Diffie-Hellman parameters reset in progress" )
         return

      if self.dhparamsLastResetFailed:
         print( "Last attempt to reset Diffie-Hellman parameters failed" )

      lastReset = self.dhparamsLastSuccessfulReset
      if lastReset:
         when = datetime.datetime.fromtimestamp( lastReset )
         print( "Last successful reset on "
                f"{when.strftime( '%b %d %H:%M:%S %Y' )}" )

      params = self.diffieHellmanParameters
      if not params:
         return

      print( f"Diffie-Hellman Parameters {params.size} bits" )
      _printLineItem( "   Generator", f"{params.generator}", space=20 )
      # Continuation lines align with the 20-column value field.
      padding = " " * 20
      primeText = textwrap.fill( f"{params.prime:x}", width=85,
                                 initial_indent=padding,
                                 subsequent_indent=padding )
      _printLineItem( "   Prime", f"{primeText.lstrip()}", space=20 )

class InvalidCertificateDates( Model ):
   """Dates explaining why a certificate is outside its validity period."""
   # SslStatus formats these through _epochSecondsToTimestamp, so they are
   # presumably seconds since the Unix epoch -- confirm with the producer.
   # When both are set, SslStatus reports notBefore and ignores notAfter.
   notBefore = Float( help="Certificate validity start date", optional=True )
   notAfter = Float( help="Certificate expiry date", optional=True )

class ProfileError( Model ):
   """One error that puts an SSL profile into the 'invalid' state."""
   # Which part of the profile is at fault (e.g. certificate, key, crl).
   errorAttr = Enum( help=( "SSL profile attribute to which the"
                            " error applies" ),
                           values=ErrorAttr.attributes )
   # The offending value (e.g. a certificate name); SslStatus omits it from
   # the message when errorAttr is "profile".
   errorAttrValue = Str( help="SSL profile attribute value" )
   errorType = Enum( help="Error type",
                     values=ErrorType.attributes )
   # For certCloseToExpiry, SslStatus parses this with float() as the expiry
   # time; other uses are not visible in this file.
   errorTypeExtra = Str( help="Error type extra information",
                         optional=True )
   # When set (and the extended-invalid-state toggle is on), SslStatus
   # describes the error using these dates.
   errorDates = Submodel(
            help="Invalid certificate dates",
            valueType=InvalidCertificateDates,
            optional=True )

class ProfileStatus( Model ):
   """State and configuration summary of a single SSL profile."""
   # Optional fields may be unset; SslStatus.render prints "n/a" (or nothing)
   # for most of them when they are missing or empty.
   profileError = List( help="List of SSL profile errors in 'invalid' state",
                        valueType=ProfileError, optional=True )
   profileState = Enum( help="SSL profile state",
                        values=ProfileState.attributes )
   certName = Str( help="Certificate name",
                  optional=True )
   keyName = Str( help="Key name",
                 optional=True )
   crls = List( help="Certificate revocation lists", valueType=str,
               optional=True )
   trustedCertificates = List( help="Trusted certificates", valueType=str,
                              optional=True )
   chainedCertificates = List( help="Chained certificates", valueType=str,
                              optional=True )
   ocspProfileName = Str( help="OCSP profile name", optional=True )
   # Joined with spaces for display in the detailed view.
   tlsVersion = List( help="TLS versions", valueType=str, optional=True )
   fipsMode = Bool( help="Whether FIPS restriction is enabled", optional=True )
   cipherList = Str( help="Cipher list of TLSv1.2 and below", optional=True )
   cipherSuiteV1_3 = Str( help="Cipher suite names of TLSv1.3", optional=True )

class SslStatus( Model ):
   """Status of all SSL profiles.

   Rendered either as a summary table or, when ``_detail`` is set, as a
   per-profile detailed listing.
   """
   profileStatus = Dict( help="Mapping from SSL profile name to status",
                         valueType=ProfileStatus )
   _hasError = Bool( help="Whether there is at least one SSL profile with error" )
   _detail = Bool( help="Detailed SSL status" )

   def _getCertCloseToExpiryMsg( self, msg, errorTypeExtra ):
      """Return ``msg`` extended with the time remaining until expiry.

      ``errorTypeExtra`` is parsed with float() and compared with
      ``Tac.utcNow()`` (presumably epoch seconds). Returns None when that
      time has already passed; the caller then drops the message.
      """
      notAfter = float( errorTypeExtra )
      now = Tac.utcNow()
      if now > notAfter:
         return None
      timeUntilExpiry, timeUnit = SslMonitorExpiry.getDisplayableTimeAndUnit(
         notAfter - now )
      return f'! {msg} in {timeUntilExpiry} {timeUnit}'

   def _epochSecondsToTimestamp( self, seconds ):
      """Format epoch ``seconds`` as a UTC 'Mon DD HH:MM:SS YYYY' string."""
      # datetime.utcfromtimestamp() is deprecated since Python 3.12. An aware
      # UTC datetime has the same wall-clock fields, so the text is unchanged.
      converted = datetime.datetime.fromtimestamp(
         seconds, tz=datetime.timezone.utc )
      return converted.strftime( '%b %d %H:%M:%S %Y' )

   def _addErrorType( self, e, msg ):
      """Append a description of error ``e`` to ``msg`` and return it.

      Errors carrying ``errorDates`` are described by those dates. If no
      usable date is set, fall back to the generic error-type description so
      the message never ends without one.
      """
      validityMsg = ""
      if e.errorDates:
         validityMsg = self._formatValidityMessage(
            e.errorDates.notBefore, e.errorDates.notAfter )
      return msg + ( validityMsg or self._errorType( e.errorType ) )

   def _formatValidityMessage( self, notBefore, notAfter ):
      """Describe an invalid validity period; '' if neither date is set.

      ``notBefore`` takes precedence over ``notAfter``.
      """
      if notBefore:
         return ( "is valid from "
                 f"{self._epochSecondsToTimestamp( notBefore )} GMT" )
      elif notAfter:
         return ( "has expired on "
                 f"{self._epochSecondsToTimestamp( notAfter )} GMT" )
      return ""

   def _errorType( self, errType ):
      """Map an ErrorType value to its user-facing description.

      Raises KeyError for error types missing from the table.
      """
      errTypeDict = { "noProfileData": "has no data",
                      "notExist": "does not exist",
                      "notMatchingCertKey": "does not match with key",
                      "certNotYetValid": "is not yet valid",
                      "certChainNotValid": "has invalid certificate chain",
                      "certTrustChainNotValid": ( "has invalid trusted certificate"
                                                  " chain" ),
                      "missingCrlForTrustChain": ( "has missing crl issued by"
                                                   " the trusted chain" ),
                      "certCloseToExpiry": "will expire",
                      "certExpired": "has expired",
                      "noExtendedKeyUsage": "has no extended key usage value",
                      "noCABasicConstraintTrust": ( "is trusted certificate but"
                                                    " does not have CA basic"
                                                    " constraint set to True" ),
                      "noCABasicConstraintChain": ( "is chained certificate but"
                                                    " does not have CA basic"
                                                    " constraint set to True" ),
                      "noCrlSign": ( "is a CRL and signed by a CA who does not have"
                                     " the cRLSign key usage bits set" ),
                      "crlNotSignedByCa": ( "is a CRL but is not signed by any"
                                            " configured trusted certificate" ),
                      "hostnameMismatch": ( "hostname of this device does not match"
                                            " any entry of the Common Name nor"
                                            " Subject Alternative Name in the"
                                            " certificate" ),
                      "fileMultiplePEMs": ( "file has multiple PEM encoded"
                                            " certificates" ),
                      "notFqdn": ( "has Common Name or Subject Alternative Name"
                                   " in the trusted certificate which uses non-FQDN"
                                   " hostnames" ),
                      "vrfDoesNotExist": ( "is configured to use a VRF which does"
                                            " not exist" )
                      }
      return errTypeDict[ errType ]

   def _errorMessage( self, profileError ):
      """Build newline-separated messages for a profile's errors.

      Produces one line per entry. A certCloseToExpiry error whose expiry
      time has already passed is omitted. ``profileError`` is declared
      optional, so None is treated as "no errors".
      """
      msgs = []
      attrDict = { "profile": "Profile",
                   "certificate": "Certificate",
                   "key": "Key",
                   "trustedCertificate": "Certificate",
                   "chainedCertificate": "Certificate",
                   "crl": "CRL",
                   "ocsp": "OCSP profile" }

      for e in profileError or []:
         msg = attrDict[ e.errorAttr ] + " "
         if e.errorAttr != "profile":
            # Values containing spaces are printed bare; others are quoted.
            if " " in e.errorAttrValue:
               msg += f"{e.errorAttrValue} "
            else:
               msg += f"'{e.errorAttrValue}' "
         if MgmtSecurityToggleLib.toggleExtendedInvalidCertificateStateEnabled():
            msg = self._addErrorType( e, msg )
         else:
            msg += self._errorType( e.errorType )
         if e.errorTypeExtra and e.errorType == ErrorType.certCloseToExpiry:
            msg = self._getCertCloseToExpiryMsg( msg, e.errorTypeExtra )
            if not msg:
               continue
         msgs.append( msg )
      return '\n'.join( msgs )

   def render( self ):
      """Print the summary table, or the detailed listing if ``_detail``."""

      def renderStatus():
         if self._hasError:
            tableHeadings = ( "Profile", "State", "Additional Information" )
         else:
            tableHeadings = ( "Profile", "State" )
         table = TableOutput.createTable( tableHeadings )
         f = TableOutput.Format( justify="left", maxWidth=40, wrap=True )
         f.noPadLeftIs( True )
         f.padLimitIs( True )
         table.formatColumns( *( [ f ] * len( tableHeadings ) ) )
         for name, profStatus in sorted( self.profileStatus.items() ):
            # Supply exactly one cell per column; the error column exists
            # only when _hasError is set.
            row = [ name, profStatus.profileState ]
            if self._hasError:
               row.append( self._errorMessage( profStatus.profileError ) )
            table.newRow( *row )
         print( table.output() )

      def renderStatusDetail():
         for i, name in enumerate( sorted( self.profileStatus.keys() ) ):
            profStatus = self.profileStatus[ name ]
            print( "Profile:", name )
            print( "State:", profStatus.profileState )
            if ( MgmtSecurityToggleLib
                  .toggleExtendedInvalidCertificateStateEnabled() ):
               additionalInfo = self._errorMessage( profStatus.profileError )
               if additionalInfo:
                  print( "Additional information:", additionalInfo )
            print( "Certificate:", profStatus.certName if
                  profStatus.certName else "n/a" )
            print( "Key:", profStatus.keyName if profStatus.keyName else "n/a" )

            print( "Chained certificates: "
                   f"{'n/a' if not profStatus.chainedCertificates else ''}" )
            for cert in profStatus.chainedCertificates or []:
               print( "\t-", cert )

            print( "Trusted certificates: "
                   f"{'n/a' if not profStatus.trustedCertificates else ''}" )
            for cert in profStatus.trustedCertificates or []:
               print( "\t-", cert )

            print( "Certificate revocation lists: "
                   f"{'n/a' if not profStatus.crls else ''}" )
            for crl in profStatus.crls or []:
               print( "\t-", crl )

            print( f"OCSP profile: {profStatus.ocspProfileName or 'n/a'}" )
            # tlsVersion is optional; join an empty list rather than None.
            print( "TLS versions:", " ".join( profStatus.tlsVersion or [] ) )
            print( "FIPS restrictions: "
                   f"{'enabled' if profStatus.fipsMode else 'disabled'}" )
            print( "TLS cipher-list(v1.2 and below):", profStatus.cipherList )
            print( "TLS cipher-suite(v1.3):", profStatus.cipherSuiteV1_3 or "n/a" )

            # Blank line between profiles, but not after the last one.
            if i != len( self.profileStatus ) - 1:
               print()

      if self._detail:
         renderStatusDetail()
         return

      renderStatus()

class ProfileCipher( Model ):
   """Cipher suite names configured for one SSL profile."""
   # Printed one per line by ProfileCiphers.render ("n/a" when empty).
   cipherList = List( help="Cipher suite names of TLSv1.2 and below",
                      valueType=str )
   cipherSuiteV1_3 = List( help="Cipher suite names of TLSv1.3",
                           valueType=str )

class ProfileCiphers( Model ):
   """Configured cipher suite names for every SSL profile."""
   profileCiphers = Dict( keyType=str, valueType=ProfileCipher,
         help="Mapping from SSL profile name to configured cipher suite names" )

   def render( self ):
      """Print each profile's cipher suites, in the dict's iteration order.

      A blank line separates consecutive profiles.
      """
      def printCiphers( heading, ciphers ):
         # Heading shows "n/a" when nothing is configured; each cipher is
         # listed on its own tab-indented line.
         print( f"{heading}: {'' if ciphers else 'n/a'}" )
         for cipher in ciphers:
            print( "\t-", cipher )

      lastIndex = len( self.profileCiphers ) - 1
      for index, ( name, ciphers ) in enumerate( self.profileCiphers.items() ):
         print( "Profile:", name )
         printCiphers( "TLSv1.2 and below", ciphers.cipherList )
         printCiphers( "TLSv1.3", ciphers.cipherSuiteV1_3 )
         if index != lastIndex:
            print()
