#!/usr/bin/env python3
# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import LazyMount
import BasicCli, Tac, Arnet
from IpLibConsts import DEFAULT_VRF
import McastCommonCliLib
from McastCommonCliLib import createMcastIntfConfig
from McastCommonCliLib import AddressFamily
from McastCommonCliLib import validateMulticastAddress
from McastCommonCliLib import mcastGenRoutingSupportedGuard
from McastCommonCliLib import validateRouting
from McastCommonCliLib import RoutePriority
import CliToken.Ip
from CliPlugin import EthIntfCli
from CliPlugin import Ip6AddrMatcher
from CliPlugin import IpAddrMatcher
from CliPlugin import IraIpCli
from CliPlugin import IraIpIntfCli
from CliPlugin import IraIpRouteCliLib
from CliPlugin import LagIntfCli
from CliPlugin import MrouteCli
from CliPlugin import RouterMulticastCliLib
from CliPlugin import SubIntfCli
from CliPlugin import VlanIntfCli
from CliPlugin.IntfCli import IntfConfigMode
from CliPlugin.RouterMulticastCliLib import (
      RouterModeCallbackBase,
      RouterMulticastMode,
      configGetters,
      legacyCliCallback,
)
from CliPlugin.VirtualIntfRule import VirtualIntfMatcher, IntfMatcher
from CliPlugin import TunnelIntfCli
from CliPlugin import SwitchIntfCli
import CliCommand
import CliMatcher

# TAC type name of the static mroute config collection. There is one
# collection per address family (see mcastStaticConfigColl( af ) below); it
# holds per-VRF config in `vrfConfig` and per-interface enablement in
# `intfConfig`.
StaticMrouteConfigColl = "McastCommon::StaticMrouteConfigColl"

# Ip/Ip6 generic status used to map an interface name to its VRF. Assigned by
# Plugin() when the plugin is loaded; None until then.
ipGenStatus = None

# Accessors built by configGetters. Judging from their use in this module,
# each tuple is ( collection, collection-from-CLI-mode, per-VRF entry,
# per-VRF-entry-from-CLI-mode ): first for the MFIB per-VRF config (its
# `config` collection), then for the static mroute config.
( mfibConfigRoot,
  mfibConfigRootFromMode,
  mfibConfig,
  mfibConfigFromMode ) = configGetters( MrouteCli.MfibVrfConfig,
                                        collectionName='config' )
( mcastStaticConfigColl,
  mcastStaticConfigCollFromMode,
  mcastStaticConfig,
  mcastStaticConfigFromMode ) = configGetters( StaticMrouteConfigColl )
def mcastStaticConfigCollFromFamily( ipFamily, legacy=False ):
   '''Return the static mroute config collection for an address family.

   Legacy commands (``legacy == True``) always use the IPv4 collection.
   Otherwise ``ipFamily`` selects the collection; ``None`` selects the
   family-independent (``AddressFamily.ipunknown``) collection that holds
   common config.
   '''
   # pylint: disable-next=singleton-comparison
   if legacy == True:
      return mcastStaticConfigColl( AddressFamily.ipv4 )
   if ipFamily is None:
      return mcastStaticConfigColl( AddressFamily.ipunknown )
   return mcastStaticConfigColl( ipFamily )

def _staticMcastVrfCreationHook( vrfName ):
   '''VRF-definition hook: access the static mroute config of ``vrfName`` in
   both the IPv4 and IPv6 collections.

   The returned entries are discarded; the getter is called only for its
   side effect (presumably creating the per-VRF entry on first access).
   '''
   for addrFamily in ( AddressFamily.ipv4, AddressFamily.ipv6 ):
      mcastStaticConfig( addrFamily, vrfName )

def _staticMcastVrfDeletionHook( vrfName ):
   '''VRF-deletion hook for the IPv4 and IPv6 static mroute config.

   The default VRF's entry is kept and only reset. Any other VRF's entry is
   removed from each family's ``vrfConfig`` collection if present.
   '''
   isDefaultVrf = vrfName == DEFAULT_VRF
   for addrFamily in ( AddressFamily.ipv4, AddressFamily.ipv6 ):
      if isDefaultVrf:
         mcastStaticConfig( addrFamily, vrfName ).reset()
         continue
      perVrfConfig = mcastStaticConfigColl( addrFamily ).vrfConfig
      if vrfName in perVrfConfig:
         del perVrfConfig[ vrfName ]

def mcastStaticConfigCommand( func ):
   '''Decorator for the static multicast config handlers in this module
   (the interface "multicast static" handler and the route handler).

   Before calling ``func`` it determines the address family and VRF the
   command applies to and calls validateRouting( ..., msgType='warn' ) for
   them, then invokes ``func`` with the original arguments. The wrapped
   handler must receive the parsed CLI arguments as the ``args`` keyword.

   Address family: the AF token if present, IPv4 for legacy commands,
   otherwise ``mode.af``. VRF: ``mode.vrfName`` when the mode has one, the
   interface's VRF in interface config mode, else the default VRF.
   '''
   def newFunc( mode, *args, **kwargs):
      cliArgs = kwargs[ 'args' ]
      afKey = McastCommonCliLib.ipFamilyAlias
      if afKey in cliArgs:
         af = cliArgs[ afKey ]
      elif kwargs.get( 'legacy' ):
         af = AddressFamily.ipv4
      else:
         af = mode.af

      vrfName = DEFAULT_VRF
      if hasattr( mode, 'vrfName' ):
         vrfName = mode.vrfName
      elif isinstance( mode, IntfConfigMode ):
         vrfName = McastCommonCliLib.getVrfNameFromIntf( ipGenStatus,
                                                         mode.intf.name )

      validateRouting( mode, vrfName, af, msgType='warn' )
      return func( mode, *args, **kwargs )
   return newFunc

# Keyword matchers and guarded nodes for the "multicast ... static",
# "ip route multicast ..." and "fastdrop static" commands.
mcastKwMatcher = CliMatcher.KeywordMatcher( 'multicast',
   helpdesc='Multicast routing commands' )
mcastNode = CliCommand.Node( matcher=mcastKwMatcher,
   guard=mcastGenRoutingSupportedGuard )
# NOTE(review): configured identically to mcastNode. It is used by the legacy
# "ip ... multicast" syntaxes, presumably so they can be deprecated
# independently -- confirm.
mcastDeprecatedNode = CliCommand.Node( matcher=mcastKwMatcher,
   guard=mcastGenRoutingSupportedGuard )
staticKwMatcher = CliMatcher.KeywordMatcher( 'static',
   helpdesc='Static Multicast routes allowed' )
fastdropKwMatcher = CliMatcher.KeywordMatcher( 'fastdrop',
   helpdesc='Fastdrop routes' )
fastdropStaticKwMatcher = CliMatcher.KeywordMatcher( 'static',
   helpdesc='Enable dynamic fastdrop route creation' )

# Interface config modelet on which the "[ip] multicast [AF] static" commands
# are registered.
modelet = IraIpIntfCli.RoutingProtocolIntfConfigModelet

def noMcastStatic( mode, args, legacy=False ):
   '''Handler for the no/default forms of "ip multicast static" and
   "multicast AF static".

   Removes the interface from the selected family's static mroute interface
   collection if it has a truthy entry there; otherwise does nothing.
   '''
   intfName = mode.intf.name
   family = args.get( McastCommonCliLib.ipFamilyAlias )
   intfColl = mcastStaticConfigCollFromFamily( family, legacy ).intfConfig
   if intfColl.get( intfName ):
      del intfColl[ intfName ]

noMcastStaticLegacy = legacyCliCallback( noMcastStatic )

@mcastStaticConfigCommand
def setMcastStatic( mode, args, legacy=False ):
   '''Handler for "ip multicast static" / "multicast AF static" in interface
   config mode.

   If the interface is not yet in the selected family's static mroute
   interface collection, it is added through createMcastIntfConfig, which is
   given the interface's VRF and a callback that performs the insertion.
   Nothing is done if the interface is already present.
   '''
   # Use the same key as noMcastStatic and the mcastStaticConfigCommand
   # decorator, so the set and no forms act on the same collection.
   ipFamily = args.get( McastCommonCliLib.ipFamilyAlias )
   intfName = mode.intf.name
   _mcastStaticConfigColl = mcastStaticConfigCollFromFamily( ipFamily,
                                                             legacy=legacy )
   if intfName in _mcastStaticConfigColl.intfConfig:
      return

   def createFcn():
      _mcastStaticConfigColl.intfConfig[ intfName ] = True

   # NOTE(review): for legacy commands ipFamily is None here even though the
   # IPv4 collection is used above; confirm getVrfNameFromIntf and
   # createMcastIntfConfig handle None as intended.
   vrfName = McastCommonCliLib.getVrfNameFromIntf( ipGenStatus, intfName,
                                                   af=ipFamily )
   createMcastIntfConfig( mode, vrfName, ipFamily, intfName, createFcn, True )

setMcastStaticLegacy = legacyCliCallback( setMcastStatic )

# Legacy interface command "[no|default] ip multicast static". Its handlers
# are legacyCliCallback wrappers (presumably passing legacy=True, which makes
# mcastStaticConfigCollFromFamily pick the IPv4 collection).
class LegacyMcastStatic( CliCommand.CliCommandClass ):
   syntax = 'ip multicast static'
   noOrDefaultSyntax = syntax + " ..."
   data = {
      'ip': CliToken.Ip.ipMatcherForConfigIf,
      'multicast': mcastDeprecatedNode,
      'static': staticKwMatcher
   }
   handler = setMcastStaticLegacy
   noOrDefaultHandler = noMcastStaticLegacy

modelet.addCommandClass( LegacyMcastStatic )

# Interface command "[no|default] multicast AF static": enable/disable static
# mroutes on the interface for the given address family.
class McastStatic( CliCommand.CliCommandClass ):
   syntax = 'multicast AF static'
   noOrDefaultSyntax = syntax + " ..."
   data = {
      'multicast': mcastNode,
      'AF': McastCommonCliLib.IpFamilyExpr,
      'static': staticKwMatcher
   }
   handler = setMcastStatic
   noOrDefaultHandler = noMcastStatic

modelet.addCommandClass( McastStatic )

# IPv4 address matchers for the route commands. NUM_GROUPS (1-1000) lets one
# command configure several group addresses derived from GROUP (see
# setMulticastRoutes).
groupAddrMatcher = IpAddrMatcher.IpAddrMatcher( "Group address" )
sourceAddrMatcher = IpAddrMatcher.IpAddrMatcher( "Source address" )
numGroupsMatcher = CliMatcher.IntegerMatcher( 1, 1000,
   helpdesc='the number of multicast groups to be configured' )

# IPv6 address matchers for the route commands.
group6AddrMatcher = Ip6AddrMatcher.Ip6AddrMatcher( "IPv6 Group address" )
source6AddrMatcher = Ip6AddrMatcher.Ip6AddrMatcher( "IPv6 Source address" )
# Keywords introducing the incoming (IIF) and MoFrr (IIFFRR) interfaces.
iifKwMatcher = CliMatcher.KeywordMatcher( 'iif',
   helpdesc="Incoming interface" )
iifFrrKwMatcher = CliMatcher.KeywordMatcher( 'iifFrr',
   helpdesc="MoFrr interface" )

# Interface types accepted for IIF, OIFS and IIFFRR.
intfMatcher = IntfMatcher()
intfMatcher |= VlanIntfCli.VlanIntf.matcher
intfMatcher |= EthIntfCli.EthPhyIntf.ethMatcher
intfMatcher |= LagIntfCli.EthLagIntf.matcher
intfMatcher |= SubIntfCli.subMatcher
intfMatcher |= LagIntfCli.subMatcher
# NOTE(review): 'Interface that drops all traffic' is an unusual description
# for the Register interface; confirm this help text is intended.
intfMatcher |= VirtualIntfMatcher( 'Register', 0, 0,
   helpdesc='Interface that drops all traffic' )
intfMatcher |= IraIpCli.nullIntfMatcher
intfMatcher |= VirtualIntfMatcher( 'Pmsi', 0, 0,
   helpdesc='PMSI interface for the VRF' )
intfMatcher |= TunnelIntfCli.TunnelIntf.matcher
intfMatcher |= SwitchIntfCli.SwitchIntf.matcher

# "oif { OIFS }": outgoing interfaces of a route. setMulticastRoute iterates
# the OIFS value, so several interfaces may be given.
oifKwMatcher = CliMatcher.KeywordMatcher( 'oif', helpdesc="Outgoing interfaces" )
class OifsExpr( CliCommand.CliExpression ):
   expression = 'oif { OIFS }'
   data = {
      'oif': oifKwMatcher,
      'OIFS': intfMatcher
   }

# "cpu": copy the route's traffic to the CPU (sets the route's toCpu flag).
cpuKwMatcher = CliMatcher.KeywordMatcher( 'cpu', helpdesc="Copy to CPU" )

# "priority PRIORITY": route priority in [ RoutePriority.min,
# RoutePriority.max ]. The bare priorityKwMatcher is also used on its own by
# the no-forms to reset a route's priority (see noMulticastRoute).
priorityKwMatcher = CliMatcher.KeywordMatcher( 'priority',
   helpdesc='Route Priority' )
routePriorityMatcher = CliMatcher.IntegerMatcher(
      RoutePriority.min,
      RoutePriority.max,
   helpdesc='Priority of route' )

class RoutePriorityExpr( CliCommand.CliExpression ):
   expression = 'priority PRIORITY'
   data = {
      'priority': priorityKwMatcher,
      'PRIORITY': routePriorityMatcher
   }

# "iifFrr IIFFRR": MoFrr interface of the route (stored in the route's
# iifFrr attribute by setMulticastRoute).
class IifFrrExpr( CliCommand.CliExpression ):
   expression = 'iifFrr IIFFRR'
   data = {
      'iifFrr': iifFrrKwMatcher,
      'IIFFRR': intfMatcher
   }

def sgKey( s, g ):
   '''Build the Routing::Multicast::Fib::IpGenRouteKey for source ``s`` and
   group ``g``.

   When ``s`` is falsy (no source given), the source is the default prefix
   of the group's address family (presumably the "any source" key).
   '''
   groupPrefix = McastCommonCliLib.toPrefix( g )
   sourcePrefix = ( McastCommonCliLib.toPrefix( s ) if s
                    else McastCommonCliLib.defaultPrefix( groupPrefix.af ) )
   return Tac.Value( "Routing::Multicast::Fib::IpGenRouteKey", sourcePrefix,
                     groupPrefix )

@mcastStaticConfigCommand
def setMulticastRoute( mode, group, args, legacy=False ):
   '''Create or update the static multicast route ( SOURCE, group ).

   Reads SOURCE, IIF, OIFS, cpu, IIFFRR and PRIORITY from the parsed CLI
   ``args``. The route is created if missing. Its incoming interface,
   priority (default RoutePriority.staticPriority), copy-to-CPU flag and
   outgoing-interface set are then overwritten, and rpaId is reset to 0.
   The incoming interface is never kept as an outgoing interface.
   '''
   source = args.get( "SOURCE" )
   intf = args[ "IIF" ]
   iifName = str( intf )
   outputIntfs = args.get( "OIFS" )
   cpu = args.get( "cpu" )
   iifFrr = args.get( "IIFFRR" )
   rtPriority = args.get( "PRIORITY" )
   if rtPriority is None:
      # Test for None rather than truthiness so a configured priority at the
      # bottom of the range (0) is not replaced by the default.
      rtPriority = RoutePriority.staticPriority

   routeKey = sgKey( source, group )
   mcastCommonConfig = mcastStaticConfigFromMode( mode, legacy=legacy )
   if routeKey not in mcastCommonConfig.staticMcastRoute:
      mcastCommonConfig.newStaticMcastRoute( routeKey )
   smr = mcastCommonConfig.staticMcastRoute[ routeKey ]
   smr.iif = iifName
   if iifFrr:
      # NOTE(review): an iifFrr set by an earlier command is left unchanged
      # when the command is re-issued without IIFFRR; confirm whether it
      # should be cleared like the other attributes.
      smr.iifFrr = str( iifFrr )
   smr.routePriority = rtPriority
   smr.toCpu = bool( cpu )
   smr.rpaId = 0

   # OIFS entries may be interface objects or plain names; normalize to names.
   newOifs = { getattr( oif, 'name', oif ) for oif in outputIntfs or () }
   # The ingress interface must never also be an egress interface. Drop it
   # before diffing, so it is also removed if an earlier config (with a
   # different IIF) had it as an OIF.
   newOifs.discard( iifName )
   oldOifs = set( smr.oifs.keys() )

   for oif in newOifs - oldOifs:
      smr.oifs[ oif ] = True
   for oif in oldOifs - newOifs:
      del smr.oifs[ oif ]

   return None

def setMulticastRoutes( mode, args, legacy=False ):
   '''Handler for "[ip] route [multicast] GROUP [ NUM_GROUPS ] ... iif IIF ...".

   Validates GROUP (adding an error to ``mode`` and returning if invalid).
   Then configures the route for GROUP alone or, when NUM_GROUPS is given
   (IPv4 syntaxes only), for every group address returned by
   Arnet.getMcastGroupAddresses( GROUP, NUM_GROUPS ).
   '''
   group = args[ "GROUP" ]
   numGroups = args.get( "NUM_GROUPS" )
   err = validateMulticastAddress( group )
   if err:
      mode.addError( err )
      return

   if not numGroups:
      setMulticastRoute( mode, group, args=args, legacy=legacy )
      return

   # Only the IPv4 syntaxes accept NUM_GROUPS (Mcast6Routes omits it).
   assert RouterMulticastCliLib.getAddressFamilyFromMode( mode, legacy ) == \
         AddressFamily.ipv4
   for g in Arnet.getMcastGroupAddresses( group, numGroups ):
      # Pass the caller's legacy flag, as the single-group path above and
      # noMulticastRoutes do, so generated routes go to the same config as a
      # single-group route entered in the same mode.
      setMulticastRoute( mode, g, args=args, legacy=legacy )

setMulticastRoutesLegacy = legacyCliCallback( setMulticastRoutes )

#-------------------------------------------------------------------------------
#    + If only the group is specified, only the group-only route (the one with
#      no source) is deleted, not every route for that group.
#-------------------------------------------------------------------------------
def noMulticastRoute( mode, group, source=None, rtPriority=None, legacy=False ):
   '''Delete the static route ( source, group ), or only reset its priority.

   With a truthy ``rtPriority`` the route's priority is restored to
   RoutePriority.staticPriority and the route is kept; otherwise the route is
   deleted. A route that does not exist is ignored.
   '''
   routeKey = sgKey( source, group )
   routes = mcastStaticConfigFromMode( mode, legacy=legacy ).staticMcastRoute
   route = routes.get( routeKey )
   if not route:
      return
   if rtPriority:
      route.routePriority = RoutePriority.staticPriority
   else:
      del routes[ routeKey ]

def noMulticastRoutes( mode, args, legacy=False ):
   '''Handler for the no/default forms of the static multicast route commands.

   Validates GROUP (adding an error to ``mode`` and returning if invalid).
   Then removes, or only resets the priority of when the "priority" keyword
   is given, the route for GROUP or for each group address generated from
   GROUP and NUM_GROUPS.
   '''
   group = args[ "GROUP" ]
   err = validateMulticastAddress( group )
   if err:
      mode.addError( err )
      return

   source = args.get( "SOURCE" )
   rtPriority = args.get( "priority" )
   numGroups = args.get( "NUM_GROUPS" )
   if numGroups:
      groups = Arnet.getMcastGroupAddresses( group, numGroups )
   else:
      groups = [ group ]
   for g in groups:
      noMulticastRoute( mode, g, source=source,
                        rtPriority=rtPriority, legacy=legacy )

noMulticastRoutesLegacy = legacyCliCallback( noMulticastRoutes )

# Tokens shared by the static multicast route commands; each subclass merges
# baseData into its own data. ROUTE_PRIORITY ("priority PRIORITY") is used by
# the set-forms, while the bare "priority" keyword is used by the no-forms to
# reset a route's priority.
class McastRoutesBase( CliCommand.CliCommandClass ):
   baseData = {
      "route": IraIpRouteCliLib.routeMatcherForConfig,
      "iif": iifKwMatcher,
      "IIF": intfMatcher,
      "OIFS": OifsExpr,
      "cpu": cpuKwMatcher,
      "IIFFRR": IifFrrExpr,
      "ROUTE_PRIORITY": RoutePriorityExpr,
      "priority": priorityKwMatcher,
   }

# Legacy IPv4 command "[no|default] ip route multicast ...". It is registered
# in global config mode and on the shared router-multicast modelet, and is
# handled through legacyCliCallback wrappers.
class LegacyMcastRoutes( McastRoutesBase ):
   syntax = ( "ip route multicast GROUP [ NUM_GROUPS ] [ SOURCE ] iif IIF "
              "[ OIFS ] [ cpu ] [ IIFFRR ] [ ROUTE_PRIORITY ]" )
   noOrDefaultSyntax = ( "ip route multicast GROUP [ NUM_GROUPS ] [ SOURCE ] "
                         "[ priority ] ..." )
   data = {
      "ip": CliToken.Ip.ipMatcherForConfig,
      "multicast": mcastDeprecatedNode,
      "GROUP": groupAddrMatcher,
      "NUM_GROUPS": numGroupsMatcher,
      "SOURCE": sourceAddrMatcher,
   }
   data.update( McastRoutesBase.baseData )
   handler = setMulticastRoutesLegacy
   noOrDefaultHandler = noMulticastRoutesLegacy

# IPv4 "[no|default] route GROUP ..." command, registered on the router
# multicast IPv4 modelet.
class McastRoutes( McastRoutesBase ):
   syntax = ( "route GROUP [ NUM_GROUPS ] [ SOURCE ] iif IIF "
              "[ OIFS ] [ cpu ] [ IIFFRR ] [ ROUTE_PRIORITY ]" )
   noOrDefaultSyntax = "route GROUP [ NUM_GROUPS ] [ SOURCE ] [ priority ] ..."
   data = {
      "GROUP": groupAddrMatcher,
      "NUM_GROUPS": numGroupsMatcher,
      "SOURCE": sourceAddrMatcher,
   }
   data.update( McastRoutesBase.baseData )
   handler = setMulticastRoutes
   noOrDefaultHandler = noMulticastRoutes

# TODO: Enable NUM_GROUPS (numGroupsRule) for IPv6 once group-address
# generation supports IPv6; setMulticastRoutes only expands IPv4 group ranges.
# IPv6 "[no|default] route GROUP ..." command, registered on the router
# multicast IPv6 modelet. Unlike McastRoutes it takes no NUM_GROUPS.
class Mcast6Routes( McastRoutesBase ):
   syntax = ( "route GROUP [ SOURCE ] iif IIF "
              "[ OIFS ] [ cpu ] [ IIFFRR ] [ ROUTE_PRIORITY ]" )
   noOrDefaultSyntax = "route GROUP [ SOURCE ] [ priority ] ..."
   data = {
      "GROUP": group6AddrMatcher,
      "SOURCE": source6AddrMatcher,
   }
   data.update( McastRoutesBase.baseData )
   handler = setMulticastRoutes
   noOrDefaultHandler = noMulticastRoutes

#---------------------------------------------------------------
# The "[no | default] fastdrop static" command, in
# router multicast/vrf/af config mode.
#---------------------------------------------------------------
def setFastdropStatic( mode, args ):
   '''Handler for "fastdrop static": set fastdropEnabled in the static mroute
   config of the current mode. ``args`` is not used.'''
   mcastStaticConfigFromMode( mode ).fastdropEnabled = True

def noFastdropStatic( mode, args=None ):
   '''Handler for "[no|default] fastdrop static"; also invoked by
   FastdropResetCallback with the callback object standing in for ``mode``.

   Always clears fastdropEnabled. clearFastdrop in the mode's MFIB VRF
   config is stamped with Tac.now() only when fastdrop static was enabled
   and that VRF has an MFIB config entry. Disabling an already-disabled
   fastdrop therefore has no other side effect.
   '''
   staticMrouteConfig = mcastStaticConfigFromMode( mode )
   if staticMrouteConfig.fastdropEnabled and (
         RouterMulticastCliLib.getVrfNameFromMode( mode ) in
         mfibConfigRootFromMode( mode ).config ):
      mfibConfigFromMode( mode ).clearFastdrop = Tac.now()
   staticMrouteConfig.fastdropEnabled = False

class FastdropResetCallback( RouterModeCallbackBase ):
   '''Router multicast mode callback that disables fastdrop static for both
   IPv4 and IPv6 when the mode is deleted.'''
   def modeDeleted( self, **kwargs ):
      # The callback object itself is handed to noFastdropStatic as its
      # "mode"; its af attribute is set to each family in turn (presumably
      # read by the *FromMode config getters).
      for family in ( AddressFamily.ipv4, AddressFamily.ipv6 ):
         self.af = family # pylint: disable=attribute-defined-outside-init
         noFastdropStatic( self )

RouterMulticastMode.registerCallback( FastdropResetCallback() )

# "[no|default] fastdrop static", registered on the router multicast
# address-family shared modelet.
class FastdropStatic( CliCommand.CliCommandClass ):
   syntax = 'fastdrop static'
   noOrDefaultSyntax = syntax
   data = {
      'fastdrop': fastdropKwMatcher,
      'static': fastdropStaticKwMatcher,
   }
   handler = setFastdropStatic
   noOrDefaultHandler = noFastdropStatic


# The legacy "ip route multicast ..." command is available both on the shared
# router-multicast modelet and in global configuration mode.
RouterMulticastCliLib.RouterMulticastSharedModelet.addCommandClass(
   LegacyMcastRoutes )
BasicCli.GlobalConfigMode.addCommandClass( LegacyMcastRoutes )

# "route ..." for the IPv4 address-family modelet.
RouterMulticastCliLib.RouterMulticastIpv4Modelet.addCommandClass(
   McastRoutes )

# "route ..." for the IPv6 address-family modelet.
RouterMulticastCliLib.RouterMulticastIpv6Modelet.addCommandClass(
   Mcast6Routes )

# "fastdrop static" on the address-family shared modelet.
RouterMulticastCliLib.RouterMulticastAfSharedModelet.addCommandClass(
   FastdropStatic )

def Plugin( entityManager ):
   '''CLI plugin entry point.

   Mounts the static mroute config collections, mounts the IPv4/IPv6 config
   and status with mode 'r', and builds the module-level ipGenStatus from
   them. It then registers this module's VRF creation/deletion hooks and
   its multicast-interface collection hook.
   '''
   global ipGenStatus

   # Address-family independent config mounts.
   RouterMulticastCliLib.doConfigMounts( entityManager,
                                         [ StaticMrouteConfigColl ] )

   def mountReadOnly( path, tacType ):
      return LazyMount.mount( entityManager, path, tacType, 'r' )

   ipConfig = mountReadOnly( 'ip/config', 'Ip::Config' )
   ip6Config = mountReadOnly( 'ip6/config', 'Ip6::Config' )
   ipStatus = mountReadOnly( 'ip/status', 'Ip::Status' )
   ip6Status = mountReadOnly( 'ip6/status', 'Ip6::Status' )
   ipGenStatus = McastCommonCliLib.ipGenStatusInit( ipConfig, ipStatus,
                                                    ip6Config, ip6Status )

   vrfDefinitionHook = MrouteCli.routerMcastVrfDefinitionHook
   vrfDeletionHook = MrouteCli.routerMcastVrfDeletionHook
   vrfDefinitionHook.addExtension( _staticMcastVrfCreationHook )
   vrfDeletionHook.addExtension( _staticMcastVrfDeletionHook )

   # Hook for the static mroute interface collection, so the CLI can keep
   # track of every interface engaged in multicast routing.
   def _staticMrouteIntfConfColl( vrfName, af ):
      '''Return { intfName: intfConfig entry } for the static-mroute
      interfaces of family ``af`` that belong to ``vrfName``; ``{}`` for any
      family other than IPv4 and IPv6.'''
      if af not in ( AddressFamily.ipv4, AddressFamily.ipv6 ):
         return {}
      intfConfig = mcastStaticConfigCollFromFamily( af ).intfConfig
      return { intfName: intfConfig[ intfName ]
               for intfName in intfConfig
               if vrfName == McastCommonCliLib.getVrfNameFromIntf(
                  ipGenStatus, intfName, af ) }
   McastCommonCliLib.mcastIfCollHook.addExtension( _staticMrouteIntfConfColl )
