#!/usr/bin/env python3
# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import datetime as dt
import argparse
import os
import re
import sys
import subprocess
import glob

import Tac

# Default settings (used as defaults by parseArgs and appendQuickTrace)
defaultLogFileDir = "/mnt/flash/"
defaultLogFileName = "autobw-Rsvp.log"
defaultQtDir = "/var/log/qt/"
defaultQtFileName = "autobw-Rsvp.qt"
# Quicktrace level(s) passed to `qtcat -l`; on the command line, --levels all
# disables level filtering (parseArgs maps it to None).
defaultQtLevels = "0"
defaultLogFileRotateSize = 10000000 # 10 MB
defaultNumOfLogFiles = 10
# Only quicktrace lines matching this regexp are copied to the log. It expects
# `qtcat --tsc` output: date, time, an integer field, a hex CPU cycle counter
# followed by a comma, an optionally '+'-prefixed integer, then a quoted
# message beginning with "Tunnel: or "hMean: .
defaultPattern = r"^\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d+ \d+ 0[xX][0-9a-fA-F]+, "\
                 r"\+{0,1}\d+ \"(Tunnel|hMean): "
# Constants
# Name of the file, kept in the log directory, that records the last quicktrace
# line copied to the log.
checkpointFileName = ".autobw-Rsvp-checkpoint.txt"

def parseArgs( *args ):
   """Parse command-line arguments for appendQuickTrace.

   ``args`` is forwarded to ``ArgumentParser.parse_args``. Call with no
   arguments to parse ``sys.argv[ 1: ]``, or pass an explicit argument list.

   Returns the tuple ``( qtFile, logFileDir, logFileName, logFileRotateSize,
   numOfLogFiles, qtLevels, pattern )``, in the positional order accepted by
   appendQuickTrace.

   A positional <log file> overrides --log-file-dir and --log-file-name:
     - if it has no directory component, the current directory is used;
     - if it ends in a path separator, the default log file name is used.

   ``qtLevels`` is None when --levels is "all".
   """
   parser = argparse.ArgumentParser(
         description="append quicktrace file to a log file" )
   parser.add_argument( 'qtFile', nargs='?', help="full path of quicktrace file",
         default=os.path.join( defaultQtDir, defaultQtFileName ),
         metavar='<quicktrace file>' )
   parser.add_argument( 'logFile', nargs='?', help="full path of log file",
         metavar='<log file>' )
   parser.add_argument( '-n', '--log-file-num',
                     dest='numOfLogFiles', type=int,
                     help='number of kept log files',
                     metavar='NUM',
                     default=defaultNumOfLogFiles )
   parser.add_argument( '-s', '--log-file-size',
                     dest='logFileRotateSize', type=int,
                     help='log file size cap',
                     metavar='BYTES',
                     default=defaultLogFileRotateSize )
   parser.add_argument( '-d', '--log-file-dir',
                     dest='logFileDir',
                     help='log file directory (not used when <log file> is set)',
                     metavar='DIRECTORY',
                     default=defaultLogFileDir )
   parser.add_argument( '-f', '--log-file-name',
                     dest='logFileName',
                     help='log file name (not used when <log file> is set)',
                     metavar='FILENAME',
                     default=defaultLogFileName )
   parser.add_argument( '-l', '--levels',
                     dest='levels',
                     help='quicktrace levels',
                     default=defaultQtLevels )
   parser.add_argument( '-p', '--pattern',
                     dest='pattern',
                     help='quicktrace filter regexp pattern',
                     default=defaultPattern )
   parsedArgs = parser.parse_args( *args )

   # determine log file
   if parsedArgs.logFile:
      logFile = parsedArgs.logFile
      # A bare file name has an empty dirname. Use the current directory instead,
      # otherwise the directory-existence check in appendQuickTrace fails on "".
      logFileDir = os.path.dirname( logFile ) or os.curdir
      logFileName = os.path.basename( logFile )
      if not logFileName:
         # <log file> ended in a separator, i.e. named only a directory
         logFileName = defaultLogFileName
   else:
      logFileDir = parsedArgs.logFileDir
      logFileName = parsedArgs.logFileName

   # determine qt levels
   if parsedArgs.levels == "all":
      qtLevels = None
   else:
      qtLevels = parsedArgs.levels

   return ( parsedArgs.qtFile, logFileDir, logFileName,
            parsedArgs.logFileRotateSize, parsedArgs.numOfLogFiles, qtLevels,
            parsedArgs.pattern )

def main():
   """Command-line entry point.

   Runs appendQuickTrace with the parsed arguments and writes any error message
   it returns to stderr.
   """
   error = appendQuickTrace( *parseArgs() )
   if error is None:
      return
   sys.stderr.write( error )

def _readQt( qtFile, qtLevels ):
   """Start ``qtcat --tsc`` on qtFile and return the running Popen object.

   When qtLevels is not None it is passed as ``-l <qtLevels>``. stdout and
   stderr are text-mode pipes for the caller to read. stdin is closed
   immediately.
   """
   levelArgs = [] if qtLevels is None else [ "-l", qtLevels ]
   command = [ "qtcat", "--tsc" ] + levelArgs + [ qtFile ]
   # pylint: disable-next=consider-using-with
   proc = subprocess.Popen( command, bufsize=-1, stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                            close_fds=True, text=True )
   # Give qtcat EOF on stdin instead of an open pipe, in case it drops into pdb.
   proc.stdin.close()
   return proc

def getTimestampFromTraceLine( line ):
   """Return the timestamp of a trace/log line as float seconds since the epoch.

   The first two whitespace-separated fields of ``line`` must be a date and a
   time in ``%Y-%m-%d %H:%M:%S.%f`` format. They are interpreted as local time.

   Returns None when:
     - the line has fewer than two fields;
     - the fields do not parse;
     - the value cannot be converted to a POSIX timestamp.
   """
   try:
      date, logTime = line.split()[ : 2 ]
      parsed = dt.datetime.strptime( date + " " + logTime,
                                     "%Y-%m-%d %H:%M:%S.%f" )
      # datetime.timestamp() treats a naive datetime as local time. That matches
      # the old strftime( '%s.%f' ) round-trip without relying on the
      # platform-specific '%s' directive.
      return parsed.timestamp()
   except ( ValueError, OverflowError, OSError ):
      # ValueError: too few fields or unparsable date/time.
      # OverflowError/OSError: date outside the platform's timestamp range.
      return None

def qtToLogFormat( qtLine ):
   """Convert a ``qtcat --tsc`` output line into the log-file line format.

   The line is split on whitespace and the fields at indices 2, 3 and 4 are
   removed. The remaining fields (date, time and message) are re-joined with
   single spaces, which also drops any trailing newline.
   """
   # Indices 2-4 are qtcat metadata that the log does not keep: the level, the
   # hex CPU cycle counter printed by --tsc, and ticks since the previous
   # message.
   fields = qtLine.split()
   del fields[ 2 : 5 ]
   return " ".join( fields )

class LogWriter:
   """Append lines to a log file, rotating it by size.

   Once the file on disk reaches ``logFileRotateSize`` bytes, the next write
   first renames it with a timestamp suffix, gzips it and starts a fresh file.
   Only the ``numOfLogFiles`` most recently modified ``<logFile>*gz`` files are
   then kept.
   """
   def __init__( self, logFile, logFileRotateSize, numOfLogFiles ):
      self.fname = logFile
      # pylint: disable-next=consider-using-with
      self.currFile = open( logFile, 'a' )
      self.logFileRotateSize = logFileRotateSize
      self.numOfLogFiles = numOfLogFiles

   def close( self ):
      """Flush and fsync the current file, then close it."""
      self.currFile.flush()
      os.fsync( self.currFile.fileno() )
      return self.currFile.close()

   def backupCurrentLog( self ):
      """Close the current log, rename it with a timestamp suffix, gzip it, and
      reopen (recreate) the log file."""
      # close the file as we're about to replace it
      self.close()

      # The rotated name is <log>.<YYYY-MM-DD>_<epoch seconds>. A numeric suffix
      # is added while that name is taken, either as a compressed backup or as
      # an uncompressed leftover (e.g. from an earlier failed gzip), so that
      # os.rename never silently replaces an existing file.
      now = dt.datetime.now()
      rotatedLogFilePrefix = "%s.%s_%d" % ( self.fname, now.strftime( "%Y-%m-%d" ),
                                            int( now.timestamp() ) )
      rotatedLogFile = rotatedLogFilePrefix
      dedupIndex = 1
      while ( os.path.exists( rotatedLogFile ) or
              os.path.exists( rotatedLogFile + '.gz' ) ):
         rotatedLogFile = "%s.%d" % ( rotatedLogFilePrefix, dedupIndex )
         dedupIndex += 1
      os.rename( self.fname, rotatedLogFile )

      # compress file; a gzip failure is ignored (presumably leaving the
      # rotated file uncompressed)
      Tac.run( [ "gzip", rotatedLogFile ], ignoreReturnCode=True,
               stdout=Tac.DISCARD )

      # reopen (recreate) log
      # pylint: disable-next=consider-using-with
      self.currFile = open( self.fname, 'a' )

   def deleteOldLogs( self ):
      """Delete all but the numOfLogFiles newest ``<logFile>*gz`` files.

      Files are ranked by modification time. Files that disappear or cannot be
      removed are skipped.
      """
      # Escape the log path so glob metacharacters in it ('[', '*', '?') are
      # matched literally.
      candidates = glob.glob( glob.escape( self.fname ) + "*gz" )
      zippedLogFiles = []
      for path in candidates:
         try:
            zippedLogFiles.append( ( os.path.getmtime( path ), path ) )
         except OSError:
            # Removed between glob() and stat(); nothing left to delete.
            continue
      zippedLogFiles.sort( key=lambda entry: entry[ 0 ], reverse=True )
      for _, oldFile in zippedLogFiles[ self.numOfLogFiles : ]:
         try:
            os.remove( oldFile )
         except OSError:
            pass

   def writeLine( self, line ):
      """
      Append ``line`` plus a newline to the current file. If the file on disk
      has already reached logFileRotateSize bytes, rotate it first.
      """
      self.currFile.flush() # Write out the file before we compare size
      if os.path.getsize( self.fname ) >= self.logFileRotateSize:
         self.backupCurrentLog()
         self.deleteOldLogs()
      self.currFile.write( line + '\n' )

def _lastLoggedLineInfo( checkpointFile, logFile ):
   """Describe the last line already copied to the log.

   Returns ( lastTs, lastTsc, lastLineContent ):
     lastTs          -- timestamp of that line, or None if unknown
     lastTsc         -- its CPU cycle counter column, or None if unknown
     lastLineContent -- its message fields (list of str); set only when the log
                        file, rather than the checkpoint, was consulted
   """
   lastTs = None
   lastTsc = None
   lastLineContent = None

   # Checkpoint contains last quicktrace line copied. The fourth column should
   # contain the CPU cycle counter value, which is unique between trace lines.
   try:
      with open( checkpointFile ) as f:
         lastLine = f.readline().strip()
      lastTs = getTimestampFromTraceLine( lastLine )
      lastLineCols = lastLine.split()
      lastTsc = lastLineCols[ 3 ] if len( lastLineCols ) >= 4 else None
   except OSError:
      # Missing or unreadable checkpoint; fall back to the log below.
      pass

   # If the checkpoint file was invalid or missing, look at the last line in the
   # log instead. The log doesn't include the CPU cycle counter, so the trace
   # content is used to identify the associated quicktrace line, even though
   # it is not guaranteed to be unique.
   checkpointInvalid = lastTs is None or lastTsc is None
   if checkpointInvalid and os.path.exists( logFile ):
      # Reset lastTsc, since it shouldn't be used
      lastTsc = None
      lastLine = Tac.run( [ "tail", "-n", "-1", logFile ],
                          stdout=Tac.CAPTURE ).strip()
      lastTs = getTimestampFromTraceLine( lastLine )
      # Everything after the 2nd column is the trace content for log lines
      lastLineContent = lastLine.split()[ 2 : ]
   return lastTs, lastTsc, lastLineContent

def _firstUnloggedIndex( qtLinesSameTime, lastTsc, lastLineContent ):
   """Return the index in qtLinesSameTime of the first line not yet logged.

   qtLinesSameTime holds the quicktrace lines whose timestamp equals that of the
   last logged line. The search runs backwards and returns one past the *last*
   line that matches the last logged line. When several lines could match
   (possible when identifying by content), this avoids duplicating log lines.
   Returns 0 when nothing matches.
   """
   for i in range( len( qtLinesSameTime ) - 1, -1, -1 ):
      qtLineCols = qtLinesSameTime[ i ].split()
      if lastTsc is not None:
         # The cycle counter is unique, so it identifies the line exactly. The
         # slice avoids an IndexError on lines with fewer than four columns.
         matched = qtLineCols[ 3 : 4 ] == [ lastTsc ]
      else:
         matched = lastLineContent == qtLineCols[ 5 : ]
      if matched:
         return i + 1
   return 0

def appendQuickTrace( qtFile, logFileDir, logFileName,
      logFileRotateSize=defaultLogFileRotateSize,
      numOfLogFiles=defaultNumOfLogFiles,
      qtLevels=defaultQtLevels,
      pattern=defaultPattern ):
   """Append quicktrace lines that have not been logged yet to a log file.

   The ``qtcat`` output of ``qtFile`` is filtered to lines matching ``pattern``
   that are newer than the last logged line. Those lines are written, in log
   format, to ``<logFileDir>/<logFileName>``, rotated per logFileRotateSize and
   numOfLogFiles. The last line copied is saved to a checkpoint file in
   ``logFileDir`` so the next run can resume after it.

   Returns None on success. Otherwise returns an error message ending in a
   newline when ``qtFile`` or ``logFileDir`` does not exist, or when qtcat
   exits with a non-zero status.
   """
   if not os.path.exists( qtFile ):
      return qtFile + " does not exist\n"
   if not os.path.exists( logFileDir ):
      return logFileDir + " does not exist\n"
   logFile = os.path.join( logFileDir, logFileName )
   checkpointFile = os.path.join( logFileDir, checkpointFileName )
   rePattern = re.compile( pattern )

   lastTs, lastTsc, lastLineContent = _lastLoggedLineInfo( checkpointFile,
                                                           logFile )

   # read qt, append new lines to logfile
   qtrace = _readQt( qtFile, qtLevels )
   qtLine = qtrace.stdout.readline()
   lastQtLine = None # Last qtLine written
   logWriter = LogWriter( logFile, logFileRotateSize, numOfLogFiles )
   try:
      # If we have a valid lastTs, fast forward to the first line with
      # TS > lastTs. Lines where the timestamp == lastTs are cached, since they
      # may be new.
      if lastTs is not None:
         qtLinesSameTime = []
         while qtLine:
            currentLineTs = getTimestampFromTraceLine( qtLine )
            if not rePattern.match( qtLine ) or currentLineTs is None:
               # Skip, we don't care about this line
               pass
            elif currentLineTs == lastTs:
               qtLinesSameTime.append( qtLine )
            elif currentLineTs > lastTs:
               break
            qtLine = qtrace.stdout.readline()
         # Of the lines sharing the last logged timestamp, write out only those
         # following the last logged line.
         firstNewLogIdx = _firstUnloggedIndex( qtLinesSameTime, lastTsc,
                                               lastLineContent )
         for line in qtLinesSameTime[ firstNewLogIdx : ]:
            lastQtLine = line
            logWriter.writeLine( qtToLogFormat( line ) )

      # All remaining matching lines are new.
      while qtLine:
         if rePattern.match( qtLine ):
            lastQtLine = qtLine
            logWriter.writeLine( qtToLogFormat( qtLine ) )
         qtLine = qtrace.stdout.readline()
   finally:
      # Flush, fsync and close the log even if writing or rotation failed.
      logWriter.close()

   if lastQtLine:
      with open( checkpointFile, 'w' ) as f:
         f.write( lastQtLine )

   err = qtrace.stderr.readlines()
   qtrace.stdout.close()
   qtrace.stderr.close()
   qtrace.wait()
   if qtrace.returncode:
      # qtcat may fail without writing to stderr; don't index an empty list.
      detail = err[ 0 ] if err else "exit status %d\n" % qtrace.returncode
      return "qtcat terminated with error: " + detail
   return None

# When run as a script, parse the command line and append new quicktrace lines.
if __name__ == "__main__":
   main()
