#!/usr/bin/env arista-python
# Copyright (c) 2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import fileinput, re
import functools

# Report layout: an ordered list of ( category title, sub-categories ) pairs.
# Each sub-category is a ( match string, row label ) pair; the match string is
# looked for as a substring of every input line, in the order given here.
# The last category holds the free-space entries and is reported separately.
# The inner containers are lists (not tuples) because main() extends the
# C++ one with the namespaces it discovers.
categories = [
   ( "Strings", [
     ( "<c-string>", "C strings" ),
     ( "<c++-string>", "C++ strings" ),
   ] ),
   ( "Introspection objects", [
     ( "_Type", "_Type objects" ),
     ( "Other Introspection", "others" ),
   ] ),
   ( "GenericIf objects", [
     ( "GenericIf", "" ),
   ] ),
   ( "TacNboAttrLogSa objects", [
     ( "TacNboAttrLogSa", "" ),
   ] ),
   ( "Other C++ objects", [
     ( "::", "" ),
     ( "Other C++", "no namespace" ),
   ] ),
   ( "Python objects", [
     ( "<python-arena>", "arenas" ),
     ( "Other Python", "individual objects" ),
     ( "<python-aux>", "auxiliary data" ),
   ] ),
   ( "Unknown", [
     ( "<unknown>", "" ),
   ] ),
   ( "Free space", [
     ( "<free>", "free chunks" ),
     ( "<top>", "top chunks" ),
   ] ),
]

# Every match string, flattened in declaration order (this is the order in
# which lines are tested against them).
allSubcats = [ matchStr
               for _title, entries in categories
               for matchStr, _label in entries ]

# Match strings of the trailing "Free space" category only.
freeSubcats = [ matchStr for matchStr, _label in categories[ -1 ][ 1 ] ]

def main():
   """Summarize a memory-usage listing into a per-category Size/Pss report.

   Input is read through fileinput.input(), i.e. from the files named on the
   command line, or from stdin when none are given.  Each line is attributed
   to the first entry of allSubcats that occurs in it as a substring; the
   last two whitespace-separated fields of that line are read as the integer
   size and pss and accumulated for that entry.  Lines that match "::" are
   further keyed by the first "::word" token on the line, and every distinct
   token gets its own row (in sorted order) ahead of the other rows in the
   category that declares the ( "::", "" ) entry.  Lines that match no entry
   are ignored.

   The report is printed to stdout: one block per category (a totals row,
   plus one row per sub-category with data when the category has more than
   one sub-category), then a SUBTOTAL of everything except the free-space
   entries, the free-space category, and a TOTAL.  A category with no
   matching lines is reported as 0 / 0.

   Raises:
      IndexError, ValueError: a matched line does not end in two integer
         fields.
   """
   namespaces = set()
   # Maps a match string (or "::word" token) to [ total size, total pss ].
   d = {}

   with fileinput.input() as lines:
      for line in lines:
         for subcat in allSubcats:
            if subcat not in line:
               continue
            if subcat == "::":
               m = re.search( r"::\w+", line )
               # A "::" that is not followed by a word character has no
               # token to key on; such a line stays under the generic "::"
               # entry instead of crashing the whole report.
               if m:
                  subcat = m.group()
                  namespaces.add( subcat )
            fields = line.split()
            size = int( fields[ -2 ] )
            pss = int( fields[ -1 ] )
            totals = d.setdefault( subcat, [ 0, 0 ] )
            totals[ 0 ] += size
            totals[ 1 ] += pss
            break

   # Build the report layout locally rather than inserting into the
   # module-level table, so that calling main() again does not duplicate the
   # namespace rows.  The rows go in front of the category that declares
   # the "::" entry, labelled with the token minus its leading "::".
   reportCategories = []
   for cat, subcats in categories:
      if ( "::", "" ) in subcats:
         subcats = ( [ ( ns, ns[ 2: ] ) for ns in sorted( namespaces ) ] +
                     list( subcats ) )
      reportCategories.append( ( cat, subcats ) )

   def sums( pairs ):
      # Element-wise sum of ( size, pss ) pairs; ( 0, 0 ) when there are none.
      totalSize = 0
      totalPss = 0
      for size, pss in pairs:
         totalSize += size
         totalPss += pss
      return ( totalSize, totalPss )

   formatStr = "%-25s %12s %12s"
   print( formatStr % ( "Category", "Size", "Pss" ) )
   print( "=" * 51 )

   def printCategory( cat, subcats ):
      s = sums( d[ subcat ] for ( subcat, desc ) in subcats if subcat in d )
      print( formatStr % ( ( cat + ":", ) + s ) )
      # A single-entry category needs no breakdown: the totals row says it all.
      if len( subcats ) > 1:
         for subcat, desc in subcats:
            if subcat in d:
               print( formatStr % ( ( " => " + desc, ) + tuple( d[ subcat ] ) ) )

   # Everything but the trailing free-space category, separated by rules.
   first = True
   for cat, subcats in reportCategories[ :-1 ]:
      if not first:
         print( "-" * 51 )
      first = False
      printCategory( cat, subcats )

   print( "=" * 51 )
   print( formatStr % ( ( "SUBTOTAL", ) + sums(
      totals for subcat, totals in d.items() if subcat not in freeSubcats ) ) )

   print( "-" * 51 )
   printCategory( *reportCategories[ -1 ] )

   print( "=" * 51 )
   print( formatStr % ( ( "TOTAL", ) + sums( d.values() ) ) )

# Produce the report only when run as a script, not when imported as a module.
if __name__ == "__main__":
   main()

