#!/usr/bin/env python3
# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import grpc
import os
import sysconfig
import sys
import subprocess
import logging
import time
import signal
import argparse
import configparser
import SztpRestconfTestLib
import base64
from concurrent import futures

# The bootz package and everything it depends on are installed in a 'bootz'
# folder under the interpreter's 'purelib' (site-packages) directory. Put that
# folder near the front of sys.path so the generated protobuf modules imported
# below resolve from it.
pyLibDir = sysconfig.get_path( 'purelib' )
bootzLibDir = os.path.join( pyLibDir, 'bootz' )
sys.path.insert( 1, bootzLibDir )

#pylint: disable=wrong-import-position,no-name-in-module,import-error
from github.com.openconfig.bootz.proto import bootz_pb2
from github.com.openconfig.bootz.proto import bootz_pb2_grpc
#pylint: enable=wrong-import-position,no-name-in-module,import-error

logger = logging.getLogger( __name__ )

# Server settings. run() overwrites these from the [DEFAULT] section of the
# config file; BootzBootstrap reads them while serving requests.
bootstrapCfgFile = ""
ownerVoucherCmsFile = ""
ownerCertFile = ""
ownerCertCmsFile = ""
ownerKeyFile = ""
workDir = "/tmp"          # scratch directory for openssl input/output files
verifySig = False         # self-check each generated response signature

# Intended-image settings. run() overwrites these only when the config file
# has an [IMAGE] section; an empty imageUrl means no image is offered.
imageUrl = ""
imageVersion = ""
osImageHash = ""
hashAlgorithm = "ietf-sztp-conveyed-info:sha-256"

class BootzBootstrap( bootz_pb2_grpc.BootstrapServicer ):
   """
   Lightweight bootz server for testing purposes

   Usage:
   /src/ZeroTouch/test/bootz.py --help

   Sample configuration file (option names are case-insensitive):

   [DEFAULT]
   pidfile = /tmp/bootz-server-host1.pid
   logfile = /tmp/bootz-server-host1.log
   ip = 127.0.0.1
   port = 50051
   maxworkers = 1
   bootstrapcfgfile = /tmp/boot-startup.cfg
   OwnerVoucherCmsFile = /tmp/ownerVoucher.cms
   OwnerCertCmsFile = /tmp/ownerCert.cms
   OwnerKeyFile = /tmp/owner.key
   WorkDir = /tmp
   verifySig = True

   [IMAGE]
   imagepath = http://10.0.0.2/bootz/images/EOS-bootz.swi
   imageversion = 4.33.0
   imagehash = d9:a5:d1:0b:09:fa:4e:96:f2:40:bf:6a:00:00
   hashalgorithm = ietf-sztp-conveyed-info:sha-256

   The owner voucher and owner certificate are only returned when the
   GetBootstrapData request carries a nonce. The bootstrap data is only
   signed when OwnerKeyFile is configured.
   """

   def _runOpenssl( self, opensslCmdArgs, captureStdout=False ):
      """Log and run an openssl command line.

      Args:
         opensslCmdArgs: argv list, e.g. [ "openssl", "x509", ... ].
         captureStdout: if True, the command's stdout is returned as text in
            ``proc.stdout``; otherwise it is discarded.

      Returns:
         The subprocess.CompletedProcess (its returncode is always 0).

      Raises:
         subprocess.CalledProcessError: openssl exited non-zero.
      """
      logger.info( "Issuing command: %s", " ".join( opensslCmdArgs ) )
      stdout = subprocess.PIPE if captureStdout else subprocess.DEVNULL
      proc = subprocess.run( opensslCmdArgs, stdout=stdout, text=True, check=True )
      logger.info( "Result: %d", proc.returncode )
      return proc

   def _getSoftwareImage( self,
                          name="EOS",
                          version="4.33.0",
                          url="http://images.com/4.30.01/EOS.swi",
                          imageHash="d9:a5:d1:0b:09:fa:4e:96:f2:40:bf:6a:82:f5",
                          hashAlgo="ietf-sztp-conveyed-info:sha-256" ):
      """Build a bootz SoftwareImage message describing the intended image."""
      logger.info( "Generating SoftwareImage..." )
      softwareImage = bootz_pb2.SoftwareImage(
         name=name,
         version=version,
         url=url,
         os_image_hash=imageHash,
         hash_algorithm=hashAlgo
      )
      return softwareImage

   def _getBootConfig( self, startupConfig ):
      """Build a BootConfig whose vendor_config is *startupConfig* as UTF-8."""
      logger.info( "Generating BootConfig..." )
      # only attribute we're interested in
      bootConfig = bootz_pb2.BootConfig(
         vendor_config=startupConfig.encode() )
      return bootConfig

   def _getBootPasswordHash( self ):
      """Return the value sent as boot_password_hash.

      NOTE(review): this is a placeholder that reuses the configured image
      hash; it is not a real password hash.
      """
      sampleHash = osImageHash
      return sampleHash

   def _getBootstrapDataResponse( self, request, startupConfig ):
      """Build the per-chassis BootstrapDataResponse.

      The intended image is only included when an image URL is configured.
      """
      logger.info( "Generating BootstrapDataResponse..." )
      intendedImage = None
      if imageUrl:
         intendedImage = self._getSoftwareImage(
            version=imageVersion,
            url=imageUrl,
            imageHash=osImageHash,
            hashAlgo=hashAlgorithm
         )
      bootstrapDataResponse = bootz_pb2.BootstrapDataResponse(
         serial_num=request.chassis_descriptor.serial_number,
         boot_password_hash=self._getBootPasswordHash(),
         intended_image=intendedImage,
         boot_config=self._getBootConfig( startupConfig ),
         )
      return bootstrapDataResponse

   def _getBootstrapDataSigned( self, request, startupConfig ):
      """Wrap a single BootstrapDataResponse with the request's nonce."""
      logger.info( "Generating BootstrapDataSigned..." )
      response = self._getBootstrapDataResponse( request, startupConfig )
      bootstrapDataSigned = bootz_pb2.BootstrapDataSigned(
         nonce=request.nonce )
      bootstrapDataSigned.responses.extend( [ response ] )
      return bootstrapDataSigned

   def _extractOwnerCertFromCms( self ):
      """Extract the certificate(s) from the owner-cert CMS file.

      Returns:
         Path of the extracted certificate file (<workDir>/extOwner.cert).

      Raises:
         subprocess.CalledProcessError: openssl failed.
      """
      logger.info( "Extracting owner-cert from CMS file %s...", ownerCertCmsFile )
      extOwnerCertFile = os.path.join( workDir, "extOwner.cert" )
      self._runOpenssl( [ "openssl", "cms", "-cmsout",
                          "-in", ownerCertCmsFile,
                          "-inform", "der", "-noout",
                          "-certsout", extOwnerCertFile ] )
      logger.info( "Extracted owner-cert, file %s", extOwnerCertFile )
      return extOwnerCertFile

   def _extractPublicKey( self, certFile ):
      """Extract the public key from *certFile*.

      Returns:
         Path of the extracted key file (<workDir>/extOwnerPublic.key).

      Raises:
         subprocess.CalledProcessError: openssl failed.
      """
      logger.info( "Extracting public-key from cert %s...", certFile )
      publicKeyFile = os.path.join( workDir, "extOwnerPublic.key" )
      self._runOpenssl( [ "openssl", "x509", "-pubkey",
                          "-in", certFile,
                          "-out", publicKeyFile ] )
      logger.info( "Extracted public-key, file %s", publicKeyFile )
      return publicKeyFile

   def _verifySignature( self, responseSigB64, bootstrapDataSignedBytes ):
      """Self-check a generated response signature with the owner public key.

      Args:
         responseSigB64: base64-encoded SHA-256 signature.
         bootstrapDataSignedBytes: the exact bytes that were signed.

      Raises:
         subprocess.CalledProcessError: key extraction failed, or the
            signature does not verify.
      """
      logger.info( "Verifying signature, sig-len(base64)=%d, data-len=%s...",
                   len( responseSigB64 ), len( bootstrapDataSignedBytes ) )
      # first base64-decode the response signature
      responseSig = base64.b64decode( responseSigB64 )
      logger.info( "Extracted response-signature, len=%d", len( responseSig ) )

      # extract the owner cert from the CMS, then the public key from the cert
      extOwnerCertFile = self._extractOwnerCertFromCms()
      extOwnerPublicKeyFile = self._extractPublicKey( extOwnerCertFile )

      # openssl verifies file-to-file, so save the signature and the data
      rcvdResponseSigFile = os.path.join( workDir, "response.sig" )
      with open( rcvdResponseSigFile, 'wb' ) as f:
         f.write( responseSig )
      logger.info( "Saved response-signature to file %s, len=%d",
                   rcvdResponseSigFile, len( responseSig ) )

      rcvdBootstrapDataSignedFile = os.path.join( workDir, "dataSigned.bin" )
      with open( rcvdBootstrapDataSignedFile, 'wb' ) as f:
         f.write( bootstrapDataSignedBytes )
      logger.info( "Saved serialized-bootstrap-data to file %s, len=%d",
                   rcvdBootstrapDataSignedFile, len( bootstrapDataSignedBytes ) )

      # last, attempt to verify the signature; capture stdout so openssl's
      # verdict actually reaches the log
      proc = self._runOpenssl( [ "openssl", "dgst", "-sha256",
                                 "-verify", extOwnerPublicKeyFile,
                                 "-signature", rcvdResponseSigFile,
                                 rcvdBootstrapDataSignedFile ],
                               captureStdout=True )
      logger.info( "Verify output: %s", ( proc.stdout or "" ).strip() )

   def _getResponseSig( self, serializedBootstrapData ):
      """Sign *serializedBootstrapData* with ownerKeyFile (SHA-256).

      Returns:
         The base64 signature from SztpRestconfTestLib.runOpensslCmd; falsy
         if signing failed.
      """
      logger.info(
         "Generating response signature; ownerKeyFile: %s", ownerKeyFile )
      serializedBootstrapDataFile = os.path.join(
         workDir, "serializedBootstrapData.bin" )
      with open( serializedBootstrapDataFile, 'wb' ) as f:
         f.write( serializedBootstrapData )
      logger.info( "Saved serialized-bootstrap-data to file %s, len=%d",
                   serializedBootstrapDataFile, len( serializedBootstrapData ) )

      # obtain signature
      opensslCmdArgs = [ "openssl", "dgst", "-sha256",
                         "-sign", ownerKeyFile,
                         serializedBootstrapDataFile ]
      logger.info( "Issuing command: %s", " ".join( opensslCmdArgs ) )
      responseSigB64 = SztpRestconfTestLib.runOpensslCmd(
         opensslCmdArgs, [ "base64", "-w0" ] )
      if not responseSigB64:
         logger.info( "Failed to generate response-signature" )
      else:
         logger.info( "Successfully generated response-signature(base64), len=%s",
                      len( responseSigB64 ) )

      # Verify the signature is correct. Verification needs both a signature
      # and the owner-cert CMS; without them openssl would be handed empty or
      # missing inputs.
      if not verifySig:
         logger.info( "Skipping verify signature")
      elif not responseSigB64:
         logger.info( "Skipping verify signature: no response-signature" )
      elif not ownerCertCmsFile:
         logger.info( "Skipping verify signature: OwnerCertCmsFile not configured" )
      else:
         self._verifySignature( responseSigB64, serializedBootstrapData )

      return responseSigB64

   def GetBootstrapData( self, request, context ):
      """gRPC handler: build and return a GetBootstrapDataResponse.

      The bootstrap data carries the startup config read from
      bootstrapCfgFile, or a placeholder string if that file is missing.
      Owner voucher and owner certificate CMS blobs are only included when
      the request has a nonce. The serialized bootstrap data is only signed
      when ownerKeyFile is configured.
      """
      # get chassis-descriptor
      chassisDescriptor = request.chassis_descriptor
      chassisSpec = ( f"Chassis: manuf={chassisDescriptor.manufacturer}, "
                      f"partNo={chassisDescriptor.part_number}, "
                      f"serialNo={chassisDescriptor.serial_number}" )

      # get control-card-state
      state = request.control_card_state
      cardStatus = bootz_pb2.ControlCardState.ControlCardStatus.Name( state.status )
      cardSpec = f"card={state.serial_number} ({cardStatus})"

      nonceSpec = f"nonce={request.nonce}"
      logger.info(
         "Received GetBootstrapData: %s; %s; %s", chassisSpec, cardSpec, nonceSpec )

      # read bootstrap config; decode as UTF-8 to match _getBootConfig's
      # str.encode() (UTF-8) instead of depending on the locale
      startupConfig = "--default-bootstrap-config--"
      try:
         with open( bootstrapCfgFile, encoding='utf-8' ) as f:
            startupConfig = f.read().rstrip()
      except FileNotFoundError:
         logger.info( "Bootstrap config file not found %s", bootstrapCfgFile )

      bootstrapDataSigned = self._getBootstrapDataSigned( request, startupConfig )
      # Serialize exactly once and use the same bytes for the signature and
      # the response. Protobuf serialization is not guaranteed to be
      # deterministic, and a signature is only valid for the exact bytes that
      # were signed.
      serializedBootstrapData = bootstrapDataSigned.SerializeToString()

      # owner voucher/cert are only provided if nonce is included in request
      ownerVoucherCms = b""
      ownerCertCms = b""
      if request.nonce:
         logger.info( "Nonce configured, including owner voucher and cert..." )
         if ownerVoucherCmsFile:
            try:
               with open( ownerVoucherCmsFile, mode='rb' ) as f:
                  ownerVoucherCms = f.read()
            except FileNotFoundError:
               logger.info(
                  "Owner voucher cms-file not found %s", ownerVoucherCmsFile )

         if ownerCertCmsFile:
            try:
               with open( ownerCertCmsFile, mode='rb' ) as f:
                  ownerCertCms = f.read()
            except FileNotFoundError:
               logger.info( "Owner cert file not found %s", ownerCertCmsFile )
      else:
         logger.info( "Nonce not configured, omitting owner voucher and cert" )

      # sign the serialized response data; coerce a failed (falsy) signature
      # to "" so the protobuf string field is always set to a str
      responseSigB64 = ""
      if ownerKeyFile:
         responseSigB64 = self._getResponseSig( serializedBootstrapData ) or ""
      else:
         logger.info( "OwnerKeyFile not configured" )

      return bootz_pb2.GetBootstrapDataResponse(
         ownership_voucher=ownerVoucherCms,
         ownership_certificate=ownerCertCms,
         response_signature=responseSigB64,
         serialized_bootstrap_data=serializedBootstrapData )

   def ReportStatus( self, request, context ):
      """gRPC handler: log the reported bootstrap and per-card status."""
      status = bootz_pb2.ReportStatusRequest.BootstrapStatus.Name( request.status )
      statusMsg = request.status_message
      # Use a distinct name for each card's status so the overall bootstrap
      # `status` logged below is not overwritten by the last card's status.
      cardEntries = []
      for state in request.states:
         cardStatus = bootz_pb2.ControlCardState.ControlCardStatus.Name( state.status )
         cardEntries.append( f", card={state.serial_number} ({cardStatus})" )
      cardStates = "".join( cardEntries )

      logger.info(
         "Received ReportStatus: status=%s, msg=%s%s",
         status, statusMsg, cardStates )

      return bootz_pb2.EmptyResponse()

def _startLogging( logFile ):
   """Send INFO-and-above log records to *logFile*, then log start-up.

   logging.basicConfig() appends to the file (default filemode 'a') and does
   nothing if the root logger already has handlers.
   """
   lineFormat = '%(asctime)s %(levelname)-8s %(message)s'
   stampFormat = '%Y-%m-%d %H:%M:%S'
   logging.basicConfig( filename=logFile,
                        level=logging.INFO,
                        format=lineFormat,
                        datefmt=stampFormat )
   logger.info( "Bootz server started" )

def _wait_forever( server ):
   """Keep the main thread alive until interrupted, then stop *server*.

   Requests are served on the gRPC server's executor threads, so this thread
   only sleeps. A KeyboardInterrupt (Python's default reaction to SIGINT,
   which _stopBootzServer sends) ends the wait. The server is then stopped
   with grace=None, i.e. without waiting for in-flight RPCs.
   """
   print( "Waiting for termination" )
   napSeconds = 20
   try:
      while True:
         time.sleep( napSeconds )
   except KeyboardInterrupt:
      noGracePeriod = None
      server.stop( noGracePeriod )

def _startBootzServer( server,
                      hostIp="[::]",
                      port="50051",
                      pidFile="/tmp/bootzd.pid" ):
   """Serve BootzBootstrap on *server* at hostIp:port until interrupted.

   The current process id is written to *pidFile* so that a separate
   ``--stop`` invocation can signal this process. The file is removed once
   the server stops, so a stale pid (possibly reused by an unrelated process)
   is never signalled later.
   """
   # save process pid
   with open( pidFile, 'w', encoding='utf-8' ) as f:
      f.write( str( os.getpid() ) )

   try:
      bootz_pb2_grpc.add_BootstrapServicer_to_server( BootzBootstrap(), server )
      ipSpec = f"{hostIp}:{port}"
      server.add_insecure_port( ipSpec )
      server.start()
      print( f"Bootz server started, listening on {ipSpec}..." )
      _wait_forever( server )
   finally:
      # Best-effort cleanup; the file may already have been removed by hand.
      try:
         os.remove( pidFile )
      except FileNotFoundError:
         pass

def _stopBootzServer( pidFile="/tmp/bootzd.pid" ):
   """Send SIGINT to the bootz server whose pid is recorded in *pidFile*.

   Problems are reported on stdout instead of raised: a missing pid file, a
   pid file without a positive integer, or a process that has already exited.
   Non-positive pids are never passed to os.kill(). There, 0 and negative
   values address process groups (-1 means every process the caller may
   signal).
   """
   try:
      with open( pidFile, encoding='utf-8' ) as f:
         pid = int( f.read() )
   except FileNotFoundError:
      print( f"Pid file not found {pidFile}" )
      return
   except ValueError:
      print( f"Invalid pid in file {pidFile}" )
      return

   if pid <= 0:
      if pid < 0:
         print( f"Invalid pid in file {pidFile}" )
      return

   print( f"Stopping Bootz server {pid}..." )
   try:
      os.kill( pid, signal.SIGINT )
   except ProcessLookupError:
      print( f"Bootz server {pid} is not running" )
      
def run( args ):
   """Load the bootz config file, then stop or start the server.

   Settings are read from the [DEFAULT] section of ``args.cfgFile`` (or
   /tmp/bootz.cfg when that is empty), plus the optional [IMAGE] section.
   They are stored in the module globals that BootzBootstrap reads.

   ``args.stop`` is checked before ``args.start``. The command line declares
   ``--start`` as store_true with default=True, so checking start first would
   make ``-t`` launch another server instead of stopping the running one.
   """
   global bootstrapCfgFile, ownerVoucherCmsFile, ownerCertFile
   global ownerCertCmsFile, ownerKeyFile, workDir, verifySig
   global imageUrl, imageVersion, osImageHash, hashAlgorithm

   cfgFile = args.cfgFile or "/tmp/bootz.cfg"

   # Parse config file. ConfigParser.read() silently skips files it cannot
   # open, so say so rather than quietly running on built-in defaults.
   config = configparser.ConfigParser()
   if not config.read( cfgFile ):
      print( f"Config file not found {cfgFile}, using defaults" )
   defConfig = config[ 'DEFAULT' ]

   # Read config values
   pidFile = defConfig.get( 'PidFile', '/tmp/bootzd.pid' )
   logFile = defConfig.get( 'LogFile', '/tmp/bootzd.log' )
   hostIp = defConfig.get( 'Ip', '[::]' )
   port = defConfig.get( 'Port', '50051' )
   bootstrapCfgFile = defConfig.get( 'BootstrapCfgFile', '/tmp/startup-config' )
   maxWorkers = defConfig.getint( 'MaxWorkers', 1 )
   ownerVoucherCmsFile = defConfig.get( 'OwnerVoucherCmsFile', None )
   ownerCertFile = defConfig.get( 'OwnerCertFile', None )
   ownerCertCmsFile = defConfig.get( 'OwnerCertCmsFile', None )
   ownerKeyFile = defConfig.get( 'OwnerKeyFile', None )
   # Scratch directory for openssl files. It is passed to os.path.join(), so
   # it must default to a real directory rather than None.
   workDir = defConfig.get( 'WorkDir', '/tmp' )
   verifySig = defConfig.getboolean( 'VerifySig', False )

   # Read Download image config
   if 'IMAGE' in config.sections():
      imageConfig = config[ 'IMAGE' ]
      imageUrl = imageConfig.get(
         'ImagePath', 'http://10.0.0.2/bootz/images/EOS-bootz.swi' )
      imageVersion = imageConfig.get( 'ImageVersion', '4.33.0' )
      osImageHash = imageConfig.get(
         'ImageHash', 'd9:a5:d1:0b:09:fa:4e:96:f2:40:bf:6a:00:00' )
      hashAlgorithm = imageConfig.get(
         'hashAlgorithm', 'ietf-sztp-conveyed-info:sha-256' )

   if args.stop:
      _stopBootzServer( pidFile )
   elif args.start:
      _startLogging( logFile )
      server = grpc.server( futures.ThreadPoolExecutor( max_workers=maxWorkers ) )
      _startBootzServer( server, hostIp, port, pidFile )
      logger.info( "Bootz server stopped" )

if __name__ == "__main__":
   # Command-line entry point. Note that -s/--start is store_true with
   # default=True, so it is always set whether or not it is passed.
   cliParser = argparse.ArgumentParser()
   cliParser.add_argument( '-c', '--cfgFile',
                           action='store',
                           default='/tmp/bootzd.cfg',
                           help='Bootz server config file' )
   cliParser.add_argument( '-s', '--start',
                           action='store_true',
                           default=True,
                           help='Start bootz server' )
   cliParser.add_argument( '-t', '--stop',
                           action='store_true',
                           default=False,
                           help='Stop bootz server' )
   run( cliParser.parse_args() )

