# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import BasicCliUtil
from CliPlugin.Pki import (
   getSignRequestParams,
   PkiProfileConfigMode,
)
import ConfigMount
import LazyMount
import os
import SslCertKey
import Tac

# Entities mounted by Plugin(); they stay None until the plugin is loaded.
config = execRequest = status = None

# Error reported when SslCertKey.isSslDirsCreated() is false.
pkiNotReadyError = 'PKI not ready'
Constants = Tac.Type( 'Mgmt::Security::Ssl::Constants' )
PublicKeyAlgo = Tac.Type( 'Mgmt::Security::Ssl::PublicKeyAlgorithm' )

def GotoPkiProfileModeCmd_handler( mode, args ):
   '''
   Enter the configuration mode for the profile named by PROFILE_NAME.

   If that profile is not of type 'profileTypePki', an error is reported
   and the session stays in the current mode.
   '''
   pkiMode = mode.childMode( PkiProfileConfigMode,
                             profileName=args[ 'PROFILE_NAME' ] )
   isPkiProfile = pkiMode.profileConfig_.profileType == 'profileTypePki'
   if isPkiProfile:
      mode.session_.gotoChildMode( pkiMode )
   else:
      mode.addError( 'Not a PKI profile' )

def GotoPkiProfileModeCmd_noOrDefaultHandler( mode, args ):
   '''
   Handle the 'no'/'default' form of the PKI profile command.

   Deletes config.profileConfig[ PROFILE_NAME ], unless an entry exists there
   whose type is not 'profileTypePki'. In that case an error is reported and
   nothing is deleted.
   '''
   name = args[ 'PROFILE_NAME' ]
   existing = config.profileConfig.get( name )
   if not existing or existing.profileType == 'profileTypePki':
      del config.profileConfig[ name ]
      return
   mode.addError( 'Not a PKI profile' )

def PkiGenerateKey_handler( mode, args ):
   '''
   Generate a private key for use in PKI.

   The key is written to sslkey:<KEY_NAME>. RSA keys take their size from
   RSA_BIT_LENGTH; ECDSA keys take their curve from EC_CURVE_NAME. Errors
   raised while generating the key are reported via mode.addError rather
   than propagated.
   '''
   key = args[ 'KEY_NAME' ]
   keyAlgo = args[ 'KEY_ALGO' ]
   if not SslCertKey.isSslDirsCreated():
      mode.addError( pkiNotReadyError )
      return
   genericError = 'Error generating key'
   keyFile = os.path.join( Constants.keysDirPath(), key )
   try:
      # Captured before the file is (re)written and handed to the syslog
      # helper below; presumably used to report created vs. replaced keys --
      # confirm against SslCertKey.
      keyLogAction, keyHash = SslCertKey.getLogActionAndFileHash(
         keyFile, 'sslkey:', 'created' )
      # We only support ECDSA and RSA right now.
      if keyAlgo == PublicKeyAlgo.RSA:
         SslCertKey.generateRsaPrivateKey( keyFile,
                                           int( args[ 'RSA_BIT_LENGTH' ] ) )
      elif keyAlgo == PublicKeyAlgo.ECDSA:
         SslCertKey.generateEcdsaPrivateKey( keyFile,
                                             curve=args[ 'EC_CURVE_NAME' ] )
      else:
         # Without this, nothing would be generated, yet the key-creation
         # syslog below would still be emitted.
         raise SslCertKey.SslCertKeyError(
            f'Unsupported key algorithm {keyAlgo}' )
      SslCertKey.generateSslKeyCertSysLog( keyFile, 'sslkey:',
                                           keyLogAction, keyHash )
   except SslCertKey.SslCertKeyError as e:
      mode.addError( f'{genericError} ({e})' )
   except EnvironmentError as e:
      # strerror may be None when the OSError was raised with only a message.
      mode.addError( f'{genericError} ({e.strerror or e})' )
   except Exception: # pylint: disable=broad-except
      mode.addError( genericError )

def _getKeyFilepath( mode, keyParams ):
   '''
   Resolve the filesystem path of the key to use for a CSR/certificate.

   keyParams is either falsy or a dict with 'key' (the key name) and
   'genNewKey' (True when the key is about to be created). If it is falsy,
   the user is prompted for a key name and the key must already exist.

   Raises SslCertKey.SslCertKeyError if no key name is available, or if an
   existing key is required but no such file is under sslkey:.
   '''
   if keyParams:
      keyName = keyParams[ 'key' ]
      willGenerate = keyParams[ 'genNewKey' ]
   else:
      keyName = BasicCliUtil.getSingleLineInput( mode,
                                                 'PKI Key to use for CSR: ' )
      willGenerate = False

   if not keyName:
      raise SslCertKey.SslCertKeyError( 'Key is needed' )

   path = os.path.join( Constants.keysDirPath(), keyName )
   # A key that will be generated need not exist yet; otherwise it must.
   mustExist = not willGenerate
   if mustExist and not os.path.isfile( path ):
      raise SslCertKey.SslCertKeyError( 'Key not found under sslkey:' )
   return path

def PkiGenerateCsrOrCert_handler( mode, args ):
   '''
   Generate a certificate signing request (CSR) or a self signed certificate.

   Without CERT_NAME a CSR is generated and emitted as a CLI message.
   Otherwise a certificate is written to certificate:<CERT_NAME>. When
   KEY_ALGO is given, generateCertificate is asked to create a new key.
   Otherwise an existing key under sslkey: is used, and the user is prompted
   for its name when KEY_NAME is absent. Errors raised during generation are
   reported via mode.addError rather than propagated.
   '''
   digest = args.get( 'DIGEST', Constants.defaultDigest )
   signReqParams = args.get( 'SIGN_REQ_PARAMS' )
   cert = args.get( 'CERT_NAME' )
   keyAlgo = args.get( 'KEY_ALGO' )
   newKeyBits = args.get( 'RSA_BIT_LENGTH' )
   ecCurveName = args.get( 'EC_CURVE_NAME' )
   genNewKey = bool( keyAlgo )
   keyParams = None
   if 'KEY_NAME' in args:
      keyParams = { 'key': args[ 'KEY_NAME' ],
                    'genNewKey': genNewKey }
   # A validity period only applies to certificates, not to CSRs.
   validityDays = None
   if 'CERT_NAME' in args:
      validityDays = args.get( 'VALIDITY', Constants.defaultCertValidity )

   if not SslCertKey.isSslDirsCreated():
      mode.addError( pkiNotReadyError )
      return

   # No certificate name means the caller asked for a CSR.
   signRequest = not cert
   genericError = f"Error generating {'CSR' if signRequest else 'certificate'}"

   try:
      keyFilepath = _getKeyFilepath( mode, keyParams )
      signReqParams = getSignRequestParams( mode, signReqParams )

      # Snapshot the state of each file that is about to be written, for the
      # syslog messages emitted below. Files that are not written (an
      # existing key, or the absent certificate of a CSR) are not inspected.
      # In particular, a None path is never handed to the helper.
      keyLogAction = keyHash = None
      if genNewKey:
         keyLogAction, keyHash = SslCertKey.getLogActionAndFileHash(
            keyFilepath, 'sslkey:', 'created' )
      certFilepath = None
      certLogAction = certHash = None
      if not signRequest:
         certFilepath = os.path.join( Constants.certsDirPath(), cert )
         certLogAction, certHash = SslCertKey.getLogActionAndFileHash(
            certFilepath, 'certificate:', 'created' )

      ( csr, _ ) = SslCertKey.generateCertificate( keyFilepath=keyFilepath,
                                                   certFilepath=certFilepath,
                                                   signRequest=signRequest,
                                                   genNewKey=genNewKey,
                                                   keyType=keyAlgo,
                                                   newKeyBits=newKeyBits,
                                                   digest=digest,
                                                   curve=ecCurveName,
                                                   validity=validityDays,
                                                   **signReqParams )
      if genNewKey:
         SslCertKey.generateSslKeyCertSysLog( keyFilepath, 'sslkey:',
                                              keyLogAction, keyHash )

      if signRequest:
         mode.addMessage( csr )
      else:
         mode.addMessage( f'certificate:{cert} generated' )
         SslCertKey.generateSslKeyCertSysLog( certFilepath, 'certificate:',
                                              certLogAction, certHash )
   except SslCertKey.SslCertKeyError as e:
      mode.addError( f'{genericError} ({e})' )
   except EnvironmentError as e:
      # strerror may be None when the OSError was raised with only a message.
      mode.addError( f'{genericError} ({e.strerror or e})' )
   except Exception: # pylint: disable=broad-except
      mode.addError( genericError )

def Plugin( entityManager ):
   '''
   CLI plugin entry point: mount the SSL entities used by this module.

   config is mounted via ConfigMount with mode 'w'. status (mode 'r') and
   execRequest (mode 'w') are mounted via LazyMount.
   '''
   global config, status, execRequest
   em = entityManager
   config = ConfigMount.mount(
      em, 'mgmt/security/ssl/config', 'Mgmt::Security::Ssl::Config', 'w' )
   status = LazyMount.mount(
      em, 'mgmt/security/ssl/status', 'Mgmt::Security::Ssl::Status', 'r' )
   execRequest = LazyMount.mount(
      em, 'mgmt/security/ssl/execRequest',
      'Mgmt::Security::Ssl::ExecRequest', 'w' )
