#!/usr/bin/env python3
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
# This file provides methods for validating the syntax of a "version" pattern
# with the scheme specified in AID6537 (version section), and for matching
# versions against such patterns. It is not a general EOS version validator.

from collections import OrderedDict
import importlib
import re

import Tracing

# Set up tracing: t0 is an alias for Tracing.trace0, used by the example at the
# bottom of this file when it is run directly.
t0 = Tracing.trace0


def removePrefixAndSuffix( text, prefix, suffix ):
   """Strip both `prefix` and `suffix` from `text`.

   Both affixes are mandatory: if `text` does not start with `prefix` or does
   not end with `suffix`, an empty string is returned instead.
   """
   if not ( text.startswith( prefix ) and text.endswith( suffix ) ):
      return ""
   # Compute the end index explicitly so an empty suffix keeps the tail
   # (a negative end index of -0 would slice to nothing).
   end = len( text ) - len( suffix )
   return text[ len( prefix ):end ]

class VersionStr:
   """A single dot-separated component of a version rule or version.

   A component is either a literal (e.g. "20" or "2F") or contains exactly one
   numeric range "{low-high}", optionally surrounded by literal text
   (e.g. "rc{1-5}"). The upper bound may be "$", meaning unbounded. Both
   bounds are inclusive.

   Attributes:
      containsRange: True if the component contains a range.
      prefix: text before the range ("" if there is no range).
      suffix: text after the range ("" if there is no range).
      valueRange: [ low, high ] bounds of the range; high is float( "inf" )
         for "$". Stays [ 0, 0 ] if there is no range.
      text: the original component text.

   Raises:
      ValueError: from the constructor if the text is not a valid component
         (see parse()).
   """
   def __init__( self, text ):
      self.containsRange = False  # If this version string contains a range
      self.prefix = ""            # The string before the range
      self.suffix = ""            # The string after the range
      self.valueRange = [ 0, 0 ]  # The inclusive numeric bounds of the range
      self.text = text            # The original version string text
      self.parse()

   def parse( self ):
      """Validate self.text and fill in the range attributes.

      Raises:
         ValueError: if the text is empty; contains '*'; contains a character
            other than a-z, A-Z, a digit, '{', '}', '$', ' ' or '-'; or has a
            malformed range (unbalanced or repeated braces, not exactly one
            '-' inside the braces, a missing or non-integer bound, or an
            upper bound not greater than the lower bound).
      """
      if self.text == "":
         raise ValueError( "Invalid version string" )
      # A trailing '*' is a rule-level prefix marker that is stripped before a
      # rule is split into components, so any '*' still present is misplaced.
      if "*" in self.text:
         raise ValueError( "Invalid use of '*'; it is expected to be the"
                           " last character in a version rule." )
      invalidCh = re.search( r"[^a-zA-Z\d\{\}\$ -]", self.text )
      if invalidCh:
         # pylint: disable-next=consider-using-f-string
         raise ValueError( "Invalid character '%s'" % invalidCh.group( 0 ) )
      # Locate the single allowed range: indices of '{', '-' and '}'.
      # A '-' outside the braces is an ordinary character.
      left = -1
      right = -1
      split = -1
      for i, ch in enumerate( self.text ):
         if ch == "}":
            if left == -1:
               raise ValueError( "Range requires opening '{' before closing '}'" )
            if right != -1:
               raise ValueError( "Unexpected '}' after range" )
            if split == -1:
               raise ValueError( "Missing '-' from range" )
            right = i
         elif ch == "{":
            if right != -1:
               raise ValueError( "Unexpected '{' after range; each dot-separated "
                                 "identifier can only contain a single range" )
            if left != -1:
               raise ValueError( "Unexpected '{' inside range" )
            left = i
         elif ch == "-" and left != -1 and right == -1:
            # Only treat "-" as a split when inside a range.
            if split != -1:
               raise ValueError( "Range should contain a single '-'" )
            split = i
      if left != -1:
         if right == -1:
            raise ValueError( "Range not closed by '}'" )
         self.prefix = self.text[ :left ]
         self.suffix = self.text[ right + 1: ]
         first, second = self.text[ left + 1:split ], self.text[ split + 1:right ]
         if first == "":
            raise ValueError( "Range missing lower bound" )
         try:
            self.valueRange[ 0 ] = int( first )
         except ValueError:
            # pylint: disable-next=consider-using-f-string
            msg = "Range lower bound '%s' invalid" % first
            # The int() error adds nothing to this message; don't chain it.
            raise ValueError( msg ) from None
         if second == "":
            raise ValueError( "Range missing upper bound" )
         if second == "$":
            # "$" means the range has no upper limit.
            self.valueRange[ 1 ] = float( "inf" )
         else:
            try:
               self.valueRange[ 1 ] = int( second )
            except ValueError:
               # pylint: disable-next=consider-using-f-string
               msg = "Range upper bound '%s' invalid" % second
               raise ValueError( msg ) from None
            if self.valueRange[ 0 ] >= self.valueRange[ 1 ]:
               raise ValueError( "Range upper bound must be greater than lower "
                                 "bound" )
         self.containsRange = True

   def match( self, versionStr ):
      """Return True if the component string `versionStr` matches this one.

      Without a range this is an exact string comparison. With a range,
      `versionStr` must be this component's prefix, then a run of decimal
      digits, then its suffix, and the number must lie within valueRange
      (inclusive on both ends).
      """
      # Exact match
      if not self.containsRange:
         return self.text == versionStr
      # Range match
      numStr = removePrefixAndSuffix( versionStr, self.prefix, self.suffix )
      # isdecimal() rather than isdigit(): isdigit() is also True for
      # characters such as superscript digits, on which int() raises
      # ValueError. Every isdecimal() character is accepted by int().
      if not numStr.isdecimal():
         return False
      low, high = self.valueRange
      return low <= int( numStr ) <= high


def tokenize( text ):
   """Split a rule or a version into a list of VersionStr components.

   A single trailing '*' (the prefix-match marker) is dropped before the text
   is split on '.'.

   Raises:
      ValueError: if any component is invalid; the message quotes the whole
         input text followed by the component error.
   """
   body = text[ :-1 ] if text and text.endswith( "*" ) else text
   try:
      return [ VersionStr( part ) for part in body.split( "." ) ]
   except ValueError as err:
      # Re-raise with the full input so the caller can see which rule failed.
      # pylint: disable-next=consider-using-f-string,raise-missing-from
      raise ValueError( "Invalid version rule \"%s\": %s" % ( text, err ) )

class Rule:
   """A single version rule, i.e. one comma-free entry of a version pattern.

   Examples: "4.20", "4.20.{0-12}*", "4.21*". A trailing '*' makes the rule a
   prefix match.

   Attributes:
      matchPrefix: True if the rule text ends with '*'.
      versionStrs: the rule's dot-separated components (VersionStr objects).
      text: the original rule text.
   """

   def __init__( self, text ):
      self.matchPrefix = False  # True when the rule is a prefix match
      self.versionStrs = []     # Version strings separated by dots
      self.text = text          # The original rule string text
      self.parse()

   def parse( self ):
      """Record whether the rule ends in '*' and split it into components.

      Raises:
         ValueError: if any component of the rule is invalid.
      """
      endsWithStar = self.text and self.text.endswith( "*" )
      if endsWithStar:
         self.matchPrefix = True
      self.versionStrs = tokenize( self.text )

   def match( self, version ):
      """Return True if the version string `version` satisfies this rule.

      Rule and version components are compared pairwise. A non-prefix rule
      needs exactly as many components as the version; a prefix rule ('*')
      allows the version to have more. When a pair does not match, two
      relaxations are tried:
        * on the version's last component, a trailing 'F' or 'M' is ignored
          (4.2F and 4.2M match 4.2);
        * on the last component of a prefix rule, a leading part of the
          version component may match, provided it is not followed by a digit
          (4.2* matches 4.2a but not 4.21).

      Raises:
         ValueError: if `version` itself cannot be tokenized.
      """
      verStrs = tokenize( version )
      ruleCount = len( self.versionStrs )
      if len( verStrs ) < ruleCount:
         return False
      if len( verStrs ) != ruleCount and not self.matchPrefix:
         return False
      lastVerIdx = len( verStrs ) - 1
      pairs = zip( self.versionStrs, verStrs )
      for idx, ( ruleStr, verStr ) in enumerate( pairs ):
         piece = verStr.text
         # Exact and range match
         if ruleStr.match( piece ):
            continue
         # 4.2F and 4.2M should match 4.2
         if ( idx == lastVerIdx and piece[ -1 ] in ( "F", "M" ) and
              ruleStr.match( piece[ :-1 ] ) ):
            return True
         # Prefix match: try ever shorter leading parts of the component,
         # cutting only where the next character is not a digit
         if idx == ruleCount - 1 and self.matchPrefix:
            for cut in range( len( piece ), 0, -1 ):
               boundary = cut == len( piece ) or not piece[ cut ].isdigit()
               if boundary and ruleStr.match( piece[ :cut ] ):
                  return True
         # A mismatching component that no relaxation rescued
         return False
      # Every rule component matched
      return True


def parse( pattern ):
   """Parse a comma-separated version pattern into a list of Rule objects.

   All whitespace is removed first. Empty entries (e.g. from a trailing or
   doubled comma) are skipped.

   Raises:
      ValueError: if any rule is invalid.
   """
   compact = re.sub( r"\s+", "", pattern )
   return [ Rule( piece ) for piece in compact.split( "," ) if piece ]


def getEosVersions():
   """Return the shipped EOS versions reported by a Bugz connection.

   Callers treat each entry as a version string (presumably; verify against
   Bugz). Bugz is imported at call time rather than at module import
   (presumably so the pure parsing/matching helpers work without it).
   """
   bugzModule = importlib.import_module( "Bugz" )
   return bugzModule.Connection().eosVersions( shipped=True )


def getMatchingVersions( pattern ):
   """Group release-database versions by the first rule in `pattern` they match.

   Returns an OrderedDict that maps rule text to the list of matching
   versions. Versions are listed in the order getEosVersions() yields them.
   Each version goes to at most one rule: the earliest matching rule in
   pattern order. Rules that match nothing do not appear. Keys are ordered by
   when their first matching version is encountered.
   """
   rules = parse( pattern )
   grouped = OrderedDict()
   for version in getEosVersions():
      firstHit = next( ( r for r in rules if r.match( version ) ), None )
      if firstHit is not None:
         grouped.setdefault( firstHit.text, [] ).append( version )
   return grouped


def validateVersion( pattern, version ):
   """Return True if `version` matches at least one rule of `pattern`.

   An empty pattern (no rules) matches nothing.

   Raises:
      ValueError: if the pattern or the version cannot be parsed.
   """
   return any( rule.match( version ) for rule in parse( pattern ) )


# When run directly, trace a sample pattern and the database versions it
# matches.
if __name__ == "__main__":
   samplePattern = "4.20, 4.20.{0-12}*, 4.21*"
   t0( samplePattern + ":" )
   t0( getMatchingVersions( samplePattern ) )
