# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, division, print_function
import Arnet
import ConfigMount
import LazyMount
import CliExtensions
import CliParser
import CliMatcher
import CliCommand
from CliPlugin import IntfCli
from CliPlugin.RouterGeneralCli import ( RouterGeneralMode,
                                         routerGeneralCleanupHook )
import Tac
import Tracing
from CliMode.SegmentRoutingMode import SrModeBase
import BasicCli
from TypeFuture import TacLazyType
from enum import Enum
from Toggles.MplsToggleLib import (
   toggleReserveServiceLabelEnabled,
   toggleSrProxySegDefaultPopEnabled,
)

traceHandle = Tracing.Handle( 'SegmentRoutingCli' )
t1 = traceHandle.trace1 # Info

# pkgdeps: library MplsSysdbTypes

# Sysdb-backed state used by the handlers in this file. Each starts out as None
# and is assigned by Plugin() at the bottom of this file when it mounts the
# state from Sysdb.
srConfig = None
mplsHwCapability = None
mplsConfig = None
ip6Config = None
ipConfig = None
flexAlgoConfig = None
# Sentinel stored in srNodeSegmentIndex / srV6NodeSegmentIndex when no
# default-algorithm node-segment index is configured for that address family.
srSidInvalid = Tac.Type( "Routing::SegmentRoutingCli::Constants" ).srSidInvalid
# Lazily-resolved MPLS label type; its min/max/unassignedMin bound the CLI
# index and label matchers.
MplsLabel = TacLazyType( 'Arnet::MplsLabel' )

class NodeSegmentType( Enum ):
   """Which node segments a 'no node-segment ...' command removes.

   NONE  - neither 'index' nor 'label' was given: remove everything.
   INDEX - only index-based node segments.
   LABEL - only label-based node segments.
   """
   NONE, INDEX, LABEL = 0, 1, 2

#--------------------------------------------------------
# Segment routing currently only supports the MPLS data plane
#--------------------------------------------------------
def mplsSupportedGuard( mode, token ):
   """CLI guard that hides a keyword on platforms without MPLS support.

   Returns None (keyword allowed) when the mounted hardware capability reports
   mplsSupported, otherwise CliParser.guardNotThisPlatform. `mode` and `token`
   are required by the guard signature and are not used.
   """
   supported = mplsHwCapability.mplsSupported
   return None if supported else CliParser.guardNotThisPlatform

# The node-segment ipv[4|6] command is common to both ISIS and OSPF and is
# implemented here. Callbacks registered on the hook below check for conflicts
# with other prefix/proxy-node configuration in ISIS/OSPF; a configuration is
# accepted only if every registered callback accepts it. This should typically
# not be a problem, because each protocol callback checks whether the interface
# is of interest to it.
# Each callback is called as hook( intfName, index=..., ipv6=..., label=... )
# and returns an ( accept, message ) tuple; see checkForConflicts().
# NOTE: Bugs 373190 & 147320 track removing the conflict resolution logic from
# the Cli; once they are fixed, the hooks below can be removed.
validateNodeSegmentHook = CliExtensions.CliHook()

def checkForConflicts( mode, index=None, ipv6=False, label=None ):
   """Ask every validateNodeSegmentHook callback about a node-segment config.

   Each callback is invoked as hook( intfName, index=, ipv6=, label= ) and
   returns an ( accept, hookMsg ) tuple. A non-empty hookMsg is shown to the
   user as a warning if the callback accepted the config, or as an error if it
   rejected it.

   Returns:
      True as soon as one callback rejects the configuration (the remaining
      callbacks are not consulted), False if every callback accepts it.
   """
   intfName = mode.intf.name
   for hook in validateNodeSegmentHook.extensions():
      accept, hookMsg = hook( intfName, index=index, ipv6=ipv6, label=label )
      if hookMsg:
         if accept:
            mode.addWarning( hookMsg )
         else:
            mode.addError( hookMsg )

      if not accept:
         # Log the callback's name: formatting hook.__str__ would print a
         # method-wrapper object rather than identifying the callback.
         hookName = getattr( hook, '__name__', repr( hook ) )
         t1( '%s reported conflict for %s' % ( hookName, intfName ) )
         return True

   return False

def _deleteSrIntfConfigIfAllDefaults( intfName ):
   """Drop intfName's srConfig.intfConfig entry once nothing is configured.

   The entry is removed only if both default-algorithm node-segment indexes
   (IPv4 and IPv6) are srSidInvalid and all four node-segment collections
   (IPv4/IPv6 index and label) are empty. A missing entry is ignored.
   """
   intfCfg = srConfig.intfConfig.get( intfName, None )
   if not intfCfg:
      return
   defaultIndexesUnset = ( intfCfg.srV6NodeSegmentIndex == srSidInvalid and
                           intfCfg.srNodeSegmentIndex == srSidInvalid )
   if not defaultIndexesUnset:
      return
   segmentCollections = ( 'srV4NodeSegment', 'srV6NodeSegment',
                          'srV4NodeSegmentLabel', 'srV6NodeSegmentLabel' )
   if any( getattr( intfCfg, attrName ) for attrName in segmentCollections ):
      return
   del srConfig.intfConfig[ intfName ]

def deleteExistingSameFlexAlgoConfig( srIntfConfig, label, nodeSegIndexCol,
                                      nodeSegLabelCol, ipv6=False,
                                      flexAlgoName=None ):
   """Clear node segments that a new label node segment would conflict with.

   Args:
      srIntfConfig: the interface's SR config entry.
      label: the label about to be configured.
      nodeSegIndexCol: flex-algo index collection of the selected family.
      nodeSegLabelCol: label collection of the selected family.
      ipv6: selects which default index attribute is invalidated.
      flexAlgoName: algorithm of the new segment; empty/None means the default
         algorithm.

   Effects:
   - Default algorithm: the family's default node-segment index is set to
     srSidInvalid. Flex-algo: every index entry of that algorithm is removed.
   - Every label entry that uses `label` or belongs to the same algorithm
     is removed.
   """
   # Default-algorithm label entries carry an empty flexAlgoName, so treat
   # None the same as '' when matching them.
   algoName = flexAlgoName or ''
   if not algoName:
      if ipv6:
         srIntfConfig.srV6NodeSegmentIndex = srSidInvalid
      else:
         srIntfConfig.srNodeSegmentIndex = srSidInvalid
   else:
      # Iterate over a snapshot: entries are deleted inside the loop, which
      # is unsafe while iterating a live view of the collection.
      for sid, sidAlgo in list( nodeSegIndexCol.items() ):
         if sidAlgo.flexAlgoName == algoName:
            del nodeSegIndexCol[ sid ]

   for labelKey, labelEntry in list( nodeSegLabelCol.items() ):
      if labelKey == label or labelEntry.flexAlgoName == algoName:
         del nodeSegLabelCol[ labelKey ]

def updateSrNodeSid( intfName, index, explicitNull, noPhp, flexAlgoName=None,
                     label=None, ipv6=False ):
   """Write a node segment for intfName into srConfig.

   Callers invoke this only after checkForConflicts() found the configuration
   acceptable. The interface's SR config entry is created if missing.

   - With `label`: a label node segment is added for flexAlgoName (the
     default algorithm when empty). Conflicting entries are removed first via
     deleteExistingSameFlexAlgoConfig, and `index` is not used.
   - Otherwise `index` is configured:
     - label node segments of the same algorithm are removed;
     - any other use of the same index in this address family is cleared;
     - the index becomes either the default node segment (no flexAlgoName)
       or the single node segment for flexAlgoName.

   explicitNull / noPhp are stored with the segment; `ipv6` selects the
   address family.
   """
   srIntfConfig = srConfig.intfConfig.get( intfName, None )
   if not srIntfConfig:
      srIntfConfig = srConfig.intfConfig.newMember( intfName )

   if ipv6:
      nodeSegIndexCol = srIntfConfig.srV6NodeSegment
      nodeSegLabelCol = srIntfConfig.srV6NodeSegmentLabel
   else:
      nodeSegIndexCol = srIntfConfig.srV4NodeSegment
      nodeSegLabelCol = srIntfConfig.srV4NodeSegmentLabel

   if label:
      nodeSeg = Tac.Value( 'Routing::SegmentRoutingCli::NodeSegmentLabel',
                           label )
      # Label entries store '' for the default algorithm.
      flexAlgoName = flexAlgoName or ''
      deleteExistingSameFlexAlgoConfig( srIntfConfig, label, nodeSegIndexCol,
                                        nodeSegLabelCol, ipv6=ipv6,
                                        flexAlgoName=flexAlgoName )
      nodeSeg.flexAlgoName = flexAlgoName
      nodeSeg.explicitNull = explicitNull
      nodeSeg.noPhp = noPhp
      nodeSegLabelCol.addMember( nodeSeg )
      return

   # An index node segment replaces any label node segment of the same
   # algorithm (an empty/None name means the default algorithm). Iterate over
   # a snapshot because entries are deleted inside the loop.
   algoName = flexAlgoName or ''
   for labelKey, labelEntry in list( nodeSegLabelCol.items() ):
      if ( labelEntry.flexAlgoName or '' ) == algoName:
         del nodeSegLabelCol[ labelKey ]

   t1( 'Updating', 'ipv6' if ipv6 else 'ipv4', 'node-segment for', intfName, 'to',
       index, 'flexAlgo %s' % flexAlgoName if flexAlgoName else '' )

   # The same index may be used only once per address family: clear it from
   # the default node segment and from the flex-algo collection.
   if ipv6 and srIntfConfig.srV6NodeSegmentIndex == index:
      srIntfConfig.srV6NodeSegmentIndex = srSidInvalid
   elif not ipv6 and srIntfConfig.srNodeSegmentIndex == index:
      srIntfConfig.srNodeSegmentIndex = srSidInvalid

   del nodeSegIndexCol[ index ]

   # Default-algorithm node segment lives in dedicated attributes.
   if not flexAlgoName:
      if ipv6:
         srIntfConfig.srV6NodeSegmentIndex = index
         srIntfConfig.srV6NodeSegmentExplicitNull = explicitNull
         srIntfConfig.srV6NodeSegmentNoPhp = noPhp
      else:
         srIntfConfig.srNodeSegmentIndex = index
         srIntfConfig.srNodeSegmentExplicitNull = explicitNull
         srIntfConfig.srNodeSegmentNoPhp = noPhp
      return

   # Flex-Algo node SIDs. Allow one node SID per flex-algo. The loop stops
   # right after the deletion, so the iterator is never advanced afterwards.
   for sid, sidAlgo in nodeSegIndexCol.items():
      if sidAlgo.flexAlgoName == flexAlgoName:
         del nodeSegIndexCol[ sid ]
         break
   nodeSeg = Tac.Value( 'Routing::SegmentRoutingCli::NodeSegment', index )
   nodeSeg.flexAlgoName = flexAlgoName
   nodeSeg.explicitNull = explicitNull
   nodeSeg.noPhp = noPhp
   nodeSegIndexCol.addMember( nodeSeg )

def deleteSrNodeSid( intfName, index=None, label=None, nodeSegmentType=None,
                     ipv6=False ):
   """Remove node-segment configuration of one address family from intfName.

   Args:
      intfName: interface whose srConfig.intfConfig entry is edited.
      index: node-segment index to remove; None means "any index".
      label: label to remove (LABEL type only); falsy means "all labels".
      nodeSegmentType: NodeSegmentType selecting what is removed:
         LABEL - the label node segment `label`, or all label segments;
         INDEX - the index node segment `index` (default algorithm or
                 flex-algo), or all index segments when index is None;
         anything else (None / NONE) - the default index (only if it equals
                 `index` when one is given) plus every flex-algo index and
                 every label node segment.
      ipv6: selects the IPv6 (True) or IPv4 (False) node segments.

   Nothing happens if the interface has no SR config. Otherwise the entry is
   deleted afterwards if nothing remains configured on it.
   """
   srIntfConfig = srConfig.intfConfig.get( intfName, None )

   if not srIntfConfig:
      return

   t1( 'Deleting %s node-segment for %s' % ( 'ipv6' if ipv6 else 'ipv4', intfName ) )
   if ipv6:
      nodeSegIndexCol = srIntfConfig.srV6NodeSegment
      nodeSegLabelCol = srIntfConfig.srV6NodeSegmentLabel
   else:
      nodeSegIndexCol = srIntfConfig.srV4NodeSegment
      nodeSegLabelCol = srIntfConfig.srV4NodeSegmentLabel

   def setNodeSegmentIndexToInvalid():
      # Clear the default-algorithm index of the selected address family
      # only; when `index` is given, only if it matches. The family is
      # checked first, so an IPv6 delete can never clear the IPv4 index even
      # when both indexes happen to have the same value.
      if ipv6:
         if index is None or srIntfConfig.srV6NodeSegmentIndex == index:
            srIntfConfig.srV6NodeSegmentIndex = srSidInvalid
      elif index is None or srIntfConfig.srNodeSegmentIndex == index:
         srIntfConfig.srNodeSegmentIndex = srSidInvalid

   if nodeSegmentType == NodeSegmentType.LABEL:
      if label:
         del nodeSegLabelCol[ label ]
      else:
         nodeSegLabelCol.clear()
   elif nodeSegmentType == NodeSegmentType.INDEX:
      setNodeSegmentIndexToInvalid()
      if index is not None:
         del nodeSegIndexCol[ index ]
      else:
         nodeSegIndexCol.clear()
   else:
      setNodeSegmentIndexToInvalid()
      nodeSegIndexCol.clear()
      nodeSegLabelCol.clear()

   _deleteSrIntfConfigIfAllDefaults( intfName )

#---------------------------------------------------------------------------------
# SrMode: config mode for segment-routing commands that do not depend on the IGP
#---------------------------------------------------------------------------------
class SrMode( SrModeBase, BasicCli.ConfigModeBase ):
   """Config mode entered with 'segment-routing' under 'router general'.

   Holds segment-routing settings that are not tied to a particular IGP.
   """
   name = "Segment Routing Configuration"

   def __init__( self, parent, session ):
      # Both bases are initialized explicitly. SrModeBase is passed None as
      # its parameter; its meaning is defined in CliMode.SegmentRoutingMode
      # (not visible here).
      SrModeBase.__init__( self, None )
      BasicCli.ConfigModeBase.__init__( self, parent, session )

#-------------------------------------------------------------------------------
# Create a new modelet so that the node-segment commands are added only to
# relevant interfaces
#-------------------------------------------------------------------------------
class SrIntfConfigModelet( CliParser.Modelet ):
   """Interface-config modelet that carries the node-segment commands.

   It is attached only to interfaces that support routing and whose name does
   not start with "Management".
   """
   @staticmethod
   def shouldAddModeletRule( mode ):
      intf = mode.intf
      routable = intf.routingSupported()
      if not routable:
         return routable
      return not intf.name.startswith( "Management" )

# Add SrIntfConfigModelet to IntfConfigMode
IntfCli.IntfConfigMode.addModelet( SrIntfConfigModelet )
srIntfModelet = SrIntfConfigModelet

class SegmentRoutingIntf( IntfCli.IntfDependentBase ):
   """Interface dependent registered in Plugin(); setDefault() removes all of
   the interface's IPv4 and IPv6 node-segment configuration."""
   def setDefault( self ):
      intfName = self.intf_.name
      # IPv4 first, then IPv6.
      for isV6 in ( False, True ):
         deleteSrNodeSid( intfName, ipv6=isV6 )

#-------------------------------------------------------------------------------
# [no|default] node-segment ipv4|ipv6 ( index <index> | label <label> )
# [flex-algo <name>]. Also adds the deprecated form 'node-segment index <index>',
# which is hidden.
#-------------------------------------------------------------------------------
def setNodeSegmentV4( mode, index, flexAlgoName, explicitNull, noPhp, label ):
   """Configure an IPv4 node segment (by index or by label) on mode.intf.

   Warns when the interface has no /32 address (see below), then writes the
   config only if checkForConflicts() reports no rejection.
   """
   intfName = mode.intf.name
   ipIntfConfig = ipConfig.ipIntfConfig.get( intfName )

   # Warn when there is no IP config at all, or when an address is set that is
   # not a /32.
   # NOTE(review): an unset address ('0.0.0.0') does not trigger the warning;
   # confirm this is intended.
   if ipIntfConfig is None:
      needsWarning = True
   else:
      addrWithMask = ipIntfConfig.addrWithMask
      needsWarning = ( addrWithMask.address != '0.0.0.0' and
                       addrWithMask.len != 32 )
   if needsWarning:
      mode.addWarning( "/32 IPv4 address is not configured on the interface" )

   if label:
      if checkForConflicts( mode, label=label ):
         return
      updateSrNodeSid( intfName, index, explicitNull, noPhp, flexAlgoName, label )
   elif not checkForConflicts( mode, index ):
      updateSrNodeSid( intfName, index, explicitNull, noPhp, flexAlgoName )

def setNodeSegmentV6( mode, index, flexAlgoName, explicitNull, noPhp, label ):
   """Configure an IPv6 node segment (by index or by label) on mode.intf.

   Warns when none of the interface's IPv6 addresses is a host (/128) prefix,
   then writes the config only if checkForConflicts() reports no rejection.
   """
   intfName = mode.intf.name
   ip6IntfConfig = ip6Config.intf.get( intfName )
   addrs = ip6IntfConfig.addr if ip6IntfConfig else ()
   hostFlags = [ Arnet.IpGenPrefix( str( addr ) ).isHost for addr in addrs ]
   if not any( hostFlags ):
      mode.addWarning( "/128 IPv6 address is not configured on the interface" )

   if label:
      if checkForConflicts( mode, ipv6=True, label=label ):
         return
      updateSrNodeSid( intfName, index, explicitNull, noPhp, flexAlgoName, label,
                       ipv6=True )
   elif not checkForConflicts( mode, index, ipv6=True ):
      updateSrNodeSid( intfName, index, explicitNull, noPhp, flexAlgoName,
                       ipv6=True )

def setNodeSegment( mode, af, index, flexAlgoName,
                    explicitNull=False, noPhp=False, label=None ):
   """Hand a node-segment config to the handler for address family `af`.

   `af` is 'ipv4' or 'ipv6'; any other value is ignored.
   """
   settersByAf = { 'ipv4': setNodeSegmentV4, 'ipv6': setNodeSegmentV6 }
   setter = settersByAf.get( af )
   if setter is not None:
      setter( mode, index, flexAlgoName, explicitNull, noPhp, label )

def noNodeSegment( mode, af, index, label=None, nodeSegmentType=None ):
   """Remove node-segment config of address family `af` from mode.intf.

   Any `af` other than 'ipv6' is treated as IPv4; see deleteSrNodeSid for how
   index, label and nodeSegmentType are interpreted.
   """
   isV6 = af == 'ipv6'
   deleteSrNodeSid( mode.intf.name, index, label, nodeSegmentType, ipv6=isV6 )

# INDEX argument of the node-segment commands: MplsLabel.min up to
# MplsLabel.max - MplsLabel.unassignedMin + 1.
srIndexMatcher = CliMatcher.IntegerMatcher( MplsLabel.min,
                              MplsLabel.max - MplsLabel.unassignedMin + 1,
                              helpdesc='Index to be mapped with IP prefix' )
# LABEL argument of 'node-segment ... label LABEL': 16 up to MplsLabel.max.
srLabelMatcher = CliMatcher.IntegerMatcher( 16, MplsLabel.max,
                              helpdesc="Absolute Node-SID for the prefix" )

# FAD_NAME argument: offers the names of the flex-algo definitions in
# flexAlgoConfig (an empty list until Plugin() has mounted it).
flexAlgoNameMatcher = CliMatcher.DynamicNameMatcher(
   lambda mode: ( [ fad.name for fad in flexAlgoConfig.definition.values() ]
                  if flexAlgoConfig else [] ),
   helpdesc='Algorithm name' )

#-------------------------------------------------------------------------------
# [no|default] node-segment ipv4|ipv6 ( index <index> | label <label> )
# [ flex-algo <fadName> ] [ explicit-null | no-php ]
#-------------------------------------------------------------------------------
class NodeSegmentCommand( CliCommand.CliCommandClass ):
   # 'node-segment' in interface config mode: configures an IPv4/IPv6 node
   # segment by index or by label, optionally for a flex-algo and with the
   # explicit-null or no-php flag. The no/default form removes index
   # segments, label segments, or (with neither keyword) everything for the
   # address family.
   # NOTE(review): the top-level '|' between '( index INDEX )' and
   # '( label LABEL )' is not wrapped in parentheses; confirm the CLI grammar
   # binds it as index-or-label with the trailing options applying to both.
   syntax = 'node-segment ( ipv4 | ipv6 ) ( index INDEX ) | ( label LABEL )' \
      ' [ flex-algo FAD_NAME ] [ explicit-null | no-php ]'
   noOrDefaultSyntax = 'node-segment ( ipv4 | ipv6 ) [ ( index [ INDEX ] ) |' \
      ' ( label [ LABEL ] ) ]...'

   data = {
         'node-segment': 'Configure a node segment',
         'ipv4': 'IPv4 node config',
         'ipv6': 'IPv6 node config',
         'index': 'Node segment identifier',
         'INDEX': srIndexMatcher,
         'label': 'Node segment label',
         'LABEL': srLabelMatcher,
         'flex-algo': 'Flexible algorithm',
         'FAD_NAME': flexAlgoNameMatcher,
         'explicit-null': 'Set Explicit Null flag',
         'no-php': 'Set No-PHP flag',
   }

   @staticmethod
   def handler( mode, args ):
      setNodeSegment( mode,
                      'ipv4' if 'ipv4' in args else 'ipv6',
                      args.get( 'INDEX' ),
                      args.get( 'FAD_NAME' ),
                      explicitNull='explicit-null' in args,
                      noPhp='no-php' in args,
                      label=args.get( 'LABEL' ) )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      # 'index' takes precedence over 'label'; neither means "remove all".
      if 'index' in args:
         segType = NodeSegmentType.INDEX
      elif 'label' in args:
         segType = NodeSegmentType.LABEL
      else:
         segType = NodeSegmentType.NONE
      noNodeSegment( mode, 'ipv4' if 'ipv4' in args else 'ipv6',
                     args.get( "INDEX" ), args.get( "LABEL" ), segType )

class HiddenNodeSegmentCommand( CliCommand.CliCommandClass ):
   # Deprecated, hidden form 'node-segment index INDEX'. It sets the IPv4
   # default-algorithm node-segment index; the no/default form removes IPv4
   # node segments via noNodeSegment() without a nodeSegmentType.
   hidden = True
   syntax = 'node-segment index INDEX'
   noOrDefaultSyntax = 'node-segment index ...'

   data = {
         'node-segment': 'Configure a node segment',
         'index': 'Node segment identifier',
         'INDEX': srIndexMatcher
   }

   @staticmethod
   def handler( mode, args ):
      index = args[ 'INDEX' ]
      setNodeSegment( mode, 'ipv4', index, '' )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      index = args.get( 'INDEX' )
      noNodeSegment( mode, 'ipv4', index )

# Register both node-segment command forms on the interface modelet.
srIntfModelet.addCommandClass( NodeSegmentCommand )
srIntfModelet.addCommandClass( HiddenNodeSegmentCommand )

def delSegmentRouting():
   """Disable segment routing and reset the SrMode settings to their defaults.

   Used by 'no segment-routing' and by the 'router general' cleanup hook.
   The values written match the no/default handlers of the SrMode commands.
   Interface node-segment configuration (srConfig.intfConfig) is not touched.
   """
   mplsConfig.tunnelIgpFecSharing = True
   mplsConfig.reserveServiceLabel = False
   srConfig.enabled = False
   mplsConfig.srProxySegFallbackPopEnabled = False

class CfgSegmentRoutingCmd( CliCommand.CliCommandClass ):
   # 'segment-routing' under 'router general': enters SrMode and sets
   # srConfig.enabled. The no/default form calls delSegmentRouting(). The
   # keyword is guarded by mplsSupportedGuard.
   syntax = 'segment-routing'
   noOrDefaultSyntax = syntax
   data = {
      'segment-routing': CliCommand.guardedKeyword( 'segment-routing',
                                                    'Segment Routing configuration',
                                                    guard=mplsSupportedGuard )
   }

   @staticmethod
   def handler( mode, args ):
      srMode = mode.childMode( SrMode )
      mode.session_.gotoChildMode( srMode )
      srConfig.enabled = True

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      delSegmentRouting()

RouterGeneralMode.addCommandClass( CfgSegmentRoutingCmd )

#-------------------------------------------------------------------------------#
# rtr1(config-sr)#[no|default] fec sharing igp tunnel disabled
# No need to add guard because segment routing mode is already protected by guard
#-------------------------------------------------------------------------------#
class TunnelIgpFecSharingCmd( CliCommand.CliCommandClass ):
   # 'fec sharing igp tunnel disabled' sets mplsConfig.tunnelIgpFecSharing to
   # False; the no/default form sets it back to True, which is also the value
   # delSegmentRouting() restores.
   syntax = 'fec sharing igp tunnel disabled'
   noOrDefaultSyntax = syntax

   data = {
      'fec': 'FEC configuration',
      'sharing': 'Share the FEC',
      'igp': 'Running IGP protocol',
      'tunnel': 'Tunnel for the IGP',
      'disabled': 'Disable FEC sharing',
   }

   @staticmethod
   def handler( mode, args ):
      mplsConfig.tunnelIgpFecSharing = False

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      mplsConfig.tunnelIgpFecSharing = True

SrMode.addCommandClass( TunnelIgpFecSharingCmd )

class ProxyNodeSegmentFallbackPopForwardCmd( CliCommand.CliCommandClass ):
   # 'proxy-node-segment mpls fallback pop forward' sets
   # mplsConfig.srProxySegFallbackPopEnabled. The no/default form accepts any
   # tokens after 'fallback' and clears the flag.
   syntax = "proxy-node-segment mpls fallback pop forward"
   noOrDefaultSyntax = "proxy-node-segment mpls fallback ..."

   data = {
      'proxy-node-segment': 'Node segment on behalf of another node',
      'mpls': 'MPLS configuration',
      'fallback': 'Fallback entry if SR reachability is not available',
      'pop': 'Pop the top label',
      'forward': 'Forward the packet based on IGP best path',
   }

   @staticmethod
   def handler( mode, args ):
      mplsConfig.srProxySegFallbackPopEnabled = True

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      mplsConfig.srProxySegFallbackPopEnabled = False

# Only offered in SrMode when the SrProxySegDefaultPop toggle is enabled.
if toggleSrProxySegDefaultPopEnabled():
   SrMode.addCommandClass( ProxyNodeSegmentFallbackPopForwardCmd )

class ReserveServiceLabelCmd( CliCommand.CliCommandClass ):
   # 'maximum-sid-depth reserve service-label' sets
   # mplsConfig.reserveServiceLabel; the no/default form clears it.
   syntax = 'maximum-sid-depth reserve service-label'
   noOrDefaultSyntax = syntax

   data = {
         'maximum-sid-depth': 'Maximum depth of MPLS label stack',
         'reserve': 'Reserve space for',
         'service-label': 'Service-label in MPLS label stack',
   }

   @staticmethod
   def handler( mode, args ):
      mplsConfig.reserveServiceLabel = True

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      mplsConfig.reserveServiceLabel = False

# Only offered in SrMode when the ReserveServiceLabel toggle is enabled.
if toggleReserveServiceLabelEnabled():
   SrMode.addCommandClass( ReserveServiceLabelCmd )
#---------------------------------------------------------------
# Remove all segment-routing configs when the parent is removed
# i.e., "no router general" in config mode.
#---------------------------------------------------------------

def noOrDefaultRouterGeneralMode( mode=None ):
   # Cleanup hook for removing 'router general' (registered below). `mode` is
   # accepted for the hook signature and is unused.
   delSegmentRouting()

routerGeneralCleanupHook.addExtension( noOrDefaultRouterGeneralMode )

#-------------------------------------------------------------------------------
# Have the Cli Agent mount all needed state from sysdb
#-------------------------------------------------------------------------------
def Plugin( entMan ):
   """Cli plugin entry point.

   Mounts the Sysdb state used by this file into the module globals:
   srConfig and mplsConfig writable via ConfigMount, the rest read-only via
   LazyMount. Then registers SegmentRoutingIntf as an interface dependent.
   """
   global srConfig, mplsConfig, ipConfig, ip6Config, mplsHwCapability
   global flexAlgoConfig

   # pkgdeps: rpmwith %{_libdir}/SysdbMountProfiles/ConfigAgent-SegmentRoutingCli
   srConfig = ConfigMount.mount( entMan, "routing/sr/config",
                                 "Routing::SegmentRoutingCli::Config", 'w' )

   mplsConfig = ConfigMount.mount( entMan, "routing/mpls/config",
                                   "Mpls::Config", "w" )
   mplsHwCapability = LazyMount.mount( entMan,
                                       "routing/hardware/mpls/capability",
                                       "Mpls::Hardware::Capability", "r" )
   ipConfig = LazyMount.mount( entMan, "ip/config", "Ip::Config", "r" )
   ip6Config = LazyMount.mount( entMan, "ip6/config", "Ip6::Config", "r" )
   flexAlgoConfig = LazyMount.mount( entMan, "te/flexalgo/config",
                                     "FlexAlgo::Config", "r" )

   IntfCli.Intf.registerDependentClass( SegmentRoutingIntf, priority=20 )
