# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pkgdeps: rpm MatchList-lib

import Arnet
import BasicCli
import ConfigMount
import CliCommand
from CliDynamicSymbol import CliDynamicPlugin
import CliGlobal
import CliMatcher
# pylint: disable-next=consider-using-from-import
import CliMode.MatchListCliMode as MatchListCliMode
from MultiRangeRule import MultiRangeMatcher
# pylint: disable-next=consider-using-from-import
import CliPlugin.Ip6AddrMatcher as Ip6AddrMatcher
# pylint: disable-next=consider-using-from-import
import CliPlugin.IpAddrMatcher as IpAddrMatcher
import Tac

# --------------------------------------------------------
# Global definitions
# --------------------------------------------------------
# Tac value/enum types used by the command handlers in this file.
AddressFamily = Tac.Type( 'Arnet::AddressFamily' )
MatchStringInfo = Tac.Type( "MatchList::MatchStringInfo" )
MatchStringType = Tac.Type( "MatchList::MatchStringType" )

# Plugin-wide state; matchListConfig is filled in by Plugin() at load time.
gv = CliGlobal.CliGlobal( matchListConfig=None )

MatchListModels = CliDynamicPlugin( "MatchListModels" )

# Largest allowed sequence number (the 32-bit maximum, 0xFFFFFFFF).
MAX_SEQ = ( 1 << 32 ) - 1

# --------------------------------------------------------
# Utility functions
# --------------------------------------------------------
def getMatchListModelInfo( matchListName, addressFamily ):
   """Build a MatchListInfo model for the prefix match-lists of one family.

   addressFamily == AddressFamily.ipv4 selects the IPv4 lists; any other value
   selects the IPv6 lists. When matchListName is falsy every configured list
   of that family is reported; otherwise only the named list is, and an
   unknown name yields a model with no lists.
   """
   config = gv.matchListConfig
   result = MatchListModels.MatchListInfo()
   matchLists = ( config.matchIpv4PrefixList
                  if addressFamily == AddressFamily.ipv4
                  else config.matchIpv6PrefixList )

   if not matchListName:
      names = matchLists
   elif matchListName in matchLists:
      names = [ matchListName ]
   else:
      names = []

   for name in names:
      entry = result.MatchList()
      entry.prefixes = [ subnet.stringValue
                         for subnet in matchLists[ name ].subnets ]
      result.matchLists[ name ] = entry

   return result

# ----------------------------------------------------------------------------
# Match-list modes
# ----------------------------------------------------------------------------

class MatchListStringConfigMode( MatchListCliMode.MatchListMode,
                                 BasicCli.ConfigModeBase ):
   # Config submode for a match-list whose entries are strings (regex or exact).
   name = "Match list configuration for strings"

   def __init__( self, parent, session, matchListName ):
      # Store the list name before the base initializers run; the command
      # handlers in this mode look it up as mode.matchListName.
      self.matchListName = matchListName
      MatchListCliMode.MatchListMode.__init__( self, matchListName, "string" )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

class MatchListIpv4PrefixConfigMode( MatchListCliMode.MatchListMode,
                                     BasicCli.ConfigModeBase ):
   # Config submode for a match-list whose entries are IPv4 prefixes.
   name = "Match list configuration for IPv4 prefixes"

   def __init__( self, parent, session, matchListName ):
      # Store the list name before the base initializers run; the command
      # handlers in this mode look it up as mode.matchListName.
      self.matchListName = matchListName
      MatchListCliMode.MatchListMode.__init__( self, matchListName,
                                               "prefix-ipv4" )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

class MatchListIpv6PrefixConfigMode( MatchListCliMode.MatchListMode,
                                     BasicCli.ConfigModeBase ):
   # Config submode for a match-list whose entries are IPv6 prefixes.
   name = "Match list configuration for IPv6 prefixes"

   def __init__( self, parent, session, matchListName ):
      # Store the list name before the base initializers run; the command
      # handlers in this mode look it up as mode.matchListName.
      self.matchListName = matchListName
      MatchListCliMode.MatchListMode.__init__( self, matchListName,
                                               "prefix-ipv6" )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

def enterMatchList( mode, args ):
   """Enter the match-list submode that matches args[ 'TYPE' ].

   The submode is created for the list named args[ 'NAME' ]; an unexpected
   TYPE trips an assertion.
   """
   submodeForType = {
      'string': MatchListStringConfigMode,
      'prefix-ipv4': MatchListIpv4PrefixConfigMode,
      'prefix-ipv6': MatchListIpv6PrefixConfigMode,
   }
   listName = args[ 'NAME' ]
   listType = args[ 'TYPE' ]
   assert listType in submodeForType

   submode = mode.childMode( submodeForType[ listType ], matchListName=listName )
   mode.session_.gotoChildMode( submode )

def noMatchList( mode, args ):
   """Delete the match-list named args[ 'NAME' ] of kind args[ 'TYPE' ].

   Each TYPE maps to its own collection on gv.matchListConfig; an
   unexpected TYPE trips an assertion.
   """
   collectionForType = {
      'string': 'matchStringList',
      'prefix-ipv4': 'matchIpv4PrefixList',
      'prefix-ipv6': 'matchIpv6PrefixList',
   }
   listName = args[ 'NAME' ]
   listType = args[ 'TYPE' ]
   assert listType in collectionForType

   lists = getattr( gv.matchListConfig, collectionForType[ listType ] )
   del lists[ listName ]

def duplicateMatch( mode, matchStringList, matchInfo, seqNumber ):
   """Return True when matchInfo must not be added to matchStringList.

   Problems with an explicitly given seqNumber are reported via
   mode.addError(). Without a seqNumber, a duplicate is skipped silently.
   """
   explicitSeq = seqNumber is not None

   if explicitSeq:
      # An occupied sequence number is only acceptable when it already holds
      # exactly this match (re-entering the same line is a no-op).
      existing = matchStringList.matchInfo.get( seqNumber )
      if existing:
         if existing != matchInfo:
            mode.addError( "Error: Duplicate sequence number" )
         return True

      # Only exact matches are checked for duplicates under a new sequence
      # number. A duplicated exact string would break the exactMatch set:
      # removing one copy could not simply drop the string from the set while
      # another entry still uses it. Identical regexes under different
      # sequence numbers are pointless but are the shipping behavior and are
      # harmless apart from efficiency.
      if matchInfo.type != MatchStringType.exact:
         return False

   if not matchStringList.duplicate( matchInfo ):
      return False

   if explicitSeq:
      mode.addError( "Error: Duplicate match" )
   return True

def handleMatchString( mode, matchType, string, seqNumber ):
   """Add a regex or exact-string entry to the current string match-list.

   Args:
      mode: the string match-list config mode; errors are reported via
         mode.addError() and the list is looked up by mode.matchListName.
      matchType: MatchStringType.regex or MatchStringType.exact.
      string: the regex or literal string to match.
      seqNumber: explicit sequence number, or None to append after the
         list's lastSeqNo() (10 for a new list).

   Nothing is written when the regex is invalid, when duplicateMatch()
   rejects the entry, or when the auto-assigned number would exceed MAX_SEQ.
   """
   matchInfo = MatchStringInfo( matchType, string )
   if matchType == MatchStringType.regex:
      # Validate before touching config so a bad regex leaves no trace.
      regexValidator = Tac.newInstance( 'MatchList::RegexValidator' )
      if not regexValidator.validatePosixRegex( string ):
         mode.addError( 'Invalid regex' )
         return

   matchStringList = gv.matchListConfig.matchStringList.get( mode.matchListName )
   if matchStringList:
      if duplicateMatch( mode, matchStringList, matchInfo, seqNumber ):
         return

   if seqNumber is None:
      # Auto-number in steps of 10 after the list's current lastSeqNo().
      seqNumber = ( matchStringList.lastSeqNo() if matchStringList else 0 ) + 10
      if seqNumber > MAX_SEQ:
         mode.addError( 'Error: Sequence number out of range' )
         return

   # Create the list lazily, only once the entry is known to be valid.
   if not matchStringList:
      matchStringList = gv.matchListConfig.newMatchStringList( mode.matchListName )

   matchStringList.matchInfo[ seqNumber ] = matchInfo

   # Exact strings are also tracked in the exactMatch set.
   if matchType == MatchStringType.exact:
      matchStringList.exactMatch.add( string )

def noMatchString( mode, string, matchType ):
   """Remove the first entry whose type and string equal the arguments.

   "First" means first in matchInfo iteration order. A missing list or no
   matching entry is a silent no-op. Removing an exact entry also drops its
   string from the exactMatch set.
   """
   matchStringList = gv.matchListConfig.matchStringList.get( mode.matchListName )
   if not matchStringList:
      return

   hit = next( ( seqNo
                 for seqNo, entry in matchStringList.matchInfo.items()
                 if entry.type == matchType and entry.string == string ),
               None )
   if hit is None:
      return

   del matchStringList.matchInfo[ hit ]
   if matchType == MatchStringType.exact:
      matchStringList.exactMatch.remove( string )

def noMatchSeqRange( mode, seqRanges ):
   """Remove every entry whose sequence number lies in one of seqRanges.

   Args:
      mode: the string match-list config mode (list looked up by
         mode.matchListName).
      seqRanges: iterable of inclusive ( start, end ) sequence-number pairs.

   Removing an exact entry also drops its string from the exactMatch set.
   A missing list is a silent no-op.
   """
   matchStringList = gv.matchListConfig.matchStringList.get( mode.matchListName )
   if not matchStringList:
      return
   entries = matchStringList.matchInfo
   for start, end in seqRanges:
      # Collect the victims first and delete afterwards. Deleting from the
      # collection while its items() iterator is still being advanced is
      # unsafe (a plain dict raises RuntimeError).
      doomed = []
      for seqNo, matchInfo in entries.items():
         # The early exit assumes items() yields ascending seqNo, presumably
         # because matchInfo is an ordered collection -- confirm.
         if seqNo > end:
            break
         if seqNo >= start:
            doomed.append( ( seqNo, matchInfo ) )
      for seqNo, matchInfo in doomed:
         del entries[ seqNo ]
         if matchInfo.type == MatchStringType.exact:
            matchStringList.exactMatch.remove( matchInfo.string )

# ------------------------------------------------------------------------------
# The "[no|default] match regex | exact" command in
# "match-list input string" mode
# ------------------------------------------------------------------------------

# Both matchers reject single and double quotes. The current consumer writes
# these expressions into an rsyslog conf file wrapped in quotes; accepting
# quotes would require enforcing proper string escaping.
regexMatcher = CliMatcher.StringMatcher( helpname='REGEXP',
                                         helpdesc="Regular expression (POSIX ERE)",
                                         pattern=r'[^"\']+' )

stringMatcher = CliMatcher.StringMatcher( helpname='STRING',
                                          helpdesc="String",
                                          pattern=r'[^"\']+' )

class MatchStringCommand( CliCommand.CliCommandClass ):
   # Adds a regex or exact-string entry to the current string match-list.
   # The no/default form removes one entry by value, or every entry whose
   # sequence number falls in SEQ_RANGE.
   # NOTE(review): confirm the grammar groups "match" with both alternatives;
   # if "|" binds looser than juxtaposition, "exact STRING" would also parse
   # without the "match" keyword.
   syntax = "[ SEQNO ] match ( regex REGEX ) | ( exact STRING )"
   noOrDefaultSyntax = "( match ( regex REGEX ) | ( exact STRING ) ) | SEQ_RANGE"

   data = { "match": "Configure matching",
            "regex": "Match using a regular expression",
            "REGEX": regexMatcher,
            "exact": "Match an exact string",
            "STRING": stringMatcher,
            "SEQNO": CliMatcher.IntegerMatcher( 1, MAX_SEQ,
                                                helpdesc='Index in the sequence' ),
            "SEQ_RANGE": MultiRangeMatcher( lambda: ( 1, MAX_SEQ ), False,
                                            'Index in the sequence' )
          }

   @staticmethod
   def handler( mode, args ):
      seqNo = args.get( 'SEQNO' )
      regex = args.get( "REGEX" )
      if regex:
         handleMatchString( mode, MatchStringType.regex, regex, seqNo )
      else:
         handleMatchString( mode, MatchStringType.exact, args[ "STRING" ], seqNo )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      seqRange = args.get( 'SEQ_RANGE' )
      if seqRange is None:
         regex = args.get( "REGEX" )
         if regex:
            noMatchString( mode, regex, MatchStringType.regex )
         else:
            noMatchString( mode, args[ "STRING" ], MatchStringType.exact )
      else:
         noMatchSeqRange( mode, seqRange.ranges() )

MatchListStringConfigMode.addCommandClass( MatchStringCommand )

# ------------------------------------------------------------------------------
# The "resequence [starting] [increment]" command in
# "match-list input string" mode
# ------------------------------------------------------------------------------
class ResequenceSeqNo( CliCommand.CliCommandClass ):
   # Renumbers the entries of the current string match-list as START,
   # START+INC, START+2*INC, ... (both default to 10). The order is whatever
   # order matchInfo.values() yields. Nothing changes when the last new
   # number would exceed MAX_SEQ.
   syntax = "resequence [ START [ INC ] ]"

   data = { "resequence": "Resequence the list",
            "START": CliMatcher.IntegerMatcher( 1, MAX_SEQ,
                       helpdesc='Starting sequence number (default 10)' ),
            "INC": CliMatcher.IntegerMatcher( 1, MAX_SEQ,
                  helpdesc='Step to increment the sequence number (default 10)' ),
          }

   @staticmethod
   def handler( mode, args ):
      first = args.get( 'START', 10 )
      step = args.get( 'INC', 10 )
      matchStringList = gv.matchListConfig.matchStringList.get( mode.matchListName )
      entries = matchStringList.matchInfo if matchStringList else None
      if not entries:
         return

      if first + ( len( entries ) - 1 ) * step > MAX_SEQ:
         mode.addError( 'Error: Sequence number out of range' )
         return

      # Snapshot the renumbered entries before clearing, since values() is a
      # view of the collection being rewritten.
      renumbered = { first + index * step: entry
                     for index, entry in enumerate( entries.values() ) }
      entries.clear()
      entries.update( renumbered )

MatchListStringConfigMode.addCommandClass( ResequenceSeqNo )

# ------------------------------------------------------------------------------
# The "[no] match prefix-ipv4 <P>/<M>" command in
# "match-list input prefix-ipv4" mode
# ------------------------------------------------------------------------------
class MatchIpv4Subnets( CliCommand.CliCommandClass ):
   # Adds or removes IPv4 prefixes in the current prefix-ipv4 match-list.
   # Each prefix is keyed by Arnet.AddrWithMask( prefix ).subnet.
   syntax = "match prefix-ipv4 { PREFIX_LIST }"
   noOrDefaultSyntax = syntax

   data = { "match": "Configure matching",
            "prefix-ipv4": "Match IPv4 prefix",
            "PREFIX_LIST": IpAddrMatcher.ipPrefixMatcher }

   @staticmethod
   def handler( mode, args ):
      subnets = gv.matchListConfig.newMatchIpv4PrefixList(
         mode.matchListName ).subnets
      for prefix in args[ "PREFIX_LIST" ]:
         subnets[ Arnet.AddrWithMask( prefix ).subnet ] = True

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      matchList = gv.matchListConfig.matchIpv4PrefixList.get( mode.matchListName )
      if not matchList:
         return
      for prefix in args[ "PREFIX_LIST" ]:
         del matchList.subnets[ Arnet.AddrWithMask( prefix ).subnet ]

MatchListIpv4PrefixConfigMode.addCommandClass( MatchIpv4Subnets )

# ------------------------------------------------------------------------------
# The "[no] match prefix-ipv6 <P>/<M>" command in
# "match-list input prefix-ipv6" mode
# ------------------------------------------------------------------------------
class MatchIpv6Subnets( CliCommand.CliCommandClass ):
   # Adds or removes IPv6 prefixes in the current prefix-ipv6 match-list.
   # NOTE(review): unlike the IPv4 command, the key here is
   # Arnet.Ip6AddrWithMask( prefix ) without ".subnet"; confirm whether host
   # bits in the entered prefix are meant to be kept.
   syntax = "match prefix-ipv6 { PREFIX_LIST }"
   noOrDefaultSyntax = syntax

   data = { "match": "Configure matching",
            "prefix-ipv6": "Match IPv6 prefix",
            "PREFIX_LIST": Ip6AddrMatcher.ip6PrefixMatcher }

   @staticmethod
   def handler( mode, args ):
      subnets = gv.matchListConfig.newMatchIpv6PrefixList(
         mode.matchListName ).subnets
      for prefix in args[ "PREFIX_LIST" ]:
         subnets[ Arnet.Ip6AddrWithMask( prefix ) ] = True

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      matchList = gv.matchListConfig.matchIpv6PrefixList.get( mode.matchListName )
      if not matchList:
         return
      for prefix in args[ "PREFIX_LIST" ]:
         del matchList.subnets[ Arnet.Ip6AddrWithMask( prefix ) ]

MatchListIpv6PrefixConfigMode.addCommandClass( MatchIpv6Subnets )

# ------------------------------------------------------------------------------
# show match-list prefix-ipv4 [ name ]
# ------------------------------------------------------------------------------
def handleShowMatchListPrefixIpv4( mode, args ):
   """Return the IPv4 match-list model: one named list, or all if unnamed."""
   return getMatchListModelInfo( args.get( 'MATCHLIST_NAME' ),
                                 AddressFamily.ipv4 )

# ------------------------------------------------------------------------------
# show match-list prefix-ipv6 [ name ]
# ------------------------------------------------------------------------------
def handleShowMatchListPrefixIpv6( mode, args ):
   """Return the IPv6 match-list model: one named list, or all if unnamed."""
   return getMatchListModelInfo( args.get( 'MATCHLIST_NAME' ),
                                 AddressFamily.ipv6 )

def Plugin( entityManager ):
   """Mount the CLI match-list config for writing and store it in gv."""
   configPath = "matchlist/config/cli"
   configType = "MatchList::Config"
   gv.matchListConfig = ConfigMount.mount( entityManager, configPath,
                                           configType, "w" )
