#!/usr/bin/env python3
# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from collections import deque
import os.path
import sys

import Tracing

traceHandle = Tracing.Handle( "ObjectTree" )
# Trace shorthands: t5 is used when an ObjectTree is constructed, t9 for the
# per-node detail of hash aggregation.
t5, t9 = traceHandle.trace5, traceHandle.trace9

class ObjectTree:
   """Represents the object tree structure on an agent. Constructs an internal trie
   given a list of parsed HashMsgs.

   Each HashMsg is read as a mapping with "path", "hash", "value" and "type" keys.
   Paths are stored relative to `prefix`, one trie level per path segment. After
   construction every interior node carries a hash aggregated from its children, so
   two trees can be compared via their root hashes. `root` is None when no
   messages were given.
   """
   def __init__( self, agent, prefix, msgs ):
      self.agent = agent
      self.prefix = prefix

      # If not given any messages, there is no tree to generate.
      if not msgs:
         self.root = None
         return

      t5( f"Constructing ObjectTree given messages: {msgs!r}" )

      # If the first message is for the prefix itself, the root node is a leaf.
      # NOTE(review): any remaining messages are ignored in this case.
      if msgs[ 0 ][ "path" ] == self.prefix:
         self.root = ObjectTreeNode( self.agent, msgs[ 0 ] )
         return

      self.root = ObjectTreeNode( self.agent )
      for msg in msgs:
         self._insert( msg[ "path" ], ObjectTreeNode( self.agent, msg ) )

      self._aggregateHashes( self.root )

   def dumpTree( self ):
      """Print a representation of the ObjectTree to stdout, one node per line,
      followed by a blank line."""
      if self.root is None:
         # Built from no messages: there are no nodes to walk.
         print( f"{self.prefix}   (empty)" )
         print() # Trailing double-newline
         return
      print( f"{self.prefix}   {self.root}" )
      self._dumpTreeHelper( self.root, "|" )
      print() # Trailing double-newline

   def _dumpTreeHelper( self, node, prefix ):
      """Recursively print the children of node, indenting each level with prefix."""
      for seg, child in node.children.items():
         print( f"{prefix}-- {seg}\t{child}" )
         self._dumpTreeHelper( child, prefix + "   |" )

   def _insert( self, path, node ):
      """Insert node at path, creating empty interior nodes for any missing
      intermediate segments. path is made relative to self.prefix first.

      Assumes that all HashMsgs are for unique paths: if a node already exists at
      the final segment, the existing node is kept and `node` is dropped.
      """
      segments = list( splitPath( stripPrefix( path, self.prefix ) ) )
      if not segments:
         # A path that strips to nothing names the root itself, which is an
         # interior node here; there is no segment to attach the leaf under.
         error( f"Message for {path} has no path segments below prefix "
                f"{self.prefix}." )
      cur = self.root
      for seg in segments[ :-1 ]:
         # Only allocate an interior node when the segment is actually missing.
         if seg not in cur.children:
            cur.children[ seg ] = ObjectTreeNode( self.agent )
         cur = cur.children[ seg ]
      cur.setdefault( segments[ -1 ], node )

   def _aggregateHashes( self, node ):
      """Compute, store on node, and return the hash of the subtree rooted at node.

      Leaves keep the hash from their message. Interior nodes hash the tuple of
      ( segment, child hash ) pairs.
      """
      if node.isLeaf():
         t9( f"Hit leaf with hash {node.hash}" )
         return node.hash

      # Aggregate child hashes along with their path segment to make sure that trees
      # with the same graph topology and values but different paths (ex:
      # /test/col[key=0]/colValue = "ent" vs. /test/ent/name = "ent") have different
      # hashes. Children are visited in sorted segment order so the hash does not
      # depend on the order in which messages arrived; diffSubtrees matches children
      # by segment, so order must not affect equality here either.
      hashes = [ ( seg, self._aggregateHashes( node.children[ seg ] ) )
                 for seg in sorted( node.children ) ]
      t9( f"Aggregating child hashes: {hashes!r}" )
      nodeHash = hash( tuple( hashes ) )
      t9( f"Result: {nodeHash}" )
      node.hash = nodeHash
      return nodeHash

class ObjectTreeNode:
   """A node in an ObjectTree.

   A node built from a HashMsg copies its "hash", "value" and "type" fields.
   Interior nodes are built without a message; ObjectTree assigns their hash after
   aggregating their children. Fields that have not been set are None.
   """
   def __init__( self, agent, msg=None ):
      # Default every field so that __str__ and hash lookups work on nodes built
      # without a message (e.g. interior nodes before hashes are aggregated).
      self.hash = None
      self.value = None
      self.type = None
      if msg:
         self.hash = msg[ "hash" ]
         self.value = msg[ "value" ]
         self.type = msg[ "type" ]
      self.agent = agent
      # Maps a path segment to the child ObjectTreeNode.
      self.children = {}

   def __str__( self ):
      """Describe the node: hash, value and type for leaves; just hash otherwise."""
      if self.isLeaf():
         return f"hash: {self.hash}, value: {self.value!r}, type: {self.type}"
      else:
         return f"hash: {self.hash}"

   def __getitem__( self, key ):
      """Return the child at path segment key; raises KeyError if absent."""
      return self.children[ key ]

   def setdefault( self, key, default ):
      """Return the child at key, first storing default there if key is absent."""
      return self.children.setdefault( key, default )

   def isLeaf( self ):
      """Return True if the node has no children."""
      return not self.children

# Utility Functions

def error( msg ):
   """Write msg to stderr, then terminate the process with exit status 1."""
   print( msg, file=sys.stderr )
   raise SystemExit( 1 )

def stripPrefix( path, prefix ):
   """Return path with a leading prefix removed, or path unchanged when it does not
   start with prefix. The match is purely textual, not segment-aware."""
   hasPrefix = path.startswith( prefix )
   return path[ len( prefix ) : ] if hasPrefix else path

def splitPath( path ):
   """Split a slash-separated path into its segments, returned as a deque.

   "/path/to/object" becomes deque( [ "path", "to", "object" ] ). Leading,
   trailing and repeated slashes never produce empty segments, so "/a/b/" also
   yields [ "a", "b" ], and "" or "/" yields an empty deque.

   NOTE(review): a "/" inside a segment (e.g. inside a bracketed key) is treated
   as a separator, as with any plain split on "/".
   """
   return deque( seg for seg in path.split( "/" ) if seg )

def diffTrees( a, b ):
   """Top level function to compare ObjectTrees a and b built for the same prefix,
   printing their differences. Treats tree a as the source of truth.

   A tree built from no messages has a root of None. Two such trees are reported
   as matching; if only one is empty, the other is reported as having objects the
   empty one lacks.
   """
   if a.root is None or b.root is None:
      if a.root is None and b.root is None:
         print( f"{b.agent} matches {a.agent} at {a.prefix}." )
      else:
         has, doesnt = ( ( a.agent, b.agent ) if a.root is not None
                         else ( b.agent, a.agent ) )
         print( f"{has} has objects at {a.prefix} but {doesnt} has none.\n" )
      return

   if a.root.hash == b.root.hash:
      print( f"{b.agent} matches {a.agent} at {a.prefix}." )
   else:
      diffSubtrees( a.root, b.root, a.prefix )

def diffSubtrees( a, b, pathHead, pathTail=None ):
   """Compare the subtrees rooted at a and b, treating subtree a as the source of
   truth, and print what differs.

   Only called when the two subtrees' hashes differ, so some difference is always
   expected. pathHead is the path of the parent and pathTail (if any) the segment
   naming this node under it. Exits via error() if no difference can be located.
   """
   # Two leaves with different hashes must differ in their own data.
   if a.isLeaf() and b.isLeaf():
      leafDiff( a, b, pathHead, pathTail )
      return

   fullPath = pathHead if not pathTail else os.path.join( pathHead, pathTail )

   # Segments seen only on one side are structural differences; shared segments
   # with different hashes are descended into.
   onlyInA = set()
   onlyInB = set( b.children )
   recursed = False
   for seg, aChild in a.children.items():
      if seg not in onlyInB:
         onlyInA.add( seg )
         continue
      onlyInB.discard( seg )
      bChild = b.children[ seg ]
      if aChild.hash != bChild.hash:
         recursed = True
         diffSubtrees( aChild, bChild, fullPath, seg )

   if onlyInA or onlyInB:
      structuralDiff( a, b, onlyInA, onlyInB, fullPath )
   elif not recursed:
      error( f"Failed to find difference at {fullPath} despite hash mismatch." )

def leafDiff( a, b, pathHead, pathTail ):
   """Print how leaf a (the source of truth) differs from leaf b.

   pathHead is the path of the parent object and pathTail the attribute segment.
   Callers only get here after a hash mismatch, so leaves that agree on both value
   and type are treated as an internal inconsistency and reported via error().
   """
   sameValue = a.value == b.value
   if sameValue and a.type == b.type:
      error( f"Found matching values in nodes with different hashes.\na: {a}\t"
             f"b: {b}" )

   # Values that compare equal but carry different types can legitimately hash
   # differently; report that as a type difference instead of aborting.
   kind = "type" if sameValue else "value"
   msg = ( f"Attribute {pathTail} of {pathHead} differs by {kind}.\n"
           f"{a.agent}: {str( a.value )!r} (of type {a.type})\n"
           f"{b.agent}: {str( b.value )!r} (of type {b.type})\n" )
   print( msg )

def structuralDiff( a, b, aDiff, bDiff, path ):
   """Print the child segments at path that exist in only one of nodes a and b.

   aDiff holds the segments present under a but not b; bDiff the reverse. Each
   non-empty set is listed, sorted, under its agent's name.
   """
   if aDiff and bDiff:
      header = f"Both agents have unmatched children at {path}.\n"
   elif aDiff:
      header = f"{a.agent} has children not present in {b.agent} at {path}.\n"
   else:
      header = f"{b.agent} has children not present in {a.agent} at {path}.\n"

   parts = [ header ]
   for agent, diff in ( ( a.agent, aDiff ), ( b.agent, bDiff ) ):
      if not diff:
         continue
      parts.append( f"{agent}:\n" )
      # Sorted largely to have deterministic output for tests.
      parts.extend( f"   {child}\n" for child in sorted( diff ) )

   print( "".join( parts ) )
