#!/usr/bin/env python3
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import argparse
from enum import Enum
import os
import platform
import shlex
import shutil
import sys
import tempfile
import time
import zipfile

import six

import SignatureFile
import SignatureRequest
import Swi
import SwiSignLib
import Tac
import VerifySwi

# Signing a single-rootfs image is straightforward: call
#   Swi.sign_( imagePath, forceSign, useDevCA, devCaKeyPair, signUrl ).
# For a multi-rootfs image, we have to extract each optimization and sign that. Then
# extract each just-inserted signature and insert it back into the multi-rootfs image
# while making its name unique to the optimization (<optim>.signature). Then sign
# the new multi-rootfs.
# The newly signed multi-rootfs image will have its own signature (swi-signature), as
# well as a signature for each of its optimizations (<optim>.signature).
# At runtime, during Aboot or during the "install source" cli command, the
# multi-rootfs image will be scaled down to a single rootfs. At that time, the
# optimization's signature will be moved into the image's and all other signatures
# removed.

# Note that what we sign is not the meta image, but the flavors created out of it.
# That's suboptimal in case the same optimization exists in multiple flavors, but
# that's not a thing today. This keeps the signing code simple (always sign both
# the optimizations and the (super) image itself).

# Note that if a flavor is not signed, even if it contains the signatures for its
# contained optimizations, when it gets optimized on a DUT, the resulting optimized
# image will be signature-less (we copy foo.sig to swi-signature *if* it exists).

# Debug/timing output switch, used during Abuild so we can see how long things take.
# Read by dprint() and PrintTimed; signHandler() sets it to True for --verbose.
printDebugs = False

class SWI_SIGN_RESULT:
   '''Integer result codes produced by the signing functions in this file.
   signHandler() hands the code it gets back from sign() to sys.exit().'''
   SUCCESS = 0 # sign_() generated the signature file
   ALREADY_SIGNED_INVALID = 1 # signed, signature invalid, forceSign not set
   ALREADY_SIGNED_VALID = 2 # signed, signature valid, forceSign not set
   SERVER_ERROR = 3 # SignatureRequest.SigningServerError was raised while signing
   ERROR_NOT_A_SWI = 4 # We just check the version has "SWI_VERSION"
   ERROR_SWI_NOT_BLESSED = 5 # not blessed and signing it was not confirmed
   CANNOT_ADD_OPTIMIZATION_SIGNATURES = 6 # zip of <optim>.signature files failed
   CANNOT_EXTRACT_SIGNATURE = 7 # exit code used by extractSignature() on failure

def alreadySigned( swiFile ):
   ''' Report the signature state of swiFile as a tuple
       ( isSigned, validSignature ): "isSigned" is what
       SwiSignLib.swiSignatureExists() says about the file, "validSignature"
       is the first element returned by SwiSignLib.verifySwiSignature().'''
   sigPresent = SwiSignLib.swiSignatureExists( swiFile )
   # Only the validity flag of the verification triple is of interest here
   sigValid, _unusedA, _unusedB = SwiSignLib.verifySwiSignature( swiFile )
   return sigPresent, sigValid

def querySignNotBlessed( force ):
   '''Decide whether a SWI that is not blessed should be signed anyway.

   Returns True right away when "force" is set. Otherwise asks on the terminal
   and returns True only if the answer starts with "y" or "Y"; any other
   answer (including an empty one) means "do not sign".'''
   if not force:
      answer = six.moves.input(
         "SWI is not blessed. Are you sure you wish to sign it? [y/N] " )
      # Not signing is the default
      return answer.lower().startswith( 'y' )
   return True

class UPDATE_SIG_RESULT( Enum ):
   '''Outcomes of updateSwiSignature(); UPDATE_SIG_MESSAGE maps each member to
   a printable message.'''
   UPDATED_SIG = 0 # the SWI was re-signed successfully
   NO_SIG = 1 # the SWI carries no signature, nothing to update
   INVALID_SIG = 2 # signature data unreadable, or X509CertException while checking
   ERROR_UPDATING = 3 # sign() did not return SWI_SIGN_RESULT.SUCCESS
   ERROR_RELEASE_CERT = 4 # checks out against the Arista root CA, not allowed
   ERROR_UNKNOWN_CERT = 5 # matches neither the Arista nor the dev root CA

# Printable message for every UPDATE_SIG_RESULT member.
UPDATE_SIG_MESSAGE = {
   UPDATE_SIG_RESULT.UPDATED_SIG :
      "SWI signature updated.",
   UPDATE_SIG_RESULT.NO_SIG :
      "SWI not signed, signature update not needed.",
   UPDATE_SIG_RESULT.INVALID_SIG :
      "Invalid SWI signature, signature update aborted.",
   UPDATE_SIG_RESULT.ERROR_UPDATING :
      "Failed to update SWI signature, please re-sign the SWI.",
   UPDATE_SIG_RESULT.ERROR_RELEASE_CERT :
      "Signed with release CA, please manually re-sign the SWI.",
   UPDATE_SIG_RESULT.ERROR_UNKNOWN_CERT :
      "Signed with an unknown certificate, please manually re-sign the SWI.",
}

def updateSwiSignature( swiFile, allowReleaseCA=False ):
   '''Re-sign an already signed SWI using the same kind of CA that signed it.

   The signing certificate found in the SWI is checked against the Arista root
   CA first and then, if that root CA file exists, against the development
   root CA. On a match the SWI is force-signed again through sign().

   Returns an UPDATE_SIG_RESULT member:
     NO_SIG              the SWI carries no signature
     INVALID_SIG         the signature data cannot be read, or checking it
                         against the Arista root CA raised X509CertException
     ERROR_RELEASE_CERT  Arista root CA matched but allowReleaseCA is not set
     UPDATED_SIG         re-signing succeeded
     ERROR_UPDATING      re-signing failed
     ERROR_UNKNOWN_CERT  neither root CA matched'''

   def resign( withDevCA ):
      # Force a fresh signature and translate the sign() result code
      outcome = sign( swiFile, forceSign=True, useDevCA=withDevCA )
      if outcome == SWI_SIGN_RESULT.SUCCESS:
         return UPDATE_SIG_RESULT.UPDATED_SIG
      return UPDATE_SIG_RESULT.ERROR_UPDATING

   if not SwiSignLib.swiSignatureExists( swiFile ):
      return UPDATE_SIG_RESULT.NO_SIG

   sigData = VerifySwi.getSwiSignatureData( swiFile )
   if sigData is None:
      return UPDATE_SIG_RESULT.INVALID_SIG

   # First candidate: the Arista (release) root CA
   try:
      releaseRoot = VerifySwi.loadRootCert( VerifySwi.ARISTA_ROOT_CA_FILE_NAME )
      releaseCheck = VerifySwi.checkSigningCert( sigData, releaseRoot )
   except VerifySwi.X509CertException:
      return UPDATE_SIG_RESULT.INVALID_SIG

   if releaseCheck == VerifySwi.VERIFY_SWI_RESULT.SUCCESS:
      if not allowReleaseCA:
         return UPDATE_SIG_RESULT.ERROR_RELEASE_CERT
      return resign( withDevCA=False )

   # Second candidate: the development root CA, when it is installed
   if not os.path.exists( VerifySwi.DEV_ROOT_CA_FILE_NAME ):
      return UPDATE_SIG_RESULT.ERROR_UNKNOWN_CERT

   devRoot = VerifySwi.loadRootCert( VerifySwi.DEV_ROOT_CA_FILE_NAME )
   devCheck = VerifySwi.checkSigningCert( sigData, devRoot )
   if devCheck == VerifySwi.VERIFY_SWI_RESULT.SUCCESS:
      return resign( withDevCA=True )

   return UPDATE_SIG_RESULT.ERROR_UNKNOWN_CERT

# Entry point for scripts. Some scripts do not go through Tac.run("swi sign") and
# instead import this module and call sign.sign directly, bypassing all the
# argparse checks...
def sign( swiFile, forceSign=False, useDevCA=False, devCaKeyPair=None,
          workDir=None, signUrl=SignatureRequest.defaultSwiSignURL() ):
   '''Sign swiFile, including a signature for each optimization it contains.

   Images without a ".rootfs-<arch>.sqsh" member (legacy) or without any
   "*.version" member (already adapted) are handed straight to sign_(). All
   other images go through signAll().

   Arguments:
     swiFile       path of the image to sign
     forceSign     replace an existing signature instead of stopping
     useDevCA      sign with the development CA instead of the signing server
     devCaKeyPair  passed through to the development CA signing request
     workDir       directory receiving the extracted optimizations; created if
                   missing and kept afterwards. When not given, a private
                   temporary directory is used and deleted at the end.
     signUrl       signing server URL

   Returns a SWI_SIGN_RESULT code. Calls sys.exit( -1 ) if workDir exists but
   is not a directory, or if swiFile is not an image holding a "version" file.'''
   # The path is interpolated into shell command lines below: quote it so that
   # spaces or shell metacharacters in the file name cannot break the commands.
   quotedSwi = shlex.quote( swiFile )

   # Shortcut in case of legacy swi (discovered by counting the dot-rootfs entries)
   if platform.machine() == 'aarch64':  # A4NOCHECK
      r = os.system(
              r"exit $( unzip -Z1 %s | grep \\.rootfs-aarch64\\.sqsh | wc -l )"
                      % quotedSwi )
   else:
      r = os.system( r"exit $( unzip -Z1 %s | grep \\.rootfs-i386\\.sqsh | wc -l )"
                      % quotedSwi )
   r = os.WEXITSTATUS( r )
   if r < 1:
      return sign_( swiFile, forceSign, useDevCA, devCaKeyPair, signUrl=signUrl )
   # Also shortcut already adapted swis, they no longer have any .version files and
   # swadapting them would cause the version file to miss, and signing to fail.
   r = os.system( r'exit $( unzip -Z1 %s | grep \\.version$ | wc -l )'
                  % quotedSwi )
   r = os.WEXITSTATUS( r )
   if r < 1:
      return sign_( swiFile, forceSign, useDevCA, devCaKeyPair, signUrl=signUrl )

   # We will sign the image and add signatures for all its contained optimized
   # images, which causes all optimizations in the provided image to be extracted
   # to workDir and signed as a side-effect. The workDir is deleted at the end,
   # unless it was provided as argument (for debugging or access to optimized img).
   # Even if the provided image only has a single rootfs, we still do all the hoopla
   # (double signature) to ensure the signature will be valid after image adaptation
   # on the dut (which might remove "debris" like foo.version thus altering sign).
   # We only skip the recursive signing if it is a legacy (non-swim) image.

   # Set only once we have created a private directory ourselves, so that the
   # cleanup below can never remove a directory that is not ours.
   createdWorkDir = None
   try:
      if workDir:
         if not os.path.exists( workDir ):
            os.mkdir( workDir )
         elif not os.path.isdir( workDir ):
            print( "Error: '%s' should be a directory" % workDir )
            sys.exit( -1 )
      else:
         # mkdtemp picks an unused name and creates the directory atomically,
         # unlike a predictable /tmp/optims-<pid> path.
         workDir = tempfile.mkdtemp( prefix="optims-", dir="/tmp" )
         createdWorkDir = workDir
      # Make sure the image we got is a swi file
      if ( not os.path.isfile( swiFile ) or
           os.system( 'set -e; image=$(readlink -f %s); cd %s;'
                      'unzip -o -q "$image" version'
                      % ( quotedSwi, shlex.quote( workDir ) ) ) ):
         print( "Error: '%s' should be an image" % swiFile )
         sys.exit( -1 )

      # Optimized images are only kept when the caller asked for a work dir
      removeSwis = createdWorkDir is not None
      returnCode = signAll( swiFile, workDir, forceSign, useDevCA, devCaKeyPair,
                            removeSwis, signUrl=signUrl )

   finally:
      if createdWorkDir is not None:
         shutil.rmtree( createdWorkDir )

   return returnCode

def sign_( swiFile, forceSign=False, useDevCA=False, devCaKeyPair=None,
           signUrl=SignatureRequest.defaultSwiSignURL() ):
   '''Sign one image file in place, without looking at its optimizations.

   An image that is already signed is left alone unless forceSign is set, in
   which case its signature file is removed first. Images that are not blessed
   are only signed with the development CA, with forceSign, or after an
   interactive confirmation. The signature comes from the development CA when
   useDevCA is set, from the signing server at signUrl otherwise.

   Returns a SWI_SIGN_RESULT code.'''
   hasSig, sigIsValid = alreadySigned( swiFile )
   if hasSig and not forceSign:
      if sigIsValid:
         print( 'SWI is already signed with a valid signature.' )
         return SWI_SIGN_RESULT.ALREADY_SIGNED_VALID
      print( 'Warning: SWI is signed with an invalid signature.' )
      return SWI_SIGN_RESULT.ALREADY_SIGNED_INVALID
   if hasSig:
      # Force sign. Remove the swi-signature file from the SWI.
      Tac.run( [ 'zip', '-q', '-d', swiFile, SwiSignLib.SIG_FILE_NAME ] )

   blessed, version = Swi.getBlessedAndVersion( swiFile )
   if version is None:
      print( 'Error: Version file not found in SWI' )
      return SWI_SIGN_RESULT.ERROR_NOT_A_SWI

   # Signing with the release path requires a blessed image or a confirmation
   needsConfirmation = not ( useDevCA or blessed )
   if needsConfirmation and not querySignNotBlessed( forceSign ):
      print( "Not signing SWI that is not blessed." )
      return SWI_SIGN_RESULT.ERROR_SWI_NOT_BLESSED

   sigObj = SignatureFile.Signature()
   payload = SignatureFile.prepareDataForServer( swiFile, version, sigObj )
   try:
      if useDevCA:
         signed = SignatureRequest.getDataFromDevCA( swiFile, payload,
                                                     devCaKeyPair=devCaKeyPair )
      else:
         signed = SignatureRequest.getDataFromServer( swiFile, payload,
                                                      licenseServerUrl=signUrl )
      SignatureFile.generateSigFileFromServer( signed, swiFile, sigObj )
   except SignatureRequest.SigningServerError as err:
      print( err )
      # Remove the null signature
      Tac.run( [ 'zip', '-d', swiFile, SwiSignLib.SIG_FILE_NAME ] )
      return SWI_SIGN_RESULT.SERVER_ERROR
   return SWI_SIGN_RESULT.SUCCESS

class PrintTimed:
   '''Context manager that prints "<msg>, time taken: <n>s" when the block it
   wraps is left, but only if the module-level printDebugs flag is set.'''

   def __init__( self, msg ):
      self.msg = msg
      self.t0 = 0 # start time, set on entry

   def __enter__( self ):
      self.t0 = time.time()
      return self

   def __exit__( self, excType, excValue, excTb ):
      if not printDebugs:
         return
      elapsed = int( time.time() - self.t0 )
      print( f"{self.msg}, time taken: {elapsed}s" )

def dprint( msg ):
   '''Print msg, but only when the module-level printDebugs flag is set.'''
   if not printDebugs:
      return
   print( msg )

def getOptimizations( image, workDir ): # pylint: disable=unused-argument
   '''Return the optimization names listed in the image's "swimSqshMap" member.

   Each non-blank line of swimSqshMap has the form "<optim>=<value>"; the part
   before the first "=" is the optimization name. Names are returned in file
   order.

   Arguments:
     image    path of the image (a zip archive)
     workDir  kept for interface compatibility; the member is read straight
              from the archive, so no scratch file is written any more

   Returns None if swimSqshMap cannot be read from the image (legacy image).
   Raises ValueError on a non-blank line that has no "=".'''
   try:
      with zipfile.ZipFile( image ) as swi:
         sqshMap = swi.read( "swimSqshMap" ).decode()
   except ( KeyError, OSError, zipfile.BadZipFile ):
      return None # legacy image
   optims = []
   for line in sqshMap.split( "\n" ):
      if not line.strip():
         # tolerate blank lines, typically the one after the last newline
         continue
      optim, sep, _ = line.partition( "=" )
      if not sep:
         raise ValueError( "Malformed swimSqshMap entry: %r" % line )
      optims.append( optim )
   return optims

def getSigFileName():
   '''Return SwiSignLib.SIG_FILE_NAME, the name of the signature file that
   sign_() removes from a SWI and extractSignature() extracts by default.'''
   return SwiSignLib.SIG_FILE_NAME

def extractSignature( image, destFile, sigFileName=None ):
   '''Copy the signature file stored in image to destFile.

   Arguments:
     image        path of the image (a zip archive) holding the signature
     destFile     path the signature is written to (overwritten if present)
     sigFileName  name of the archive member to extract; defaults to
                  getSigFileName()

   The member is read straight from the archive and written to destFile, so
   no intermediate file is dropped next to destFile and paths containing
   spaces or shell metacharacters work.

   On failure, prints an error and exits the process with
   SWI_SIGN_RESULT.CANNOT_EXTRACT_SIGNATURE.'''
   sigFn = sigFileName or getSigFileName()
   try:
      with zipfile.ZipFile( image ) as swi:
         sigData = swi.read( sigFn )
      with open( destFile, "wb" ) as out:
         out.write( sigData )
   except ( KeyError, OSError, zipfile.BadZipFile ):
      # KeyError: no such member; BadZipFile/OSError: unreadable image or
      # unwritable destination
      print( f"Error: Cannot extract signature file {sigFn} from  {image}" )
      sys.exit( SWI_SIGN_RESULT.CANNOT_EXTRACT_SIGNATURE )

def signAll( image, workDir, forceSign=False, useDevCa=False, devCaKeyPair=None,
             removeSwis=True, signUrl=SignatureRequest.defaultSwiSignURL() ):
   '''Sign every non-Default optimization of image, then image itself.

   For each optimization listed by getOptimizations() whose name does not
   start with "Default", the optimized image is produced in workDir with
   swadapt and force-signed, and its signature is saved as
   <workDir>/<optim>.signature. All those signature files are then added to
   image, and image is signed last.

   Arguments:
     image         path of the image to sign
     workDir       existing directory for optimized images and signature files
     forceSign     replace an existing signature on image instead of stopping
     useDevCa      sign with the development CA instead of the signing server
     devCaKeyPair  passed through to sign_()
     removeSwis    delete each optimized image once its signature is extracted
     signUrl       signing server URL

   Returns a SWI_SIGN_RESULT code: the first failure met, or the result of
   signing image itself. extractSignature() exits the process on failure.'''

   optims = getOptimizations( image, workDir )
   if optims is None:
      dprint( "legacy image, just sign the image" )
      return sign_( image, forceSign, useDevCa, devCaKeyPair, signUrl=signUrl )
   dprint( "Optimizations in {}: {}".format( image, " ".join( optims ) ) )

   # Paths and names end up on shell command lines below: quote them
   quotedImage = shlex.quote( image )

   optimSigFiles = []
   for optim in optims:
      if optim.startswith( "Default" ):
         dprint( "Not signing Default optimization" )
         continue

      optimImage = f"{workDir}/{optim}.swi"
      # Adapt image
      # NOTE(review): the swadapt exit status is not checked; a failure only
      # shows up when signing the missing/incomplete optimized image below.
      with PrintTimed( "Swadapted %s" % optimImage ):
         os.system( "swadapt %s %s %s" % ( quotedImage,
                                           shlex.quote( optimImage ),
                                           shlex.quote( optim ) ) )
      # Sign image
      with PrintTimed( "Signed %s" % optimImage ):
         ret = sign_( optimImage, forceSign=True, useDevCA=useDevCa,
                      devCaKeyPair=devCaKeyPair, signUrl=signUrl )
         if ret != SWI_SIGN_RESULT.SUCCESS:
            # Report the optimized image: that is the one that failed to sign
            print( "Error: cannot sign %s" % optimImage )
            return ret

      # Extract the signature from each optim and call it <optim>.signature
      optimSigFile = "%s.signature" % optim
      optimSigPath = f"{workDir}/{optimSigFile}"
      optimSigFiles.append( os.path.basename( optimSigPath ) )
      extractSignature( optimImage, optimSigPath )

      # Signature extracted, remove swi as it is no longer needed
      if removeSwis:
         os.remove( optimImage )

   # update the source swi with the signatures of its "baby" swis (optims)
   dprint( "Adding signature files to {}: {}".format( image,
                                                  " ".join( optimSigFiles ) ) )
   # "&&" keeps zip from running in the wrong directory should the cd fail
   ret = os.system( 'src=$(readlink -f %s); cd %s && '
                    'zip -q -0 -X "$src" %s' % (
                    quotedImage, shlex.quote( workDir ),
                    " ".join( shlex.quote( f ) for f in optimSigFiles ) ) )
   if ret != 0:
      print( "Error: cannot add variant signatures to %s" % image )
      return SWI_SIGN_RESULT.CANNOT_ADD_OPTIMIZATION_SIGNATURES

   # now sign the source swi
   with PrintTimed( "Signed %s" % image ):
      ret = sign_( image, forceSign, useDevCa, devCaKeyPair=devCaKeyPair,
                   signUrl=signUrl )
      if ret != SWI_SIGN_RESULT.SUCCESS:
         print( "Error: cannot sign %s" % image )

   return ret

def signHandler( args ):
   '''Command-line handler for "swi sign".

   Arguments:
     args  list of command-line arguments (options and the SWI path, without
           the program name). None means "use sys.argv[ 1: ]". An empty list
           is passed to argparse as is, which reports the missing SWI path.

   Parses the arguments, turns debug output on for --verbose, selects the
   signing URL from --intermediate-cert (1-based index into
   SignatureRequest.SWI_SIGN_URL) and calls sign(). Never returns: the
   process exits with the code returned by sign().'''
   if args is None:
      args = sys.argv[ 1: ]
   parser = argparse.ArgumentParser( prog="swi sign" )
   parser.add_argument( "swi", metavar="EOS.swi",
                        help="Path of the SWI to be signed" )
   parser.add_argument( "--force-sign", help="Overwrite any existing signature",
                        action="store_true")
   parser.add_argument( "--dev-ca", help="Use development certificates for signing",
                        action="store_true" )
   parser.add_argument( "--work-dir", help="If specified, will be kept and contain "
                                           "all extracted and signed optimizations "
                                           "(convenient for debugging)",
                        action="store" )
   parser.add_argument( "--verbose", help="Print debugging/timing info",
                        action="store_true" )
   # Choices are 1-based positions in SignatureRequest.SWI_SIGN_URL
   parser.add_argument( "--intermediate-cert", type=int,
                        choices=list(
                           range( 1, len( SignatureRequest.SWI_SIGN_URL ) + 1 )
                        ),
                        default=SignatureRequest.DEFAULT_SWI_SIGN_URL_POS + 1,
                        help="Choose the intermediate CA to use for the signing "
                             "request" )

   args = parser.parse_args( args )

   if args.verbose:
      global printDebugs
      printDebugs = True

   # Back to a 0-based index
   signUrl = SignatureRequest.SWI_SIGN_URL[ args.intermediate_cert - 1 ]
   returnCode = sign( args.swi, args.force_sign, args.dev_ca, None, args.work_dir,
                      signUrl=signUrl )

   sys.exit( returnCode )
