# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from CliDynamicSymbol import CliDynamicPlugin
from CliPlugin.ForwardingDestinationCommon import (
   ArgToLabel,
   IPv4Fields,
   IPv6Fields,
   InnerIPv4Fields,
   InnerIPv6Fields,
   InnerL2Fields,
   L4Fields,
   matchers,
   NvgreFields,
)

# The prompt-tree building blocks are taken from the 'PromptTree' CLI dynamic
# plugin and bound to module-level names so the classes below can use them
# directly.
PromptTree = CliDynamicPlugin( 'PromptTree' )
( PromptTreeBase,
  PromptTreeBooleanChoice,
  PromptTreeField,
  PromptTreeChoice ) = ( PromptTree.PromptTreeBase,
                         PromptTree.PromptTreeBooleanChoice,
                         PromptTree.PromptTreeField,
                         PromptTree.PromptTreeChoice )

def promptForProtocolOrNextHeader( treeDict ):
   """Decide whether to prompt for the outer IPv4 protocol / IPv6 next header.

   Used as the prompt condition for <ipProto> (IPv4) and <nextHeader> (IPv6).
   The field is skipped (presumably because its value can then be derived)
   when any of these holds:
     * a TCP or UDP L4 header was chosen,
     * an IPv4 or IPv6 inner packet type was chosen,
     * the packet type is GRE.

   Args:
      treeDict: answers collected so far, keyed by argument name. The keys are
         read in the order '<l4Type>', '<innerPacketType>', '<packetType>',
         stopping at the first check that rules the prompt out.

   Returns:
      True if the user should be prompted for the field, False otherwise.
   """
   if treeDict[ '<l4Type>' ] in [ 'tcp', 'udp' ]:
      return False
   if treeDict[ '<innerPacketType>' ] in [ 'ipv4', 'ipv6' ]:
      return False
   return treeDict[ '<packetType>' ] != 'gre'

class L4PromptTree( PromptTreeBase ):
   """Prompt tree for the TCP/UDP (L4) header.

   Adds one field per entry of L4Fields, each using the matcher and label
   registered for it in matchers / ArgToLabel. No per-field prompt condition
   is attached here.
   """
   def __init__( self ):
      super().__init__()
      self.fields.extend(
         PromptTreeField( name, matchers[ name ], ArgToLabel[ name ] )
         for name in L4Fields )

class Ipv4PromptTree( PromptTreeBase ):
   """Prompt tree for the outer IPv4 header (fields listed in IPv4Fields).

   <ipProto> is the only conditional field. It is gated by
   promptForProtocolOrNextHeader, so it is not asked for when the L4 type is
   TCP/UDP, when an IPv4/IPv6 inner packet type was chosen, or when the packet
   type is GRE.
   """
   def __init__( self ):
      super().__init__()
      # Prompt conditions for specific fields; every other field is left as
      # constructed.
      promptConditions = { '<ipProto>': promptForProtocolOrNextHeader }
      for name in IPv4Fields:
         newField = PromptTreeField( name, matchers[ name ], ArgToLabel[ name ] )
         if name in promptConditions:
            newField.promptLambda = promptConditions[ name ]
         self.fields.append( newField )

class InnerL2PromptTree( PromptTreeBase ):
   """Prompt tree for the encapsulated L2 header (fields in InnerL2Fields).

   The inner L2 header is currently only used with NVGRE, which should never
   carry 802.1Q tags, so tags are not prompted for here. The inner EtherType
   (<innerL2EtherType>) is only asked for when the inner packet type is
   'ethernet'.
   """
   def __init__( self ):
      super().__init__()
      promptConditions = {
         '<innerL2EtherType>':
            lambda treeDict: treeDict[ '<innerPacketType>' ] == 'ethernet',
      }
      for name in InnerL2Fields:
         newField = PromptTreeField( name, matchers[ name ], ArgToLabel[ name ] )
         if name in promptConditions:
            newField.promptLambda = promptConditions[ name ]
         self.fields.append( newField )

class InnerIpv4PromptTree( PromptTreeBase ):
   """Prompt tree for the encapsulated IPv4 header (fields in InnerIPv4Fields).

   The inner IP protocol (<innerIpProto>) is skipped when a TCP or UDP L4
   header was chosen.
   """
   def __init__( self ):
      super().__init__()
      promptConditions = {
         '<innerIpProto>':
            lambda treeDict: treeDict[ '<l4Type>' ] not in [ 'tcp', 'udp' ],
      }
      for name in InnerIPv4Fields:
         newField = PromptTreeField( name, matchers[ name ], ArgToLabel[ name ] )
         if name in promptConditions:
            newField.promptLambda = promptConditions[ name ]
         self.fields.append( newField )

class Ipv6PromptTree( PromptTreeBase ):
   """Prompt tree for the outer IPv6 header (fields listed in IPv6Fields).

   <nextHeader> is the only conditional field. It shares the IPv4 <ipProto>
   condition, promptForProtocolOrNextHeader, so it is not asked for when the
   L4 type is TCP/UDP, when an IPv4/IPv6 inner packet type was chosen, or when
   the packet type is GRE.
   """
   def __init__( self ):
      super().__init__()
      promptConditions = { '<nextHeader>': promptForProtocolOrNextHeader }
      for name in IPv6Fields:
         newField = PromptTreeField( name, matchers[ name ], ArgToLabel[ name ] )
         if name in promptConditions:
            newField.promptLambda = promptConditions[ name ]
         self.fields.append( newField )

class InnerIpv6PromptTree( PromptTreeBase ):
   """Prompt tree for the encapsulated IPv6 header (fields in InnerIPv6Fields).

   The inner next header (<innerNextHeader>) is skipped when a TCP or UDP L4
   header was chosen.
   """
   def __init__( self ):
      super().__init__()
      promptConditions = {
         '<innerNextHeader>':
            lambda treeDict: treeDict[ '<l4Type>' ] not in [ 'tcp', 'udp' ],
      }
      for name in InnerIPv6Fields:
         newField = PromptTreeField( name, matchers[ name ], ArgToLabel[ name ] )
         if name in promptConditions:
            newField.promptLambda = promptConditions[ name ]
         self.fields.append( newField )

class RawPacketPromptTree( PromptTreeBase ):
   """Prompt tree for a packet entered in raw form.

   Replaces whatever field list the base class set up with the single
   <rawPacket> field.
   """
   def __init__( self ):
      super().__init__()
      rawArg = '<rawPacket>'
      self.fields = [ PromptTreeField( rawArg, matchers[ rawArg ],
                                       ArgToLabel[ rawArg ] ) ]

class GrePromptTree( PromptTreeBase ):
   """Prompt tree that picks the GRE flavour and hands off to its sub-tree.

   The '<greType>' answer selects the follow-up tree: 'gre' selects
   GreTypePromptTree and 'nvgre' selects NvgreTypePromptTree.
   """
   def __init__( self ):
      super().__init__()
      self.fields.append(
         PromptTreeChoice( '<greType>', 'GRE type', [ 'gre', 'nvgre' ] ) )

      def greTypeIs( wanted ):
         return lambda treeDict: treeDict[ '<greType>' ] == wanted

      self.furtherPromptTrees = [
         ( GreTypePromptTree(), greTypeIs( 'gre' ) ),
         ( NvgreTypePromptTree(), greTypeIs( 'nvgre' ) ),
      ]

class GreTypePromptTree( PromptTreeBase ):
   """Prompt tree for plain (non-NVGRE) GRE encapsulation.

   Asks, in order:
     1. the yes/no 'gre-checksum' choice,
     2. the GRE key and sequence number (both marked required=False),
     3. the encapsulated packet type (IPv4 or IPv6),
     4. its L4 header type.
   """
   def __init__( self ):
      super().__init__()
      self.fields.append(
         PromptTreeBooleanChoice( 'gre-checksum', 'gre-checksum' ) )
      self.fields.extend(
         PromptTreeField( name, matchers[ name ], ArgToLabel[ name ],
                          required=False )
         for name in ( '<greKey>', '<greSequence>' ) )
      # Encapsulated packet: always IP here, so no prompt condition is
      # attached to the L4 type.
      self.fields.extend( [
         PromptTreeChoice( '<innerPacketType>', 'Inner packet type',
                           [ 'ipv4', 'ipv6' ] ),
         PromptTreeChoice( '<l4Type>', 'L4 header', [ 'none', 'tcp', 'udp' ] ),
      ] )

class NvgreTypePromptTree( PromptTreeBase ):
   """Prompt tree for NVGRE options and the encapsulated packet.

   Asks for every field in NvgreFields, then the inner packet type
   ('ethernet', 'ipv4' or 'ipv6'). The L4 header type is only asked for when
   the inner packet is IPv4 or IPv6.
   """
   def __init__( self ):
      super().__init__()
      self.fields.extend(
         PromptTreeField( name, matchers[ name ], ArgToLabel[ name ] )
         for name in NvgreFields )

      def innerIsIp( treeDict ):
         return treeDict[ '<innerPacketType>' ] in [ 'ipv4', 'ipv6' ]

      # Describe the encapsulated packet.
      self.fields.append(
         PromptTreeChoice( '<innerPacketType>', 'Inner packet type',
                           [ 'ethernet', 'ipv4', 'ipv6' ] ) )
      self.fields.append(
         PromptTreeChoice( '<l4Type>', 'L4 header', [ 'none', 'tcp', 'udp' ],
                           promptLambda=innerIsIp ) )

class PacketPromptTree( PromptTreeBase ):
   """Top-level prompt tree describing the packet to be built.

   Args:
      rawPacketSupported: when truthy, 'raw' is offered as a packet type.
      innerPacketPrompt: when truthy, 'gre' is offered as a packet type. It also
         enables the IP-version question (for GRE) and the inner-packet-type
         question (for IPv4/IPv6 packets).

   The top-level fields are followed by sub-trees in furtherPromptTrees. Each
   sub-tree is paired with a predicate over the answers collected so far
   (treeDict). The predicates select, in order: outer IPv4, outer IPv6, GRE,
   inner L2 (NVGRE only), inner IPv4, inner IPv6, L4 (TCP/UDP) and raw packet.
   """
   def __init__( self, rawPacketSupported, innerPacketPrompt ):
      super().__init__()
      packetTypes = [ 'ethernet', 'ipv4', 'ipv6' ]
      packetTypes += [ 'raw' ] if rawPacketSupported else []
      packetTypes += [ 'gre' ] if innerPacketPrompt else []

      def packetTypeIs( wanted ):
         return lambda treeDict: treeDict[ '<packetType>' ] == wanted

      isEthernet = packetTypeIs( 'ethernet' )
      isRaw = packetTypeIs( 'raw' )
      isGre = packetTypeIs( 'gre' )

      def isIp( treeDict ):
         return treeDict[ '<packetType>' ] in [ 'ipv4', 'ipv6' ]

      def notRaw( treeDict ):
         return not isRaw( treeDict )

      def isNvgre( treeDict ):
         # '<greType>' is only consulted once the packet is known to be GRE.
         return isGre( treeDict ) and treeDict[ '<greType>' ] == 'nvgre'

      def isInnerIp( treeDict ):
         return treeDict[ '<innerPacketType>' ] in [ 'ipv4', 'ipv6' ]

      def outerIpIs( version ):
         # The outer IP header is either the packet itself (ipv4/ipv6 packet
         # types) or, for GRE, the header whose version '<ipVersion>' asked for.
         def predicate( treeDict ):
            return ( treeDict[ '<packetType>' ] == version or
                     ( isGre( treeDict ) and
                       treeDict[ '<ipVersion>' ] == version ) )
         return predicate

      def innerTypeIs( wanted ):
         return lambda treeDict: treeDict[ '<innerPacketType>' ] == wanted

      def l4Applies( treeDict ):
         return ( ( isIp( treeDict ) or isInnerIp( treeDict ) ) and
                  not isGre( treeDict ) )

      def argField( name, **kwargs ):
         # Field whose matcher and label are registered under its argument name.
         return PromptTreeField( name, matchers[ name ], ArgToLabel[ name ],
                                 **kwargs )

      self.fields = [
         argField( '<ingressIntf>' ),
         PromptTreeChoice( '<packetType>', 'Packet type', packetTypes ),
         PromptTreeChoice( '<ipVersion>', 'IP version', [ 'ipv4', 'ipv6' ],
                           promptLambda=lambda treeDict: (
                              innerPacketPrompt and isGre( treeDict ) ) ),
         PromptTreeChoice( '<innerPacketType>', 'Inner packet type',
                           [ 'none', 'ipv4', 'ipv6' ],
                           promptLambda=lambda treeDict: (
                              innerPacketPrompt and isIp( treeDict ) ) ),
         PromptTreeChoice( '<l4Type>', 'L4 header', [ 'none', 'tcp', 'udp' ],
                           promptLambda=l4Applies ),
         argField( '<srcMac>', promptLambda=notRaw ),
         argField( '<dstMac>', promptLambda=notRaw ),
         argField( '<etherType>', promptLambda=isEthernet ),
         argField( '<vlan>', required=False, promptLambda=notRaw ),
         # The inner VLAN is only offered once an outer VLAN has been given.
         argField( '<innerVlan>', required=False,
                   promptLambda=lambda treeDict: treeDict.get( '<vlan>' ) ),
      ]
      self.furtherPromptTrees = [
         ( Ipv4PromptTree(), outerIpIs( 'ipv4' ) ),
         ( Ipv6PromptTree(), outerIpIs( 'ipv6' ) ),
         ( GrePromptTree(), isGre ),
         ( InnerL2PromptTree(), isNvgre ),
         ( InnerIpv4PromptTree(), innerTypeIs( 'ipv4' ) ),
         ( InnerIpv6PromptTree(), innerTypeIs( 'ipv6' ) ),
         ( L4PromptTree(),
           lambda treeDict: treeDict[ '<l4Type>' ] in [ 'tcp', 'udp' ] ),
         ( RawPacketPromptTree(), isRaw ),
      ]
