# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from abc import abstractmethod, ABC as AbstractBaseClass
from dataclasses import dataclass
from enum import Enum
from typing import List, Union

@dataclass
class ConditionFragment( AbstractBaseClass ):
   """Interface for pieces of RCF text that appear inside an if-condition.

   Concrete subclasses must provide formatStr, which renders the fragment as the
   text of a condition expression.
   """
   @abstractmethod
   def formatStr( self ):
      """Render this fragment as RCF condition text."""
      ...

@dataclass
class ModificationFragment( AbstractBaseClass ):
   """Interface for pieces of RCF text that appear in the assigns/modifies
   attributes.

   Concrete subclasses must provide formatStr, which renders the fragment as the
   text of a modification statement.
   """
   @abstractmethod
   def formatStr( self ):
      """Render this fragment as RCF modification text."""
      ...

class RelationalOp( Enum ):
   """Binary relational operators rendered by RelationalFragment.

   Each member's value is the RCF token emitted between the left-hand and
   right-hand operands.
   """
   IS = "is"
   # Emitted verbatim as the single token "is_not".
   IS_NOT = "is_not"
   GE = ">="
   LE = "<="

@dataclass
class RelationalFragment( ConditionFragment ):
   """Binary attribute relation rendered as a parenthesized RCF expression.

   Example:
      RelationalFragment( "med", RelationalOp.IS, "100" )
   renders as:
      (med is 100)
   """
   lhs: str
   op: RelationalOp
   rhs: str

   def formatStr( self ):
      # The surrounding parentheses keep the relation intact no matter which
      # operators it is later combined with.
      return "({} {} {})".format( self.lhs, self.op.value, self.rhs )

class ExternalRefOp( Enum ):
   """Operators that test an attribute against an external reference, rendered by
   ExternalRefOpFragment. The value is the RCF token emitted after the attribute.
   """
   MATCH = "match"

class ExternalRefType( Enum ):
   """Kinds of external reference named by ExternalRefOpFragment; the value is
   emitted verbatim between the operator and the reference name.
   """
   # Selected by RcfFragmentsMixin.matchPrefixSet for proto 4.
   PREFIX_LIST_V4 = "prefix_list_v4"
   # Selected by RcfFragmentsMixin.matchPrefixSet for proto 6.
   PREFIX_LIST_V6 = "prefix_list_v6"

@dataclass
class ExternalRefOpFragment( ConditionFragment ):
   """Type containing values to format as an external ref (match) operation. For
   example:
      ExternalRefOpFragment( "prefix", ExternalRefOp.MATCH,
                             ExternalRefType.PREFIX_LIST_V4,
                             "FOO" )
   Becomes:
      (prefix match prefix_list_v4 FOO)
   """
   # Attribute being tested, e.g. "prefix".
   lhs: str
   op: ExternalRefOp
   externalRefType: ExternalRefType
   # Name of the external reference, e.g. "FOO".
   rhs: str

   def formatStr( self ):
      # Wrap with parentheses to prevent issues with operator precedence.
      return f"({self.lhs} {self.op.value} {self.externalRefType.value} {self.rhs})"

class UnaryOp( Enum ):
   """Unary operators that UnaryOpFragment applies to a whole condition fragment.
   The value is the RCF token emitted before the inner condition.
   """
   NOT = "not"

@dataclass
class UnaryOpFragment( ConditionFragment ):
   """Type containing another condition fragment whose result is modified by a
   unary operator (not). For example:
      UnaryOpFragment(
         UnaryOp.NOT,
         ExternalRefOpFragment( "prefix", ExternalRefOp.MATCH,
                                ExternalRefType.PREFIX_LIST_V4,
                                "FOO" ),
      )
   Becomes:
      (not (prefix match prefix_list_v4 FOO))
   """
   op: UnaryOp
   # The wrapped condition; its own formatStr output is embedded verbatim.
   innerCondition: ConditionFragment

   def formatStr( self ):
      # Wrap with parentheses to prevent issues with operator precedence.
      return f"({self.op.value} {self.innerCondition.formatStr()})"

class ModificationOp( Enum ):
   """Operators used by AssignmentFragment; each value is the RCF token emitted
   between the modified attribute and its right-hand side.
   """
   EQ = "="
   INC = "+="
   DEC = "-="
   # Used for as_path by RcfFragmentsMixin.setAsPathPrepend.
   PREPEND = "prepend"

@dataclass
class AssignmentFragment( ModificationFragment ):
   """Type containing values to format as an assignment statement. For example,
   AssignmentFragment( "med", ModificationOp.EQ, "100" ) becomes:
      "med = 100;"
   """
   lhs: str
   op: ModificationOp
   rhs: str

   def formatStr( self ):
      # Statements are terminated with ";".
      return f"{self.lhs} {self.op.value} {self.rhs};"

@dataclass
class ModifierFunctionFragment( ModificationFragment ):
   """Type containing a function call emitted as a modification statement. For
   example, ModifierFunctionFragment( "set_next_hop_self()" ) becomes:
      "set_next_hop_self();"
   """
   # Function-call text emitted verbatim, including its argument list.
   func: str

   def formatStr( self ):
      # Statements are terminated with ";".
      return f"{self.func};"

class Directive( Enum ):
   """Directives that TextBlock.formatLines emits as "!<value>;" lines."""
   # NOTE(review): presumably makes RCF compilation fail on purpose; confirm.
   FAIL_COMPILE = "FAIL_COMPILE"

class ReturnValue( Enum ):
   """Values that TextBlock.formatLines emits as "return <value>;" statements.

   RcfFragmentsMixin.returnValue maps a permit flag to TRUE or FALSE.
   """
   TRUE = "true"
   FALSE = "false"
   UNKNOWN = "unknown"

class LineBuffer:
   """Accumulates indented lines of generated RCF text.

   Args:
      tabSize (int): number of spaces per indentation level
   """
   def __init__( self, tabSize ):
      self.tabSize = tabSize
      self.lines = []

   def addLine( self, line, indent ):
      """Append one line, prefixed by tabSize * indent spaces.

      Args:
         line (str): text of the line, without leading indentation
         indent (int): indentation level (number of tab stops)
      """
      prefix = " " * self.tabSize * indent
      self.lines.append( "{}{}".format( prefix, line ) )

   def getLines( self ):
      """Return the buffered lines in insertion order.

      Note: this is the buffer's own list, not a copy.
      """
      return self.lines

@dataclass
class TextBlock:
   """Type containing a list of comments, condition expressions, and action
   statements (in the RCF grammar sense) to be formatted as a block of RCF text.

   See fragmentsToFunctionText for formatting.

   This is used as an intermediate for OpenConfig functions and the route-map
   converter.
   """
   # Emitted first, one "# <comment>" line each.
   commentLines: List[ str ]
   # Joined with " and " into a single if-statement when non-empty.
   conditionFragments: List[ ConditionFragment ]
   # Emitted inside the if-statement, or in the current scope without conditions.
   modificationFragments: List[ ModificationFragment ]
   # Optional; emitted as "!<value>;" right after the comments.
   directive: Union[ None, Directive ]
   # Optional; emitted as "return <value>;" after the modifications.
   returnValue: Union[ None, ReturnValue ]

   @classmethod
   def initEmpty( cls ):
      """Construct an empty TextBlock.

      Returns: an empty TextBlock instance
      """
      return cls( commentLines=[], conditionFragments=[], modificationFragments=[],
                  directive=None, returnValue=None )

   def formatLines( self, lineBuffer, indent ):
      """Append the RCF text for this block to lineBuffer.

      - Comments are first.
      - The directive, if any, is second.
      - If there are conditions, an if statement joining them with " and " is
        produced and the modify ops are placed inside it. Otherwise, the modify
        ops are placed in the current scope.
      - The return value, if any, follows the modify ops in the same scope
        (inside the if statement when the block is conditional).

      Args:
         lineBuffer (LineBuffer): buffer the formatted lines are appended to
         indent (int): indentation level of the block's outermost lines

      Returns: None; the output is written to lineBuffer.
      """
      for comment in self.commentLines:
         lineBuffer.addLine( f"# {comment}", indent )

      if self.directive is not None:
         lineBuffer.addLine( f"!{self.directive.value};", indent )

      # Statements go one level deeper when wrapped in an if statement.
      bodyIndent = indent
      if self.conditionFragments:
         # The condition fragments defined in this module parenthesize
         # themselves, so a plain " and " join keeps each one intact.
         allConditionsConjunction = " and ".join(
            condition.formatStr() for condition in self.conditionFragments )
         lineBuffer.addLine( f"if {allConditionsConjunction} {{", indent )
         bodyIndent = indent + 1

      for modification in self.modificationFragments:
         lineBuffer.addLine( modification.formatStr(), bodyIndent )

      if self.returnValue is not None:
         lineBuffer.addLine( f"return {self.returnValue.value};", bodyIndent )

      if self.conditionFragments:
         lineBuffer.addLine( "}", indent )

   def updateWith( self, otherTextBlock ):
      """Merge another TextBlock into this one.

      The comment, condition and modification lists of otherTextBlock are
      appended (in place) to this block's lists. The directive and return value
      are overwritten only when otherTextBlock's value is not None.

      Args:
         otherTextBlock: the TextBlock instance to copy from
      """
      self.commentLines += otherTextBlock.commentLines
      self.conditionFragments += otherTextBlock.conditionFragments
      self.modificationFragments += otherTextBlock.modificationFragments
      if otherTextBlock.directive is not None:
         self.directive = otherTextBlock.directive
      if otherTextBlock.returnValue is not None:
         self.returnValue = otherTextBlock.returnValue

def fragmentsToFunctionText( functionName, textBlocks, tabSize=3 ):
   """Generate an RCF text function named functionName whose body is built from a
   list of TextBlock values. For example, with functionName="FOO", the single
   block

      TextBlock( commentLines=[ "Comment line 1", "Comment line 2" ],
                 conditionFragments=[
                    RelationalFragment( "med", RelationalOp.IS, "100" ) ],
                 modificationFragments=[
                    AssignmentFragment( "med", ModificationOp.EQ, "200" ) ],
                 directive=Directive.FAIL_COMPILE,
                 returnValue=ReturnValue.TRUE )

   becomes (with the default tabSize of 3):

      function FOO() {
         # Comment line 1
         # Comment line 2
         !FAIL_COMPILE;
         if (med is 100) {
            med = 200;
            return true;
         }
      }

   Args:
      functionName (str): function name
      textBlocks (List[TextBlock]): list of block fragments, emitted in order
      tabSize (int): number of spaces per indentation level
   Returns:
      RCF text for the function, ending with a newline.
   """
   lineBuffer = LineBuffer( tabSize )
   indent = 0

   lineBuffer.addLine( f"function {functionName}() {{", indent )
   indent += 1
   for block in textBlocks:
      block.formatLines( lineBuffer, indent )
   indent -= 1
   lineBuffer.addLine( "}", indent )
   # Add a trailing empty line so the joined text ends with a newline and the
   # generated function concatenates nicely with others.
   lineBuffer.addLine( "", indent )

   return "\n".join( lineBuffer.getLines() )

class RcfFragmentsMixin:
   """Static helpers that build RCF fragments for common match/set operations.

   The match* helpers return a single ConditionFragment. asPathLength returns a
   one-element list of condition fragments, or None for an unknown operator. The
   set* helpers return a list of ModificationFragments; setMed returns None for
   an unsupported metric type.
   """

   @staticmethod
   def matchLocalPref( value ):
      """Condition: (input.local_preference is <value>)."""
      return RelationalFragment( "input.local_preference", RelationalOp.IS, value )

   @staticmethod
   def matchTag( value ):
      """Condition: (input.igp.tag is <value>)."""
      return RelationalFragment( "input.igp.tag", RelationalOp.IS, value )

   @staticmethod
   def matchMed( value ):
      """Condition: (input.med is <value>)."""
      return RelationalFragment( "input.med", RelationalOp.IS, value )

   @staticmethod
   def matchRouteType( value ):
      """Condition: (bgp.route_source is <value>)."""
      return RelationalFragment( "bgp.route_source", RelationalOp.IS, value )

   @staticmethod
   def matchPrefixSet( value, proto, invert=False ):
      """Condition matching the route prefix against a prefix list.

      Args:
         value (str): name of the prefix list
         proto (int): 4 selects prefix_list_v4, 6 selects prefix_list_v6
         invert (bool): if True, negate the match with "not"
      Returns:
         An ExternalRefOpFragment, or a UnaryOpFragment wrapping it when invert
         is set.
      Raises:
         ValueError: if proto is neither 4 nor 6.
      """
      # Validate explicitly instead of with assert: asserts are stripped under
      # "python -O", which would silently treat any proto as IPv4.
      if proto not in ( 4, 6 ):
         raise ValueError( f"unsupported prefix-set protocol: {proto!r}" )
      externalRefType = ( ExternalRefType.PREFIX_LIST_V6 if proto == 6
                          else ExternalRefType.PREFIX_LIST_V4 )
      fragment = ExternalRefOpFragment( "prefix", ExternalRefOp.MATCH,
                                        externalRefType, value )
      if invert:
         fragment = UnaryOpFragment( UnaryOp.NOT, fragment )
      return fragment

   @staticmethod
   def setLocalPref( value ):
      """Modification: local_preference = <value>;"""
      return [ AssignmentFragment( "local_preference", ModificationOp.EQ, value ) ]

   @staticmethod
   def setWeight( value ):
      """Modification: weight = <value>;"""
      return [ AssignmentFragment( "weight", ModificationOp.EQ, value ) ]

   @staticmethod
   def setLevel( value ):
      """Modification: isis.levels = <value>;"""
      return [ AssignmentFragment( "isis.levels", ModificationOp.EQ, value ) ]

   @staticmethod
   def setMetric( value ):
      """Modification: isis.metric = <value>;"""
      return [ AssignmentFragment( "isis.metric", ModificationOp.EQ, value ) ]

   @staticmethod
   def setMed( metricType, value ):
      """Modifications applied to "med" for the given metric type.

      Args:
         metricType (str): one of the keys of medOps below
         value: operand used by the variants that take an explicit value
      Returns:
         A list of AssignmentFragments on "med", or None if metricType is not
         supported.
      """
      # Ordered (operator, rhs) pairs per metric type. Only the selected entry
      # is turned into fragments.
      medOps = {
         "medNormal": [ ( ModificationOp.EQ, value ) ],
         "medIgpNexthopCost": [ ( ModificationOp.EQ, "next_hop_metric" ) ],
         "medAdditive": [ ( ModificationOp.INC, value ) ],
         "medAddIgpNexthopCost": [ ( ModificationOp.INC, "next_hop_metric" ) ],
         "medSubtractive": [ ( ModificationOp.DEC, value ) ],
         "medValueAddIgpNexthopCost": [
            ( ModificationOp.EQ, value ),
            ( ModificationOp.INC, "next_hop_metric" ) ],
         "medIgpMetric": [ ( ModificationOp.EQ, "igp.metric" ) ],
         "medAddIgpMetric": [ ( ModificationOp.INC, "igp.metric" ) ],
         "medValueAddIgpMetric": [
            ( ModificationOp.EQ, value ),
            ( ModificationOp.INC, "igp.metric" ) ],
      }
      ops = medOps.get( metricType )
      if ops is None:
         # *<value> is not supported, as we do not support floating numbers.
         # Returning None is the error case.
         return None
      return [ AssignmentFragment( "med", op, rhs ) for op, rhs in ops ]

   @staticmethod
   def setOrigin( value ):
      """Modification: origin = <value>;"""
      return [ AssignmentFragment( "origin", ModificationOp.EQ, value ) ]

   @staticmethod
   def setAsPathPrepend( value ):
      """Modification: as_path prepend <value>;"""
      return [ AssignmentFragment( "as_path", ModificationOp.PREPEND, value ) ]

   @staticmethod
   def setNextHop( value ):
      """Modification setting the next hop.

      "SELF" yields "set_next_hop_self();", "PEER_ADDRESS" yields
      "next_hop = source_session.remote.ip_address;", and any other value yields
      "next_hop = <value>;".
      """
      if value == "SELF":
         return [ ModifierFunctionFragment( "set_next_hop_self()" ) ]
      elif value == "PEER_ADDRESS":
         return [ AssignmentFragment( "next_hop", ModificationOp.EQ,
                                      "source_session.remote.ip_address" ) ]
      else:
         return [ AssignmentFragment( "next_hop", ModificationOp.EQ, value ) ]

   @staticmethod
   def returnValue( permit ):
      """Map a permit flag to ReturnValue.TRUE (truthy) or ReturnValue.FALSE."""
      return ReturnValue.TRUE if permit else ReturnValue.FALSE

   @staticmethod
   def asPathLength( operator, value ):
      """Condition comparing input.as_path.length against value.

      Args:
         operator (str): "ATTRIBUTE_EQ" (is), "ATTRIBUTE_GE" (>=) or
            "ATTRIBUTE_LE" (<=)
         value: right-hand side of the comparison
      Returns:
         A one-element list holding a RelationalFragment, or None if operator is
         not recognized.
      """
      relationalOps = {
         "ATTRIBUTE_EQ": RelationalOp.IS,
         "ATTRIBUTE_GE": RelationalOp.GE,
         "ATTRIBUTE_LE": RelationalOp.LE,
      }
      op = relationalOps.get( operator )
      if op is None:
         return None
      return [ RelationalFragment( "input.as_path.length", op, value ) ]
