#!/usr/bin/env python3
# Copyright (c) 2009, 2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import optparse # pylint: disable=deprecated-module
import os
import re
import sys

from ExtensionMgr import ( 
      errors,
      logs,
)
import ExtensionMgrLib
import SignatureVerificationMapLib as SigVerifyMapLib
import SslCertKey
import SwiSignLib
import Tac
import TpmGeneric.Defs as TpmDefs
from TpmGeneric.Tpm import TpmGeneric

# Tac handle for the Mgmt::Security::Ssl::Constants type; _verifySwixSignature()
# uses its trustedCertsPath() to locate an SSL profile's trusted certificates.
MgmtSslConstants = Tac.Type( "Mgmt::Security::Ssl::Constants" )

class LoadExtensionsError( errors.Error ):
   """Raised by main() when extensions cannot be loaded at all: the config file
   is missing, unreadable or malformed, the output file cannot be opened, or an
   extension file cannot be read."""


class ParseError( errors.Error ):
   """Raised by parseConfig() for a config line that contains a disallowed
   character or does not match the expected line format."""

def _pcrExtendNoExtension():
   """Record in the TPM that no boot extensions were installed (best effort).

   Does nothing unless the TPM reports the MEASUREDBOOT toggle bit as set. In
   that case the EOS_EXTENSIONS PCR is extended with the NO_BOOT_EXT default
   value as an EOS_EXTENSION_INSTALL event, with ['no boot extensions'] passed
   as the log argument. Any TpmDefs.Error is deliberately ignored, so a TPM
   problem never interrupts extension loading.
   """
   try:
      tpm = TpmGeneric()
      if tpm.isToggleBitSet( TpmDefs.SBToggleBit.MEASUREDBOOT ):
         tpm.pcrExtendFromData(
            TpmDefs.PcrRegisters.EOS_EXTENSIONS,
            TpmDefs.PcrEventType.EOS_EXTENSION_INSTALL,
            TpmDefs.DefaultPcrExtend.NO_BOOT_EXT.value,
            log=[ 'no boot extensions' ] )
   except TpmDefs.Error:
      # Measurement is optional here; swallow TPM failures on purpose.
      pass

def getFormatFromName( name ):
   """Guess a package format from a file name's extension (case-insensitive).

   Returns 'formatRpm' for names ending in .rpm, 'formatSwix' for names ending
   in .swix, and 'formatUnknown' for anything else.
   """
   suffixFormats = ( ( '.rpm', 'formatRpm' ), ( '.swix', 'formatSwix' ) )
   lowered = name.lower()
   for suffix, fmt in suffixFormats:
      if lowered.endswith( suffix ):
         return fmt
   return 'formatUnknown'

def parseConfig( lines ):
   """Parse the lines of a boot-extensions config file.

   Each meaningful line has the form
      extension [no|force|boot|no,boot|force,boot] [format [dependencies]]
   Lines starting with '#' and whitespace-only lines are skipped.

   Returns a list of ( extension name, flags, format, dependencies ) tuples in
   file order. flags is None when absent. When the format field is absent, the
   format is derived from the file name with getFormatFromName() and
   dependencies is None. Otherwise dependencies is the rest of the line after
   the format, which may be ''.

   Raises ParseError, naming the 1-based line number, for a line holding a
   disallowed control or 0x7f-0xff character, or one that does not match the
   line format.
   """
   invalidCharRe = re.compile( ".*[\x00-\x08\x0b-\x1f\x7f-\xff]+.*" )
   entryRe = re.compile(
      r'(?P<file>\S+)(\s+(?P<flags>force,boot|no,boot|boot|force|no))?'
      r'(\s+(?P<format>\S+)\s*(?P<deps>.*))?' )
   blankRe = re.compile( "^\\s*$" )

   parsed = []
   for lineNum, line in enumerate( lines, start=1 ):
      if line.startswith( '#' ) or blankRe.match( line ):
         continue
      if invalidCharRe.match( line ):
         raise ParseError( "Invalid character found on line %d" % lineNum )
      match = entryRe.match( line )
      if not match:
         raise ParseError( "Parse error on line %d: %s" % ( lineNum, line ) )
      fields = match.groupdict()
      fmt = fields[ 'format' ]
      if fmt is None:
         fmt = getFormatFromName( fields[ 'file' ] )
      parsed.append( ( fields[ 'file' ], fields[ 'flags' ], fmt,
                       fields[ 'deps' ] ) )
   return parsed

def bootStatus( flags ):
   """Map a config line's flags field to an ExtensionMgrLib.BootStatus value.

   Flags containing 'boot' ('boot', 'no,boot', 'force,boot') mean the entry was
   configured by the 'boot extension' command, giving bootByConfig. Any other
   flags, including None, give bootByCopy.
   """
   byConfig = bool( flags ) and 'boot' in flags
   if not byConfig:
      return ExtensionMgrLib.BootStatus.bootByCopy
   return ExtensionMgrLib.BootStatus.bootByConfig

# Help text handed to optparse.OptionParser in main(); optparse prints it for
# -h/--help and together with command-line error messages.
usage = """
LoadExtensions -f <config-file> -d <extensions dir> [options]

Reads the config file, installs the extensions specified in the config,
and generates output indicating the status of each extension, i.e. whether
each extension installed properly or not.

This program must be run as root.
"""
def main( argv, stdout=sys.stdout, stderr=sys.stderr ):
   """Install the extensions listed in a boot-extensions config file.

   argv holds the command-line arguments without the program name. They are
   parsed with optparse (see `usage`), and an invalid command line ends in
   parser.error(). The config file named by -f is parsed with parseConfig().
   Each entry is resolved against the -d directory (absolute names are used
   as-is), then read and installed through ExtensionMgrLib. After a two-line
   header, one status record per entry is written to the -o file, or to
   `stdout` when -o is not given. When the config lists no extensions,
   _pcrExtendNoExtension() is called instead.

   An output file opened here is always closed, even on error. A stream passed
   in as `stdout` is only flushed, never closed, so the caller can keep using
   it (e.g. to read back captured output).

   Raises LoadExtensionsError if:
     - the config file is missing, unreadable or malformed;
     - the output file cannot be opened; or
     - an extension file cannot be read.
   An errors.InstallError while installing a single extension is reported on
   `stderr` (unless -q is given), and the remaining extensions are still
   processed.
   """
   parser = optparse.OptionParser( usage=usage )
   parser.add_option( "-d", "--dir", action="store",
      help="the extensions directory" )
   parser.add_option( "-f", "--config", action="store",
      help="the config file" )
   parser.add_option( "-o", "--output", action="store", default="",
      help="the output file (default: stdout)" )
   parser.add_option( "-q", "--quiet", action="store_true", default=False,
      help="suppress output (default: %default)" )
   parser.add_option( "-l", "--logging", action="store", default="",
      help="where to log the output, syslog if none specified" )

   options, args = parser.parse_args( args=argv )
   if args:
      parser.error( "Unrecognized argument" )
   if not options.config:
      parser.error( "Must specify -f <config file>" )
   if not options.dir:
      parser.error( "Must specify -d <extensions directory>" )
   if options.logging:
      ExtensionMgrLib.LOGFILE = options.logging

   def err( msg ):
      # -q only silences these diagnostics; the status output is still written.
      if not options.quiet:
         stderr.write( msg )

   if not os.path.exists( options.config ):
      msg = "%s not found: no extensions will be installed" % options.config
      raise LoadExtensionsError( msg )

   try:
      with open( options.config ) as configFile:
         lines = configFile.readlines()
      extensionsToLoad = parseConfig( lines )
   except ParseError as e:
      raise LoadExtensionsError( str( e ) ) from e
   except Exception as e:
      msg = "Error reading config file: %s" % e
      raise LoadExtensionsError( msg ) from e

   ownsOutput = bool( options.output )
   if ownsOutput:
      try:
         output = open( options.output, "w" ) # pylint: disable=consider-using-with
      except Exception as e:
         msg = "Error opening output file: %s\n" % e
         raise LoadExtensionsError( msg ) from e
   else:
      output = stdout

   try:
      # print output header
      output.write( "# --config=%s\n" % options.config )
      output.write( "# --dir=%s\n" % options.dir )
      _installExtensions( extensionsToLoad, options.dir, output, err )
      if not extensionsToLoad:
         _pcrExtendNoExtension()
   finally:
      if ownsOutput:
         output.close()
      else:
         # Never close a stream we were handed (e.g. sys.stdout); flushing is
         # enough to push the status output out.
         output.flush()

def _installExtensions( extensionsToLoad, extensionsDir, output, err ):
   """Read, install and report each ( name, flags, format, deps ) config entry."""
   # Construct a dummy Extension::Status instance to pass to ExtensionMgrLib
   # functions that need one.  This script runs before Sysdb starts, so it's
   # not possible to get the "real" Extension::Status.
   status = Tac.newInstance( "Extension::Status", "status" )
   for extensionName, flags, pType, deps in extensionsToLoad:
      depsList = _parseDepsList( deps )
      shouldForce = bool( flags ) and 'force' in flags
      # os.path.join() keeps an absolute extensionName unchanged and otherwise
      # resolves it relative to the extensions directory.
      path = os.path.join( extensionsDir, extensionName )
      try:
         ExtensionMgrLib.readExtensionFile( path, status,
                                            pType=pType, deps=depsList )
      except Exception as e:
         msg = "Error reading extension file: %s\n" % e
         raise LoadExtensionsError( msg ) from e
      info = ExtensionMgrLib.latestExtensionForName( extensionName, status )
      if info is None:
         err( "Extension %s is not present\n" % extensionName )
         # Report the missing extension through a placeholder Extension::Info.
         key = Tac.Value( "Extension::InfoKey", extensionName,
                          pType or 'formatUnknown', 1 )
         info = Tac.newInstance( "Extension::Info", key )
         info.presence = "absent"
         info.status = "notInstalled"
         info.boot = bootStatus( flags )
         ExtensionMgrLib.printExtensionInfo( output, info )
         continue

      info.boot = bootStatus( flags )
      # Pre-set so the report below never sees an unbound name, or a value left
      # over from the previous extension, if signature verification raises.
      sigValid = None
      try:
         sigValid = _verifySignature( info )
         ExtensionMgrLib.installExtension( status, info, force=shouldForce )
      except errors.InstallError as e:
         err( f"Error installing {extensionName}: {e}\n" )
      finally:
         ExtensionMgrLib.printExtensionInfo( output, info, sigValid=sigValid )

def _parseDepsList( deps ):
   """Split a config line's dependencies field into a list of file names."""
   # The field is presumably a list literal of quoted names such as
   # "['a.rpm', 'b.rpm']" (that is what the slicing assumes): drop the brackets,
   # split on commas, then drop the quotes around each name. Trailing
   # whitespace is stripped first so it is not taken for the closing bracket,
   # and empty entries (e.g. from "[]") are discarded.
   deps = deps.strip() if deps else ''
   if not deps:
      return []
   names = ( item.strip()[ 1:-1 ] for item in deps[ 1:-1 ].split( ',' ) )
   return [ name for name in names if name ]

def _verifySignature( info ):
   """Return the signature-check result for an extension's package file.

   Only SWIX packages carry signatures, so any other format yields None. For a
   SWIX package, the True/False/None result of _verifySwixSignature() on
   info.filepath is returned unchanged.
   """
   if info.format != 'formatSwix':
      return None
   return _verifySwixSignature( info.filepath )

def _verifySwixSignature( filepath ):
   """Check a SWIX file's signature as configured by the signature-verification
   mapping file.

   Returns:
     None  -- the ENFORCE_SIGNATURE attribute is not 'True' (feature disabled);
              nothing is logged.
     True  -- the signature verifies against the trusted certificates of the
              configured SSL profile; an EXTENSION_SIGNATURE_VALID log is
              emitted.
     False -- no SSL profile is set, the profile has no trusted certificates,
              the signature is invalid, or an OSError occurred; an
              EXTENSION_SIGNATURE_INVALID log with the reason is emitted.
   """
   try:
      swixName = os.path.basename( filepath )
      attrs = SigVerifyMapLib.FileMgr( 'extension' ).readAttrs()
      enforce = attrs.get( SigVerifyMapLib.FILE_ATTR.ENFORCE_SIGNATURE )
      profile = attrs.get( SigVerifyMapLib.FILE_ATTR.SSL_PROFILE )
      if enforce != 'True':
         return None # Feature is not enabled
      if not profile:
         ExtensionMgrLib.log(
            logs.EXTENSION_SIGNATURE_INVALID, swixName,
            "Unable to verify signature. SSL profile is unknown." )
         return False
      certs = SslCertKey.getAllPem( MgmtSslConstants.trustedCertsPath( profile ) )
      if not certs:
         ExtensionMgrLib.log(
            logs.EXTENSION_SIGNATURE_INVALID, swixName,
            "Unable to verify signature. No trusted certificates defined." )
         return False
      valid, reason, _ = SwiSignLib.verifySwixSignature(
         filepath, certs, rootCAIsFile=False )
      if not valid:
         ExtensionMgrLib.log( logs.EXTENSION_SIGNATURE_INVALID, swixName, reason )
         return False
      ExtensionMgrLib.log( logs.EXTENSION_SIGNATURE_VALID, swixName )
      return True
   except OSError as e:
      ExtensionMgrLib.log( logs.EXTENSION_SIGNATURE_INVALID, swixName,
                           "Unable to verify signature: %s" % e )
      return False
