#!/usr/bin/env python3
# Copyright (c) 2021 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Arnet
import CliCommand
import CliMatcher
import ConfigMount
from McastCommonCliLib import AddressFamily
from CliPlugin import RouterMulticastCliLib
from CliPlugin.RouterMulticastCliLib import (
   RouterMulticastMode,
   RouterMulticastSharedModelet,
   getVrfNameFromMode,
   RouterMulticastMplsStaticTunnelMode,
   RouterModeCallbackBase
)
from CliPlugin.MplsCli import mplsMvpnSupportedGuard, nextHopMatcher
from CliPlugin.MplsCli import labelStackKeywordNode, labelStackValNode
from CliPlugin.MplsCli import labelOperation
from CliPlugin.MplsCli import intfValMatcher, validateLabelStackSize, MplsVia
from CliPlugin.MplsCli import mplsNodeForConfig, \
      mplsStaticVrfLabelSupported, topLabelValMatcher
from CliPlugin.MrouteCli import minSupportVersion, enterSubmode, deleteSubmode
from CliPlugin.StaticMrouteCli import groupAddrMatcher, sourceAddrMatcher
from CliPlugin import IpAddrMatcher
from CliCommand import isNoOrDefaultCmd
import Tac
from TypeFuture import TacLazyType
import Toggles.McastCommonToggleLib

# pkgdeps: library MplsSysdbTypes

# Sysdb entities used by the CLI handlers in this file. They stay None until
# Plugin() mounts them.
_mvpnConfig = _mplsVrfLabelConfig = _staticConfig = _multicastLegacyConfig = None

# TACC type handles (via TacLazyType).
VrfLabel = TacLazyType( 'Mpls::VrfLabel' )
MvpnConfig = TacLazyType( 'Routing::Multicast::MvpnConfig' )
StaticTunnelSourceGroupConfigEntry = TacLazyType(
   'Tunnel::Static::Multicast::StaticTunnelSourceGroupConfigEntry' )
StaticSourceGroup = TacLazyType( 'Tunnel::Static::Multicast::StaticSourceGroup' )

# Generic IP address type; used here for ipGenAddrZero().
IpGenAddr = Tac.Type( 'Arnet::IpGenAddr' )

# Keywords for "mpls tunnel static ...". The `static` keyword gets its own
# name here because a separate `staticKwMatcher`, with different help text,
# is defined further down for other commands.
mplsKwMatcher = CliMatcher.KeywordMatcher( 'mpls',
                                helpdesc='Configure MPLS command(s)' )
tunnelKwMatcher = CliMatcher.KeywordMatcher( 'tunnel',
                                helpdesc='Configure MPLS tunnel command(s)' )
tunnelStaticKwMatcher = CliMatcher.KeywordMatcher( 'static',
                        helpdesc='Configure MPLS tunnel static command(s)' )
mplsTunnelStaticNode = CliCommand.Node( tunnelStaticKwMatcher,
                                        guard=mplsMvpnSupportedGuard )

# Read once at import time: decides whether the optional [ GROUP SOURCE ]
# arguments are part of the "mpls tunnel static" syntax.
mvpnStaticTunnelSelectiveEnabled = \
   Toggles.McastCommonToggleLib.toggleMvpnStaticTunnelSelectiveEnabled()

def _sourceGroupGenAddrs( af, args ):
   """Return the ( source, group ) IpGenAddr pair for "mpls tunnel static".

   A missing SOURCE or GROUP argument becomes the all-zeros address of `af`.
   If the two addresses were entered the wrong way round (a multicast SOURCE
   and a unicast GROUP), they are swapped so the group is the multicast one.
   Both the config and the no/default handlers use this helper. That way the
   same user input always maps to the same StaticSourceGroup key.
   """
   src = args.get( 'SOURCE' )
   grp = args.get( 'GROUP' )
   zeroGenAddr = IpGenAddr.ipGenAddrZero( af )
   srcGenAddr = Arnet.IpGenAddr( str( src ) ) if src else zeroGenAddr
   grpGenAddr = Arnet.IpGenAddr( str( grp ) ) if grp else zeroGenAddr
   if srcGenAddr.isMulticast and not grpGenAddr.isMulticast:
      srcGenAddr, grpGenAddr = grpGenAddr, srcGenAddr
   return srcGenAddr, grpGenAddr

class RouterMulticastMplsStaticTunnel( CliCommand.CliCommandClass ):
   # "[no] mpls tunnel static [ GROUP SOURCE ]" enters (or removes) the static
   # MPLS tunnel submode for the current VRF. The optional GROUP/SOURCE pair is
   # only in the syntax when the selective-tunnel toggle is enabled. Without
   # it, the entry is keyed on the all-zeros ( source, group ).
   syntax = ( f"mpls tunnel static"
              f"{' [ GROUP SOURCE ]' if mvpnStaticTunnelSelectiveEnabled else ''}" )
   noOrDefaultSyntax = syntax
   data = {
      'mpls' : mplsKwMatcher,
      'tunnel' : tunnelKwMatcher,
      'static' : mplsTunnelStaticNode,
      'GROUP' : groupAddrMatcher,
      'SOURCE' : sourceAddrMatcher
   }

   @staticmethod
   def handler( mode, args ):
      """Validate the ( source, group ), enter the submode, and create config."""
      minSupportVersion( _multicastLegacyConfig.ipMode )
      vrfName = RouterMulticastCliLib.getVrfNameFromMode( mode )
      af = RouterMulticastCliLib.getAddressFamilyFromMode( mode )
      srcGenAddr, grpGenAddr = _sourceGroupGenAddrs( af, args )

      if args.get( 'GROUP' ) or args.get( 'SOURCE' ):
         # Exactly one address must be multicast. After normalization that
         # is the group; the other is the unicast source.
         if srcGenAddr.isMulticast == grpGenAddr.isMulticast:
            mode.addError( "Need valid multicast group and unicast source adresses" )
            return

         # Do not change config if src or grp is zero
         if srcGenAddr.isAddrZero or grpGenAddr.isAddrZero:
            return

      childMode = mode.childMode(
         RouterMulticastCliLib.RouterMulticastMplsStaticTunnelMode,
         vrfName=vrfName, af=af, source=srcGenAddr, group=grpGenAddr )
      mode.session_.gotoChildMode( childMode )
      enterSubmode( "mpls tunnel static", vrfName=vrfName, af=af, source=srcGenAddr,
                    group=grpGenAddr )

      staticTunnelConfigEntry = _staticConfig.newEntry( vrfName )
      staticTunnelConfigEntry.af = af

      # Add the ( source, group ) member only if it is not there yet.
      sourceGroup = StaticSourceGroup( srcGenAddr, grpGenAddr )
      if not staticTunnelConfigEntry.sourceGroupEntry.get( sourceGroup ):
         staticTunnelConfigEntry.sourceGroupEntry.addMember(
            StaticTunnelSourceGroupConfigEntry( sourceGroup ) )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      """Remove the ( source, group ) entry and its submode for this VRF."""
      vrfName = RouterMulticastCliLib.getVrfNameFromMode( mode )
      af = RouterMulticastCliLib.getAddressFamilyFromMode( mode )
      # Normalize exactly as handler() does. Otherwise interchanged addresses
      # would never match the entry handler() created for the same input.
      srcGenAddr, grpGenAddr = _sourceGroupGenAddrs( af, args )

      if vrfName in _staticConfig.entry:
         staticTunnelConfigEntry = _staticConfig.entry[ vrfName ]
         sourceGroup = StaticSourceGroup( srcGenAddr, grpGenAddr )

         del staticTunnelConfigEntry.sourceGroupEntry[ sourceGroup ]
         deleteSubmode( "mpls tunnel static", vrfName=vrfName, af=af,
                        source=srcGenAddr, group=grpGenAddr )

         # Drop the per-VRF entry once its last ( source, group ) is gone.
         if not staticTunnelConfigEntry.sourceGroupEntry:
            del _staticConfig.entry[ vrfName ]

RouterMulticastCliLib.RouterMulticastIpv4Modelet.addCommandClass(
   RouterMulticastMplsStaticTunnel )

def enableMvpnStatic( mode, args ):
   """CLI handler for "mvpn ipv4 static pmsi".

   Sets mvpnStatic on the mounted MVPN config; `mode` and `args` are unused.
   """
   mvpnConfig = _mvpnConfig
   mvpnConfig.mvpnStatic = True

def disableMvpnStatic( mode, args ):
   """No/default CLI handler for "mvpn ipv4 static pmsi".

   Clears mvpnStatic on the mounted MVPN config; `mode` and `args` are unused.
   """
   mvpnConfig = _mvpnConfig
   mvpnConfig.mvpnStatic = False

# Keywords for "mvpn ipv4 static pmsi". `staticKwMatcher` is also used by
# the "mpls static label" command.
mvpnKwMatcher = CliMatcher.KeywordMatcher( 'mvpn', helpdesc='Mvpn commands' )
mvpnNode = CliCommand.Node( mvpnKwMatcher, guard=mplsMvpnSupportedGuard )
ipv4KwMatcher = CliMatcher.KeywordMatcher( 'ipv4', helpdesc='Ipv4 addressfamily' )
staticKwMatcher = CliMatcher.KeywordMatcher( 'static',
                                             helpdesc='Static creation command' )
pmsiKwMatcher = CliMatcher.KeywordMatcher( 'pmsi', helpdesc='PMSI interface' )

class MvpnStatic( CliCommand.CliCommandClass ):
   # "[no|default] mvpn ipv4 static pmsi": the plain form sets mvpnStatic in
   # the MVPN config and the no/default form clears it.
   syntax = "mvpn ipv4 static pmsi"
   noOrDefaultSyntax = syntax
   data = dict( mvpn=mvpnNode,
                ipv4=ipv4KwMatcher,
                static=staticKwMatcher,
                pmsi=pmsiKwMatcher )
   handler = enableMvpnStatic
   noOrDefaultHandler = disableMvpnStatic

def updateTunnelVias( mode, args ):
   """Add or remove a static MPLS via on the submode's ( source, group ).

   Shared handler for "[no] next-hop ADDR INTF label-stack { LABELS }" in the
   mpls tunnel static submode. The no/default form is detected with
   isNoOrDefaultCmd. Only one label is supported, so a multi-label stack is
   reduced to its last label in BOTH forms. Otherwise the no form would build
   a via that can never match the one stored by the add form.
   """
   vrfName = mode.vrfName
   source = mode.source
   group = mode.group
   nexthop = args[ 'ADDR' ]
   intf = args[ 'INTF' ]
   labels = args[ 'LABELS' ]

   add = not isNoOrDefaultCmd( args )
   if add and not validateLabelStackSize( mode, labels ):
      return
   if len( labels ) > 1:
      labels = [ labels[ -1 ], ]
      if add:
         mode.addWarning( 'Only one label is supported for now. Only the last '
                          'label will be used' )

   # RouterMulticastMplsStaticTunnel.handler creates these entries when the
   # submode is entered.
   assert vrfName in _staticConfig.entry
   staticTunnelConfigEntry = _staticConfig.entry[ vrfName ]
   sourceGroup = StaticSourceGroup( source, group )
   assert sourceGroup in staticTunnelConfigEntry.sourceGroupEntry
   # Take a mutable version of the entry and write it back with addMember
   # below (presumably a value copy -- hence the write-back).
   sourceGroupConfigEntry = \
      Tac.nonConst( staticTunnelConfigEntry.sourceGroupEntry.get( sourceGroup ) )

   nexthopAddr = Arnet.IpGenAddr( str( nexthop ) )
   intfId = Tac.Value( 'Arnet::IntfId', intf.name )
   mplsVia = MplsVia( nexthopAddr, intfId, labelOperation( labels ) )
   if add:
      sourceGroupConfigEntry.via[ mplsVia ] = True
   else:
      del sourceGroupConfigEntry.via[ mplsVia ]
   staticTunnelConfigEntry.sourceGroupEntry.addMember( sourceGroupConfigEntry )

class MvpnTunnelStatic( CliCommand.CliCommandClass ):
   # "[no|default] next-hop ADDR INTF label-stack { LABELS }" inside the mpls
   # tunnel static submode. Both forms go to updateTunnelVias, which tells
   # them apart with isNoOrDefaultCmd.
   syntax = 'next-hop ADDR INTF label-stack { LABELS }'
   noOrDefaultSyntax = syntax
   data = {
      'next-hop': nextHopMatcher,
      'ADDR': IpAddrMatcher.IpAddrMatcher( "Address of the nexthop router" ),
      'INTF': intfValMatcher,
      'label-stack': labelStackKeywordNode,
      'LABELS': labelStackValNode,
   }
   handler = updateTunnelVias
   noOrDefaultHandler = updateTunnelVias

# Hook the commands above into their respective modes.
RouterMulticastMode.addCommandClass( MvpnStatic )
RouterMulticastMplsStaticTunnelMode.addCommandClass( MvpnTunnelStatic )

#------------------------------------------
# [no] mpls static label <label>
#------------------------------------------
class MplsStaticVrfLabelCmd( CliCommand.CliCommandClass ):
   syntax = "mpls static label LABEL"
   noOrDefaultSyntax = syntax
   data = {
      'mpls': mplsNodeForConfig,
      'static': staticKwMatcher,
      'label': CliCommand.guardedKeyword( 'label',
                      helpdesc="Specify the MPLS label to vrf mapping",
                      guard=mplsStaticVrfLabelSupported ),
      'LABEL': topLabelValMatcher,
   }

   @staticmethod
   def handler( mode, args ):
      """Bind LABEL to the current VRF as a multicast VRF label.

      Errors if LABEL is already bound to a different VRF. Re-entering an
      existing binding for this same VRF is a no-op, so the command is
      idempotent.
      """
      vrfName = getVrfNameFromMode( mode )
      labelValue = args[ 'LABEL' ]
      if labelValue in _mplsVrfLabelConfig.vrfLabel:
         if _mplsVrfLabelConfig.vrfLabel[ labelValue ].vrfName != vrfName:
            mode.addError( 'label currently in use' )
         return
      # Build the label -> VRF mapping and publish it to Sysdb.
      vrfLabel = VrfLabel( labelValue, vrfName )
      vrfLabel.multicast = True
      _mplsVrfLabelConfig.addVrfLabel( vrfLabel )

   @staticmethod
   def noOrDefaultHandler( mode, args ):
      """Remove LABEL, but only if it is bound to the current VRF."""
      labelVal = args[ 'LABEL' ]
      if labelVal not in _mplsVrfLabelConfig.vrfLabel:
         return
      vrfLabel = _mplsVrfLabelConfig.vrfLabel[ labelVal ]
      vrfName = getVrfNameFromMode( mode )
      if vrfLabel.vrfName == vrfName:
         del _mplsVrfLabelConfig.vrfLabel[ labelVal ]

RouterMulticastSharedModelet.addCommandClass( MplsStaticVrfLabelCmd )

class RouterMulticastListener( RouterModeCallbackBase ):
   """Clean up this plugin's config when router-multicast (sub)modes go away.

   Registered with RouterMulticastMode in Plugin().
   """

   def modeDeleted( self, **kwargs ):
      # Presumably called when the whole router-multicast mode is removed:
      # reset every piece of config this file owns.
      _mvpnConfig.mvpnStatic = False
      _staticConfig.entry.clear()
      _mplsVrfLabelConfig.vrfLabel.clear()

   def vrfModeDeleted( self, vrfName, **kwargs ):
      del _staticConfig.entry[ vrfName ]
      # Collect first, then delete. Mutating the collection while iterating it
      # is unsafe, and nothing visible here stops one VRF from owning several
      # labels, so every match must be removed, not just the first.
      staleLabels = [ label for label in _mplsVrfLabelConfig.vrfLabel
                      if _mplsVrfLabelConfig.vrfLabel[ label ].vrfName == vrfName ]
      for label in staleLabels:
         del _mplsVrfLabelConfig.vrfLabel[ label ]

   def afModeDeleted( self, vrfName, af, **kwargs ):
      # NOTE(review): this drops the VRF's static tunnel entry whatever `af`
      # is, although the entry records its own af. Confirm whether only a
      # matching af should be removed.
      del _staticConfig.entry[ vrfName ]

def Plugin( entityManager ):
   """CLI plugin entry point.

   Mounts the Sysdb config entities used by the handlers in this file into
   the module globals, then registers RouterMulticastListener for
   router-multicast mode callbacks.
   """
   global _mvpnConfig, _mplsVrfLabelConfig, _staticConfig, _multicastLegacyConfig

   mount = ConfigMount.mount
   _mvpnConfig = mount( entityManager, MvpnConfig.mountPath( AddressFamily.ipv4 ),
                        'Routing::Multicast::MvpnConfig', 'w' )
   _mplsVrfLabelConfig = mount( entityManager,
                                'routing/mpls/multicast/vrfLabel/input/cli',
                                'Mpls::VrfLabelConfigInput', 'wi' )
   _staticConfig = mount( entityManager, 'tunnel/static/multicast/config',
                          'Tunnel::Static::Multicast::Config', 'w' )
   _multicastLegacyConfig = mount( entityManager, 'routing/multicast/legacyconfig',
                                   'Routing::Multicast::MulticastLegacyConfig',
                                   'w' )

   listener = RouterMulticastListener()
   RouterMulticastMode.registerCallback( listener )
