#!/usr/bin/env python3
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

'''
Contains functions for getting a cryptographic signature from IT's
signing server, used for SWI/SWIX/etc signing.
Also contains functions for generating a test signature locally using
development certificates.
'''

from __future__ import absolute_import, division, print_function

import base64
from collections import namedtuple
import json
import requests
import M2Crypto
from ArPyUtils.ArTrace import Tracing

# Module trace handle; t0 is shorthand for its trace0 function (used below to
# trace the signing server's JSON response).
traceHandle = Tracing.Handle( 'SwiSignRequest' )
t0 = traceHandle.trace0

# Various log-in details
# getRobotPassword() logs in to IT's vault with VAULT_ROLE_ID plus a secret ID
# read from SECRET_ID_FILE, then uses the returned client token to read the
# image-sign-robot password stored at VAULT_PASSWORD_PATH.
VAULT_BASE_URL = "https://vault.aristanetworks.com:8200/v1/"
VAULT_LOGIN_PATH = "auth/approle/login"
VAULT_PASSWORD_PATH = "secret/data/lss/public/image-sign-robot"
VAULT_ROLE_ID = "9f79b6f1-e5a5-c5de-10a9-afea65534e89"
SECRET_ID_FILE = "/etc/swi-signing-secret-id.txt"
# Default user name for HTTP Basic auth against the signing server.
SIGNING_USER = 'image-sign-robot'

# Locations of development cert/key
# (used by getDataFromDevCA when no explicit key pair is passed)
DEV_SIGNING_CERT = "/etc/swi-signing-devCA/signing.crt"
DEV_SIGNING_KEY = "/etc/swi-signing-devCA/signing.key"

# Known signing-server endpoints, each annotated with the certificate it
# corresponds to. defaultSwiSignURL() returns the entry at index
# DEFAULT_SWI_SIGN_URL_POS.
SWI_SIGN_URL = [
   # This URL corresponds to the following certificate:
   #
   # cert. version     : 3
   # serial number     : 16:00:00:00:03:7C:02:7D:2C:8A:AB:97:4C:00:00:00:00:00:03
   # issuer name       : CN=Arista Networks Internal IT Root Cert Authority
   # subject name      : DC=com, DC=aristanetworks, CN=AristaIT-ICA ECDSA
   #                     Issuing Cert Authority
   # issued  on        : 2017-05-31 22:54:28
   # expires on        : 2027-05-31 23:04:28
   # signed using      : RSA with SHA-256
   # EC key size       : 256 bits
   # basic constraints : CA=true
   # key usage         : Digital Signature, Key Cert Sign, CRL Sign
   'https://license.aristanetworks.com/sign/swi-hash/',
   # This URL corresponds to the following certificate:
   #
   # cert. version     : 3
   # serial number     : 28:00:00:00:11:3B:51:91:32:5F:E7:5A:5E:00:00:00:00:00:11
   # issuer name       : CN=Arista Networks Internal IT Root Cert Authority
   # subject name      : C=US, ST=California, L=Santa Clara, O=Arista Networks Inc.,
   #                     OU=Information Technology, CN=Arista Networks SWI Signer
   #                     Issuing Authority
   # issued  on        : 2020-06-10 22:09:07
   # expires on        : 2035-06-10 22:19:07
   # signed using      : RSA with SHA-256
   # EC key size       : 521 bits
   # basic constraints : CA=true
   # key usage         : Digital Signature, Key Cert Sign, CRL Sign
   'https://license.aristanetworks.com/sign/v2/eos/swi/',
]
# Index into SWI_SIGN_URL of the endpoint used by default (0 selects the
# .../sign/swi-hash/ endpoint).
DEFAULT_SWI_SIGN_URL_POS = 0

def defaultSwiSignURL():
   '''Return the signing-server URL stored at index DEFAULT_SWI_SIGN_URL_POS
   of SWI_SIGN_URL.'''
   urlIndex = DEFAULT_SWI_SIGN_URL_POS
   return SWI_SIGN_URL[ urlIndex ]

# Immutable record of the fields that extractServerData() pulls out of a
# signing-server (or local dev-CA) response.
SigningServerData = namedtuple( 'SigningServerData',
                                [ 'hashAlgorithm', 'certificate', 'signature' ] )

class SigningServerError( Exception ):
   '''Raised when the signing server (or the local dev-CA stand-in) cannot be
   reached, rejects the request, or returns data that cannot be parsed.'''

def getRobotPassword():
   '''Get the password for our automated account image-sign-robot from
   IT's vault database.

   The AppRole secret ID is read from SECRET_ID_FILE; the last line that is
   neither blank nor a '#' comment wins. That secret ID plus VAULT_ROLE_ID is
   posted to vault's AppRole login endpoint, and the returned client token is
   used to read the password stored at VAULT_PASSWORD_PATH.

   Returns the password string.
   Raises SigningServerError if the secret ID cannot be read, vault cannot be
   reached, or vault's reply is not in the expected format.'''
   secretId = ""
   try:
      with open( SECRET_ID_FILE, 'r' ) as secret:
         for line in secret:
            # Strip before testing so whitespace-only lines (e.g. a trailing
            # "\n" at end of file) do not overwrite a valid ID with "".
            line = line.strip()
            if line and not line.startswith( '#' ):
               secretId = line
   except IOError:
      raise SigningServerError( "Unable to get secret ID needed for authentication" )
   if not secretId:
      # The file had no usable line; fail here instead of sending an empty
      # secret ID to vault.
      raise SigningServerError( "Unable to get secret ID needed for authentication" )
   loginData = { "role_id": VAULT_ROLE_ID,
                 "secret_id": secretId }
   output = None
   try:
      response = requests.post( VAULT_BASE_URL + VAULT_LOGIN_PATH,
                                data=json.dumps( loginData ) )
      output = json.loads( response.text )
      client_token = output[ 'auth' ][ 'client_token' ]
      header = { "X-VAULT-TOKEN": client_token }
      response = requests.get( VAULT_BASE_URL + VAULT_PASSWORD_PATH,
                               headers=header )
      output = json.loads( response.text )
      return output[ 'data' ][ 'data' ][ 'password' ]
   except requests.exceptions.ConnectionError as e:
      raise SigningServerError( "Unable to get automated account password: %s" % e )
   except ValueError:
      # A reply body was not valid JSON. Deliberately do not echo @output:
      # after a failed second decode it still holds the login reply, which
      # contains the client token.
      raise SigningServerError( "Unable to get automated account password:"
                                " Received a response that is not valid JSON" )
   except ( KeyError, TypeError ):
      # KeyError: an expected field is missing.
      # TypeError: a field is null or not a mapping (e.g. "auth": null).
      # Wrap @output in a tuple so '%' cannot misinterpret it as an
      # argument list.
      raise SigningServerError( "Unable to get automated account password:"
                                " Received unexpected data format: %s" % ( output, ) )

def getDataFromServer( swiFile, swiData,
        licenseServerUrl=None,
        user=SIGNING_USER,
        passwd=None ):
   '''Authenticate to the signing server and post @swiData, returning the
   json response from the signing server.

   @swiFile is not used here. It is accepted so that the leading arguments
   match getDataFromDevCA.
   @swiData is serialized to JSON as the request body.
   @licenseServerUrl is the endpoint to post to. None (the default) means
   defaultSwiSignURL(), resolved at call time.
   @user / @passwd are the HTTP Basic auth credentials. If @passwd is empty,
   it is fetched with getRobotPassword().

   Returns the decoded JSON body of a 201 response.
   Raises SigningServerError if the server cannot be reached, rejects the
   credentials (403), returns any other non-201 status, or returns a body
   that is not valid JSON.'''
   if licenseServerUrl is None:
      # Resolve per call rather than at import time, so a change to
      # DEFAULT_SWI_SIGN_URL_POS made after import takes effect.
      licenseServerUrl = defaultSwiSignURL()
   if not passwd:
      passwd = getRobotPassword()
   authzBytes = ( user + ":" + passwd ).encode( 'utf-8' )
   header = {
               'Content-Type': 'application/json',
               'Authorization': 'Basic ' + base64.standard_b64encode( authzBytes
                  ).decode( 'utf-8' )
            }
   try:
      response = requests.post( licenseServerUrl,
                                data=json.dumps( swiData ), headers=header )
   except requests.exceptions.RequestException as e:
      # RequestException is the base of ConnectionError, Timeout, etc., so
      # every transport-level failure is reported the same way.
      raise SigningServerError( "Unable to access signing server: %s" % e )

   if response.status_code == 201:
      try:
         output = json.loads( response.text )
      except ValueError:
         raise SigningServerError( "Signing server gave unexpected data format: %s" %
                                   ( response.text, ) )
      t0( output )
      return output
   if response.status_code == 403:
      raise SigningServerError( "Wrong username/password." )
   # Any other status: dump the request/response to aid debugging.
   print( "Input data: ", swiData )
   print( "Headers: ", response.headers )
   print( response.text )
   raise SigningServerError( "Error %s. SWI not signed!" %
                             response.status_code )

def getDataFromDevCA( swiFile, swiData, devCaKeyPair=None ):
   '''Use local development cert/key to sign the file, returning
   the same structure of response as the signing server would.

   @swiFile: path of the file whose contents are signed.
   @swiData: dict supplying 'version', 'product' and 'hash_algorithm'. These
   are copied unchanged into the result.
   @devCaKeyPair: optional ( certPath, keyPath ) pair. It defaults to
   ( DEV_SIGNING_CERT, DEV_SIGNING_KEY ).

   Returns a dict with 'version', 'product', 'hash_algorithm',
   'signing_certificate' (the text contents of the cert file) and
   'signature' ("vault:v1:" followed by the base64-encoded signature).
   Raises SigningServerError on any failure, after printing the traceback.'''
   errorMsg = "Dev SWI Signing CA not installed"

   devCaKeyPair = devCaKeyPair or ( DEV_SIGNING_CERT, DEV_SIGNING_KEY )

   def loadFile( fileName ):
      with open( fileName, "r" ) as f:
         return f.read()

   try:
      signingCertPath, signingKeyPath = devCaKeyPair

      # A failure up to this point is reported as "CA not installed".
      signingCert = loadFile( signingCertPath )

      errorMsg = "Failed to locally generate signature data"
      result = {}
      result[ "version" ] = swiData[ 'version' ]
      result[ "product" ] = swiData[ 'product' ]
      result[ "hash_algorithm" ] = swiData[ 'hash_algorithm' ]
      result[ "signing_certificate" ] = signingCert

      key = M2Crypto.EVP.load_key( signingKeyPath )
      # NOTE(review): the digest is fixed to sha256 regardless of
      # swiData[ 'hash_algorithm' ]. Confirm callers only request SHA-256.
      key.reset_context( md='sha256' )
      key.sign_init()
      with open( swiFile, 'rb' ) as swiFileHandle:
         # Feed the signer in 1 MiB chunks rather than reading the whole file.
         while True:
            data = swiFileHandle.read( 2**20 )
            if not data:
               break
            key.sign_update( data )
      # Use the "vault:v1:<signature>" form so extractServerData can handle
      # local and server results the same way.
      result[ "signature" ] = "vault:v1:%s" % base64.b64encode(
         key.sign_final() ).decode()
      return result
   except Exception:
      # Catch Exception rather than using a bare except, so KeyboardInterrupt
      # and SystemExit propagate instead of becoming a SigningServerError.
      import traceback
      print( traceback.format_exc() )
      raise SigningServerError( errorMsg )

def extractServerData( serverData ):
   '''Perform some validation on @serverData from the signing server or
   our local "server" and return the data.

   @serverData: mapping with 'hash_algorithm', 'signing_certificate' and
   'signature' keys. 'signature' has the form "vault:v<N>:<SIGNATURE>".

   Returns a SigningServerData whose signature field holds only <SIGNATURE>.
   Raises SigningServerError if @serverData is not in that format.'''
   try:
      hashAlgorithm = serverData[ "hash_algorithm" ]
      certificate = serverData[ "signing_certificate" ]
      signature = serverData[ "signature" ]

      # Signature is a vault signature in the form of vault:v1:SIGNATURE,
      # v1 is a version number. Split at most twice so everything after the
      # second colon is kept as the signature. str() turns a unicode value
      # into a native string (a no-op on Python 3).
      signature = str( signature.split( ':', 2 )[ 2 ] )

      return SigningServerData( hashAlgorithm=hashAlgorithm,
                                certificate=certificate,
                                signature=signature )
   except ( KeyError, IndexError, TypeError, AttributeError ):
      # TypeError: @serverData cannot be indexed by string keys.
      # AttributeError: 'signature' is not a string (e.g. None).
      # Wrap @serverData in a tuple so a tuple argument cannot break the
      # '%' formatting.
      raise SigningServerError( 'Signing server gave unexpected data format: %s' %
                                 ( serverData, ) )
