# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import struct
import json
import math
import itertools
import datetime
import argparse
from pprint import pprint
import sys
import re

try:
   import TableOutput
except ImportError:
   TableOutput = None

# Highest file format version this reader understands; Reader.parseHeader raises
# VersionError when a file reports a greater version.
CURRENT_CTR_FILE_VERSION = 0
# NOTE(review): not referenced by the code in this file; presumably the expected
# value of the header's headerMagic field - confirm against the writer.
CTR_MAGIC = 0xC180C180
# Magic value expected at the start of every name section.
CTR_NAME_SECTION_MAGIC = 0xC180C181
# Magic value expected at the start of the record data.
CTR_DATA_MAGIC = 0xC180C182
# groupId of a name section that applies to every group of the file.
NAME_SECTION_ALL_GROUPS = 0xFFFF
# dimensionId of the name section that carries the group names.
NAME_SECTION_GROUP_NAMES = 0xFF

class VersionError( BaseException ):
   """File version more recent than the reader lib

   Raised by Reader.parseHeader when the version stored in the file header is
   greater than CURRENT_CTR_FILE_VERSION.
   """
   # NOTE(review): deriving from BaseException means that `except Exception`
   # handlers do not catch this error - confirm that this is intended.

def readVarInt( f ):
   """Read one variable-length unsigned integer from `f`.

   The integer is stored least-significant group first: each byte carries 7
   data bits, and its top bit is set when another byte follows.

   Args:
      f: file-like object whose read( 1 ) returns the next byte.

   Returns:
      The decoded non-negative integer.

   Raises:
      EOFError: the stream ended before the last byte of the integer.
      ValueError: the continuation bit is still set after 10 bytes.
   """
   maxVarIntLen = 10
   dataBits = 7
   continuationBit = 1 << dataBits
   dataMask = 0xFF ^ continuationBit

   out = 0
   for p in range( maxVarIntLen ):
      raw = f.read( 1 )
      if not raw:
         # An empty read means end of stream; report it as such instead of
         # letting ord() fail with an unrelated TypeError.
         raise EOFError( "Unexpected end of data while reading a varint" )
      v = ord( raw )
      # Each byte contributes a distinct 7-bit slice of the result.
      out |= ( v & dataMask ) << ( dataBits * p )
      if not v & continuationBit:
         return out
   raise ValueError( f"Varint longer than {maxVarIntLen} bytes" )

class Options:
   """Decoded view of the header's option bitmap.

   Attributes:
      sparseRecords: set when bit SPARSE_RECORD_OPTION is present; Reader.getData
         then reads a presence bitmap in front of each record's values.
      timestampU32: set when bit TIMESTAMP_U32_OPTION is present; Reader.getData
         then reads record timestamps as 32-bit instead of 64-bit integers.
   """
   SPARSE_RECORD_OPTION = 0x1
   TIMESTAMP_U32_OPTION = 0x2

   def __init__( self, bitmap ):
      # Keep the raw value around next to the decoded flags.
      self._bitmap = bitmap
      self.sparseRecords = ( bitmap & self.SPARSE_RECORD_OPTION ) != 0
      self.timestampU32 = ( bitmap & self.TIMESTAMP_U32_OPTION ) != 0

class Reader:
   """Parser for binary .ctr counter dump files.

   All fixed-size fields are read as little-endian. As consumed by
   parseHeaders() and getData(), a file is laid out as:
    - a fixed header (headerFields),
    - one groupInfoFields entry per group,
    - a u32 byte count followed by that many bytes of additional data (decoded
      as text, and as JSON when it parses),
    - name sections (nameSectionFields followed by newline-separated names),
    - the record data, introduced by CTR_DATA_MAGIC.

   The file is opened by the constructor. Call close(), or use the reader as a
   context manager, to release it.
   """

   # Each entry is ( struct format character, attribute / key name ).
   headerFields = (
         ( 'I', 'headerMagic' ),
         ( 'I', 'fileSize' ),
         ( 'B', 'version' ),
         ( 'I', 'optionsRaw' ),
         ( 'B', 'timestampUnit' ),
         ( 'Q', 'startTimestamp' ),
         ( 'H', 'numGroups' ),
         ( 'B', 'recordDimension' ),
         )

   groupInfoFields = (
         ( 'H', 'recordSize' ),
         ( 'I', 'numRecords' ),
         )

   nameSectionFields = (
         ( 'I', 'magic' ),
         ( 'H', 'groupId' ),
         ( 'B', 'dimensionId' ),
         ( 'H', 'dimensionSize' ),
         ( 'I', 'dataSize' ),
         )

   def __init__( self, filename ):
      """Open `filename` for reading; nothing is parsed yet."""
      # pylint: disable=consider-using-with
      self.fp = open( filename, 'rb' )
      self.headerMagic = None
      self.fileSize = None
      self.version = None
      self.optionsRaw = None
      self.options = None
      self.timestampUnit = 0
      self.startTimestamp = 0
      self.numGroups = None
      self.recordSize = None # list, indexed by group id
      self.recordDimension = None
      self.numRecords = None # list, indexed by group id
      self.additionalDataRaw = None
      self.additionalData = None
      self.groupNames = None
      self.nameSection = {} # ( groupId, dimensionId ) -> section dict

   def close( self ):
      """Close the underlying file. Safe to call more than once."""
      self.fp.close()

   def __enter__( self ):
      return self

   def __exit__( self, excType, excValue, traceback ):
      self.close()

   def readStruct( self, fields ):
      """Read one packed little-endian structure at the current position.

      Args:
         fields: sequence of ( struct format character, name ) pairs.

      Returns:
         A dict mapping each name to its decoded value.

      Raises:
         struct.error: fewer bytes than the structure size were available.
      """
      formats, names = zip( *fields )
      fmt = '<' + ''.join( formats )
      raw = self.fp.read( struct.calcsize( fmt ) )
      values = struct.unpack( fmt, raw )
      return dict( zip( names, values ) )

   def readU32( self ):
      """Read one little-endian unsigned 32-bit integer."""
      return struct.unpack( '<I', self.fp.read( 4 ) )[ 0 ]

   def parseHeader( self ):
      """Read the fixed header and store each field as an attribute.

      Raises:
         VersionError: the file version is greater than
            CURRENT_CTR_FILE_VERSION.
      """
      header = self.readStruct( self.headerFields )
      for fieldName, fieldValue in header.items():
         assert hasattr( self, fieldName ) # sanity check
         setattr( self, fieldName, fieldValue )
      self.options = Options( self.optionsRaw )
      # NOTE(review): headerMagic and fileSize are stored but never validated
      # here - confirm whether headerMagic should be compared with CTR_MAGIC.

      if self.version > CURRENT_CTR_FILE_VERSION:
         raise VersionError

   def parseGroupInfo( self ):
      """Read recordSize and numRecords for each of the numGroups groups."""
      self.recordSize = []
      self.numRecords = []
      for _ in range( self.numGroups ):
         data = self.readStruct( self.groupInfoFields )
         self.recordSize.append( data[ 'recordSize' ] )
         self.numRecords.append( data[ 'numRecords' ] )

   def parseAdditionalData( self ):
      """Read the length-prefixed additional data blob.

      The text is kept in additionalDataRaw. When it is non-empty and valid
      JSON, the decoded value is stored in additionalData; otherwise
      additionalData is None.
      """
      addDataSize = self.readU32()
      raw = self.fp.read( addDataSize )
      self.additionalDataRaw = raw.decode()
      if self.additionalDataRaw:
         try:
            self.additionalData = json.loads( self.additionalDataRaw )
         except json.decoder.JSONDecodeError:
            # Free-form text is tolerated; only the raw form is kept.
            self.additionalData = None

   def parseNameSection( self ):
      """Read one name section and register it.

      A section whose dimensionId is NAME_SECTION_GROUP_NAMES provides the
      group names. Any other section provides the field names of one dimension,
      for a single group or, when groupId is NAME_SECTION_ALL_GROUPS, for every
      group.

      Returns:
         The section dict: the nameSectionFields values plus 'data', the list
         of names.
      """
      section = self.readStruct( self.nameSectionFields )
      assert section[ 'magic' ] == CTR_NAME_SECTION_MAGIC
      dataRaw = self.fp.read( section[ 'dataSize' ] )
      data = dataRaw.decode().split( '\n' )
      assert len( data ) == section[ 'dimensionSize' ]
      section[ 'data' ] = data

      if section[ 'dimensionId' ] == NAME_SECTION_GROUP_NAMES:
         assert self.groupNames is None, "The file contains 2 group names sections"
         self.groupNames = section[ 'data' ]
      else:
         if section[ 'groupId' ] == NAME_SECTION_ALL_GROUPS:
            groupIds = range( self.numGroups )
         else:
            groupIds = [ section[ 'groupId' ] ]

         for groupId in groupIds:
            key = ( groupId, section[ 'dimensionId' ] )
            assert key not in self.nameSection, (
                  "The file contains 2 name sections for "
                  f"groupId={groupId}, dimensionId={section[ 'dimensionId' ]}" )
            self.nameSection[ key ] = section

      return section

   def parseNameSections( self ):
      """Read name sections until every ( group, dimension ) pair has one."""
      allKeys = set(
            itertools.product(
               range( self.numGroups ), range( self.recordDimension ) ) )
      while set( self.nameSection ) != allKeys:
         self.parseNameSection()

   def parseHeaders( self ):
      """Parse everything that precedes the record data, in file order."""
      self.parseHeader()
      self.parseGroupInfo()
      self.parseAdditionalData()
      self.parseNameSections()

   def dimensionSize( self, groupId, dimensionId ):
      """Return the number of names along `dimensionId` for group `groupId`."""
      return self.nameSection[ ( groupId, dimensionId ) ][ 'dimensionSize' ]

   def getGroupName( self, groupId ):
      """Return the name of group `groupId`, or None without group names."""
      if not self.groupNames:
         return None
      return self.groupNames[ groupId ]

   def getFieldNames( self, groupId, dimensionId ):
      """Return the list of names along `dimensionId` for group `groupId`."""
      return self.nameSection[ ( groupId, dimensionId ) ][ 'data' ]

   def getData( self, groupId=None ):
      """Iterate over the records stored at the current file position.

      The file position must be at the start of the record data, which is where
      parseHeaders() leaves it. The position is restored when the generator
      finishes, is closed early or fails, so getData() can be called again.

      Args:
         groupId: when not None, only records of that group are yielded.

      Yields:
         ( timestamp, groupId, values ) for each record, where values has one
         integer per record entry. With sparse records, entries whose bit is
         clear in the record's bitmap are not stored and are reported as 0.
      """
      dataPosition = self.fp.tell()
      try:
         dataMagic = self.readU32()
         assert dataMagic == CTR_DATA_MAGIC
         sparse = self.options.sparseRecords
         # One presence bit per record entry, rounded up to whole bytes.
         bitmapSize = [
            math.ceil( recordSize / 8 )
            for recordSize in self.recordSize
         ]
         keyFmt = '<IH' if self.options.timestampU32 else '<QH'
         keySize = struct.calcsize( keyFmt )
         while True:
            raw = self.fp.read( keySize )
            if not raw:
               break
            timestamp, gid = struct.unpack( keyFmt, raw )
            bitmap = self.fp.read( bitmapSize[ gid ] ) if sparse else None
            # The values are decoded even for filtered-out groups: they are
            # variable-length, so this is the only way to reach the next record.
            values = []
            for i in range( self.recordSize[ gid ] ):
               if sparse:
                  idx, bit = divmod( i, 8 )
                  if not bitmap[ idx ] & ( 1 << bit ):
                     values.append( 0 )
                     continue
               values.append( readVarInt( self.fp ) )
            if groupId is not None and groupId != gid:
               continue
            yield timestamp, gid, values
      finally:
         # Reset the fp so we can call getData again, including when the
         # consumer stopped iterating early or a record failed to parse.
         if not self.fp.closed:
            self.fp.seek( dataPosition )

def maskList( lst, mask ):
   """Keep the elements of `lst` whose entry in `mask` is truthy.

   Args:
      lst: the list to filter.
      mask: None, or a sequence of the same length as `lst`.

   Returns:
      `lst` itself when `mask` is None, otherwise a new list holding the
      selected elements in their original order.
   """
   if mask is None:
      return lst
   assert len( lst ) == len( mask )
   return list( itertools.compress( lst, mask ) )

def fmtTimestampUnit( unit ):
   """Convert the timestamp unit to a human friendly string.

   `unit` is the negated power of ten of one timestamp tick, in seconds: a tick
   lasts 10**-unit seconds.

   Examples: 9 -> 'ns' ; 8 -> 'x10ns' ; 7 -> 'x100ns' ; 6 -> 'us'

   Units that do not map to one of the known SI prefixes are rendered in the
   generic form '1e<exponent>s' (for example 40 -> '1e-40s') rather than
   raising, since the unit comes straight from the file header.
   """
   siUnits = { 0 : 's', 3 : 'ms', 6 : 'us', 9 : 'ns', 12 : 'ps', 15 : 'fs' }
   # Round up to the next SI prefix; the remainder becomes a x10 / x100
   # multiplier on that prefix.
   roundUpUnit = math.ceil( unit / 3 ) * 3
   siUnit = siUnits.get( roundUpUnit )
   if siUnit is None:
      return f"1e{-unit}s"
   multiplier = ( '', 'x10', 'x100' )[ roundUpUnit - unit ]
   return f"{multiplier}{siUnit}"

class CliHelper:
   """Command line front end: prints the headers or the records of a .ctr file.

   Typical use is run( argv ), which parses the arguments, opens the file with
   Reader and prints either the headers (default, or --json) or the records
   (--table / --csv).
   """

   def __init__( self ):
      self.args = None # argparse.Namespace, set by parseArgs
      self.reader = None # Reader, set by run

   def parseArgs( self, cliArgs ):
      """Parse `cliArgs` (list of strings) and store the result in self.args."""
      parser = argparse.ArgumentParser(
         description="Read .ctr files generated by AleCountersBinaryDump::Writer" )
      parser.add_argument( 'file', help='read specified file' )

      # Output format: at most one of these.
      group = parser.add_mutually_exclusive_group()
      group.add_argument( '--json', action='store_true',
            help="print the headers as a json dictionnary" )
      group.add_argument( '--table', action='store_true',
            help="print the records, formatted in a table" )
      group.add_argument( '--csv', action='store_true',
            help="print the records, formatted in CSV" )

      # Group selection: by name regex or by id, not both.
      group = parser.add_mutually_exclusive_group()
      group.add_argument( '-g', '--group-regex',
            help="only display the groups matching GROUP_REGEX" )
      group.add_argument( '-G', '--group-id',
            type=int, help="only display group GROUP_ID" )
      parser.add_argument( '-f', '--field-regex',
            help="only display the fields matching FIELD_REGEX" )
      parser.add_argument( '--start', type=int,
            help="only display the records whose timestamp is equal or above START" )
      parser.add_argument( '--stop', type=int,
            help="only display the records whose timestamp is equal or below STOP" )
      self.args = parser.parse_args( cliArgs )

   def getHeadersData( self ):
      """Return the parsed headers as a dict that json.dumps can serialize."""
      out = {}
      # startTimestamp counts ticks of 10**-timestampUnit seconds.
      dt = datetime.datetime.fromtimestamp(
            self.reader.startTimestamp * 10**( -self.reader.timestampUnit ) )
      out[ 'startTime' ] = dt.isoformat( ' ' )
      out[ 'version' ] = self.reader.version
      out[ 'timestampUnit' ] = f'1e-{self.reader.timestampUnit}'
      out[ 'numGroups' ] = self.reader.numGroups
      out[ 'recordDimension' ] = self.reader.recordDimension
      out[ 'groupInfo' ] = [ {
         'name' : self.reader.groupNames[ groupId ]
         if self.reader.groupNames else '',
         'recordSize' : self.reader.recordSize[ groupId ],
         'dimensionSize' : [
            self.reader.dimensionSize( groupId, dimensionId )
            for dimensionId in range( self.reader.recordDimension ) ],
         'numRecords' : self.reader.numRecords[ groupId ],
         'fieldNames' : {
            dimensionId : self.reader.getFieldNames( groupId, dimensionId )
            for dimensionId in range( self.reader.recordDimension ) } }
         for groupId in range( self.reader.numGroups ) ]
      out[ 'additionalData' ] = self.reader.additionalData
      return out

   def printHeaders( self, headers ):
      """Print `headers` (as built by getHeadersData) in a readable form."""
      print( f"Start time: {headers[ 'startTime' ]}" )
      print( f"Timestamp unit: {fmtTimestampUnit(self.reader.timestampUnit)}" )
      print( f"Number of groups: {headers[ 'numGroups' ]}" )
      print( f"Record dimension: {headers[ 'recordDimension' ]}" )

      def printGroupInfo( groupInfo, singleGroup=False ):
         # With several groups, the details are indented under the group name
         # and followed by a blank line.
         indent = '' if singleGroup else '  '
         print( f"{indent}Record size: {groupInfo['recordSize']}" )
         print( f"{indent}Number of records: {groupInfo['numRecords']}" )
         for dimId, names in groupInfo[ 'fieldNames' ].items():
            print( f"{indent}Field names for dimension {dimId}: "
                   f"{', '.join( names ) }" )
         if not singleGroup:
            print()

      if len( headers[ 'groupInfo' ] ) == 1:
         printGroupInfo( headers[ 'groupInfo' ][ 0 ], True )
      else:
         if any( group[ 'name' ] for group in headers[ 'groupInfo' ] ):
            groupNames = [ group[ 'name' ] or f'Group {i}'
                           for ( i, group ) in enumerate( headers[ 'groupInfo' ] ) ]
            print( f"Groups: {', '.join( groupNames ) }" )
         for groupId, groupInfo in enumerate( headers[ 'groupInfo' ] ):
            groupName = groupInfo[ 'name' ] or f"Group {groupId}"
            print( f"{groupName}:" )
            printGroupInfo( groupInfo )
      if headers[ 'additionalData' ]:
         print( 'Additional data:' )
         pprint( headers[ 'additionalData' ] )
      print( "To display the records, use --csv or --table" )

   def showHeaders( self ):
      """Print the headers, as JSON with --json and as text otherwise."""
      dictOut = self.getHeadersData()
      if self.args.json:
         print( json.dumps( dictOut, indent=2 ) )
      else:
         self.printHeaders( dictOut )

   def getGroupsToShow( self ):
      """Return the ( groupId, groupName ) pairs selected on the command line.

      --group-id selects a single group, --group-regex the groups whose name
      matches; without either, every group is returned. Groups without a name
      are called 'Group <id>'. Exits the program with a message when the group
      id is out of range, the regex is invalid or it matches no group.
      """
      numGroups = self.reader.numGroups
      allGroups = [
            ( groupId, self.reader.getGroupName( groupId ) or f"Group {groupId}" )
            for groupId in range( numGroups ) ]

      if self.args.group_id is not None:
         # Valid ids are 0 .. numGroups - 1. Negative ids are rejected too, as
         # they would otherwise silently index from the end of the list.
         if not 0 <= self.args.group_id < numGroups:
            sys.exit(
                  f"Valid values for group-id are 0-{numGroups - 1}" )
         return [ allGroups[ self.args.group_id ] ]

      if self.args.group_regex:
         try:
            groupRegex = re.compile( self.args.group_regex )
         except re.error as e:
            sys.exit( f"Invalid regex for group ({e.msg})" )

         groups = [ group for group in allGroups
                    if groupRegex.search( group[ 1 ] ) ]
         if not groups:
            sys.exit( f"No matching group for regex '{self.args.group_regex}'" )
         return groups

      return allGroups

   def printGroupRecords( self, groupId, groupName, multiGroup, fieldRegex ):
      """Print the records of one group, as CSV or as a table.

      Args:
         groupId: id of the group to print.
         groupName: name printed before the records when multiGroup is set.
         multiGroup: True when several groups are being printed.
         fieldRegex: compiled regex selecting the columns to print, or None to
            print all of them.
      """
      if multiGroup and not self.args.json:
         print( f"{groupName}:" )

      # A column name joins one name per dimension with ':'.
      allFields = [ ':'.join( fields ) for fields in itertools.product(
            *[ self.reader.getFieldNames( groupId, dimId ) for dimId
               in range( self.reader.recordDimension ) ] ) ]
      if fieldRegex:
         fieldMask = [ fieldRegex.search( field ) for field in allFields ]
         if not any( fieldMask ):
            # No column selected for this group: nothing to print.
            print()
            return
      else:
         fieldMask = None
      headers = [ f'Time({fmtTimestampUnit(self.reader.timestampUnit)})' ] + \
            maskList( allFields, fieldMask )

      if self.args.csv:
         print( ','.join( headers ) )
      else:
         table = TableOutput.createTable( headers )
         fTime = TableOutput.Format( isHeading=True )
         table.formatColumns( fTime )

      for timestamp, _, values in self.reader.getData( groupId=groupId ):
         # --start and --stop are inclusive bounds on the raw timestamp.
         if self.args.start is not None and timestamp < self.args.start:
            continue
         if self.args.stop is not None and timestamp > self.args.stop:
            continue
         if self.args.csv:
            print( ','.join(
               ( str( n ) for n in [
                  timestamp ] + maskList( values, fieldMask ) ) ) )
         else:
            table.newRow( *( [ timestamp ] + maskList( values, fieldMask ) ) )

      if self.args.csv:
         if multiGroup:
            print()
      else:
         print( table.output() )

   def printGroups( self, fieldRegex ):
      """Print the records of every selected group."""
      groups = self.getGroupsToShow()
      for groupId, groupName in groups:
         self.printGroupRecords( groupId, groupName, len( groups ) > 1, fieldRegex )

   def run( self, args ):
      """Entry point: parse `args`, read the file and print what was asked."""
      self.parseArgs( args )

      self.reader = Reader( self.args.file )
      self.reader.parseHeaders()

      # Without --table or --csv, only the headers are displayed.
      if not self.args.table and not self.args.csv:
         self.showHeaders()
         return

      if self.args.table and not TableOutput:
         sys.exit(
               "TableOutput library missing. "
               "Please add it to the system, or use --csv"
         )

      if self.args.field_regex:
         try:
            fieldRegex = re.compile( self.args.field_regex )
         except re.error as e:
            sys.exit( f"Invalid regex for field ({e.msg})" )
      else:
         fieldRegex = None

      self.printGroups( fieldRegex )
