#!/usr/bin/env python3
# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import sysconfig
import os
import argparse
import configparser
import sys
import subprocess
import base64
import uuid

# The bootz package and its dependencies all live in a 'bootz' directory
# under this interpreter's 'purelib' install path (sysconfig). Insert that
# directory at sys.path index 1 (right after sys.path[0]) so that the grpc
# and generated-proto imports below resolve from it. This must run before
# those imports.
pyLibDir = sysconfig.get_paths()[ 'purelib' ]
bootzLibDir = os.path.join( pyLibDir, 'bootz' )
sys.path.insert( 1, bootzLibDir )

#pylint: disable=wrong-import-position,no-name-in-module,import-error
import grpc
from github.com.openconfig.bootz.proto import bootz_pb2
from github.com.openconfig.bootz.proto import bootz_pb2_grpc
#pylint: enable=wrong-import-position,no-name-in-module,import-error

# One nesting level of the response dump printed by the extract* helpers.
indent = ' ' * 3
# Scratch directory for files written while processing GetBootstrapData
# responses; run() overwrites it with the config file's WorkDir value.
workDir = "/tmp"

def getControlCardState( serialNo="serial-1234" ):
   """Build a ControlCardState message for the control card `serialNo`.

   The card is always reported with status CONTROL_CARD_STATUS_INITIALIZED.
   """
   statusEnum = bootz_pb2.ControlCardState.ControlCardStatus
   return bootz_pb2.ControlCardState(
      serial_number=serialNo,
      status=statusEnum.CONTROL_CARD_STATUS_INITIALIZED )

def sendReportStatus( stub, reportStatusRequest ):
   """Issue one ReportStatus RPC on `stub` and print the outcome.

   Returns True if the call completed. Returns False if it raised
   grpc.RpcError; the error is printed instead of propagated.
   """
   try:
      stub.ReportStatus( reportStatusRequest )
   except grpc.RpcError as err:
      print( f"ReportStatus() failure: {err}" )
      return False
   else:
      message = reportStatusRequest.status_message
      print( f"ReportStatus(msg={message}) success" )
      return True

def sendReportStatusRequests( stub ):
   """Send a fixed sequence of sample ReportStatus requests over `stub`.

   The first request carries only a status message. Each following request
   also sets a bootstrap status and attaches one ControlCardState. The
   sequence stops at the first request that sendReportStatus() reports as
   failed.
   """
   BootstrapStatus = bootz_pb2.ReportStatusRequest.BootstrapStatus

   # Plain request: message only, no status and no control-card states.
   if not sendReportStatus(
         stub, bootz_pb2.ReportStatusRequest( status_message="hello" ) ):
      return

   # ( bootstrap status, status message, control-card serial )
   samples = (
      ( BootstrapStatus.BOOTSTRAP_STATUS_SUCCESS, "all ok", "serial-2" ),
      ( BootstrapStatus.BOOTSTRAP_STATUS_FAILURE, "problem-x", "serial-3" ),
      ( BootstrapStatus.BOOTSTRAP_STATUS_INITIATED, "bootstrap started",
        "serial-4" ),
   )
   for status, message, serial in samples:
      request = bootz_pb2.ReportStatusRequest( status=status,
                                               status_message=message )
      request.states.extend( [ getControlCardState( serial ) ] )
      if not sendReportStatus( stub, request ):
         return

def getControlCard( partNo, serialNo, slotId ):
   """Return a ControlCard message for the given part, serial and slot."""
   return bootz_pb2.ControlCard( part_number=partNo,
                                 serial_number=serialNo,
                                 slot_id=slotId )

def getChassisDescriptor( manuf, partNo, serialNo ):
   """Return a ChassisDescriptor for the given manufacturer, part and serial.

   Exactly one sample control card (part "part-1", serial "sno-1",
   slot "slot-1") is always attached to the descriptor.
   """
   sampleCard = getControlCard( "part-1", "sno-1", "slot-1" )
   return bootz_pb2.ChassisDescriptor( manufacturer=manuf,
                                       part_number=partNo,
                                       serial_number=serialNo,
                                       control_cards=[ sampleCard ] )

def extractBootstrapDataResponse(index, response):
   """Print one BootstrapDataResponse as an indented report.

   `index` is used only as the label in the header line. The serial number
   is always printed. The intended image and boot config are printed only
   when those fields are set on `response`.
   """
   lvl2, lvl3, lvl4 = indent * 2, indent * 3, indent * 4
   print(f"{lvl2}Response({index}):")
   print(f"{lvl3}serial_num={response.serial_num}")

   if response.HasField('intended_image'):
      image = response.intended_image
      print(f"{lvl3}Software Image:")
      for label, value in (("Name", image.name),
                           ("Version", image.version),
                           ("Url", image.url),
                           ("Hash", image.os_image_hash),
                           ("Algo", image.hash_algorithm)):
         print(f"{lvl4}{label}: {value}")

   if response.HasField('boot_config'):
      print(f"{lvl3}Boot Config:")
      print(f"{lvl4}Vendor Config: {response.boot_config.vendor_config}")
   
def extractBootstrapDataSigned( bootstrapDataSignedBytes ):
   """Parse serialized BootstrapDataSigned bytes and print their contents.

   Prints the nonce. Then prints each contained BootstrapDataResponse via
   extractBootstrapDataResponse(), numbered from 1.

   Parameters:
     bootstrapDataSignedBytes: wire-format BootstrapDataSigned message (the
       GetBootstrapData response's serialized_bootstrap_data field).
   """
   bootstrapDataSigned = bootz_pb2.BootstrapDataSigned()
   # NOTE(review): ParseFromString presumably raises a protobuf DecodeError on
   # malformed input; nothing here catches it.
   bootstrapDataSigned.ParseFromString( bootstrapDataSignedBytes )

   print( f"{indent*2}Nonce: {bootstrapDataSigned.nonce}" )

   # Responses are labelled 1..N in the printed report.
   for index, response in enumerate( bootstrapDataSigned.responses, start=1 ):
      extractBootstrapDataResponse( index, response )

def extractOwnerCert( ownerCertCms ):
   """Save the owner-certificate CMS blob under workDir and extract its certs.

   Writes the raw CMS bytes to <workDir>/rcvdOwnerCert.cms. Then runs
   `openssl cms -cmsout -inform der -noout -certsout ...` to write the
   certificate(s) it contains to <workDir>/rcvdOwner.cert.

   Returns the path of the extracted certificate file.
   Raises subprocess.CalledProcessError if openssl exits non-zero.
   """
   cmsPath = os.path.join( workDir, "rcvdOwnerCert.cms" )
   with open( cmsPath, 'wb' ) as cmsOut:
      cmsOut.write( ownerCertCms )
   print( f"Saved owner-cert CMS to file {cmsPath},"
          f" len={len( ownerCertCms )}" )

   print( "Extracting owner-cert from CMS..." )
   certPath = os.path.join( workDir, "rcvdOwner.cert" )
   cmd = [ "openssl", "cms", "-cmsout", "-in", cmsPath, "-inform", "der",
           "-noout", "-certsout", certPath ]
   print( "Issuing command: ", " ".join( cmd ) )
   result = subprocess.run( cmd, stdout=subprocess.DEVNULL, check=True )
   print( f"Result: {result.returncode}" )
   # check=True raises on a non-zero exit, so reaching here means success.
   print( f"Extracted owner-cert, file={certPath}" )
   return certPath

def extractPublicKey( certFile, publicKeyFile ):
   """Run `openssl x509 -pubkey` on `certFile`, writing output to `publicKeyFile`.

   Returns `publicKeyFile`.
   Raises subprocess.CalledProcessError if openssl exits non-zero.
   """
   print( f"Extracting public-key from {certFile}" )
   cmd = [ "openssl", "x509", "-pubkey", "-in", certFile, "-out", publicKeyFile ]
   print( "Issuing command: ", " ".join( cmd ) )
   result = subprocess.run( cmd, stdout=subprocess.DEVNULL, check=True )
   print( f"Result: {result.returncode}" )
   # check=True raises on a non-zero exit, so this is only reached on success.
   print( f"Extracted public-key, file {publicKeyFile}" )
   return publicKeyFile

def verifySignature( rcvdOwnerCertFile, responseSigB64, bootstrapDataSignedBytes ):
   """Verify the response signature over the serialized bootstrap data.

   Steps:
     1. base64-decode `responseSigB64`.
     2. Extract the public key from `rcvdOwnerCertFile` into
        <workDir>/rcvdOwnerPublic.key.
     3. Write the decoded signature to <workDir>/rcvdResponse.sig and
        `bootstrapDataSignedBytes` to <workDir>/rcvdBootstrapDataSigned.bin.
     4. Run `openssl dgst -sha256 -verify <key> -signature <sig> <data>`.

   openssl's verdict (its stdout) is always printed, including when
   verification fails.

   Returns True when openssl exits with status 0.
   Raises subprocess.CalledProcessError when openssl exits non-zero (e.g. the
   signature does not verify). The verdict is printed before raising.
   """
   # first base64-decode the response signature
   responseSig = base64.b64decode( responseSigB64 )
   print( f"Extracted response-signature, len={ len( responseSig ) }" )

   # next extract the public-key from the owner-cert
   ownerPublicKeyFile = os.path.join( workDir, "rcvdOwnerPublic.key" )
   extractPublicKey( rcvdOwnerCertFile, ownerPublicKeyFile )

   # next, save the signature to a file
   rcvdResponseSigFile = os.path.join( workDir, "rcvdResponse.sig" )
   with open( rcvdResponseSigFile, 'wb' ) as f:
      f.write( responseSig )
   print( f"Saved response-signature to file {rcvdResponseSigFile},"
          f" len={ len( responseSig ) }" )

   # and save the received serialized-bootstrap-data
   rcvdBootstrapDataSignedFile = os.path.join(
      workDir, "rcvdBootstrapDataSigned.bin" )
   with open( rcvdBootstrapDataSignedFile, 'wb' ) as f:
      f.write( bootstrapDataSignedBytes )
   print( f"Saved serialized-bootstrap-data to file {rcvdBootstrapDataSignedFile},"
          f" len={ len( bootstrapDataSignedBytes ) }" )

   # last, attempt to verify the signature over the saved data
   opensslCmdArgs = [ "openssl", "dgst", "-sha256",
                      "-verify", ownerPublicKeyFile,
                      "-signature", rcvdResponseSigFile,
                      rcvdBootstrapDataSignedFile ]
   print( "Issuing command: ", " ".join( opensslCmdArgs ) )
   # check=False so openssl's verdict is printed even when verification fails;
   # the exit status is enforced right after the print.
   proc = subprocess.run( opensslCmdArgs, stdout=subprocess.PIPE, check=False )
   verdict = proc.stdout.decode( errors='replace' ).rstrip()
   print( f"Result: {proc.returncode} ({verdict})" )
   # Same exception type a check=True run would raise on a non-zero exit.
   proc.check_returncode()
   return True

def sendGetBootstrapDataRequests( stub ):
   """Send one GetBootstrapData request on `stub` and report the response.

   The request carries:
     - a sample chassis descriptor,
     - control-card state "serial-1",
     - a random uuid4-hex nonce (as bytes).

   If the RPC raises grpc.RpcError, the error is printed and nothing else
   happens. Otherwise the function:
     - prints a summary of the response fields;
     - checks the signature via extractOwnerCert()/verifySignature(), but
       only when an ownership certificate was returned;
     - prints the signed bootstrap data.
   """
   request = bootz_pb2.GetBootstrapDataRequest(
      chassis_descriptor=getChassisDescriptor( "Arista", "part-1", "sno-1" ),
      control_card_state=getControlCardState( "serial-1" ),
      nonce=uuid.uuid4().hex.encode() )

   try:
      response = stub.GetBootstrapData( request )
   except grpc.RpcError as err:
      print( f"Call failure: {err}" )
      return
   print( "GetBootstrapData() success" )

   voucherCms = response.ownership_voucher
   certCms = response.ownership_certificate
   sigB64 = response.response_signature
   signedData = response.serialized_bootstrap_data

   print( f"{indent}Ownership-voucher(CMS): len= {len( voucherCms )}" )
   print( f"{indent}Ownership_certificate(CMS): len={len( certCms )}" )
   print( f"{indent}Response-signature(base64): {sigB64}" )
   print( f"{indent}BootstrapDataSigned(bytes): len={len( signedData )}" )

   if not certCms:
      print( "Ownership-certificate not available, skipping signature verification" )
   else:
      verifySignature( extractOwnerCert( certCms ), sigB64, signedData )

   extractBootstrapDataSigned( signedData )

def run( args ):
   """Load client settings from the config file and exercise the bootz server.

   Reads the DEFAULT section of the config file named by `args.cfgFile`:
     Ip      - server address (default '[::]')
     Port    - server port (default '50051')
     WorkDir - scratch directory for received files (default '/tmp'); stored
               in the module-level `workDir`.
   Then opens an insecure gRPC channel to Ip:Port. Over it, sends the sample
   ReportStatus requests followed by the GetBootstrapData request.
   """
   # Fall back to the same path the command-line parser uses as its default,
   # so the effective default does not depend on how run() is invoked.
   cfgFile = args.cfgFile or "/tmp/bootzd.cfg"

   # Parse config file. ConfigParser.read() silently skips files it cannot
   # open, so say so instead of quietly running on built-in defaults.
   config = configparser.ConfigParser()
   if not config.read( cfgFile ):
      print( f"Config file {cfgFile} not read, using defaults" )
   defConfig = config[ 'DEFAULT' ]

   # Read config values
   global workDir
   hostIp = defConfig.get( 'Ip', '[::]' )
   port = defConfig.get( 'Port', '50051' )
   workDir = defConfig.get( 'WorkDir', '/tmp' )

   bootzAddr = f"{hostIp}:{port}"
   print( f"Trying to contact bootz {bootzAddr}..." )
   with grpc.insecure_channel( bootzAddr ) as channel:
      stub = bootz_pb2_grpc.BootstrapStub( channel )
      sendReportStatusRequests( stub )
      sendGetBootstrapDataRequests( stub )

if __name__ == "__main__":
   # Command-line entry point: choose the config file, then run the client.
   parser = argparse.ArgumentParser()
   parser.add_argument( '-c', '--cfgFile',
                        action='store',
                        default='/tmp/bootzd.cfg',
                        help='Bootz server config file' )
   run( parser.parse_args() )

