#!/usr/bin/env python3
# Copyright (c) 2018 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-with

import os
import subprocess
import argparse
import Tac
import Smash
import SmashCoreUtils

description = '''This script extracts the Smash files for the specified collection
path from the given core file. Unless specified otherwise, the files will be
extracted to the shared-memory rootpath for the current directory.

For example:
   # smashrevive -c /var/core/core.8520.Rib -p ar/Smash/routing/status
   # smash -x ar/Smash/routing/status'''

# Command-line interface. The options are parsed once, at import time, into the
# module-level names (coreFilePath, collectionPath, outputDir) that the rest of
# the script reads.
parser = argparse.ArgumentParser(
   description=description,
   formatter_class=argparse.RawDescriptionHelpFormatter )
parser.add_argument( "-c", "--core", dest="coreFilePath", required=True,
                     metavar="COREFILE",
                     help="core file to extract the Smash files from" )
parser.add_argument( "-p", "--path", dest="collectionPath", required=True,
                     metavar="PATH",
                     help="Smash collection path, e.g. ar/Smash/routing/status" )
parser.add_argument( "-o", "--output", dest="outputDir", metavar="DIR",
                     help="directory to write the extracted files to "
                          "(default: <shared-memory rootpath>/PATH)" )

args = parser.parse_args()
coreFilePath = args.coreFilePath
collectionPath = args.collectionPath
outputDir = args.outputDir

# Unless -o was given, place the files under this directory's shared-memory root
# path, mirroring the collection path, so that "smash" and "smashview" can be run
# against them straight away.
if not outputDir:
   rootPath = Smash.rootPath()
   outputDir = '/'.join( ( rootPath, collectionPath ) )

class SmashFile:
   """Bookkeeping for one table file of the Smash collection being recovered.

   fileName is the table's file name. filePath is that name appended to the
   module-level collectionPath (the -p argument); it is the string searched for
   in the core's file-mapping notes. virtualAddr, filePosition and fileLength
   start out as 0 and are later overwritten with hex strings taken from readelf
   output.
   """

   def __init__( self, fileName ):
      self.fileName = fileName
      self.filePath = collectionPath + '/' + fileName
      # 0 means "not located in the core yet".
      self.virtualAddr = self.filePosition = self.fileLength = 0

# One entry per Smash table file, keyed by file name.
files = { fileName : SmashFile( fileName ) for fileName in SmashCoreUtils.tables() }

# Get the virtual address at which each Smash file was mapped. In the
# "readelf -n" listing of the core, a mapped file's path is printed on the line
# after its address line, so the first field of the previous line is the
# mapping's start address.
# The pipe yields bytes in Python 3; decode it so the str path test below works
# instead of raising TypeError. communicate() also reaps the readelf process.
readelfNotes = subprocess.Popen( [ 'readelf', '-n', coreFilePath ],
                                 stdout=subprocess.PIPE )
notesOutput = readelfNotes.communicate()[ 0 ].decode( 'utf-8', 'replace' )
lastLine = ""
for line in notesOutput.splitlines():
   for f in files.values():
      if f.filePath in line:
         prevFields = lastLine.split()
         # Guard against a match with no address line before it (IndexError).
         if prevFields:
            f.virtualAddr = prevFields[ 0 ]
   lastLine = line

# Find the LOAD program headers whose virtual addresses match our Smash files.
# With -W each header is on a single line:
#   Type  Offset  VirtAddr  PhysAddr  FileSiz  MemSiz  Flg  Align
# so fields[ 1 ] is the segment's offset in the core file, fields[ 2 ] its
# virtual address and fields[ 4 ] the number of bytes stored in the file.
# The pipe yields bytes in Python 3. Comparing bytes against "LOAD" is never
# true, so without decoding no file would ever be located. communicate() also
# reaps the readelf process.
readelfHeaders = subprocess.Popen( [ 'readelf', '-W', '-l', coreFilePath ],
                                   stdout=subprocess.PIPE )
headersOutput = readelfHeaders.communicate()[ 0 ].decode( 'utf-8', 'replace' )
for line in headersOutput.splitlines():
   fields = line.split()
   # Skip non-LOAD lines and anything too short to hold the columns we index.
   if len( fields ) < 5 or fields[ 0 ] != "LOAD":
      continue
   for f in files.values():
      if f.virtualAddr == fields[ 2 ]:
         # Get the position of this Smash file and its length.
         f.filePosition = fields[ 1 ]
         f.fileLength = fields[ 4 ]

# Now that we have all the metadata we need, copy all the Smash files from the core.
# The core and the Smash files are raw binary data, so both are opened in binary
# mode. Text mode would try to decode the bytes and would make seek()/read()
# offsets character-based rather than byte-based.
os.makedirs( outputDir, exist_ok=True )
with open( coreFilePath, 'rb' ) as core:
   for f in files.values():
      # Position/length are still the int 0 when no LOAD segment matched.
      if not f.filePosition or not f.fileLength:
         raise SystemExit( 'Could not locate %s in core file %s' %
                           ( f.filePath, coreFilePath ) )
      core.seek( int( f.filePosition, 16 ) )
      content = core.read( int( f.fileLength, 16 ) )
      # Explicit check rather than assert, so it still runs under "python -O".
      if not content:
         raise SystemExit( 'No data for %s at offset %s in core file %s' %
                           ( f.filePath, f.filePosition, coreFilePath ) )
      smashFilePath = outputDir + '/' + f.fileName
      with open( smashFilePath, 'wb' ) as smashFile:
         smashFile.write( content )
      print( 'Successfully extracted', f.fileName, 'to', smashFilePath )

# Find the fenixId of the Smash collection and create a matching "fenix" file.
# Every "Fenix: <id>" line printed by "smash -x" must carry the same id. If no
# such line appears, the fenix file is written with "0".
# NOTE(review): "smash -x" is given only collectionPath, not outputDir. When -o
# points somewhere other than the default rootpath, this presumably inspects the
# wrong files. Verify.
# The pipe yields bytes in Python 3; decode it so the str test below does not
# raise TypeError. communicate() also reaps the smash process.
smashX = subprocess.Popen( [ 'smash', '-x', collectionPath ],
                           stdout=subprocess.PIPE )
smashOutput = smashX.communicate()[ 0 ].decode( 'utf-8', 'replace' )
fenixId = 0
for line in smashOutput.splitlines():
   if 'Fenix:' not in line:
      continue
   fields = line.split()
   if len( fields ) < 2:
      continue
   newFenixId = fields[ 1 ]
   # Explicit check rather than assert, so it still runs under "python -O".
   if fenixId and newFenixId != fenixId:
      raise SystemExit( 'Smash files have different fenixIds' )
   fenixId = newFenixId
fenixFilePath = outputDir + '/fenix'
with open( fenixFilePath, 'w' ) as fenixFile:
   fenixFile.write( str( fenixId ) )
print( 'Successfully created fenix at', fenixFilePath )
