# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import ast
import base64
import binascii
import glob
import os
import re
import six
import subprocess
import ssl

from M2Crypto import EC, EVP, X509

# Location of the CPU board prefdl. getCpuBoardPrefdl() runs `genprefdl`
# instead when this file does not exist.
CPUBOARD_PREFDL_PATH = "/etc/prefdl"

# Glob matching every installed Chip Signing Root CA certificate. A signing
# certificate is accepted if any one of these verifies it (several may be
# installed to cover root certificate expiration).
CHIP_SIGNING_ROOT_CA_PATH = "/etc/chip-signing-rootCa*.crt"

# Cached result of isTpm2(); None until the TPM version has been probed.
IsTpm2 = None

def getCpuBoardPrefdl():
   """Return the raw CPU board prefdl as bytes.

   The prefdl is read from CPUBOARD_PREFDL_PATH when that path exists;
   otherwise the stdout of the `genprefdl` command is returned. Errors from
   opening the file or running the command propagate to the caller.
   """
   if not os.path.exists( CPUBOARD_PREFDL_PATH ):
      return subprocess.check_output( [ 'genprefdl' ] )
   with open( CPUBOARD_PREFDL_PATH, "rb" ) as prefdlFile:
      return prefdlFile.read()

def isTpm2():
   """Return True if the system's TPM reports major version "2".

   The version is read once from /sys/class/tpm/tpm0/tpm_version_major and the
   result is cached in the module-level IsTpm2 global, so later calls return
   the cached value without touching sysfs again. A missing/unreadable
   attribute (OSError) or an empty one is treated as "not TPM 2".
   """
   global IsTpm2
   if IsTpm2 is None:
      # Check TPM version by looking at /sys/class/tpm/tpm0/tpm_version_major
      sysfsPath = '/sys/class/tpm/tpm0/tpm_version_major'
      try:
         with open( sysfsPath ) as f:
            lines = f.read().splitlines()
      except OSError:
         lines = []
      # An empty attribute file yields no lines; indexing [ 0 ] unconditionally
      # would raise IndexError instead of reporting "not TPM 2".
      tpmVersionMajor = lines[ 0 ] if lines else ""
      IsTpm2 = tpmVersionMajor == "2"
   return IsTpm2

def processWhitespace( bytesVal ):
   """Strip spaces from prefdl values that are string-encoded dicts, lists or
   tuples.

   This makes signature verification robust to minor, unintentional whitespace
   changes made by the MFG step that programs the prefdl.

   Only ASCII space characters (b' ') are removed, and they are removed
   everywhere in the value -- including inside quoted strings -- so
   b"{'a': 'x y'}" becomes b"{'a':'xy'}". Tabs and other whitespace are kept.

   Args:
      bytesVal: raw prefdl field value (bytes).

   Returns:
      bytesVal with spaces removed if it parses (via ast.literal_eval) as a
      dict, list or tuple; otherwise bytesVal unchanged. Values that are not
      valid UTF-8 or not valid Python literals are returned unchanged.
   """
   try:
      text = ( bytesVal.decode( "utf-8" ) if isinstance( bytesVal, bytes )
               else bytesVal )
      evalObj = ast.literal_eval( text )
   except ( ValueError, TypeError, SyntaxError, MemoryError, RecursionError ):
      # ast.literal_eval documents all of these for malformed input, e.g.
      # b"{[1]: 2}" raises TypeError (unhashable key). Such values are not
      # containers we normalize, so pass them through untouched.
      return bytesVal
   if isinstance( evalObj, ( dict, list, tuple ) ):
      return bytesVal.replace( b' ', b'' )
   return bytesVal

def getCertificate( prefdl, prefix="" ):
   """Return the Base64 signing certificate stored in the prefdl.

   Finds the first newline-terminated line of the form
   ``<prefix>Certificate: <value>`` (at the start of the prefdl or right after
   a newline) and returns ``<value>`` as bytes, or b"" when there is none.
   """
   pattern = fr"(?:^|\n){prefix}Certificate: (.+)\n".encode( "utf-8" )
   match = re.search( pattern, prefdl )
   if match is None:
      return b""
   return match.group( 1 )

def getSignature( prefdl, prefix="" ):
   """Return the hex-encoded Signature prefdl field as bytes.

   Only a line ``<prefix>Signature: <word chars>`` that starts the prefdl or
   follows a newline, and is itself newline-terminated, is accepted; b"" is
   returned when no such line exists.
   """
   found = re.search( fr"(?:^|\n){prefix}Signature: (\w+)\n".encode( "utf-8" ),
                      prefdl )
   return b"" if found is None else found[ 1 ]

def getSignatureList( prefdl, prefix="" ):
   """Return the raw (unparsed) SignatureList prefdl value as bytes.

   The value is taken from the first newline-terminated line
   ``<prefix>SignatureList: <value>`` that starts the prefdl or follows a
   newline; b"" is returned if there is no such line.
   """
   listPattern = fr"(?:^|\n){prefix}SignatureList: (.+)\n"
   hit = re.search( listPattern.encode( "utf-8" ), prefdl )
   return hit.group( 1 ) if hit is not None else b""

def concatTpmPublicEndorsementKey( modulus, exponent ):
   """
   Build the TPM public endorsement key value used in the signed payload:
   the hex modulus left-padded with '0' to 512 chars, followed by the hex
   exponent left-padded with '0' to 6 chars. Works on str or bytes inputs
   (the return type matches the inputs).
   """
   modulusWidth, exponentWidth = 512, 6
   return modulus.zfill( modulusWidth ) + exponent.zfill( exponentWidth )

def getTpmPublicEndorsementKey():
   """
   Read the TPM 1.2 public endorsement key (EK).

   Runs `tpmutil readpubek` first; if that tool is missing or exits with an
   error, runs `tpm_readPubEk` instead.

   Returns:
      A (modulus, exponent) tuple of bytes parsed from the tool output (the
      exponent without its "0x" prefix). An element is b"" if its line is not
      found in the output. Returns None if neither tool ran successfully.
   """
   try:
      output = subprocess.check_output( [ 'tpmutil', 'readpubek' ] )
   except ( OSError, subprocess.CalledProcessError ):
      # Assume that tpmutil only exists in Aboot
      # tpm_readPubEk is the EOS equivalent for getting the TPM pub ek
      # A tpmutil that exists but fails is handled like a missing one, so the
      # caller gets None rather than an escaping CalledProcessError.
      try:
         output = subprocess.check_output( [ "tpm_readPubEk" ] )
      except ( OSError, subprocess.CalledProcessError ):
         return None
   m = re.search( br"Modulus: (\w+)\n", output )
   modulus = m.group( 1 ) if m else b""
   m = re.search( br"Exponent: 0x(\w+)\n", output )
   exponent = m.group( 1 ) if m else b""
   return (modulus, exponent)

def getTpm2PublicPPK():
   """Return the raw stdout of `tpm2_readPubPpk` as bytes.

   Returns None if the command cannot be run or exits with a non-zero status.
   """
   try:
      publicPpk = subprocess.check_output( [ "tpm2_readPubPpk" ] )
   except ( OSError, subprocess.CalledProcessError ):
      publicPpk = None
   return publicPpk

def getTpmPublicKey():
   """Return the raw TPM public key for this system's TPM version.

   TPM 2: the bytes from getTpm2PublicPPK(). Otherwise: the (modulus, exponent)
   tuple from getTpmPublicEndorsementKey(). Either reader may return None on
   failure.
   """
   reader = getTpm2PublicPPK if isTpm2() else getTpmPublicEndorsementKey
   return reader()

def formatTpmPublicKey( tpmPublicKeyComponents ):
   """Convert the raw TPM public key into the form used in the signed payload.

   TPM 2: Base64 encoding of the raw key bytes. Otherwise: the first two
   components (modulus, exponent) joined by concatTpmPublicEndorsementKey().
   """
   if not isTpm2():
      modulus = tpmPublicKeyComponents[ 0 ]
      exponent = tpmPublicKeyComponents[ 1 ]
      return concatTpmPublicEndorsementKey( modulus, exponent )
   return base64.b64encode( tpmPublicKeyComponents )

def evalSignatureList( signatureListBytes ):
   """Parse the raw SignatureList prefdl value into a Python list.

   Args:
      signatureListBytes: raw SignatureList value (bytes), e.g.
         b"['SerialNumber','HwEpoch']".

   Returns:
      The parsed list, or None if the value is not valid UTF-8, is not a valid
      Python literal, or evaluates to something other than a list.
   """
   try:
      text = signatureListBytes
      if isinstance( text, bytes ):
         text = text.decode( "utf-8" )
      signatureList = ast.literal_eval( text )
   except ( ValueError, TypeError, SyntaxError, MemoryError, RecursionError ):
      # ast.literal_eval documents all of these for malformed input (e.g.
      # TypeError for b"{[1]: 2}"); any of them means "not a usable list", and
      # the caller reports that as a VerificationError.
      return None
   return signatureList if isinstance( signatureList, list ) else None

def calculatePayload( prefdl, tpmPublicKey, signatureList, signatureListBytes ):
   """Calculates and returns the expected payload given the state of the running
   system. Signing this payload with the Signing certificate's private key results
   in the Signature. This calculated value is used to check against the actual
   Signature stored in the prefdl.

   The payload is the concatenation of the following fields in the order shown:

   TPM public EK modulus (hex bytes, zero padded to 512 chars) +
   TPM public EK exponent (hex bytes, zero padded to 6 chars) +
   SignatureList (value only) + prefdl fields listed in SignatureList (in order,
   value only)

   (On TPM 2 systems tpmPublicKey is instead the Base64 encoding of the
   tpm2_readPubPpk output; see formatTpmPublicKey.)

   The SignatureList value and every field value are passed through
   processWhitespace(). Each field value is taken from the first
   newline-terminated line that *starts* with "<field>: "; a field with no such
   line contributes b"".

   e.g.
   prefdl:
   ...
   SerialNumber: SSJ17450458
   HwEpoch: 02.00
   SignatureList: ['SerialNumber','HwEpoch','SwFeatures']
   SwFeatures: {'feature1':'enabled','feature2':'disabled'}
   ...

   payload:
   ccdd8520b874ca7b0968b0401ed7febd4bba1b4983a7831698fe2d54157e50c22f42803486bb18a77
   e0a2e6d14571593976ca1259910c3cc7cbb675f6e22eb8d2b7e69c1a9eb7b0e5473ec3bc0a94d4abf
   a9009dd2436944d5cc5e12916f6b6e978e10fcc05ad13404801538442738aebc53ad1df1641ceb641
   3bb9d94b5a400959aff90f36af0da2782b0c5219643dea55f7f15837abfa4e048e1b89b41f56f1289
   bdaadb56a47155fa0f41ad48e272a5dc925a860b07af8f7b5d92e0640c98146a8f33fa4271b1f69d1
   aa3a9e4ee5e8823906835d43b32d8c1098dad6ac1f10a2393c50deb81b7cef8fd07adcc325cfd95eb
   be5e0eefd099f10a6a41fadf73 + 010001 +
   ['SerialNumber','HwEpoch','SwFeatures'] + SSJ17450458 + 02.00 +
   {'feature1':'enabled','feature2':'disabled'}

   Args:
      prefdl: raw prefdl contents (bytes).
      tpmPublicKey: formatted TPM public key (bytes).
      signatureList: parsed SignatureList, with any field prefix applied.
      signatureListBytes: raw SignatureList value (bytes).

   Returns:
      The payload as bytes.
   """
   parts = [ tpmPublicKey, processWhitespace( signatureListBytes ) ]
   for field in signatureList:
      # Anchor the name to the start of a line (like getCertificate and
      # verifySignatureListFields do) and match it literally. Without this, an
      # earlier line for a different field whose name merely ends with `field`
      # (e.g. "XSerialNumber" for "SerialNumber") would supply the signed
      # value while the real field line could hold anything. Escaping also
      # stops regex metacharacters in SignatureList entries from acting as a
      # pattern.
      fieldRe = ( rb"(?:^|\n)" + re.escape( str( field ) ).encode( "utf-8" ) +
                  rb": (.+)\n" )
      m = re.search( fieldRe, prefdl )
      parts.append( processWhitespace( m.group( 1 ) if m else b"" ) )
   return b"".join( parts )

def verifyCertificate( signingCert ):
   """Return True if signingCert's signature verifies against the public key
   of at least one Chip Signing Root CA matching CHIP_SIGNING_ROOT_CA_PATH.

   Several root CAs may be installed (to account for root cert expiration).
   They are tried in glob order and the search stops at the first one for
   which signingCert.verify() returns 1. Only the signature is checked here.
   """
   return any(
      signingCert.verify( X509.load_cert( caPath ).get_pubkey() ) == 1
      for caPath in glob.glob( CHIP_SIGNING_ROOT_CA_PATH ) )

def verifySignatureFormat( signature, usingP384 ):
   """
   Return True if signature is a hex string of the expected length:
   (1) RSA signing certs: 256 bytes encoded in 512 hex chars.
   (2) P-384 signing certs: 102 bytes encoded as 204 hex chars.

   The hex check uses binascii.a2b_hex, the same decoder verifySignature()
   applies to the signature. int( signature, 16 ) would also accept a "0x"
   prefix, a sign, surrounding whitespace or digit-separating underscores,
   none of which a2b_hex can decode.

   NOTE(review): DER-encoded ECDSA P-384 signatures are normally variable
   length (roughly 102-104 bytes); the fixed 102-byte requirement presumably
   mirrors how the signer produces them -- confirm.

   Args:
      signature: hex-encoded signature (bytes or ASCII str).
      usingP384: True if the signing certificate uses a P-384 key.
   """
   expectedLen = 204 if usingP384 else 512
   if len( signature ) != expectedLen:
      return False
   try:
      binascii.a2b_hex( signature )
   except ValueError:  # binascii.Error is a subclass of ValueError
      return False
   return True

def verifySignature( signature, signingCert, payload, usingP384 ):
   """Return True if the hex-encoded Signature verifies over payload with the
   Signing certificate's public key.

   RSA certs: payload is hashed with SHA-256 and checked with
   verify_rsassa_pss( ..., 'sha256', salt_length=-2 ) (-2 presumably selects
   OpenSSL's "recover salt length from signature" mode -- confirm).
   P-384 certs: the EC key must pass check_key(); payload is hashed with
   SHA-384 and checked as an ASN.1-encoded ECDSA signature.

   Only a verify result of exactly 1 counts as success.
   """
   publicKey = signingCert.get_pubkey()
   if not usingP384:
      rsaKey = publicKey.get_rsa()
      hasher = EVP.MessageDigest( 'sha256' )
      hasher.update( payload )
      verdict = rsaKey.verify_rsassa_pss( hasher.final(),
                                          binascii.a2b_hex( signature ),
                                          'sha256', salt_length=-2 )
      return verdict == 1
   ecKey = EC.pub_key_from_der( publicKey.as_der() )
   if not ecKey.check_key():
      return False
   hasher = EVP.MessageDigest( 'sha384' )
   hasher.update( payload )
   verdict = ecKey.verify_dsa_asn1( hasher.digest(),
                                    binascii.a2b_hex( signature ) )
   return verdict == 1

def verifySignatureListFields( prefdl, signatureList ):
   """Returns a list of SignatureList members that are not successfully located
   in the prefdl.

   A member counts as present when some line of prefdl starts with
   "<member>: ". Member names are matched literally (regex metacharacters are
   escaped), so an entry like "A.B" does not match an "AxB" line and an entry
   like "A(B" does not raise re.error.

   Args:
      prefdl: raw prefdl contents (bytes).
      signatureList: iterable of field names; non-str entries are converted
         with str().

   Returns:
      The missing entries (the original objects), in signatureList order.
   """
   return [ field for field in signatureList
            if not re.search( b"^" + re.escape( str( field ) ).encode( "utf-8" ) +
                              b": ", prefdl, re.MULTILINE ) ]

class VerificationError(Exception):
   """Raised by verifyPrefdl() when any prefdl verification step fails; the
   single argument is a human-readable description of the failure."""
   def __init__( self, errorMsg ):
      super().__init__( errorMsg )

def verifyPrefdl( prefdl=None ):
   """Verify the signed fields of a prefdl against this system's TPM.

   Steps, in order:
     1. find the Base64 signing certificate (preferring "Cpu"-prefixed fields)
        and load it as an X509 object;
     2. read the Signature and check its hex encoding/length;
     3. read and parse the SignatureList and check every listed field exists;
     4. read the TPM public key and compute the expected payload;
     5. check the signing certificate chains to a Chip Signing Root CA;
     6. verify the Signature over the payload.

   Args:
      prefdl: raw prefdl bytes; when falsy the CPU board prefdl is read.

   Returns:
      The raw TPM public key value returned by getTpmPublicKey().

   Raises:
      VerificationError: if any step fails.
   """
   prefdl = prefdl or getCpuBoardPrefdl()

   # Systems such as the Mauna Kea chassis merge the CPU and switchcard prefdls,
   # prefixing the CPU prefdl (which holds the certificate/signature) with
   # "Cpu". Prefer those fields and fall back to unprefixed ones.
   for fieldPrefix in ( "Cpu", "" ):
      certificate = getCertificate( prefdl, prefix=fieldPrefix )
      if certificate:
         break
   else:
      raise VerificationError( "Failed to read Certificate from prefdl" )

   try:
      derCertificate = base64.b64decode( certificate )
   except ( TypeError, binascii.Error ) as e:
      raise VerificationError( "Certificate invalid - "
                               "failed to decode as Base64" ) from e

   try:
      signingCertX509 = X509.load_cert_string(
         str( ssl.DER_cert_to_PEM_cert( derCertificate ) ) )
   except X509.X509Error as e:
      raise VerificationError( "Certificate invalid - "
                               "failed to load into X509 object" ) from e

   # The signing cert's key type selects the signature format and algorithm.
   usingP384 = 'NIST CURVE: P-384' in signingCertX509.as_text()

   signature = getSignature( prefdl, prefix=fieldPrefix )
   if not signature:
      raise VerificationError( "Failed to read Signature from prefdl" )
   if not verifySignatureFormat( signature, usingP384 ):
      raise VerificationError( "Signature is encoded incorrectly" )

   signatureListBytes = getSignatureList( prefdl, prefix=fieldPrefix )
   if not signatureListBytes:
      raise VerificationError( "Failed to read SignatureList from prefdl" )
   parsedList = evalSignatureList( signatureListBytes )
   if not parsedList:
      raise VerificationError(
         "Failed to convert SignatureList to an iterable list" )
   # Listed fields live under the same prefix as the certificate/signature.
   signatureList = [ f"{fieldPrefix}{field}" for field in parsedList ]
   missingFields = verifySignatureListFields( prefdl, signatureList )
   if missingFields:
      raise VerificationError( f"Failed to find {missingFields} in the prefdl" )

   tpmPublicKeyComponents = getTpmPublicKey()
   if not tpmPublicKeyComponents:
      raise VerificationError( "Failed to read TPM public key" )
   payload = calculatePayload( prefdl,
                               formatTpmPublicKey( tpmPublicKeyComponents ),
                               signatureList, signatureListBytes )

   if not verifyCertificate( signingCertX509 ):
      raise VerificationError(
         "Chain of trust verification failed between Signing certificate "
         "and Chip Signing Root CA" )

   if not verifySignature( signature, signingCertX509, payload, usingP384 ):
      raise VerificationError(
         "Failed to verify Signature against expected payload" )

   return tpmPublicKeyComponents
