#!/usr/bin/env python3
# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import argparse
import json
import sys

from AgentDirectory import agentIsRunning
from ObjectTree import ObjectTree, diffTrees
import PyClient
from PyRpc import Rpc
import Tracing

traceHandle = Tracing.Handle( "ObjectAudit" )
# Short aliases for the two trace levels used in this script: t0 for connection
# progress, t9 for the raw walker output and each line being parsed.
t0, t9 = traceHandle.trace0, traceHandle.trace9

def error( msg ):
   """Print ``msg`` to standard error and terminate with exit status 1."""
   print( msg, file=sys.stderr )
   # Equivalent to sys.exit( 1 ): unwinds via SystemExit with code 1.
   raise SystemExit( 1 )

def parseArgs( argv=None ):
   """Parse the ObjectAudit command line.

   Args:
      argv: list of argument strings to parse. None (the default) makes argparse
            read sys.argv[ 1: ], which is what the script entry point uses.

   Returns:
      A tuple ( agent, reference, dumpAgent, sysname, showTrees, noFork, path ).
      ``reference`` defaults to "Sysdb" and ``sysname`` to "ar"; --referenceAgent
      and --dumpAgent are mutually exclusive.
   """
   # Note the trailing space after "of": the two literals are concatenated, and
   # without it the help text read "subtree ofone or two agents".
   parser = argparse.ArgumentParser( description="Walk a given entity subtree of " +
      "one or two agents. If two agents, compare the trees to check for state " +
      "desync" )
   parser.add_argument( "-a", "--agent", metavar="AGENT", dest="agent",
      required=True, help="name of agent to walk and/or compare" )
   compareGroup = parser.add_mutually_exclusive_group()
   compareGroup.add_argument( "-r", "--referenceAgent", metavar="AGENT",
      dest="reference", default="Sysdb",
      help="name of agent to compare against (default Sysdb, mutually exclusive " +
      "with --dumpAgent)" )
   compareGroup.add_argument( "-d", "--dumpAgent", action="store_true",
      help="only walk a single agent and print the object tree, do not compare " +
      "against a reference (mutually exclusive with --referenceAgent)" )
   parser.add_argument( "-s", "--sysname", metavar="SYSNAME", required=False,
      default="ar", help="current sysname (default ar)" )
   parser.add_argument( "--showTrees", action="store_true",
      help="print a representation of the parsed object trees" )
   parser.add_argument( "--noFork", action="store_true",
      help="do not fork the host agent to perform the walk. Off by default. " +
      "Useful when forking the host agent might consume too much memory, and " +
      "cause oom (out of memory) issues. Take extreme care with this option - " +
      "if ObjectAudit or AirStream crashes, they will also crash the host agent." )
   parser.add_argument( "path", metavar="PATH",
      help="path to the root of the subtree to compare" )
   args = parser.parse_args( argv )
   return ( args.agent, args.reference, args.dumpAgent, args.sysname,
            args.showTrees, args.noFork, args.path )

def walkTree( agent, sysname, path, fork=True ):
   """Walk the entity subtree at ``path`` inside ``agent`` and return the output.

   Args:
      agent: name of the agent to connect to.
      sysname: sysname passed to agentIsRunning and PyClient.
      path: root of the subtree to walk, passed to walkMessage.
      fork: when True (the default), connect with Rpc.execModeForkPerConnection
            so the walk runs in a forked connection rather than in the host
            agent itself, and redirect that connection's stdout/stderr with
            redirectFds.

   Returns:
      Whatever evaluating walkMessage( path ) in the agent returns; main hands
      it to reconstructTree, which parses it line by line as JSON.

   Exits via error() if the agent is not running, or if the evaluation fails
   with an RpcError mentioning NameError (reported as the path not existing).
   Any other PyClient.RpcError is re-raised.
   """
   if not agentIsRunning( sysname, agent ):
      error( f"{agent} is not running." )

   t0( f"PyClienting into {agent}" )
   if fork:
      # Use execModeForkPerConnection to fork the agent that we're connecting to.
      # This ensures that a crash in ObjectAudit doesn't crash the host agent
      pyclient = PyClient.PyClient( sysname, agent,
                                    execMode=Rpc.execModeForkPerConnection,
                                    reconnect=False )
   else:
      pyclient = PyClient.PyClient( sysname, agent )

   try:
      if fork:
         # execModeForkPerConnection redirects stdout and stderr to /dev/null, but we
         # want to see traces in case of a crash
         pyclient.execute( redirectFds( agent ) )
      messages = pyclient.eval( walkMessage( path ) )
   except PyClient.RpcError as e:
      # Tolerate an RpcError with no args or a non-string first arg: indexing
      # e.args[ 0 ] directly would raise IndexError/TypeError here and hide the
      # real RPC failure.
      detail = e.args[ 0 ] if e.args else ""
      if "NameError" in str( detail ):
         error( f"{path} does not exist in {agent}." )
      else:
         raise
   t9( f"Received output from {agent}: {messages}" )
   return messages

def walkMessage( path ):
   """Return the Python expression the target agent evaluates to walk ``path``.

   The expression calls ObjectAudit::WalkerHelper.walk on Tac.entity( path ),
   passing the path again as the second argument. ``path`` is embedded with
   repr() so that quotes or backslashes in it produce a valid string literal
   instead of breaking (or injecting into) the generated code. For ordinary
   paths the result is unchanged: repr gives the same single-quoted literal.
   """
   # This function is separated out for testing purposes
   return "Tac.Type( 'ObjectAudit::WalkerHelper' )" + \
          f".walk( Tac.entity( {path!r} ), {path!r} )"

def redirectFds( agent ):
   """Return Python source that points the remote fds 1 and 2 at a log file.

   walkTree executes this snippet in a fork-per-connection session, whose
   stdout/stderr otherwise go to /dev/null, so that traces survive a crash.
   The snippet opens /tmp/<agent>RpcOutput and dup2()s it over fds 1 and 2.
   """
   outputPath = f"/tmp/{agent}RpcOutput"
   # O_TRUNC: without it a shorter run only overwrites the start of the file,
   # leaving stale output from an earlier run after it.
   # The original fd is closed after dup2 so it is not leaked, unless open()
   # itself handed back 1 or 2 (closing it would then close stdout/stderr).
   cmd = \
f"""
import os
fd = os.open( {outputPath!r}, os.O_CREAT | os.O_WRONLY | os.O_TRUNC,
              mode=0o666 )
os.dup2( fd, 1 )
os.dup2( fd, 2 )
if fd not in ( 1, 2 ):
   os.close( fd )
"""
   return cmd

def reconstructTree( agent, path, msgs ):
   """Parse walker output for ``agent`` into an ObjectTree rooted at ``path``.

   Args:
      agent: name of the agent the output came from.
      path: root path of the walked subtree.
      msgs: text returned by walkTree, one HashMsg JSON document per line.

   Returns:
      ObjectTree( agent, path, parsedMsgs ), where parsedMsgs holds the decoded
      JSON documents in input order.

   Blank or whitespace-only lines are skipped, because json.loads rejects
   them. Any other malformed line still raises json.JSONDecodeError.
   """
   def parse( msg ):
      t9( f"Parsing {msg!r}" )
      return json.loads( msg )

   parsedMsgs = [ parse( m ) for m in msgs.splitlines() if m.strip() ]
   return ObjectTree( agent, path, parsedMsgs )

def dumpTrees( a, b ):
   """Print each tree's agent name followed by that tree's own dump, a then b."""
   for tree in ( a, b ):
      print( tree.agent )
      tree.dumpTree()

def main():
   """Script entry point: dump one agent's subtree, or diff it against a reference.

   With --dumpAgent only ``agent`` is walked and its tree is printed. Otherwise
   the reference agent is walked first, then ``agent``. Both outputs are rebuilt
   into ObjectTrees, optionally printed (--showTrees), and passed to diffTrees.
   """
   ( agent, reference, dumpAgent, sysname, showTrees, noFork,
     path ) = parseArgs()
   fork = not noFork

   if dumpAgent:
      output = walkTree( agent, sysname, path, fork )
      reconstructTree( agent, path, output ).dumpTree()
      return

   # Both agents are walked before either output is parsed.
   refOutput = walkTree( reference, sysname, path, fork )
   targetOutput = walkTree( agent, sysname, path, fork )

   refTree = reconstructTree( reference, path, refOutput )
   targetTree = reconstructTree( agent, path, targetOutput )

   if showTrees:
      dumpTrees( refTree, targetTree )
   diffTrees( refTree, targetTree )

# Run the audit only when executed as a script, not when imported as a module.
if __name__ == "__main__":
   main()
