# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from CliPlugin import SslCliLib
from CliPlugin import SslModel
from CliPlugin.Ssl import (
   config,
   SslProfileConfigMode,
   status,
   _listdir
)
import DefaultSslLib
import DefaultSslProfile
import LazyMount
import os
import SslCertKey
import Tac
from TypeFuture import TacLazyType
from Toggles import MgmtSecurityToggleLib

# Writable exec-request entity. It is mounted by Plugin() and written by
# ResetSslDiffieHellmanParametersCmd_handler; it stays None until Plugin() runs.
execRequest = None
# Handles to the Mgmt::Security::Ssl TACC types used by the handlers in this
# module.
Constants = Tac.Type( "Mgmt::Security::Ssl::Constants" )
NamedDhparams = Tac.Type( "Mgmt::Security::Ssl::NamedDhparams" )
CertLocation = TacLazyType( "Mgmt::Security::Ssl::CertLocation" )
# NOTE(review): KeyLocation is not referenced in the visible handlers; it may
# be used elsewhere or be dead -- confirm before removing.
KeyLocation = TacLazyType( "Mgmt::Security::Ssl::KeyLocation" )
ErrorType = Tac.Type( "Mgmt::Security::Ssl::ErrorType" )

def GotoSslProfileModeCmd_handler( mode, args ):
   """Enter the SSL profile configuration sub-mode for PROFILE_NAME.

   The child mode is entered only when the profile it is bound to has
   profile type "profileTypeSsl". Otherwise an error is reported and the
   session stays in the current mode.
   """
   childMode = mode.childMode( SslProfileConfigMode,
                               profileName=args[ 'PROFILE_NAME' ] )
   if childMode.profileConfig_.profileType == "profileTypeSsl":
      mode.session_.gotoChildMode( childMode )
   else:
      mode.addError( "Not an SSL profile" )

def GotoSslProfileModeCmd_noOrDefaultHandler( mode, args ):
   """Delete the profile named PROFILE_NAME from the SSL config.

   The built-in default profile (DefaultSslProfile.ARISTA_PROFILE) may not be
   deleted. An existing profile whose type is not "profileTypeSsl" is left
   alone. Both cases report an error.
   """
   name = args[ 'PROFILE_NAME' ]
   if name == DefaultSslProfile.ARISTA_PROFILE:
      mode.addError( "Cannot delete default profile" )
      return

   existing = config.profileConfig.get( name )
   if existing and existing.profileType != "profileTypeSsl":
      mode.addError( "Not an SSL profile" )
   else:
      del config.profileConfig[ name ]

def convertToSeconds( amountOfTime, timeUnit ):
   """Return ``amountOfTime`` converted to seconds.

   A ``timeUnit`` of "days" is scaled by 24 hours first. Any other unit is
   treated as hours.
   """
   hours = amountOfTime * 24 if timeUnit == "days" else amountOfTime
   return hours * 3600

def SslMonitorExpiry_handler( mode, args ):
   """Configure a certificate-expiry monitor entry.

   TIME_UNTIL_EXPIRY/TIME_UNIT and REPEAT_INTERVAL/REPEAT_INTERVAL_TIME_UNIT
   are converted to seconds. The entry is keyed by the time until expiry and
   stores the repeat interval, which may not exceed the time until expiry.
   Re-applying an identical configuration is a no-op.
   """
   timeUntilExpiry = convertToSeconds( args[ 'TIME_UNTIL_EXPIRY' ],
                                       args[ 'TIME_UNIT' ] )
   repeatInterval = convertToSeconds( args[ 'REPEAT_INTERVAL' ],
                                      args[ 'REPEAT_INTERVAL_TIME_UNIT' ] )

   if repeatInterval > timeUntilExpiry:
      mode.addError( "Repeat interval needs to be less than expiry time" )
      return

   existing = config.monitorConfig.get( timeUntilExpiry )
   if existing is not None:
      # The collection holds monitor-config entries (created by
      # newMonitorConfig below), so compare their repeatInterval. Comparing
      # the entry itself with an int never matched.
      if existing.repeatInterval == repeatInterval:
         return
      # Interval changed: drop the old entry and recreate it below.
      del config.monitorConfig[ timeUntilExpiry ]

   monitorConfig = config.newMonitorConfig( timeUntilExpiry )
   monitorConfig.repeatInterval = repeatInterval

def SslMonitorExpiry_noOrDefaultHandler( mode, args ):
   """Remove one certificate-expiry monitor entry, or all of them.

   When TIME_UNIT is given, only the entry keyed by TIME_UNTIL_EXPIRY
   (converted to seconds) is deleted. Otherwise every entry is cleared.
   """
   timeUnit = args.get( 'TIME_UNIT' )
   if not timeUnit:
      config.monitorConfig.clear()
      return
   key = convertToSeconds( args[ 'TIME_UNTIL_EXPIRY' ], timeUnit )
   del config.monitorConfig[ key ]

def VerifyHostnameActionCmd_handler( mode, args ):
   """Delegate to ``mode.enableVerifyHosenameMatch()``.

   NOTE: 'Hosename' is the spelling of the mode method being called.
   """
   mode.enableVerifyHosenameMatch()

def VerifyHostnameActionCmd_noOrDefaultHandler( mode, args ):
   """Delegate to ``mode.disableVerifyHosenameMatch()``."""
   mode.disableVerifyHosenameMatch()

def VerifyExtendedParametersCmd_handler( mode, args ):
   """Delegate to ``mode.enableExtendedParameters()``."""
   mode.enableExtendedParameters()

def VerifyExtendedParametersCmd_noOrDefaultHandler( mode, args ):
   """Delegate to ``mode.disableExtendedParameters()``."""
   mode.disableExtendedParameters()

def VerifyChainBasicConstraintTrustCmd_handler( mode, args ):
   """Enable basic-constraint verification.

   The 'trust' keyword selects the trusted-certificate variant. Without it,
   the chain variant is enabled.
   """
   enable = ( mode.enableVerifyBasicConstraintTrust if 'trust' in args
              else mode.enableVerifyBasicConstraintChain )
   enable()

def VerifyChainBasicConstraintTrustCmd_noOrDefaultHandler( mode, args ):
   """Disable the trust ('trust' keyword) or chain basic-constraint check."""
   disable = ( mode.disableVerifyBasicConstraintTrust if 'trust' in args
               else mode.disableVerifyBasicConstraintChain )
   disable()

def IgnoreExpiryDateCmd_handler( mode, args ):
   """'Ignore expiry date' command: turn the end-certificate check off.

   Because the command is phrased as "ignore", the positive form disables
   verification and the no/default form below re-enables it.
   """
   mode.disableVerifyExpiryDateEndCert()

def IgnoreExpiryDateCmd_noOrDefaultHandler( mode, args ):
   """Restore end-certificate expiry-date verification."""
   mode.enableVerifyExpiryDateEndCert()

def IgnoreExpiryDateTrustCmd_handler( mode, args ):
   """'Ignore expiry date' for trusted certificates: turn the check off."""
   mode.disableVerifyExpiryDateTrustCert()

def IgnoreExpiryDateTrustCmd_noOrDefaultHandler( mode, args ):
   """Restore trusted-certificate expiry-date verification."""
   mode.enableVerifyExpiryDateTrustCert()

def VerifyTrustHostnameFqdnCmd_handler( mode, args ):
   """Delegate to ``mode.enableVerifyTrustHostnameFqdn()``."""
   mode.enableVerifyTrustHostnameFqdn()

def VerifyTrustHostnameFqdnCmd_noOrDefaultHandler( mode, args ):
   """Delegate to ``mode.disableVerifyTrustHostnameFqdn()``."""
   mode.disableVerifyTrustHostnameFqdn()

def _setCrlExpiryCheck( mode, verify, newSyntax ):
   """Turn CRL expiry-date verification on or off and record the syntax used.

   ``newSyntax`` is stored in ``config.ignExpDateCrlsNewSyntax``. It is True
   for the revocation-based command form and False for the legacy form.
   """
   if verify:
      mode.enableVerifyExpiryDateCrl()
   else:
      mode.disableVerifyExpiryDateCrl()
   config.ignExpDateCrlsNewSyntax = newSyntax

def IgnoreExpiryDateCrlCmd_handler( mode, args ):
   """Legacy form: ignore CRL expiry dates."""
   _setCrlExpiryCheck( mode, verify=False, newSyntax=False )

def IgnoreExpiryDateCrlCmd_noOrDefaultHandler( mode, args ):
   """Legacy form: restore CRL expiry-date verification."""
   _setCrlExpiryCheck( mode, verify=True, newSyntax=False )

def RevocationIgnoreExpiryDateCrlCmd_handler( mode, args ):
   """Revocation form: ignore CRL expiry dates."""
   _setCrlExpiryCheck( mode, verify=False, newSyntax=True )

def RevocationIgnoreExpiryDateCrlCmd_noOrDefaultHandler( mode, args ):
   """Revocation form: restore CRL expiry-date verification."""
   _setCrlExpiryCheck( mode, verify=True, newSyntax=True )

def VerifyPeerHostnameActionCmd_handler( mode, args ):
   """Enable peer-hostname verification against the SAN.

   The common-name check is enabled only when the 'common-name' keyword is
   present.
   """
   profileConfig = mode.profileConfig_
   profileConfig.verifyPeerHostnameInCommonName = 'common-name' in args
   profileConfig.verifyPeerHostnameInSan = True

def VerifyPeerHostnameActionCmd_noOrDefaultHandler( mode, args ):
   """Disable both SAN and common-name peer-hostname verification."""
   profileConfig = mode.profileConfig_
   profileConfig.verifyPeerHostnameInSan = False
   profileConfig.verifyPeerHostnameInCommonName = False

def TlsVersions_handler( mode, args ):
   """Enable, extend, or remove TLS protocol versions on the profile.

   Every TLS version keyword in ``DefaultSslLib.TlsStrToVersionsMap`` that is
   present in ``args`` contributes its mask bits. With 'remove' the combined
   mask is passed to ``mode.disableTlsVersion``. Otherwise it is passed to
   ``mode.enableTlsVersion`` with ``add=True`` when 'add' is present
   (presumably extending rather than replacing the current set -- confirm
   in SslProfileConfigMode).
   """
   mask = 0
   for versionStr, versionMask in DefaultSslLib.TlsStrToVersionsMap.items():
      if versionStr in args:
         # Combine bitmasks with OR. Addition would carry into unrelated bits
         # if two keywords ever mapped to overlapping masks.
         mask |= versionMask

   if 'remove' in args:
      mode.disableTlsVersion( mask )
   else:
      mode.enableTlsVersion( mask, add='add' in args )

def TlsVersions_noOrDefaultHandler( mode, args ):
   """Re-enable every TLS version (``Constants.allTlsVersion``)."""
   mode.enableTlsVersion( Constants.allTlsVersion )

def FipRestrictionsCmd_handler( mode, args ):
   """Delegate to ``mode.enableFipsMode()``."""
   mode.enableFipsMode()

def FipRestrictionsCmd_noOrDefaultHandler( mode, args ):
   """Delegate to ``mode.disableFipsMode()``."""
   mode.disableFipsMode()

def DiffieHellmanParamsCmd_handler( mode, args ):
   """Select the DHPARAMS parameter set for the profile."""
   mode.setDhparam( args[ 'DHPARAMS' ] )

def DiffieHellmanParamsCmd_noOrDefaultHandler( mode, args ):
   """Delegate to ``mode.disableDhparam()``."""
   mode.disableDhparam()

def CipherListCmd_handler( mode, args ):
   """Set the profile cipher list to CIPHERS."""
   mode.enableCipherList( args[ 'CIPHERS' ] )

def CipherListCmd_noOrDefaultHandler( mode, args ):
   """Delegate to ``mode.disableCipherList()``."""
   mode.disableCipherList()

def CipherCmd_handler( mode, args ):
   """Set CIPHERS for TLS VERSION and flag cipher config as new syntax."""
   version = args[ 'VERSION' ]
   ciphers = args[ 'CIPHERS' ]
   mode.enableCipher( version, ciphers )
   config.ciphersNewSyntax = True

def CipherCmd_noOrDefaultHandler( mode, args ):
   """Clear ciphers for TLS VERSION and flag cipher config as new syntax."""
   version = args[ 'VERSION' ]
   mode.disableCipher( version )
   config.ciphersNewSyntax = True

def CertificateCmd_handler( mode, args ):
   """Set the profile's certificate/key pair.

   With 'auto' plus AUTO_CERT_PROFILE, the certificate is
   "<AUTO_CERT_PROFILE>.crt" in the auto-certificate location with an empty
   key name. Otherwise CERT_NAME and KEY_NAME from the certificate location
   are used.
   """
   autoProfile = args.get( 'auto' ) and args.get( 'AUTO_CERT_PROFILE' )
   if not autoProfile:
      mode.setCertKey( args[ 'CERT_NAME' ], args[ 'KEY_NAME' ], CertLocation.certs )
      return
   mode.setCertKey( f"{autoProfile}.crt", "", CertLocation.autoCerts )

def CertificateCmd_noOrDefaultHandler( mode, args ):
   """Delegate to ``mode.noCertKey()``."""
   mode.noCertKey()

def TrustCertificateCmd_handler( mode, args ):
   """Add a trusted certificate; CERT_NAME defaults to ``Constants.system``."""
   mode.addTrustedCert( args.get( 'CERT_NAME', Constants.system ) )

def NoTrustCertificateCmd_noOrDefaultHandler( mode, args ):
   """Remove a trusted certificate; CERT_NAME defaults to ``Constants.system``."""
   mode.noTrustedCert( args.get( 'CERT_NAME', Constants.system ) )

def ChainCertificateCmd_handler( mode, args ):
   """Add CERT_NAME as a chained certificate."""
   mode.addChainedCert( args[ 'CERT_NAME' ] )

def NoChainCertificateCmd_noOrDefaultHandler( mode, args ):
   """Remove CERT_NAME from the chained certificates."""
   mode.noChainedCert( args[ 'CERT_NAME' ] )

def ChainIncludeRootCACmd_handler( mode, args ):
   """Delegate to ``mode.setVerifyChainHasRootCA( True )``."""
   mode.setVerifyChainHasRootCA( True )

def ChainIncludeRootCACmd_noOrDefaultHandler( mode, args ):
   """Delegate to ``mode.setVerifyChainHasRootCA( False )``."""
   mode.setVerifyChainHasRootCA( False )

def _updateCrl( mode, args, add, newSyntax ):
   """Add or remove CRL_NAME and record the command syntax used.

   ``newSyntax`` is stored in ``config.crlsNewSyntax``. It is True for the
   revocation-based command form and False for the legacy form.
   """
   crlName = args[ 'CRL_NAME' ]
   if add:
      mode.addCrl( crlName )
   else:
      mode.noCrl( crlName )
   config.crlsNewSyntax = newSyntax

def CrlCmd_handler( mode, args ):
   """Legacy form: add a CRL to the profile."""
   _updateCrl( mode, args, add=True, newSyntax=False )

def NoCrlCmd_noOrDefaultHandler( mode, args ):
   """Legacy form: remove a CRL from the profile."""
   _updateCrl( mode, args, add=False, newSyntax=False )

def RevocationCrlCmd_handler( mode, args ):
   """Revocation form: add a CRL to the profile."""
   _updateCrl( mode, args, add=True, newSyntax=True )

def RevocationNoCrlCmd_noOrDefaultHandler( mode, args ):
   """Revocation form: remove a CRL from the profile."""
   _updateCrl( mode, args, add=False, newSyntax=True )

def OcspProfileCmd_handler( mode, args ):
   """Set the OCSP profile; the full ``args`` dict is passed to the mode."""
   mode.setOcspProfile( args )

def OcspProfileCmd_noOrDefaultHandler( mode, args ):
   """Clear the OCSP profile; the full ``args`` dict is passed to the mode."""
   mode.noOcspProfile( args )

def CommonNameRegexCmd_handler( mode, args ):
   """Set the common-name regex from the '<REGEX>' argument."""
   mode.setCommonNameRegex( args[ '<REGEX>' ] )

def CommonNameRegexCmd_noOrDefaultHandler( mode, args ):
   """Delegate to ``mode.noCommonNameRegex()``."""
   mode.noCommonNameRegex()

def ShowSecuritySslCert_handler( mode, args ):
   """Build the certificate show model.

   Source of certificates:
   - With 'auto': AUTO_CERT_NAME, or every entry of the auto-certificate
     directory.
   - Without 'auto': CERT_NAME (or the 'system' keyword), or every entry of
     the certificate directory.

   Names that are not regular files are skipped. A file holding several
   certificates yields one entry per certificate, keyed "<name>/<lineno>".
   A file holding one certificate is keyed by its name alone.
   """
   ret = SslModel.Certificates()
   isAuto = bool( args.get( 'auto' ) )

   if isAuto:
      certName = args.get( 'AUTO_CERT_NAME' )
   else:
      certName = args.get( 'CERT_NAME' ) or args.get( 'system' )

   if certName:
      certs = [ certName ]
   elif isAuto:
      certs = _listdir( Constants.autoCertsDirPath() )
   else:
      certs = _listdir( Constants.certsDirPath() )

   certLocation = CertLocation.autoCerts if isAuto else CertLocation.certs
   for cert in certs:
      certFile = ( Constants.autoCertPath( cert ) if isAuto
                   else Constants.certPath( cert ) )
      # Check existence before parsing. The name may come straight from the
      # CLI, and extractCerts should never be run on a missing path.
      if not os.path.isfile( certFile ):
         continue
      certificates = SslCertKey.extractCerts( cert, certLocation=certLocation )
      multipleCerts = len( certificates ) > 1
      for certData, lineno in certificates:
         if not SslCertKey.hasCertificate( certData ):
            continue
         certId = f"{cert}/{lineno}" if multipleCerts else cert
         ret.certificates[ certId ] = SslCliLib.getCertificateModel(
            certData, isFile=False )
   return ret

def ShowSecuritySslKey_handler( mode, args ):
   """Build the public-key show model.

   Reports KEY_NAME, or every entry of the keys directory. Names that are
   not regular files are skipped.
   """
   ret = SslModel.PublicKeys()
   keyName = args.get( 'KEY_NAME' )
   keys = [ keyName ] if keyName else _listdir( Constants.keysDirPath() )

   for key in keys:
      keyFile = Constants.keyPath( key )
      if not os.path.isfile( keyFile ):
         continue
      ret.publicKeys[ key ] = SslCliLib.getPublicKeyModel( keyFile )
   return ret

def ShowSecuritySslCrl_handler( mode, args ):
   """Build the CRL show model.

   Reports CRL_NAME, or every entry of the certificate directory (CRLs are
   looked up via ``Constants.certPath``). Names that are not regular files
   are skipped. A file holding several CRLs yields one entry per CRL, keyed
   "<name>/<lineno>". A file holding one CRL is keyed by its name alone.
   """
   crlName = args.get( 'CRL_NAME' )
   ret = SslModel.Crls()
   crlNames = [ crlName ] if crlName else _listdir( Constants.certsDirPath() )

   for crl in crlNames:
      # Check existence before parsing so extractCrls never runs on a
      # missing path.
      if not os.path.isfile( Constants.certPath( crl ) ):
         continue
      # Kept in its own variable rather than rebinding the list being
      # iterated.
      crlEntries = SslCertKey.extractCrls( crl )
      multipleCrls = len( crlEntries ) > 1
      for crlData, lineno in crlEntries:
         if not SslCertKey.hasCrl( crlData ):
            continue
         crlId = f"{crl}/{lineno}" if multipleCrls else crl
         ret.crls[ crlId ] = SslCliLib.getCrlModel( crlData, isFile=False )
   return ret

def _dhparamsResetAttempted():
   """Return a truthy value if a DH-params reset has been processed or failed.

   Returns ``status.dhparamsResetProcessed`` when it is truthy. Otherwise
   returns ``status.dhparamsLastResetFailed``.
   """
   processed = status.dhparamsResetProcessed
   if processed:
      return processed
   return status.dhparamsLastResetFailed

def ShowSecuritySslDiffieHellman_handler( mode, args ):
   """Build the Diffie-Hellman parameters show model.

   - A named DHPARAMS set other than ``NamedDhparams.generated`` is read
     directly from its file.
   - For the generated parameters, only the in-progress flag is reported
     while a reset is running.
   - Otherwise the outcome of the last reset attempt (if any) is reported,
     together with the parameters currently on disk.
   """
   ret = SslModel.DiffieHellman()
   dhparams = args.get( 'DHPARAMS' )
   if dhparams and dhparams != NamedDhparams.generated:
      ret.diffieHellmanParameters = SslCliLib.getDhparamsModel(
         Constants.namedDhparamPath( dhparams ) )
      ret.dhparamsResetInProgress = False
      return ret

   if status.dhparamsResetInProgress:
      ret.dhparamsResetInProgress = True
      return ret

   if _dhparamsResetAttempted():
      ret.dhparamsLastResetFailed = status.dhparamsLastResetFailed
      resetTime = status.dhparamsResetProcessed
      if resetTime:
         # Presumably resetTime is on the Tac.now() clock; shift it onto the
         # UTC clock before reporting.
         ret.dhparamsLastSuccessfulReset = int( resetTime + Tac.utcNow() -
                                                Tac.now() )
   ret.dhparamsResetInProgress = False
   dhparamsFile = Constants.dhParamPath()
   if os.path.isfile( dhparamsFile ):
      ret.diffieHellmanParameters = SslCliLib.getDhparamsModel( dhparamsFile )
   return ret

def _profileError( err ):
   """Convert a profile status error into an ``SslModel.ProfileError``.

   Validity dates are attached as ``errorDates`` only when all of these hold:
   - the extended-invalid-certificate-state toggle is enabled;
   - the error concerns a "certificate";
   - the error type is certNotYetValid or certExpired;
   - the file at ``Constants.certPath( errorAttrValue )`` exists.

   Only the relevant date is filled in: notBefore for certNotYetValid,
   notAfter for certExpired.
   """
   dates = None
   toggleOn = MgmtSecurityToggleLib.toggleExtendedInvalidCertificateStateEnabled()
   if ( toggleOn and err.errorAttr == "certificate" and
        err.errorType in ( ErrorType.certNotYetValid, ErrorType.certExpired ) ):
      certFile = Constants.certPath( err.errorAttrValue )
      if os.path.isfile( certFile ):
         notBefore, notAfter = SslCliLib.getDatesEpoch( certFile )
         notYetValid = err.errorType == ErrorType.certNotYetValid
         expired = err.errorType == ErrorType.certExpired
         dates = SslModel.InvalidCertificateDates(
            notBefore=notBefore if notYetValid else None,
            notAfter=notAfter if expired else None )

   return SslModel.ProfileError( errorAttr=err.errorAttr,
                                 errorAttrValue=err.errorAttrValue,
                                 errorType=err.errorType,
                                 errorTypeExtra=err.errorTypeExtra,
                                 errorDates=dates )

def ShowSecuritySslProfile_handler( mode, args ):
   """Build the SSL profile status show model.

   Reports PROFILE_NAME, or every profile in status; names missing from
   status are omitted. Errors (via _profileError) and warnings are merged
   into one ``profileError`` list, and a warning equal to an existing entry
   is skipped. ``_hasError`` is set whenever that list gains an entry, and
   ``_detail`` is set when 'detail' is given.
   """
   def _sortedOrNone( items ):
      # Empty collections are reported as None rather than [].
      return sorted( items ) if items else None

   profileName = args.get( 'PROFILE_NAME' )
   ret = SslModel.SslStatus()
   profiles = [ profileName ] if profileName else status.profileStatus

   if "detail" in args:
      ret._detail = True # pylint: disable-msg=W0212

   for name in profiles:
      profile = status.profileStatus.get( name )
      if profile is None:
         continue

      model = SslModel.ProfileStatus()
      model.profileState = profile.state

      for err in profile.error.values():
         profErr = _profileError( err )
         # pylint: disable-next=no-member
         model.profileError.append( profErr )
         ret._hasError = True # pylint: disable-msg=W0212

      for warning in profile.warning.values():
         profWarning = SslModel.ProfileError(
            errorAttr=warning.errorAttr,
            errorAttrValue=warning.errorAttrValue,
            errorType=warning.errorType,
            errorTypeExtra=warning.errorTypeExtra )
         if profWarning in model.profileError:
            continue
         # pylint: disable-next=no-member
         model.profileError.append( profWarning )
         ret._hasError = True # pylint: disable-msg=W0212

      if profile.certKeyPair:
         model.certName = profile.certKeyPair.certFile
         model.keyName = profile.certKeyPair.keyFile
      model.chainedCertificates = _sortedOrNone( profile.chainedCert )
      model.trustedCertificates = _sortedOrNone( profile.trustedCert )
      model.crls = _sortedOrNone( profile.crl )
      if profile.ocspSettings:
         model.ocspProfileName = profile.ocspSettings.name
      model.tlsVersion = DefaultSslLib.tlsVersionMaskToStrList(
         profile.tlsVersion )
      model.fipsMode = profile.fipsMode
      model.cipherList = profile.cipherSuite
      model.cipherSuiteV1_3 = profile.cipherSuiteV1_3
      ret.profileStatus[ name ] = model
   return ret

def ShowSecuritySslProfileCiphers_handler( mode, args ):
   """Build the per-profile cipher show model.

   For PROFILE_NAME (or every profile in status), the profile's
   ``cipherSuite`` and ``cipherSuiteV1_3`` strings are expanded with
   ``SslCliLib.listCipherSuiteNames`` into lists of individual names.
   Names missing from status are omitted.
   """
   def _expandCipherNames( cipherListStr="", cipherSuiteStr="" ):
      # listCipherSuiteNames yields a colon-separated string of names.
      cipherStr = SslCliLib.listCipherSuiteNames( cipherListStr, cipherSuiteStr )
      # ''.split( ':' ) returns [ '' ]; report "no ciphers" as an empty
      # list instead of a single blank cipher name.
      return cipherStr.split( ':' ) if cipherStr else []

   profileName = args.get( 'PROFILE_NAME' )
   ret = SslModel.ProfileCiphers()
   profiles = [ profileName ] if profileName else status.profileStatus

   for name in profiles:
      profile = status.profileStatus.get( name )
      if profile is None:
         continue
      profileCipher = SslModel.ProfileCipher()
      profileCipher.cipherList = _expandCipherNames(
         cipherListStr=profile.cipherSuite )
      profileCipher.cipherSuiteV1_3 = _expandCipherNames(
         cipherSuiteStr=profile.cipherSuiteV1_3 )
      ret.profileCiphers[ name ] = profileCipher
   return ret

def ResetSslDiffieHellmanParametersCmd_handler( mode, args ):
   """Request a Diffie-Hellman parameter reset.

   Writes the current ``Tac.now()`` time to
   ``execRequest.dhparamsResetRequest``. Presumably an agent watches this
   attribute and performs the reset -- TODO confirm. ``execRequest`` must
   already have been mounted by Plugin().
   """
   execRequest.dhparamsResetRequest = Tac.now()

def Plugin( entityManager ):
   """CLI plugin entry point: mount the SSL exec-request entity for writing.

   Rebinds the module-level ``execRequest`` to the entity at
   "mgmt/security/ssl/execRequest" (mode "w").
   """
   global execRequest
   execRequest = LazyMount.mount( entityManager, "mgmt/security/ssl/execRequest",
                                  "Mgmt::Security::Ssl::ExecRequest", "w" )
