# Copyright (c) 2006-2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

#------------------------------------------------------------------------------------
# This module contains a parser for regular expressions.  It is used by the CLI
# parser to parse the regular expressions for a PatternRule and to build a
# corresponding 'partial' regular expression that will match any prefix of anything
# that matches the PatternRule (which is necessary in order to do <tab> and "?"
# completion).
#------------------------------------------------------------------------------------
import string

class Term:
   """Base class for the nodes of the regular expression parse tree."""

   def setOptimize( self, val ):
      """Accepts the optimization flag and ignores it.

      This base implementation does nothing; subclasses that hold subterms or whose
      string form depends on the flag override it."""
      return None

class EmptyTerm( Term ):
   """A term whose string form is the empty string.

   CaretTerm.partial and DollarTerm.partial return instances of this class."""

   def __str__( self ):
      return ''

   def partial( self ):
      """Always fails: no partial form is defined for an EmptyTerm.

      The AssertionError is raised explicitly rather than through 'assert False',
      because assert statements are removed when Python runs with -O, in which case
      this method would silently return None."""
      raise AssertionError( "EmptyTerm" )

class CaretTerm( Term ):
   """The '^' term of a pattern."""

   def __str__( self ):
      caret = '^'
      return caret

   def partial( self ):
      """Returns an EmptyTerm as the partial form of '^'."""
      placeholder = EmptyTerm()
      return placeholder

class DollarTerm( Term ):
   """The '$' term of a pattern."""

   def __str__( self ):
      dollar = '$'
      return dollar

   def partial( self ):
      """Returns an EmptyTerm as the partial form of '$'."""
      placeholder = EmptyTerm()
      return placeholder

class Character( Term ):
   """A single unescaped pattern character, stored in c_."""

   def __init__( self, c ):
      # The character exactly as it appeared in the pattern.
      self.c_ = c

   def __str__( self ):
      # The character is its own pattern text.
      return self.c_

   def partial( self ):
      """Returns this character wrapped in an OptionalTerm."""
      optional = OptionalTerm( self )
      return optional

class EscapedCharacter( Term ):
   """A backslash-escaped pattern character; c_ holds the character after the
   backslash."""

   def __init__( self, c ):
      # The character that followed the backslash in the pattern.
      self.c_ = c

   def __str__( self ):
      # Put the backslash back in front of the stored character.
      return ''.join( ( '\\', self.c_ ) )

   def partial( self ):
      """Returns this escaped character wrapped in an OptionalTerm."""
      optional = OptionalTerm( self )
      return optional

class CharacterClass( Term ):
   """A bracketed character class; chars_ holds the text between '[' and ']'."""

   def __init__( self, chars ):
      # The raw class contents, without the surrounding brackets.
      self.chars_ = chars

   def __str__( self ):
      # Restore the brackets around the stored contents.
      return '[' + self.chars_ + ']'

   def partial( self ):
      """Returns this character class wrapped in an OptionalTerm."""
      optional = OptionalTerm( self )
      return optional

class ParenTerm( Term ):
   """A parenthesized group around a single subterm.

   special, when given, is the text that immediately follows the opening
   parenthesis (the parser in this module passes '?:' for non-capturing groups)."""

   def __init__( self, subterm, special=None ):
      self.subterm_ = subterm
      self.special_ = special
      # Whether to enable optimizations to the string representation of a pattern,
      # for example in order to cut down on the number of numbered groups
      # (which is necessary because partial patterns get very large very quickly,
      # and may have more numbered groups than are supported by the re library).
      self.optimize = True

   def __str__( self ):
      if self.optimize and not self.special_:
         # A plain group is rendered as non-capturing so that it does not use up a
         # numbered group.
         return '(?:%s)' % ( self.subterm_, )
      # Either optimizations are off, or the group already carries its own special
      # prefix.  The prefix must be emitted exactly once: prepending '?:' to a group
      # that already starts with '?:' would yield '(?:?:...)', which is not a valid
      # regular expression.
      return '(%s%s)' % ( self.special_ or '', self.subterm_ )

   def setOptimize( self, val ):
      """Sets the optimization flag on this group and on its subterm."""
      self.optimize = val
      self.subterm_.setOptimize( val )

   def partial( self ):
      """Returns a plain ParenTerm wrapping the partial form of the subterm.

      Groups that carry a special prefix are rejected by the assertion below."""
      assert self.special_ is None, \
             "Special characters in ParenTerm: %s" % self.special_
      return ParenTerm( self.subterm_.partial() )

class SequenceTerm( Term ):
   """A concatenation of subterms, kept in pattern order in the list subterms_."""

   def __init__( self, subterms ):
      self.subterms_ = subterms

   def __str__( self ):
      # The subterms are rendered back to back, with no separator.
      return ''.join( str( term ) for term in self.subterms_ )

   def setOptimize( self, val ):
      """Passes the optimization flag on to every subterm."""
      for term in self.subterms_:
         term.setOptimize( val )

   def partial( self ):
      """Returns an OrTerm with one option per subterm.

      The option for the subterm at position k is the sequence of the complete
      subterms before k followed by the partial form of subterm k."""
      alternatives = [
         SequenceTerm( self.subterms_[ : idx ] + [ term.partial() ] )
         for idx, term in enumerate( self.subterms_ ) ]
      return OrTerm( alternatives )

class OrTerm( Term ):
   """An alternation between the terms in the list options_."""

   def __init__( self, options ):
      self.options_ = options

   def __str__( self ):
      # Options are separated by '|'; no enclosing parentheses are added here.
      return '|'.join( str( option ) for option in self.options_ )

   def setOptimize( self, val ):
      """Passes the optimization flag on to every option."""
      for option in self.options_:
         option.setOptimize( val )

   def partial( self ):
      """Returns an OrTerm over the partial forms of the options, in order."""
      partials = []
      for option in self.options_:
         partials.append( option.partial() )
      return OrTerm( partials )

class StarTerm( Term ):
   """A term followed by the '*' quantifier."""

   def __init__( self, term ):
      self.term_ = term

   def __str__( self ):
      return str( self.term_ ) + '*'

   def setOptimize( self, val ):
      """Passes the optimization flag on to the quantified term."""
      self.term_.setOptimize( val )

   def partial( self ):
      """Returns a group holding this starred term followed by the partial form of
      the quantified term."""
      body = SequenceTerm( [ self, self.term_.partial() ] )
      return ParenTerm( body )

class PlusTerm( Term ):
   """A term followed by the '+' quantifier."""

   def __init__( self, term ):
      self.term_ = term

   def __str__( self ):
      return str( self.term_ ) + '+'

   def setOptimize( self, val ):
      """Passes the optimization flag on to the quantified term."""
      self.term_.setOptimize( val )

   def partial( self ):
      """Returns a group holding a StarTerm of the quantified term followed by the
      partial form of the quantified term."""
      repeated = StarTerm( self.term_ )
      body = SequenceTerm( [ repeated, self.term_.partial() ] )
      return ParenTerm( body )

class OptionalTerm( Term ):
   """A term followed by the '?' quantifier."""

   def __init__( self, term ):
      self.term_ = term

   def __str__( self ):
      return str( self.term_ ) + '?'

   def setOptimize( self, val ):
      """Passes the optimization flag on to the quantified term."""
      self.term_.setOptimize( val )

   def partial( self ):
      """Returns the partial form of the quantified term; no additional '?' is
      applied to it here."""
      inner = self.term_.partial()
      return inner

class NegatedTerm( Term ):
   """A '(?!...)' group around a single term."""

   def __init__( self, term ):
      self.term_ = term

   def __str__( self ):
      return '(?!' + str( self.term_ ) + ')'

   def setOptimize( self, val ):
      """Passes the optimization flag on to the wrapped term."""
      self.term_.setOptimize( val )

   def partial( self ):
      """Returns the partial form of the wrapped term; the '(?!' wrapper is not
      reproduced in the result."""
      inner = self.term_.partial()
      return inner

class IteratedTerm( Term ):
   """A term followed by a '{m}' or '{m,n}' repetition count.

   When n is None the count is the single-number form '{m}': lower_ is None and
   upper_ holds m.  Otherwise lower_ and upper_ hold m and n."""

   def __init__( self, term, m, n ):
      self.term_ = term
      if n is None:
         self.lower_ = None
         self.upper_ = m
      else:
         self.lower_ = m
         self.upper_ = n

   def __str__( self ):
      if self.lower_ is None:
         return '%s{%d}' % ( self.term_, self.upper_ )
      else:
         return '%s{%d,%d}' % ( self.term_, self.lower_, self.upper_ )

   def setOptimize( self, val ):
      """Passes the optimization flag on to the repeated term."""
      self.term_.setOptimize( val )

   def partial( self ):
      """Returns a term for the prefixes of up to upper_ repetitions of the term.

      In the general case this is a group holding between 0 and upper_ - 1 complete
      repetitions followed by the partial form of the term."""
      if self.upper_ <= 0:
         # Zero repetitions can only produce the empty string.  Without this case
         # the general branch below would build the range {0,-1}, which is not a
         # valid repetition count.
         return EmptyTerm()
      if self.upper_ == 1:
         # A single repetition leaves no room for complete repetitions in front.
         return self.term_.partial()
      return ParenTerm( SequenceTerm( [
         IteratedTerm( self.term_, 0, self.upper_ - 1 ),
         self.term_.partial() ] ) )

def parse_( i, pattern ):
   """Parses as much as possible of the given regular expression pattern, starting at
   position i.

   Parsing stops at the end of the pattern or at a ')' that was not opened within
   this call.  Returns a tuple ( j, term ) where j is the index of the first
   character that was not consumed and term is a SequenceTerm holding the terms
   that were parsed.

   The explicitly checked error cases fail with AssertionError.  A pattern that ends
   in the middle of a construct, or a quantifier or '|' with no preceding term,
   raises IndexError; a repetition count that is not a number raises ValueError."""

   terms = []
   while i < len( pattern ):
      c = pattern[ i ]
      if c == '[':
         start = i + 1
         end = start
         # A ']' that comes first in the class, optionally after a negating '^', is
         # a literal member of the class rather than its terminator (as in '[]a]'
         # or '[^]]'), so step over it before looking for the closing bracket.
         if pattern[ end ] == '^':
            end += 1
         if pattern[ end ] == ']':
            end += 1
         while pattern[ end ] != ']':
            if pattern[ end ] == '\\':
               # Skip the escaped character so that '\]' does not end the class.
               end += 1
            end += 1
         # Now pattern[ end ] is the ']' character.
         chars = pattern[ start : end ]
         terms.append( CharacterClass( chars ) )
         i = end + 1
         # Now pattern[ i ] is the first character after the character class.
      elif c == '(':
         if pattern[ i + 1 ] == '?':
            assert pattern[ i + 2 ] in '!:', \
                   "Invalid character following '?': '%s'" % pattern[ i + 2 ]
            special = pattern[ i + 1 : i + 3 ]
            i += 2
         else:
            special = None
         # Parse the group's contents; the recursive call stops at the matching ')'.
         ( i, subterm ) = parse_( i + 1, pattern )
         if special == '?!':
            terms.append( NegatedTerm( subterm ) )
         else:
            terms.append( ParenTerm( subterm, special ) )
         assert pattern[ i ] == ')', \
                "Found '%s' when expecting ')'" % pattern[ i ]
         i += 1
         # Now pattern[ i ] is the first character after the parenthesized
         # expression.
      elif c == ')':
         # Leave the ')' for the caller that opened the group to consume.
         break
      elif c == '|':
         # Only the single term just before the '|' becomes the left-hand option.
         # Neither SequenceTerm nor OrTerm adds parentheses when rendered, so the
         # string form still reproduces the alternation as it was written.
         options = [ terms.pop() ]
         ( i, term ) = parse_( i + 1, pattern )
         options.append( term )
         terms.append( OrTerm( options ) )
         # The recursive call consumed everything up to the end of the pattern or
         # the enclosing ')', so the loop ends on its next check.
         assert i == len( pattern ) or pattern[ i ] == ')', \
                "Pattern termination expected"
      elif c == '\\':
         terms.append( EscapedCharacter( pattern[ i + 1 ] ) )
         i += 2
      elif c == '*':
         term = terms.pop()
         terms.append( StarTerm( term ) )
         i += 1
      elif c == '+':
         term = terms.pop()
         terms.append( PlusTerm( term ) )
         i += 1
      elif c == '?':
         term = terms.pop()
         terms.append( OptionalTerm( term ) )
         i += 1
      elif c == '{':
         start = i + 1
         end = start
         while pattern[ end ] in string.digits:
            end += 1
         # Now pattern[ end ] must either be a ',' or a '}'.
         m = int( pattern[ start : end ] )
         if pattern[ end ] == ',':
            start = end + 1
            end = start
            while pattern[ end ] in string.digits:
               end += 1
            # Now pattern[ end ] must be the '}' character.
            n = int( pattern[ start : end ] )
         else:
            # Single-number form '{m}'.
            n = None
         assert pattern[ end ] == '}', \
                "Found '%s' when expecting '}'" % pattern[ end ]
         i = end + 1
         term = terms.pop()
         terms.append( IteratedTerm( term, m, n ) )
      elif c == '^':
         terms.append( CaretTerm() )
         i += 1
      elif c == '$':
         terms.append( DollarTerm() )
         i += 1
      else:
         terms.append( Character( c ) )
         i += 1
   # i is the first non-consumed character.
   return ( i, SequenceTerm( terms ) )

def parse( pattern ):
   """Parses the given regular expression pattern and returns the Term object at the
   root of the resulting parse tree.

   Two assertions guard the result: the parser must have consumed the entire
   pattern, and the tree rendered with optimizations disabled must be identical to
   the pattern.  The tree is returned with optimizations enabled."""

   consumed, root = parse_( 0, pattern )

   # A failure of either assertion below means that the pattern is invalid or that
   # there is a bug in the parser.
   assert consumed == len( pattern ), \
          "Parsing stopped at %d while the pattern's length is %d" % (
             consumed, len( pattern ) )

   # The comparison only holds for the unoptimized rendering, so switch
   # optimizations off for the check and back on afterwards.
   root.setOptimize( False )
   assert str( root ) == pattern, \
          "Term and pattern differ: %s vs %s" % ( str( root ), pattern )
   root.setOptimize( True )

   return root
