# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import difflib
import re

# Patterns used to classify a single line of Cli configuration text.
# A separator line: optional leading spaces followed by a lone '!'.
separatorRe = re.compile( r"^ *!$" )
# A comment line: optional leading spaces followed by '!!'.
commentRe = re.compile( r"^ *!!" )
# A negated command: optional leading spaces followed by 'no '.
noCommandRe = re.compile( r"^ *no " )
# Any command, split into ( leading spaces, remaining text ).
commandRe = re.compile( r"( *)(.*)" )

class DiffModelBlock:
   """ A class represents the diff block between 2 CliSave models.

   Blocks can be linked into a tree with addChildBlock().  Every block keeps
   two flags, hasRemove_ and hasAdd_, telling whether the block itself or
   any block below it is tagged REMOVE / ADD.
   """
   def __init__( self, tag, block, revertCmd=None, rootBlock=False ):
      # - tag is either DiffModel's COMMON, ADD, REMOVE
      # - block is a list of strings representing a diff block
      # - revertCmd is meaningful only when tag is REMOVE. It indicates
      #   that this is ModeEntryModel and only revertCmd should be used
      #   to generate Cli command.
      # - If rootBlock is set, this is a special empty block only used
      #   to generate block hierarchy with special indentLevel = -1.
      self.tag_ = tag
      self.block_ = block
      self.revertCmd_ = revertCmd
      self.rootBlock_ = rootBlock
      # childBlocks keep a list of child DiffModelBlock.  This is used
      # for reordering the deletion blocks when we generate cli diff.
      self.childBlocks_ = []
      # point to parent.  All childBlocks point to this node.
      self.parent_ = None
      # hasRemove/hasAdd flags indicate if there is any REMOVE/ADD block in
      # this subtree.  They start from the tag passed by the caller and are
      # widened by addChildBlock().
      self.hasRemove_ = tag == DiffModel.REMOVE
      self.hasAdd_ = tag == DiffModel.ADD

      if rootBlock:
         # The root holds no commands ('block' is not inspected), so it is
         # always tagged COMMON regardless of the tag passed in.
         self.tag_ = DiffModel.COMMON
         self.indentLevel_ = -1
      else:
         # By default, calculate indent level based on first command.
         # One indent level is 3 spaces.
         firstLine = block[ 0 ]
         numSpace = len( firstLine ) - len( firstLine.lstrip( ' ' ) )
         assert numSpace % 3 == 0
         self.indentLevel_ = numSpace // 3

   def __str__( self ):
      return f"DiffModelBlock tag: {self.tag_} block: {self.block_} " + \
             f"revertCmd: {self.revertCmd_} rootBlock: {self.rootBlock_} " + \
             f"indentLevel: {self.indentLevel_}"

   def addChildBlock( self, child ):
      """ Append 'child' to this block's children and make this block its
      parent.  If the child brings a REMOVE or ADD that this block did not
      have yet, the flags of this block and all its ancestors are updated. """
      self.childBlocks_.append( child )
      child.parent_ = self
      bringsRemove = child.hasRemove_ and not self.hasRemove_
      bringsAdd = child.hasAdd_ and not self.hasAdd_
      if not ( bringsRemove or bringsAdd ):
         # Nothing new: this block (and hence its ancestors) already
         # carries every flag the child has.
         return
      node = self
      while node:
         node.hasRemove_ |= child.hasRemove_
         node.hasAdd_ |= child.hasAdd_
         node = node.parent_

# Given a list of DiffModelBlock, update their child block based on indent
# and order.  Return the next index to be processed.
#
# Starting at blocks[ index ], every block exactly one level deeper than
# parentBlock becomes a child of parentBlock.  A block one level deeper than
# that is handled recursively, with the most recently added child as its
# parent.  Processing stops at the end of the list or at the first block
# that is shallower than the children of parentBlock.
#
# Raises AssertionError if a block is indented more than one level deeper
# than the block before it, or if an indented block has no preceding block
# that could be its parent.
def updateChildBlocks( blocks, index, parentBlock ):
   expectedLevel = parentBlock.indentLevel_ + 1
   prevBlock = None
   while index < len( blocks ):
      block = blocks[ index ]
      level = block.indentLevel_
      if level < expectedLevel:
         # Belongs to an ancestor; let the caller deal with it.
         break
      if level == expectedLevel:
         parentBlock.addChildBlock( block )
         prevBlock = block
         index += 1
      elif level == expectedLevel + 1:
         if prevBlock is None:
            # Explicit raise (rather than assert) so that the check is kept
            # when running with -O.
            raise AssertionError( "Indented block has no parent block." )
         # The recursive call consumes the whole subtree under prevBlock
         # and tells us where to resume.
         index = updateChildBlocks( blocks, index, prevBlock )
      else:
         raise AssertionError( "More indents than expected." )
   return index

class DiffModel:
   """ This represents the diff between 2 CliSave models.

   The model is an ordered list of DiffModelBlock, each tagged COMMON, ADD
   or REMOVE.  It can be written out in a diff-like format with render(),
   or converted to a list of Cli commands with getCliCommands().
   """
   COMMON = ' '
   ADD = '+'
   REMOVE = '-'

   def __init__( self ):
      # a list of DiffModelBlocks
      self.blocks_ = []
      # rootBlock is created once buildBlockTree() is called
      self.rootBlock_ = None

   def append( self, tag, block, revertCmd=None ):
      """ append the diff model with cli command 'block' and 'tag'

      'block' is a single string or a list of strings.  'revertCmd' is
      passed on to the DiffModelBlock that is created.  An empty 'block'
      is ignored. """
      assert tag in ( self.COMMON, self.ADD, self.REMOVE )
      if not block:
         # BUG991455: sometimes a block doesn't have any data.
         # This can happen like with an interface profile
         return
      # Make it a list if not
      block = [ block ] if isinstance( block, str ) else block
      # Strip off '\n' from all lines, if any
      block = [ s.rstrip( '\n' ) for s in block ]
      self.blocks_.append( DiffModelBlock( tag, block, revertCmd ) )

   def appendDiffModel( self, theirModel ):
      """ append all the blocks of 'theirModel' to this model.  The block
      objects are shared with 'theirModel', not copied. """
      self.blocks_.extend( theirModel.blocks_ )

   def render( self, stream ):
      """ print the diff model in the diff format to stream """
      for block in self.blocks_:
         for line in block.block_:
            stream.write( f'{block.tag_}{line}\n' )

   # Create block hierarchy from self.blocks_.  The self.rootBlock_ is an root
   # node of the tree.
   # NOTE: the tree is built only once.  Blocks appended to self.blocks_
   # after the first call are not added to the existing tree.
   def buildBlockTree( self ):
      if self.rootBlock_:
         return
      self.rootBlock_ = DiffModelBlock( '', None, None, rootBlock=True )
      # Update child blocks to create hierarchy
      index = updateChildBlocks( self.blocks_, 0, self.rootBlock_ )
      assert index == len( self.blocks_ )

   def processAddBlocks( self, results ):
      """ Append to 'results' the commands of every subtree containing an
      ADD block, walking the tree depth-first in forward order.
      buildBlockTree() must have been called before. """
      blockStack = [ self.rootBlock_ ]
      while blockStack:
         block = blockStack.pop()
         if not block.hasAdd_:
            continue
         assert block.tag_ in ( self.COMMON, self.ADD ) or block.rootBlock_
         self.updateResult( block, results )
         # put child blocks into stack. Reverse to preserve forward order.
         blockStack.extend( reversed( block.childBlocks_ ) )

   def processRemoveBlocks( self, results ):
      """ Append to 'results' the commands of every subtree containing a
      REMOVE block, walking the tree depth-first with the last child
      visited first.  buildBlockTree() must have been called before. """
      blockStack = [ self.rootBlock_ ]
      # True while the previously visited block was a removed comment, so
      # that consecutive removed comments yield a single "no comment".
      alreadyRemovedComment = False
      while blockStack:
         block = blockStack.pop()
         if not block.hasRemove_:
            continue
         assert block.tag_ in ( self.COMMON, self.REMOVE ) or block.rootBlock_
         alreadyRemovedComment = self.updateResult(
                                    block, results, alreadyRemovedComment )
         # puts child blocks into stack.  This effectively put last block first.
         blockStack.extend( block.childBlocks_ )

   # Update result with block info.  Return if this block removes comments.
   def updateResult( self, modelBlock, results, alreadyRemovedComment=False ):
      tag = modelBlock.tag_
      block = modelBlock.block_
      if modelBlock.rootBlock_:
         # The root block carries no commands.
         return False
      if tag in ( self.COMMON, self.ADD ):
         results.extend( block )
         return False
      if tag != self.REMOVE:
         # Explicit raise so that the check is kept when running with -O.
         raise AssertionError( f"Unknown tag {tag!r}" )

      # For REMOVE case, the only time the block len > 1 is when it
      # is ModeEntryModel, where revertCmd is set.
      line = block[ 0 ]
      prefix = ' ' * ( len( line ) - len( line.lstrip( ' ' ) ) )
      if modelBlock.revertCmd_ is not None:
         results.append( f'{prefix}{modelBlock.revertCmd_}' )
      elif separatorRe.match( line ):
         # Separator.
         results.append( line )
      elif commentRe.match( line ):
         # Avoid printing multiple "no comment"
         if not alreadyRemovedComment:
            results.append( f'{prefix}no comment' )
         return True
      elif noCommandRe.match( line ):
         # The line starts with 'no ' (after indentation), so the first
         # 'no' in the line is that leading keyword.  'count' is passed by
         # keyword: passing it positionally to re.sub() is deprecated.
         results.append( re.sub( r"no", "default", line, count=1 ) )
      else:
         # count=1 keeps the pattern from matching again (an empty match)
         # at the end of the line.
         results.append( commandRe.sub( r"\1default \2", line, count=1 ) )
      return False

   def getCliCommands( self ):
      """ return list of Cli commands of the diff model"""
      # NOTE
      # 1. For COMMON and ADD, Cli commands should be printed as is.
      # 2. For REMOVE,
      #    2.1. For ModeEntryModel, only write the revertCmd for
      #         the whole block
      #    2.2. If the removed command was not started with 'no', put
      #         'default' prefix
      #    2.3. If the removed command started with 'no', replace 'no'
      #         with 'default'
      #    2.4. For comments, write "no comment" for all removed comments.
      #         There is no multiple "no comment" for a particular block.
      # 3. separator should be printed no matter what (COMMON, ADD, REMOVE)
      #
      # The output is generated in 2 phases:
      # Phase1: remove blocks are printed in reverse order. This order is across
      #         levels.
      # Phase2: add blocks are printed.
      # For both phases, any ancestor node (tag=COMMON) required to the leaf
      # add/remove blocks are also printed.

      # Generate the block hierarchy rooted at an empty root node self.rootBlock_
      self.buildBlockTree()
      results = []

      # Output the REMOVE blocks first in reverse order.
      self.processRemoveBlocks( results )

      # Output the add blocks.
      self.processAddBlocks( results )

      return results

def diffLines( diffModel, theirLines, myLines, prefix='', filterFunc=None,
      useInsertionOrder=False ):
   '''
   Record in 'diffModel' the diff between 'theirLines' and 'myLines'.

   'useInsertionOrder' selects the differ: when it is False the traditional
   differ is used, otherwise the custom insertion-order differ is used.

   Every line is prepended with 'prefix'.

   'filterFunc', when set, is called with 'tag' and 'line' for every diff
   entry; the entry is added to 'diffModel' only if it returns True.
   '''
   def appendIfEligible( tag, line ):
      # Drop entries the filter rejects; with no filter everything is kept.
      if filterFunc and not filterFunc( tag, line ):
         return
      diffModel.append( tag, line )

   if useInsertionOrder:
      differ = _diffLinesInsertionOrder
   else:
      differ = _diffLinesTraditional
   return differ( diffModel, theirLines, myLines, prefix=prefix,
         appendFunc=appendIfEligible )

def _diffLinesTraditional( diffModel, theirLines, myLines, appendFunc, prefix='' ):
   '''
   Report the diff between 2 sets of lines through 'appendFunc'.

   'appendFunc' is called as appendFunc( tag, line ) for each line, where
   'tag' is one of DiffModel's REMOVE, ADD or COMMON and 'line' starts with
   'prefix'.  'diffModel' is not accessed directly here.

   This is similar to https://docs.python.org/3/library/difflib.html#difflib.ndiff
   however it has the important distinction that it won't print ? in the case
   lines are somewhat similar
   '''
   def emit( tag, lines, start, stop ):
      # An entry may hold several newline-separated lines; report each one.
      for pos in range( start, stop ):
         for text in lines[ pos ].splitlines():
            appendFunc( tag, f'{prefix}{text}' )

   # A 'replace' is reported as a removal followed by an addition.
   removing = ( 'replace', 'delete' )
   adding = ( 'replace', 'insert' )

   matcher = difflib.SequenceMatcher( None, theirLines, myLines )
   for op, theirStart, theirEnd, myStart, myEnd in matcher.get_opcodes():
      assert op == 'equal' or op in removing or op in adding, \
            f'Unknown tag {op}'
      if op in removing:
         emit( DiffModel.REMOVE, theirLines, theirStart, theirEnd )
      if op in adding:
         emit( DiffModel.ADD, myLines, myStart, myEnd )
      if op == 'equal':
         emit( DiffModel.COMMON, myLines, myStart, myEnd )

def _diffLinesInsertionOrder( diffModel, theirLines, myLines, appendFunc,
      prefix='' ):
   ''' This differ assumes that theirLines and myLines might not be sorted
   in the same order. For example we might have something like
   A
    a a
   B
    b b
   C
    c c

   Be changed to:
   C
    c c
   B
    b b
   A
    a a

   in which case this function should return:
   -A
   - a a
   -B
   - b b
   -C
   - c c
   +C
   + c c
   +B
   + b b
   +A
   + a a

   The result is reported through 'appendFunc', called as
   appendFunc( tag, line ) where 'tag' is one of DiffModel's COMMON, REMOVE
   or ADD and 'line' starts with 'prefix'.  'diffModel' is not accessed
   directly here.
   '''

   # If the 2 sets of lines are common just then mark them as common
   if theirLines == myLines:
      for line in theirLines:
         appendFunc( DiffModel.COMMON, f'{prefix}{line}' )
      return

   # The 2 sets of lines have differences.  Find how many leading lines
   # they share.  Counting explicitly (instead of reusing a loop index)
   # keeps the case of an empty theirLines correct: nothing is common and
   # all of myLines must be added.
   common = 0
   limit = min( len( theirLines ), len( myLines ) )
   while common < limit and theirLines[ common ] == myLines[ common ]:
      common += 1

   # Print all of the common lines, then once they have diverged remove all
   # of theirLines and add all of myLines
   for line in theirLines[ : common ]:
      appendFunc( DiffModel.COMMON, f'{prefix}{line}' )

   for line in theirLines[ common : ]:
      appendFunc( DiffModel.REMOVE, f'{prefix}{line}' )

   for line in myLines[ common : ]:
      appendFunc( DiffModel.ADD, f'{prefix}{line}' )
