#!/usr/bin/env python3
# Copyright (c) 2015 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import argparse
from collections import OrderedDict
import PyClient
import re
import TacMarco
import Tac

#
# This script establishes a pyclient connection to a set of specified
# EOS agents to retrieve Smash collections info such as owner, counters and
# mount status.
#

def getSmashEntities( root ):
   ''' Returns the list of non-directory entities found under `root`.

   If `root` itself is not a 'Tac::Dir' it is the entity, and a one-element
   list is returned. Otherwise each child from `root.values()` is searched
   recursively, depth first and in iteration order, and the entities found are
   concatenated into one flat list (an empty directory contributes nothing). '''
   typeName = root.eval( '.tacType.fullTypeName' )
   if typeName != 'Tac::Dir':
      return [ root ]

   # flatten the per-child results while walking the directory
   return [ entity
            for child in root.values()
            for entity in getSmashEntities( child ) ]

def _queryThreadShocketState( shmemEm, thread ):
   ''' Returns the shocket state for the non-main thread named `thread`.

   If no socket is registered under that name, a message listing the
   currently registered socket names is returned instead. '''
   # repr() renders the name as a quoted Python string literal, so a thread
   # name containing quotes or backslashes cannot break (or alter) the
   # remotely evaluated expression
   socketName = shmemEm.eval( f'.socketName({thread!r})' )
   if not socketName:
      currentThreads = shmemEm.eval( '.socketNamesMap()' )
      return ( f'thread name: {thread} could not be found\n'
               f'registered: {currentThreads}' )
   debugHelper = Tac.newInstance( 'TacSharedMem::DebugHelper' )
   return debugHelper.dumpStateForSocket( socketName )

def _smashControls( entity ):
   ''' Returns the '<name>Control' attribute names of `entity` that have a
   matching '<name>' attribute, i.e. its smash collections. '''
   suffix = 'Control'
   suffixLen = len( suffix )
   # read the attribute names once; use a set for the companion lookup
   attributes = list( entity.attributes )
   attributeSet = set( attributes )
   return [ name for name in attributes
            if name.endswith( suffix ) and name[ :-suffixLen ] in attributeSet ]

def _collectionInfo( pc, entity, control ):
   ''' Queries one smash collection of `entity` through its `control` attribute.

   Returns a ( rootPath, info ) pair, where info is an OrderedDict holding the
   collection's 'owner', 'counter' and 'mount status' entries. '''
   cc = '.' + control + '.'

   rootPath = entity.eval( cc + 'rootPath()' )
   mountStatus = entity.eval( cc + 'mountStatus' )

   # the second ':'-separated field of the owner string is the owner pid
   owner = entity.eval( cc + 'owner()' )
   ownerPid = int( owner.split( ':' )[ 1 ] )

   # Counters are different for reader and writer, and we expose 2 functions
   # to retrieve them. Invoking the wrong function makes the remote agent
   # throw an internal 'unimplemented' exception (caught by the PyServer
   # activity), so we must be careful which function we call. If the
   # collection is 'connected' we infer an active reader. If 'attached', it is
   # either a passive reader or a writer: we compare the owner pid to the agent
   # pid to identify a writer, and otherwise assume a passive reader.
   # This won't be needed once the mode of the collection can be queried, see
   # BUG143113.
   # Note that we cannot simply do
   #    counters = entity.eval( '.fooControl.readerCounters()' )
   # because it results in a default-constructed TacSmash::ReaderCounters
   # (see related BUG420254), so we evaluate the string representation of the
   # counters instead.
   counterFn = 'readerCounters'
   if mountStatus == 'attached' and pc.pid() == ownerPid:
      # attached and owned by this agent: treat the collection as a writer
      counterFn = 'writerCounters'
   counters = pc.eval( f"str( {repr( entity )}.{control}.{counterFn}() )" )

   return rootPath, OrderedDict( (
      ( 'owner', owner ),
      ( 'counter', counters ),
      ( 'mount status', mountStatus ),
   ) )

def queryAgent( rootName, agentName, thread ):
   ''' Opens a pyclient connection to an agent and queries smash collection info.

   rootName: system name used as the PyClient root
   agentName: name of the agent to connect to
   thread: name of a non-main thread whose shocket state should be dumped, or
      a false value to dump the main thread's state

   Returns ( smashInfo, shocketState ). smashInfo maps each collection's root
   path to an OrderedDict with 'owner', 'counter' and 'mount status'. It is
   empty when `thread` is given or when the agent has no 'Shmem' root. '''
   pc = PyClient.PyClient( rootName, agentName )

   shmemEm = pc.root()[ rootName ][ agentName ][ agentName ].shmemEm
   if thread:
      # bypass the smash counters and mounts belonging to the main thread,
      # since only the requested non-main thread is of interest
      return {}, _queryThreadShocketState( shmemEm, thread )
   shocketState = shmemEm.eval( '.dumpState()' )

   smashInfo = {}
   smashRoot = pc.root()[ rootName ].get( 'Shmem' )
   if not smashRoot:
      # no smash mounts on this agent, we are done here!
      return smashInfo, shocketState

   for entity in getSmashEntities( smashRoot ):
      for control in _smashControls( entity ):
         rootPath, info = _collectionInfo( pc, entity, control )
         smashInfo[ rootPath ] = info
   return smashInfo, shocketState

if __name__ == '__main__':
   parser = argparse.ArgumentParser( description='Script that queries agents '
                                     'for smash collections info.' )
   parser.add_argument( '--sysname', default='ar',
                        help='system name (default: \'%(default)s\')' )
   parser.add_argument( '--path', default='.',
                        help='smash path regex (default: \'%(default)s\')' )
   parser.add_argument( 'agents', nargs='+',
                        help='name of agents to collect smash information from' )
   parser.add_argument( '--thread',
                        help='name of thread to collect shocket state from, '
                             'default to main-thread if none specified' )
   args = parser.parse_args()

   # compile the path regular expression; report a malformed pattern as a
   # usage error (parser.error exits) rather than a traceback
   try:
      pathPattern = re.compile( args.path )
   except re.error as e:
      parser.error( f'invalid --path regex {args.path!r}: {e}' )

   for agent in args.agents:
      print(
         f'---------------------------- {agent} agent ----------------------------' )
      _smashInfo, _shocketState = queryAgent( args.sysname, agent, args.thread )
      # only report collections whose root path matches --path
      for smash, data in _smashInfo.items():
         if pathPattern.search( smash ):
            print( f'Smash collection {smash}:' )
            for i, value in data.items():
               print( f'\t{i}: {value}' )
            print( '\n' )

      print( f'Shocket internal state for {args.thread or "main thread"}:\n' )
      print( _shocketState )
   print( '----------------------------------------------------------------------' )
