#!/usr/bin/env python3
# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function

import argparse
import base64
import codecs
import os
import zipfile
from M2Crypto import X509, BIO
from ArPyUtils.ArTrace import Tracing

# Trace handle for this module. t0 (the handle's trace0 function) is used by
# SwiSignature.updateFields() to report malformed signature-file lines.
traceHandle = Tracing.Handle( 'VerifySwi' )
t0 = traceHandle.trace0

# Embedded PEM certificate. maybeCreateSwiCerts() writes it to
# ARISTA_ROOT_CA_FILE_NAME when that file is missing (see SWI_CERTIFICATES).
# NOTE(review): decoding the DER suggests the subject is an issuing CA
# ("AristaIT-ICA ECDSA Issuing Cert Authority", issued by the internal IT root)
# rather than a self-signed root, and that notAfter is 2027-05-31 - confirm with
# `openssl x509 -noout -subject -dates` and plan rotation before expiry.
ARISTA_ROOT_CA = """-----BEGIN CERTIFICATE-----
MIIF6TCCA9GgAwIBAgITFgAAAAN8An0siquXTAAAAAAAAzANBgkqhkiG9w0BAQsF
ADA6MTgwNgYDVQQDEy9BcmlzdGEgTmV0d29ya3MgSW50ZXJuYWwgSVQgUm9vdCBD
ZXJ0IEF1dGhvcml0eTAeFw0xNzA1MzEyMjU0MjhaFw0yNzA1MzEyMzA0MjhaMGkx
EzARBgoJkiaJk/IsZAEZFgNjb20xHjAcBgoJkiaJk/IsZAEZFg5hcmlzdGFuZXR3
b3JrczEyMDAGA1UEAxMpQXJpc3RhSVQtSUNBIEVDRFNBIElzc3VpbmcgQ2VydCBB
dXRob3JpdHkwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARPqNqFCrbuLJ1EWkKg
3RLdwkzx8kZxtnMmU0xTT1sLN8oNKMp4yFanvVFxwd4PydTeIlJUAZQ5a73dXqom
pHk2o4ICgjCCAn4wEAYJKwYBBAGCNxUBBAMCAQAwHQYDVR0OBBYEFHO110dbJLhN
GJaktFHcGpdlP31KMFcGA1UdIARQME4wTAYMKwYBAgERBwEEAgEDMDwwOgYIKwYB
BQUHAgEWLmh0dHA6Ly9pdC1wa2kuYXJpc3RhbmV0d29ya3MuY29tL3BraS9jcHMu
aHRtbAAwGQYJKwYBBAGCNxQCBAweCgBTAHUAYgBDAEEwCwYDVR0PBAQDAgGGMA8G
A1UdEwEB/wQFMAMBAf8wHwYDVR0jBBgwFoAULe/a5u/PHHdRv1oCqSjNlACt4cow
gYEGA1UdHwR6MHgwdqB0oHKGNmh0dHA6Ly9pdC1wa2kuYXJpc3RhbmV0d29ya3Mu
Y29tL3BraS9BcmlzdGFJVC1ST09ULmNybIY4aHR0cDovL2l0LXBraTAxLmFyaXN0
YW5ldHdvcmtzLmNvbS9wa2kvQXJpc3RhSVQtUk9PVC5jcmwwggESBggrBgEFBQcB
AQSCAQQwggEAMH0GCCsGAQUFBzAChnFodHRwOi8vaXQtcGtpLmFyaXN0YW5ldHdv
cmtzLmNvbS9wa2kvQXJpc3RhSVQtUk9PVEFyaXN0YSUyME5ldHdvcmtzJTIwSW50
ZXJuYWwlMjBJVCUyMFJvb3QlMjBDZXJ0JTIwQXV0aG9yaXR5LmNydDB/BggrBgEF
BQcwAoZzaHR0cDovL2l0LXBraTAxLmFyaXN0YW5ldHdvcmtzLmNvbS9wa2kvQXJp
c3RhSVQtUk9PVEFyaXN0YSUyME5ldHdvcmtzJTIwSW50ZXJuYWwlMjBJVCUyMFJv
b3QlMjBDZXJ0JTIwQXV0aG9yaXR5LmNydDANBgkqhkiG9w0BAQsFAAOCAgEAotuJ
/hLxKlJOs85pYfbDR8bg5HzsEVHOrc/fjUf85e3riGMh+PaQHj5L++Ah9cMmUujh
+bPq8ycrjhyYwi0IZGLjwJuGWHQ2TMXqB4o+lmKGchpR0gA31pcZCANt5atRghrQ
hTMHN3L5CZRDn3JSCD1xQbW/WVDYlHv6IpWkd2orem/lgQfKVwVlkeB3YPJn5Hka
Hlx37mksQ9KEh7v52Tira5JnP67mUdT1C+gvdGF3DJk3Lg7GWX9Uxo1vG28AmJOU
0n28ek5Ynh0T3uQ+jkMoJEIlyH1fKZ6zyK0sf+yLRb7brkfssZDrRIatxKEkv6Oc
h4kXO2mvvMJxQDf7VvGXEC3fSRURLwPz//6JMx942iOKsES8ZT9nT2q9MxJXfInn
3EcKGmPWKQR4n2qHfmq6sfk2eFBUYIrZBm9RUbVbyLZLCOv2KxJ7FFZ9LV1jp5An
AyHLJUMQqqw/kvUUvUq1bI/PtEOlNc9Ndt/3yeh+HByzIw8/f+gjKkUjQpVncuqS
kFotBPNNj/LjbQD40R/tJ0z/8sPXCGJuo4mE9s/MwnWmkAHxpZyCccMBlNp3LkJk
FHcsVb36Vclv5XWDe5AxU+0sQjEB4LGP7nYo8wjjvSZIpYXRiAmDRGuAGi/W/W3F
6hEQ661JK4KPJvoQsMqYaO/TkZPIXEAdgEDkmj0=
-----END CERTIFICATE-----
"""

# Embedded PEM certificate. maybeCreateSwiCerts() writes it to
# ARISTA_ROOT_CA2_FILE_NAME when that file is missing (see SWI_CERTIFICATES).
# NOTE(review): decoding the DER suggests the subject is "Arista Networks SWI
# Signer Issuing Authority" (issued by the internal IT root) and that it is
# valid 2020-06-10 to 2035-06-10 - confirm with
# `openssl x509 -noout -subject -dates`.
ARISTA_ROOT_CA2 = """-----BEGIN CERTIFICATE-----
MIIFeDCCA2CgAwIBAgITKAAAABE7UZEyX+daXgAAAAAAETANBgkqhkiG9w0BAQsF
ADA6MTgwNgYDVQQDEy9BcmlzdGEgTmV0d29ya3MgSW50ZXJuYWwgSVQgUm9vdCBD
ZXJ0IEF1dGhvcml0eTAeFw0yMDA2MTAyMjA5MDdaFw0zNTA2MTAyMjE5MDdaMIGv
MQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEUMBIGA1UEBxMLU2Fu
dGEgQ2xhcmExHTAbBgNVBAoTFEFyaXN0YSBOZXR3b3JrcyBJbmMuMR8wHQYDVQQL
ExZJbmZvcm1hdGlvbiBUZWNobm9sb2d5MTUwMwYDVQQDEyxBcmlzdGEgTmV0d29y
a3MgU1dJIFNpZ25lciBJc3N1aW5nIEF1dGhvcml0eTCBmzAQBgcqhkjOPQIBBgUr
gQQAIwOBhgAEAOAiRTX2jh4JRoKRYgXe018m8r9RRWTnhwjT499M1XpmW4MkdxwR
J4wiYqwHJNmFlX+ef40s+0HgD3DCJLIb0PiFAKsF7YkewAJxizlJXrhNkyYgeBKf
kXy/nqPfF5GhylrDg69TB+mTtK+xX12l+iga+jXvnSnC0t1/Re9Ly1Y0dsYho4IB
hzCCAYMwHQYDVR0OBBYEFG64Ehsf36CRfdGXjODKKf4WzQkRMB8GA1UdIwQYMBaA
FC3v2ubvzxx3Ub9aAqkozZQAreHKMHQGA1UdHwRtMGswaaBnoGWGY2ZpbGU6Ly8v
L1dJTi0zRzZJRzAySzFSMC9DZXJ0RW5yb2xsL0FyaXN0YSUyME5ldHdvcmtzJTIw
SW50ZXJuYWwlMjBJVCUyMFJvb3QlMjBDZXJ0JTIwQXV0aG9yaXR5LmNybDCBkQYI
KwYBBQUHAQEEgYQwgYEwfwYIKwYBBQUHMAKGc2ZpbGU6Ly8vL1dJTi0zRzZJRzAy
SzFSMC9DZXJ0RW5yb2xsL1dJTi0zRzZJRzAySzFSMF9BcmlzdGElMjBOZXR3b3Jr
cyUyMEludGVybmFsJTIwSVQlMjBSb290JTIwQ2VydCUyMEF1dGhvcml0eS5jcnQw
GQYJKwYBBAGCNxQCBAweCgBTAHUAYgBDAEEwDwYDVR0TAQH/BAUwAwEB/zALBgNV
HQ8EBAMCAYYwDQYJKoZIhvcNAQELBQADggIBAEEiS2zAM0ewCGBVc9U/Tz3PN2Y0
IQQ7ihKWfRBS7WehDEkDIH4Mm0irJfgOGctmYTHjAJTccqjjKzefPHYtZfgxbpA0
H4V0eGPMVF9wx0AA1UyLMiVn8NoxxOuhJ9a9GP9ahhYxZCNwlds0zOceS0CEcprD
IZB/MpJkc506RoreZ/MrU0K42RGOJGNbMqatv4yRJep66vdWRkK+K7fN6oOQ9sdj
4PcXfQ1cbTDp3uoj5ib0bS5EPLcC7taq/vxlAtws9aANTsLlbX13IHGsD2lrDcjX
6Kzo2+RQGiCfUR3F4ECf5XeeSamxNjIhDEVg7XNjvF+scB37ixQVwZijJIqVDKAU
0wMmk4T4sV8KEj/hEAy3KuTjTUrt0IJfy2FHVfz5PTqMJ4RXfcujnRyinOtCpMjd
TDE5QzxEjV5gUGlH19xxKGqltt7IX9L3drdNni8fC1ihIkHZqtlpMi6+zV1F4+49
u/ni/uRgHLCWhip25Jn1UF+J+sTFBjCilOlNyztuTiPZsY3H99GrL+S8UilN/TSm
waCsVJCshCaEh/cI/lZGjUfB1vefaL0Vx4rQXMluqLWF1o2LPPioqAZfGTQQPt4F
joZZ/vIFoHwsX70xI7DHjMttLQwndAA3YxNTn2EjykIdo1PQGvwUeNF9IwV/+jVW
IaxKkdU4v9hUBM16
-----END CERTIFICATE-----
"""

# Name of the signature member looked up inside a signed archive:
# SWIX_SIG_FILE_NAME for files isSwixFile() accepts, SIG_FILE_NAME otherwise.
SIG_FILE_NAME = 'swi-signature'
SWIX_SIG_FILE_NAME = 'swix-signature'
# On-disk CA certificate locations that verifySwi() consults (via
# SWI_CERTIFICATES below) when no explicit root CA is passed.
ARISTA_ROOT_CA_FILE_NAME = '/etc/swi-signature-rootCa.crt'
ARISTA_ROOT_CA2_FILE_NAME = '/etc/swi-signature-rootCa2.crt'
ARISTA_ROOT_CA3_FILE_NAME = '/etc/swi-signature-rootCa3.crt'
USER_ROOT_CA_FILE_NAME = '/etc/swi-signature-userCa.crt'
# NOTE(review): not referenced by the code in this module; presumably used by
# callers elsewhere - confirm before removing.
DEV_ROOT_CA_FILE_NAME = '/etc/swi-signing-devCA/root.crt'

# Candidate CA certificate path -> embedded PEM text, or None.
# - maybeCreateSwiCerts() writes each non-None PEM to its path when the file is
#   missing; None entries are never written by this module.
# - verifySwi() tries every path here that exists and is non-empty, in this
#   dict's iteration order (insertion order on Python 3.7+).
SWI_CERTIFICATES = {
   ARISTA_ROOT_CA_FILE_NAME: ARISTA_ROOT_CA,
   ARISTA_ROOT_CA2_FILE_NAME: ARISTA_ROOT_CA2,
   ARISTA_ROOT_CA3_FILE_NAME: None,
   USER_ROOT_CA_FILE_NAME: None,
}

def bytesRepr( c ):
   """Return a backslash-x hex escape for a single undecodable byte.

   @c is an int on Python 3 (iterating bytes yields ints) or a length-1 str on
   Python 2. The value is zero-padded to two hex digits, as the stdlib
   'backslashreplace' handler does. An unpadded escape for a byte below 0x10
   would be ambiguous when the next character is itself a hex digit.
   """
   if not isinstance( c, int ):
      c = ord( c )
   return u'\\x{:02x}'.format( c )

def textRepr( c ):
   """Return a backslash-u (BMP) or backslash-U (astral) escape for one
   unencodable character @c."""
   codePoint = ord( c )
   template = u'\\U{:08x}' if codePoint >= 0x10000 else u'\\u{:04x}'
   return template.format( codePoint )

def bsEscape( ex ):
   """Codec error handler that replaces offending input with escapes.

   Python 2/3 compatible stand-in for 'backslashreplace'. Undecodable bytes
   (UnicodeDecodeError) are rendered with bytesRepr(); unencodable characters
   are rendered with textRepr(). Returns the replacement text and the index at
   which the codec resumes.
   """
   if isinstance( ex, UnicodeDecodeError ):
      escapeOne = bytesRepr
   else:
      escapeOne = textRepr
   badSpan = ex.object[ ex.start : ex.end ]
   return ''.join( [ escapeOne( unit ) for unit in badSpan ] ), ex.end

# Registered at import time so decode()/encode() calls in this module can pass
# errors='bsEscape' (see SwiSignature.updateFields()).
codecs.register_error( 'bsEscape', bsEscape )

class SwiSignature( object ):
   """In-memory form of the key:value signature member of a SWI/SWIX archive.

   version and hashAlgo hold text. cert and signature hold whatever
   base64Decode() returned for the corresponding field. All four stay empty
   when the field never appears. offset and size are filled in by the caller
   with the member's position in, and stored size within, the archive file.
   """

   def __init__( self ):
      # Empty defaults let verifySignatureFormat() spot missing fields.
      self.version, self.hashAlgo = "", ""
      self.cert, self.signature = "", ""
      self.offset, self.size = 0, 0

   def updateFields( self, sigFile ):
      """Parse @sigFile and copy the recognised fields onto this object.

      @sigFile yields newline-delimited byte lines of the form key:value, e.g.
      key1:value1
      key2:value2
      Each line is decoded as UTF-8 using the 'bsEscape' error handler. A line
      that does not split into exactly two parts on ':' is traced and skipped.
      Keys other than Version, HashAlgorithm, IssuerCert and Signature are
      ignored.
      """
      for rawLine in sigFile:
         parts = rawLine.decode( "utf-8", "bsEscape" ).split( ':' )
         if len( parts ) != 2:
            t0( 'Unexpected format for line in swi-signature file:', rawLine )
            continue
         key, value = parts
         if key == 'Version':
            self.version = value.strip()
         elif key == 'HashAlgorithm':
            self.hashAlgo = value.strip()
         elif key == 'IssuerCert':
            self.cert = base64Decode( value.strip() )
         elif key == 'Signature':
            self.signature = base64Decode( value.strip() )

class VERIFY_SWI_RESULT( object ):
   """Integer result codes produced by the verification helpers.

   main() also uses the code as the process exit status.
   """
   SUCCESS = 0
   ERROR_SIGNATURE_FILE = 3        # no signature member in the archive
   ERROR_VERIFICATION = 4          # signature check failed / I/O problem
   ERROR_HASH_ALGORITHM = 5        # HashAlgorithm not supported
   ERROR_SIGNATURE_FORMAT = 6      # required signature field missing/empty
   ERROR_NOT_A_SWI = 7             # input is not a zip archive
   ERROR_CERT_MISMATCH = 8         # signing cert not signed by the root CA
   ERROR_INVALID_SIGNING_CERT = 9  # embedded signing cert unparsable
   ERROR_INVALID_ROOT_CERT = 10    # root CA cert unparsable

# User-facing text for each result code when checking a SWI image.
VERIFY_SWI_MESSAGE = dict( [
   ( VERIFY_SWI_RESULT.SUCCESS, "SWI verification successful." ),
   ( VERIFY_SWI_RESULT.ERROR_SIGNATURE_FILE, "SWI is not signed." ),
   ( VERIFY_SWI_RESULT.ERROR_VERIFICATION, "SWI verification failed." ),
   ( VERIFY_SWI_RESULT.ERROR_HASH_ALGORITHM,
     "Unsupported hash algorithm for SWI verification." ),
   ( VERIFY_SWI_RESULT.ERROR_SIGNATURE_FORMAT, "Invalid SWI signature file." ),
   ( VERIFY_SWI_RESULT.ERROR_NOT_A_SWI,
     "Input does not seem to be a swi image." ),
   ( VERIFY_SWI_RESULT.ERROR_CERT_MISMATCH,
     "Signing certificate used to sign SWI is not signed by root certificate." ),
   ( VERIFY_SWI_RESULT.ERROR_INVALID_SIGNING_CERT,
     "Signing certificate is not a valid certificate." ),
   ( VERIFY_SWI_RESULT.ERROR_INVALID_ROOT_CERT,
     "Root certificate is not a valid certificate." ),
] )

# Same codes as VERIFY_SWI_MESSAGE, worded for SWIX extensions. The texts are
# not a pure SWI->SWIX substitution (see ERROR_NOT_A_SWI).
VERIFY_SWIX_MESSAGE = dict( [
   ( VERIFY_SWI_RESULT.SUCCESS, "SWIX verification successful." ),
   ( VERIFY_SWI_RESULT.ERROR_SIGNATURE_FILE, "SWIX is not signed." ),
   ( VERIFY_SWI_RESULT.ERROR_VERIFICATION, "SWIX verification failed." ),
   ( VERIFY_SWI_RESULT.ERROR_HASH_ALGORITHM,
     "Unsupported hash algorithm for SWIX verification." ),
   ( VERIFY_SWI_RESULT.ERROR_SIGNATURE_FORMAT, "Invalid SWIX signature file." ),
   ( VERIFY_SWI_RESULT.ERROR_NOT_A_SWI, "Input does not seem to be a SWIX." ),
   ( VERIFY_SWI_RESULT.ERROR_CERT_MISMATCH,
     "Signing certificate used to sign SWIX is not signed by root certificate." ),
   ( VERIFY_SWI_RESULT.ERROR_INVALID_SIGNING_CERT,
     "Signing certificate is not a valid certificate." ),
   ( VERIFY_SWI_RESULT.ERROR_INVALID_ROOT_CERT,
     "Root certificate is not a valid certificate." ),
] )

def isSwixFile( filename ):
   """Return True if the file name at path @filename has a 'swix' extension.

   Any dot-separated suffix counts, case-insensitively. Copying a file produces
   a temporary name like swixname.swix.<randomSeq>, and that temporary file is
   the one validated. Only the final path component is inspected, so a dotted
   directory such as /tmp/images.swix/EOS.swi does not make a SWI look like a
   SWIX.
   """
   exts = os.path.basename( filename ).lower().split( '.' )[ 1 : ]
   return 'swix' in exts

class X509CertException( Exception ):
   """Raised when a certificate cannot be loaded.

   @code is a VERIFY_SWI_RESULT value. It is kept on the exception as .code so
   callers can return it; its VERIFY_SWI_MESSAGE text becomes the message.
   """
   def __init__( self, code ):
      super( X509CertException, self ).__init__( VERIFY_SWI_MESSAGE[ code ] )
      self.code = code

def getSwiSignatureData( swi ):
   """Locate and parse the signature member of the zip archive at path @swi.

   The member is SWIX_SIG_FILE_NAME for files isSwixFile() accepts and
   SIG_FILE_NAME otherwise. Returns a populated SwiSignature, or None when the
   archive has no such member (i.e. it is not signed).
   """
   if isSwixFile( swi ):
      sigFileName = SWIX_SIG_FILE_NAME
   else:
      sigFileName = SIG_FILE_NAME
   try:
      parsed = SwiSignature()
      with zipfile.ZipFile( swi, 'r' ) as archive:
         memberInfo = archive.getinfo( sigFileName )
         with archive.open( memberInfo, 'r' ) as member:
            # Capture where the member's stored bytes start BEFORE reading it.
            # swiSignatureValid() substitutes zeros for .size bytes from here.
            # pylint: disable-msg=protected-access
            parsed.offset = member._fileobj.tell()
            parsed.size = memberInfo.compress_size
            parsed.updateFields( member )
      return parsed
   except KeyError:
      # getinfo() raises KeyError when the signature member is absent.
      return None

def verifySignatureFormat( swiSignature ):
   """Return True when the signing cert, hash algorithm and signature fields
   are all non-empty. The version field is not checked."""
   requiredFields = ( 'cert', 'hashAlgo', 'signature' )
   # Generator keeps the checks lazy and in order, stopping at the first empty.
   return all( len( getattr( swiSignature, name ) ) != 0
               for name in requiredFields )

def base64Decode( text ):
   """Decode standard-alphabet base64 @text, returning "" if it is malformed.

   Python 2 reports bad input (e.g. wrong padding) as TypeError. Python 3
   raises binascii.Error, a ValueError subclass, and also raises ValueError
   for non-ASCII str input. All of these map to "". A corrupt signature field
   then ends up empty and is reported as ERROR_SIGNATURE_FORMAT by
   verifySignatureFormat(), instead of escaping _verifySwi() as an exception.
   """
   try:
      return base64.standard_b64decode( text )
   except ( TypeError, ValueError ):
      return ""

def loadSigningCert( swiSignature ):
   """Parse the signing certificate carried in @swiSignature.cert (held in
   memory) into an M2Crypto X509 object.

   Raises X509CertException(ERROR_INVALID_SIGNING_CERT) if M2Crypto rejects it.
   """
   certData = swiSignature.cert
   try:
      return X509.load_cert_string( certData )
   except X509.X509Error:
      raise X509CertException( VERIFY_SWI_RESULT.ERROR_INVALID_SIGNING_CERT )

def loadRootCert( rootCA, isFile=True ):
   """Load a root CA certificate as an M2Crypto X509 object.

   @rootCA is a file path when @isFile is true, otherwise the certificate text
   itself. Raises X509CertException(ERROR_INVALID_ROOT_CERT) on an X509 or BIO
   error.
   """
   loader = X509.load_cert if isFile else X509.load_cert_string
   try:
      return loader( rootCA )
   except ( X509.X509Error, BIO.BIOError ):
      raise X509CertException( VERIFY_SWI_RESULT.ERROR_INVALID_ROOT_CERT )

def signingCertValid( signingCertX509, rootCertX509 ):
   """Check that @signingCertX509 was signed by @rootCertX509's key.

   Returns SUCCESS when X509.verify() reports 1, else ERROR_CERT_MISMATCH.
   NOTE(review): only the signature is checked here. Validity dates, issuer
   name and CA constraints are not examined by this function.
   """
   issuerKey = rootCertX509.get_pubkey()
   if signingCertX509.verify( issuerKey ) != 1:
      return VERIFY_SWI_RESULT.ERROR_CERT_MISMATCH
   return VERIFY_SWI_RESULT.SUCCESS

def getHashAlgo( swiSignature ):
   """Translate the signature file's HashAlgorithm name into the digest name
   passed to M2Crypto's reset_context(md=...).

   Only 'SHA-256' is supported; any other value yields None.
   """
   return 'sha256' if swiSignature.hashAlgo == 'SHA-256' else None

def swiSignatureValid( swi, swiSignature ):
   """Verify the detached signature in @swiSignature over the file at @swi.

   The verified data is the whole file, except that the signature member's
   stored bytes (swiSignature.size bytes starting at swiSignature.offset) are
   replaced by zero bytes. Returns a VERIFY_SWI_RESULT code. Raises
   X509CertException if the embedded signing cert cannot be parsed.
   """
   signingCertX509 = loadSigningCert( swiSignature )
   hashAlgo = getHashAlgo( swiSignature )
   if hashAlgo is None:
      return VERIFY_SWI_RESULT.ERROR_HASH_ALGORITHM

   chunkSize = 65536
   sigStart = swiSignature.offset
   sigLength = swiSignature.size

   verifier = signingCertX509.get_pubkey()
   verifier.reset_context( md=hashAlgo )
   verifier.verify_init()

   with open( swi, 'rb' ) as image:
      # 1) Everything in front of the signature member, in bounded chunks.
      remaining = sigStart
      while remaining > 0:
         chunk = min( chunkSize, remaining )
         verifier.verify_update( image.read( chunk ) )
         remaining -= chunk

      # 2) The signature member is fed as zeros. Presumably it was a
      #    zero-filled placeholder when the signature was produced.
      verifier.verify_update( b'\x00' * sigLength )

      # 3) Skip the member's real bytes, then feed the rest of the file.
      image.seek( sigLength, os.SEEK_CUR )
      while True:
         block = image.read( chunkSize )
         if not block:
            break
         verifier.verify_update( block )

   # A result of 1 means the signature matched; anything else is a failure.
   if verifier.verify_final( swiSignature.signature ) == 1:
      return VERIFY_SWI_RESULT.SUCCESS
   return VERIFY_SWI_RESULT.ERROR_VERIFICATION

def checkSigningCert( swiSignature, rootCertX509 ):
   """Check the signing cert embedded in @swiSignature against @rootCertX509.

   Returns a VERIFY_SWI_RESULT code. Raises X509CertException if the embedded
   cert cannot be parsed.
   """
   return signingCertValid( loadSigningCert( swiSignature ), rootCertX509 )

def verifyAufSig( rootCA, sig ):
   """Check that the certificate in file @sig is signed by the CA in file
   @rootCA, returning signingCertValid()'s VERIFY_SWI_RESULT code.

   Both files are loaded with loadRootCert(). An invalid @sig therefore raises
   X509CertException with ERROR_INVALID_ROOT_CERT, not ..._SIGNING_CERT.
   """
   caCert = loadRootCert( rootCA )
   return signingCertValid( loadRootCert( sig ), caCert )

def _verifySwi( swi, rootCA, rootCAIsFile=True ):
   """Run every verification step for the archive @swi against one root CA.

   @rootCA is a file path when @rootCAIsFile is true, otherwise certificate
   text. Returns (VERIFY_SWI_RESULT code, loaded root X509 cert). The cert is
   only returned on success; otherwise the second element is None.
   """
   try:
      if not zipfile.is_zipfile( swi ):
         return VERIFY_SWI_RESULT.ERROR_NOT_A_SWI, None
      swiSignature = getSwiSignatureData( swi )
      if swiSignature is None:
         return VERIFY_SWI_RESULT.ERROR_SIGNATURE_FILE, None
      if not verifySignatureFormat( swiSignature ):
         return VERIFY_SWI_RESULT.ERROR_SIGNATURE_FORMAT, None
      rootCertX509 = loadRootCert( rootCA, isFile=rootCAIsFile )
      result = checkSigningCert( swiSignature, rootCertX509 )
      if result != VERIFY_SWI_RESULT.SUCCESS:
         # Signing cert invalid
         return result, None
      result = swiSignatureValid( swi, swiSignature )
      if result != VERIFY_SWI_RESULT.SUCCESS:
         return result, None
      return VERIFY_SWI_RESULT.SUCCESS, rootCertX509
   except zipfile.BadZipfile as e:
      # is_zipfile() only looks for the end-of-central-directory record. A
      # damaged central directory or local header still raises here when the
      # archive is actually opened; report it instead of crashing.
      print( e )
      return VERIFY_SWI_RESULT.ERROR_NOT_A_SWI, None
   except ( IOError, BIO.BIOError ) as e:
      print( e )
      return VERIFY_SWI_RESULT.ERROR_VERIFICATION, None
   except X509CertException as e:
      return e.code, None

def maybeCreateSwiCerts( verbose=False ):
   """Write each embedded CA certificate in SWI_CERTIFICATES to its path if
   that file does not exist yet.

   Entries whose content is None are skipped. Writing is best-effort: an
   IOError is printed when @verbose is set and otherwise ignored.
   """
   # Lazy generator: each existence check runs just before that entry's write.
   pending = ( ( path, pem ) for path, pem in SWI_CERTIFICATES.items()
               if pem is not None and not os.path.exists( path ) )
   for path, pem in pending:
      try:
         with open( path, 'w' ) as out:
            out.write( pem )
      except IOError as err:
         if verbose:
            print( "Failed to restore Arista root CA (%s): %s" % ( path, err ) )

def verifySwi( swi, rootCA=None, verbose=False, rootCAIsFile=True ):
   """Verify the SWI/SWIX at @swi, returning (VERIFY_SWI_RESULT code, root
   X509 cert used or None).

   With a truthy @rootCA only that CA is tried. Otherwise every existing,
   non-empty path in SWI_CERTIFICATES is tried in order. The search moves to
   the next CA only on ERROR_CERT_MISMATCH or ERROR_INVALID_ROOT_CERT; any
   other result, including success, is final.
   """
   maybeCreateSwiCerts( verbose=verbose )

   if rootCA:
      candidates = [ rootCA ]
   else:
      candidates = [ path for path in SWI_CERTIFICATES
                     if os.path.exists( path ) and os.path.getsize( path ) > 0 ]

   if not candidates:
      print( "No root certificate available for verification." )
      return VERIFY_SWI_RESULT.ERROR_VERIFICATION, None

   tryNextCa = ( VERIFY_SWI_RESULT.ERROR_CERT_MISMATCH,
                 VERIFY_SWI_RESULT.ERROR_INVALID_ROOT_CERT )
   for ca in candidates:
      if verbose:
         print( "Verifying against %s" % ca )
      retCode, caUsed = _verifySwi( swi, ca, rootCAIsFile=rootCAIsFile )
      if retCode not in tryNextCa:
         break
   return retCode, caUsed

def main():
   """Command-line entry point: verify one SWI/SWIX file, print the outcome,
   and exit with the VERIFY_SWI_RESULT code as the process status."""
   # sys.exit() is always available; the builtin exit() is injected by the
   # site module and is missing under `python -S` or embedded interpreters.
   import sys

   helpText = "Verify Arista SWI image or extension"
   parser = argparse.ArgumentParser( description=helpText,
               formatter_class=argparse.ArgumentDefaultsHelpFormatter )
   parser.add_argument( "swi_file", help="SWI[X] file to verify" )
   parser.add_argument( "--CAfile", default=None,
                        help="Root certificate to verify against." )

   args = parser.parse_args()
   swi = args.swi_file
   rootCA = args.CAfile

   retCode, _ = verifySwi( swi, rootCA=rootCA, verbose=True )
   messages = VERIFY_SWIX_MESSAGE if isSwixFile( swi ) else VERIFY_SWI_MESSAGE
   print( messages[ retCode ] )
   sys.exit( retCode )

if __name__ == "__main__":
   main()
