#!/usr/bin/env python3
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from TableOutput import createTable
from CliModel import Int, Str, Dict, Model

class ShowAllocsTypeModel( Model ):
   """Allocation statistics for one type in a single snapshot.

   Used as the value type of ShowAllocsModel.types / ShowAllocsModel._allTypes.
   """
   size = Int( help='Size of type in bytes' )
   # Extra bytes charged per allocation on top of `size`; ShowAllocsModel adds
   # the two together when computing memory columns.
   memoryAllocationOverhead = Int( help='Memory allocation overhead in bytes' )
   currentAllocations = Int( help='Number of objects currently allocated' )
   totalAllocations = Int( help='Total number of allocations (calls to new)' )
   highestAllocations = Int( help='High watermark of currentAllocations' )

class ShowAllocsTypeDiffModel( Model ):
   """Allocation counts for one type, compared between two runs of the command.

   Used as the value type of ShowAllocsDiffModel.types /
   ShowAllocsDiffModel._allTypes.
   """
   size = Int( help='Size of type in bytes' )
   memoryAllocationOverhead = Int( help='Memory allocation overhead in bytes' )
   # Count from the earlier run; rendered in the 'begin' column.
   prevAllocations = Int( help='Number of objects allocated when the command was '
                               'run previously' )
   # Count now; rendered in the 'end' column.
   currentAllocations = Int( help='Number of objects currently allocated' )
   # Rendered in the 'delta' column and used as the sort key. Presumably
   # currentAllocations - prevAllocations, filled in by the producer
   # (not visible here -- TODO confirm).
   deltaAllocations = Int( help='Difference in Number of object allocations' )

class ShowAllocsBase( Model ):
   """Shared rendering and sorting helpers for the 'show allocs' CLI models.

   Subclasses provide the row layout via _getRow() and the ordering via
   getSortFunction() / sortWhenTransient(). This base class trims type names,
   formats byte counts, builds the overall totals row, and sorts/limits the
   types mapping.
   """
   _sortOrder = Str( help='Order in which type info should be rendered',
                     default='typeName' )
   _typePrefix = Str( optional=True, help='Type name prefix to skip in display' )
   _maxNameLen = Int( help='Maximum number of type name characters to render',
                      default=100 )
   _tableWidth = Int( help='Character width of rendered table', default=200 )
   _limit = Int( help='Number of type entries to display', default=0 )

   def _displayTypeName( self, typeName ):
      """Return typeName with _typePrefix removed (when it starts with it),
      cut down to at most _maxNameLen characters."""
      prefix = self._typePrefix
      shown = typeName
      if prefix is not None and shown.startswith( prefix ):
         shown = shown[ len( prefix ) : ]
      return shown[ : self._maxNameLen ]

   def _formatRowMemory( self, row, memoryFields ):
      """Return a new list where every column index listed in memoryFields is
      passed through _formatMemory(); all other cells are copied unchanged."""
      formatted = []
      for column, cell in enumerate( row ):
         if column in memoryFields:
            cell = self._formatMemory( cell )
         formatted.append( cell )
      return formatted

   def _formatMemory( self, size, precision=2 ):
      """Render a byte count with a 1024-based unit, e.g. 2048 -> '2.00 KB'.

      The value is divided by 1024 until it drops below 1024 or the largest
      unit (TB) is reached. The byte unit is 'B ' (trailing space) so it is
      as wide as the two-letter units.
      """
      units = ( 'B ', 'KB', 'MB', 'GB', 'TB' )
      largest = len( units ) - 1
      unit = 0
      scaled = size
      while scaled >= 1024 and unit < largest:
         scaled = scaled / 1024.0
         unit += 1
      # pylint: disable-next=consider-using-f-string
      return '%.*f %s' % ( precision, scaled, units[ unit ] )

   def _getTotalAllData( self, ncols ):
      """Build the 'TOTAL' row over every entry in self._allTypes.

      Cells from index 2 onward of each _getRow() result are summed per
      column. When _allTypes is empty, ncols zeros are used instead. Returns
      [] when this model has no _allTypes attribute.
      """
      if not hasattr( self, "_allTypes" ):
         return []
      numericRows = [ self._getRow( name, info )[ 2 : ]
                      for name, info in self._allTypes.items() ]
      if numericRows:
         sums = [ sum( column ) for column in zip( *numericRows ) ]
      else:
         sums = [ 0 ] * ncols
      return [ 'TOTAL', '' ] + sums

   def _getRow( self, typeName, typeInfo ):
      """Return the table cells for one type; overridden by subclasses."""
      return []

   @staticmethod
   def getSortFunction( sortOrder ):
      """Return the sort key for sortOrder; the base version is the identity."""
      def identity( item ):
         return item
      return identity

   @staticmethod
   def sortWhenTransient():
      """Whether sortTypes() still sorts when sortOrder is 'transientMemory'."""
      return True

   @classmethod
   def sortTypes( cls, types, limit, sortOrder ):
      """Order and truncate the types mapping for display.

      Normally returns a list of (typeName, typeInfo) pairs. The pairs are
      ordered by cls.getSortFunction( sortOrder ) and cut to the first `limit`
      entries; a falsy limit (the default 0) means no truncation.

      Exception: when sortOrder is 'transientMemory' and
      cls.sortWhenTransient() is False, `types` itself is returned untouched.
      That result is a mapping, not a list. Transient memory is only used by
      the delta mechanism, where ordering does not matter for diffing two
      snapshots, so a plain (non-delta) model skips sorting and limiting.
      A model that renders the delta keeps sortWhenTransient() True and
      sorts on the delta measurement instead.
      """
      if sortOrder == 'transientMemory' and not cls.sortWhenTransient():
         return types

      pairs = types.items()
      key = cls.getSortFunction( sortOrder )
      ordered = sorted( pairs, key=key )
      return ordered[ : limit or None ]

   @staticmethod
   def getSortOrderDefault():
      """Return the declared default of _sortOrder ('typeName')."""
      return ShowAllocsBase._sortOrder.default

   @staticmethod
   def getLimitDefault():
      """Return the declared default of _limit (0, meaning no limit)."""
      return ShowAllocsBase._limit.default

class ShowAllocsDiffModel( ShowAllocsBase ):
   """'show allocs' output comparing allocation counts between two runs.

   Each row shows a type's 'size + overhead' followed by its allocation count
   at the previous run ('begin'), now ('end') and the difference ('delta').
   """
   types = Dict( keyType=str, valueType=ShowAllocsTypeDiffModel,
                 help='A mapping of type name to its memory usage information' )
   _allTypes = Dict( keyType=str, valueType=ShowAllocsTypeDiffModel,
                     help='All types, non-limited by the limit command' )

   @staticmethod
   def getSortFunction( sortOrder ):
      """Return a key ordering types by descending deltaAllocations.

      sortOrder is accepted for interface compatibility but ignored: the diff
      view always sorts on the delta. Equal deltas are ordered by ascending
      type name so the output is deterministic.
      """
      def change( t ):
         return ( -t[ 1 ].deltaAllocations, t[ 0 ] )
      return change

   def _getRow( self, typeName, typeInfo ):
      """Return cells: name, 'size + overhead', begin, end, delta."""
      return [ self._displayTypeName( typeName ),
               # pylint: disable-next=consider-using-f-string
               '%d + %2d' % ( typeInfo.size,
                              typeInfo.memoryAllocationOverhead ),
               typeInfo.prevAllocations,
               typeInfo.currentAllocations,
               typeInfo.deltaAllocations
               ]

   def render( self ):
      """Print the begin/end/delta table, followed by totals.

      The count column group is labelled 'transient' when _sortOrder is
      'transientMemory' and 'current' otherwise. When fewer rows are shown
      than exist in _allTypes (e.g. because of _limit), a 'TOTAL displayed'
      row sums only the visible rows before the overall 'TOTAL' row.
      """
      numSumColumns = 3  # begin, end, delta: _getRow() cells 2..4
      # Only the header label depends on _sortOrder (the delta sort ignores
      # it), so any value other than 'transientMemory' -- including the
      # field's default 'typeName' -- is shown as 'current' instead of failing.
      transient = self._sortOrder == 'transientMemory'

      sortedTypes = ShowAllocsDiffModel.sortTypes( self.types, self._limit,
                                                   self._sortOrder )

      # Prepare sorted data for display
      hdr = ( ( 'type', 'l' ),
              ( 'size +', 'r', ( 'overhead', ) ),
              ( 'transient' if transient else 'current', 'c',
                ( 'begin', 'end', 'delta' ) ) )
      table = createTable( hdr, tableWidth=self._tableWidth )
      rows = [ self._getRow( typeName, typeInfo )
               for typeName, typeInfo in sortedTypes ]
      for row in rows:
         table.newRow( *row )
      table.newRow()
      if len( rows ) != len( self._allTypes ):
         # Only a subset is shown: also total just the visible rows. Pad with
         # zeros when nothing is visible so this row matches the 'TOTAL' row.
         displayedSums = [ sum( column )
                           for column in zip( *[ r[ 2 : ] for r in rows ] ) ]
         totalsRow = [ 'TOTAL displayed', '' ]
         totalsRow += displayedSums or [ 0 ] * numSumColumns
         table.newRow( *totalsRow )
      table.newRow( *self._getTotalAllData( numSumColumns ) )
      print( table.output() )

class ShowAllocsModel( ShowAllocsBase ):
   """'show allocs' output: per-type allocation counts and memory use.

   Memory columns are count * ( size + memoryAllocationOverhead ) in bytes and
   are rendered in human-readable units.
   """
   types = Dict( keyType=str,
                 valueType=ShowAllocsTypeModel,
                 help='A mapping of type name to its memory usage information' )
   _allTypes = Dict( keyType=str, valueType=ShowAllocsTypeModel,
                     help='All types, non-limited by the limit command' )

   def _getRow( self, typeName, typeInfo ):
      """Return the table cells for one type.

      Layout by index: 0 name, 1 'size + overhead', 2 total allocations,
      3 current count, 4 current memory (bytes), 5 high-watermark count,
      6 high-watermark memory (bytes).
      """
      bytesPerObject = typeInfo.size + typeInfo.memoryAllocationOverhead
      return [ self._displayTypeName( typeName ),
               # pylint: disable-next=consider-using-f-string
               '%d + %2d' % ( typeInfo.size, typeInfo.memoryAllocationOverhead ),
               typeInfo.totalAllocations,
               typeInfo.currentAllocations,
               typeInfo.currentAllocations * bytesPerObject,
               typeInfo.highestAllocations,
               typeInfo.highestAllocations * bytesPerObject
               ]

   @staticmethod
   def sortWhenTransient():
      """Do not sort/limit for 'transientMemory'; sortTypes() then returns the
      types mapping unchanged."""
      return False

   @staticmethod
   def getSortFunction( sortOrder ):
      """Return the sort key for sortOrder.

      Supported orders: 'typeName' (ascending), and 'currentMemory',
      'highWatermarkMemory' and 'totalAllocations' (all descending). Ties in
      the descending orders are broken by ascending type name so the output is
      deterministic. Any other sortOrder raises KeyError.
      """
      def _sortSize( t ):
         return t[ 1 ].size + t[ 1 ].memoryAllocationOverhead

      def _sortCurrent( t ):
         return ( -1 * _sortSize( t ) * t[ 1 ].currentAllocations, t[ 0 ] )

      def _sortHighest( t ):
         return ( -1 * _sortSize( t ) * t[ 1 ].highestAllocations, t[ 0 ] )

      def _sortTotal( t ):
         return ( -t[ 1 ].totalAllocations, t[ 0 ] )

      sortFunctions = {
         'typeName': lambda t: t[ 0 ],
         'currentMemory': _sortCurrent,
         'highWatermarkMemory': _sortHighest,
         'totalAllocations': _sortTotal,
      }

      return sortFunctions[ sortOrder ]

   def render( self ):
      """Print the allocation table, followed by totals.

      When fewer rows are shown than exist in _allTypes (e.g. because of
      _limit), a 'TOTAL displayed' row sums only the visible rows before the
      overall 'TOTAL' row.
      """
      memoryColumns = [ 4, 6 ]  # byte-valued cells of _getRow()
      numSumColumns = 5         # _getRow() cells 2..6 are summed in totals
      sortedTypes = ShowAllocsModel.sortTypes( self.types, self._limit,
                                               self._sortOrder )
      # For 'transientMemory', sortTypes() returns the types mapping itself
      # (this class does not sort in that mode). Normalize it to
      # (name, info) pairs so the loop below doesn't iterate bare keys.
      if hasattr( sortedTypes, 'items' ):
         sortedTypes = list( sortedTypes.items() )

      table = createTable( ( ( 'type', 'l' ),
                             ( 'size +', 'r', ( 'overhead', ) ),
                             ( 'total', 'c', ( 'allocations', ) ),
                             ( 'current', 'c', ( 'count', 'memory' ) ),
                             ( 'high watermark', 'c', ( 'count', 'memory' ) ) ),
                           tableWidth=self._tableWidth )
      rows = [ self._getRow( typeName, typeInfo )
               for typeName, typeInfo in sortedTypes ]
      for row in rows:
         table.newRow( *self._formatRowMemory( row, memoryColumns ) )
      table.newRow()
      if len( rows ) != len( self._allTypes ):
         # When we display only a subset of rows, include a total of the
         # displayed rows in addition to the total of all rows. Pad with zeros
         # when nothing is visible so this row matches the 'TOTAL' row.
         displayedSums = [ sum( column )
                           for column in zip( *[ r[ 2 : ] for r in rows ] ) ]
         totalsRow = [ 'TOTAL displayed', '' ]
         totalsRow += displayedSums or [ 0 ] * numSumColumns
         table.newRow( *self._formatRowMemory( totalsRow, memoryColumns ) )
      allTotals = self._getTotalAllData( numSumColumns )
      table.newRow( *self._formatRowMemory( allTotals, memoryColumns ) )
      print( table.output() )
