# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import tempfile
from SslCertKey import dirCreate
import Tracing
import Tac
import os
from SshAlgorithms import HOSTKEYS_MAPPINGS

# Trace handle for this module; t0 is shorthand for its trace0 function and is
# used for all tracing below.
traceHandle = Tracing.Handle( 'SshCertLib' )
t0 = traceHandle.trace0

# Tac type handles: Constants supplies the SSH directory/file paths used below
# (e.g. hostKeysDirPath(), hostCertPath()); KeyType supplies key-type values
# such as KeyType.invalid.
Constants = Tac.Type( "Mgmt::Ssh::Constants" )
KeyType = Tac.Type( "Mgmt::Ssh::KeyType" )

# Maps an algorithm directory name (the <algo> in .../<algo>/<key-name>) to the
# key-type string that `ssh-keygen -y` prints as the first field for a key of
# that algorithm. Its keys are also the host-key subdirectories created by
# createSshDirs().
algoDirToSshAlgo = HOSTKEYS_MAPPINGS

class SshHostCertError( Exception ):
   '''Raised when `ssh-keygen -L` rejects an SSH host certificate file.'''

class SshKeyError( Exception ):
   '''Raised when an SSH public or private key fails validation.'''

def createSshDirs():
   '''
       Create the SSH directory tree: the base directory, the fixed
       per-purpose subdirectories, one host-key subdirectory per algorithm
       directory name in algoDirToSshAlgo, and finally the directory that
       holds named keys while sshd_config references them.
   '''
   fixedDirs = (
      Constants.sshBaseDir,
      Constants.authPrincipalsCmdDirPath(),
      Constants.caKeysDirPath(),
      Constants.hostCertsDirPath(),
      Constants.revokeListsDirPath(),
      Constants.hostKeysDirPath(),
   )
   for dirPath in fixedDirs:
      dirCreate( dirPath )

   # Host-key subdirectories are formed by plain concatenation, exactly as
   # before (no separator is inserted between the base path and the name).
   hostKeysBase = Constants.hostKeysDirPath()
   for algoDirName in algoDirToSshAlgo:
      dirCreate( "%s%s" % ( hostKeysBase, algoDirName ) )

   # Temporary home for named keys while they are referenced from the
   # sshd_config file.
   dirCreate( Constants.sshKeyConfigDir )

def getAlgoFromKeyPath( keyPath ):
   '''
       Return the name of the directory that directly contains keyPath.

       Key files are laid out as */<algo>/<key-name>, so this is the <algo>
       component, e.g. "/some/dir/rsa/myKey" -> "rsa". A path with no
       directory part yields "".
   '''
   parentDir, _keyName = os.path.split( keyPath )
   _grandParent, algo = os.path.split( parentDir )
   return algo

def validateMultipleKeysFile( keyFile ):
   '''
       Validate a file containing SSH public keys, one per line.

       Blank (whitespace-only) lines are skipped. Every other line is written
       on its own to a temporary file and checked individually with
       `ssh-keygen -lf`.

       Raises:
          SshKeyError: for the first line ssh-keygen rejects; the underlying
             Tac.SystemCommandError is chained as its cause.
          OSError: if keyFile cannot be opened.
   '''
   t0( "Validating key file: %s" % keyFile )
   with open( keyFile ) as keyFileHandle:
      # Iterate lazily instead of materializing every line with readlines().
      for key in keyFileHandle:
         if not key.strip():
            continue
         with tempfile.NamedTemporaryFile( mode="w" ) as tmpFile:
            tmpFile.write( key )
            # Flush so the ssh-keygen subprocess sees the written contents.
            tmpFile.flush()
            try:
               Tac.run( [ "ssh-keygen", "-lf", tmpFile.name ],
                        stdout=Tac.CAPTURE, stderr=Tac.CAPTURE,
                        ignoreReturnCode=False )
            except Tac.SystemCommandError as e:
               t0( "Invalid key: %s : Error: %s" % ( key.rstrip(),
                                                     e.output.rstrip() ) )
               # Chain the command failure so its details survive in tracebacks.
               raise SshKeyError( "Invalid key: %s" % key.rstrip() ) from e
            else:
               t0( "Valid key: %s" % key.rstrip() )

def validateHostCert( certFile ):
   '''
       Check that certFile is a certificate that `ssh-keygen -L` accepts.

       Returns whatever Tac.run returns for `ssh-keygen -Lf certFile` with
       stdout captured (presumably the printed certificate details).

       Raises:
          SshHostCertError: if ssh-keygen rejects the file; the underlying
             Tac.SystemCommandError is chained as its cause.
   '''
   try:
      cert = Tac.run( [ "ssh-keygen", "-Lf", certFile ],
                      stdout=Tac.CAPTURE, stderr=Tac.CAPTURE,
                      ignoreReturnCode=False )
   except Tac.SystemCommandError as e:
      t0( "Invalid cert: %s" % e.output.rstrip() )
      # Chain the command failure so its details survive in tracebacks.
      raise SshHostCertError( "Invalid certificate" ) from e
   return cert

def validateHostKey( keyFile ):
   '''
       Validate that keyFile is an OpenSSH private key stored in the
       directory for its algorithm (ssh-key:/<algo>/<file-name>).

       Raises SshKeyError if
       1. The file's first line does not start with "---".
       2. `ssh-keygen -y` cannot extract the public key from the private key.
       3. The parent directory name is not a known algorithm directory, or
          the key type it maps to differs from the key type ssh-keygen
          reports.
       Raises OSError if keyFile cannot be opened.
   '''
   t0( "ValidateHostKey: %s" % keyFile )

   # Only the first line is needed for the cheap format check. Undecodable
   # bytes are replaced so a binary file is rejected below with SshKeyError
   # instead of escaping as UnicodeDecodeError.
   with open( keyFile, errors="replace" ) as priKey:
      firstLine = priKey.readline().rstrip()

   # SSH private keys begin with a line of dashes ("---...").
   if not firstLine.startswith( "---" ):
      t0( "Not a valid SSH private key" )
      raise SshKeyError( "Invalid SSH private key" )

   try:
      # Extracting the public key also proves the private key is usable.
      pubKeyData = Tac.run( [ "ssh-keygen", "-yf", keyFile ],
                            stdout=Tac.CAPTURE, stderr=Tac.CAPTURE,
                            ignoreReturnCode=False )
   except Tac.SystemCommandError as e:
      t0( "Invalid SSH private key: %s" % e.output.rstrip() )
      # Chain the command failure so its details survive in tracebacks.
      raise SshKeyError( "Invalid SSH private key" ) from e

   # pubKeyData is of the form "<algo> <key-bytes> <comments>"; the first
   # field is the key type ssh-keygen detected. Guard against empty output
   # rather than failing with IndexError.
   fields = pubKeyData.split()
   keyType = fields[ 0 ] if fields else None

   # The algorithm is named by the directory holding the key file.
   algo = getAlgoFromKeyPath( keyFile )
   if ( algo not in algoDirToSshAlgo or keyType is None or
         algoDirToSshAlgo[ algo ] != keyType ):
      t0( "Specified algorithm directory is not correct" )
      raise SshKeyError( "Invalid algorithm directory" )

def getHostCertKeyType( certFile ):
   '''
       Return the key type recorded in the host certificate named certFile.

       The certificate path comes from Constants.hostCertPath(). The file is
       first validated with validateHostCert(). The key type is the text
       before the first space of the whitespace-stripped file contents.
       Returns KeyType.invalid if validation fails or the file cannot be read.
   '''
   certPath = Constants.hostCertPath( certFile )
   try:
      validateHostCert( certPath )
      with open( certPath ) as certFh:
         contents = certFh.read()
   except ( OSError, SshHostCertError ):
      return KeyType.invalid
   keyType, _sep, _rest = contents.strip().partition( ' ' )
   return keyType

def hostCertsByKeyTypes( certFiles ):
   '''
       Group host certificate names by the key type getHostCertKeyType()
       reports for each.

       Returns a dict mapping key type -> list of certificate names, with the
       names in input order. Names whose key type is None are skipped.
       NOTE(review): getHostCertKeyType() returns KeyType.invalid (not None)
       on failure, so unless KeyType.invalid is None, unreadable or invalid
       certificates end up grouped under that key. Confirm callers expect this.
   '''
   grouped = {}
   for certName in certFiles:
      certKeyType = getHostCertKeyType( certName )
      if certKeyType is None:
         continue
      if certKeyType not in grouped:
         grouped[ certKeyType ] = []
      grouped[ certKeyType ].append( certName )
   return grouped
