#!/usr/bin/env python3
# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from CliMode.McastVpn import RemoteDomainMode
from CliSavePlugin import IntfCliSave
import CliSave
import MultiRangeRule
import Tac
import Toggles.McastVpnLibToggleLib

# Register the command sequence under the interface config mode that the savers
# in this file fill with McastVpn commands; it is requested to come after the
# 'VxlanIntf.config' sequence.
IntfCliSave.IntfConfigMode.addCommandSequence( 'McastVpn.VxlanIntf',
                                               after=[ 'VxlanIntf.config' ] )

def getVlanFloodCmds( intfConfig, rangeSyntax=False ):
   """Return the 'vxlan vlan ... flood group ...' commands for intfConfig.

   Without rangeSyntax, one command is produced per VLAN in
   intfConfig.vlanToFloodGroup, ordered by VLAN. With rangeSyntax, VLANs that
   share the same defaultUnderlayGroup are merged into a single command whose
   VLAN field comes from MultiRangeRule.multiRangeToCanonicalString, and the
   commands are ordered by group.
   """
   floodColl = intfConfig.vlanToFloodGroup

   if not rangeSyntax:
      return [ 'vxlan vlan %d flood group %s'
               % ( vlanId, floodColl[ vlanId ].defaultUnderlayGroup )
               for vlanId in sorted( floodColl ) ]

   # Invert the vlan -> group mapping so each group lists all of its VLANs.
   vlansByGroup = {}
   for vlanId in floodColl:
      vlansByGroup.setdefault(
            floodColl[ vlanId ].defaultUnderlayGroup, [] ).append( vlanId )

   return [ 'vxlan vlan %s flood group %s'
            % ( MultiRangeRule.multiRangeToCanonicalString(
                   vlansByGroup[ grp ] ), grp )
            for grp in sorted( vlansByGroup ) ]

def vxlanStrForDomain( remoteDomain ):
   """Return the leading keyword for a command: "vxlan " or "".

   Commands saved inside the "domain remote" mode (remoteDomain truthy) use the
   newer style under "interface vxlan1" that omits the vxlan keyword, so they
   get an empty prefix; everything else keeps "vxlan ".
   """
   return "" if remoteDomain else "vxlan "

def getVlanMulticastCmds( intfConfig, remoteDomain=False ):
   """Return the '[vxlan ]vlan <vlan> multicast group <group>' commands.

   One command is produced per VLAN in intfConfig.vlanToMulticastGroup, ordered
   by VLAN. The group text is the stringValue of the entry's
   tunnelConfigEntry.defaultUnderlayGroup. The prefix is chosen by
   vxlanStrForDomain( remoteDomain ).
   """
   prefix = vxlanStrForDomain( remoteDomain )
   groups = intfConfig.vlanToMulticastGroup
   return [
      f'{prefix}vlan {vlanId} multicast group '
      f'{groups[ vlanId ].tunnelConfigEntry.defaultUnderlayGroup.stringValue}'
      for vlanId in sorted( groups ) ]

def getVrfMulticastCmds( intfConfig, remoteDomain=False ):
   """Return the '[vxlan ]vrf <vrf> multicast group ...' commands.

   VRFs in intfConfig.vrfToMulticastGroup are visited in sorted order. For each
   VRF's tunnelConfigEntry, up to three kinds of command are produced, in this
   order:
     - the default underlay group, when defaultUnderlayGroup is truthy;
     - the 'encap range ... delayed' form, when underlayGroupRange is truthy;
     - one 'overlay ... encap ... immediate' command per key of
       immediateOverlayToUnderlayGroup, ordered by overlay group.
   The prefix is chosen by vxlanStrForDomain( remoteDomain ).
   """
   prefix = vxlanStrForDomain( remoteDomain )
   cmds = []
   for vrfName in sorted( intfConfig.vrfToMulticastGroup ):
      entry = intfConfig.vrfToMulticastGroup[ vrfName ].tunnelConfigEntry
      # Every command for this VRF starts with the same stem.
      stem = f'{prefix}vrf {vrfName} multicast group'

      defaultGroup = entry.defaultUnderlayGroup
      if defaultGroup:
         cmds.append( f'{stem} {defaultGroup.stringValue}' )

      groupRange = entry.underlayGroupRange
      if groupRange:
         cmds.append( f'{stem} encap range {groupRange.stringValue} delayed' )

      overlayToUnderlay = entry.immediateOverlayToUnderlayGroup
      cmds.extend(
         f'{stem} overlay {overlay.stringValue}'
         f' encap {overlayToUnderlay[ overlay ].stringValue} immediate'
         for overlay in sorted( overlayToUnderlay ) )

   return cmds

def getUnderlayRouteTypeCmds( intfConfig, saveAll ):
   """Return the 'vxlan multicast|flood protocol pim asm|ssm' commands.

   intfConfig may be None. For the multicast protocol, the 'asm' command is
   produced when intfConfig.mcastUnderlayRouteType is pimasm; otherwise the
   'ssm' command is produced only when saveAll is set. The flood protocol
   follows the same rule using intfConfig.floodUnderlayRouteType, but only when
   the McastVpnTestCli toggle is enabled.
   """
   routeType = Tac.Type( "Routing::Multicast::UnderlayRouteType::RouteType" )
   cmds = []

   mcastIsAsm = intfConfig and \
         intfConfig.mcastUnderlayRouteType == routeType.pimasm
   if mcastIsAsm:
      cmds.append( 'vxlan multicast protocol pim asm' )
   elif saveAll:
      cmds.append( 'vxlan multicast protocol pim ssm' )

   # The flood protocol command is only saved behind the test CLI toggle.
   if not Toggles.McastVpnLibToggleLib.toggleMcastVpnTestCliEnabled():
      return cmds

   floodIsAsm = intfConfig and \
         intfConfig.floodUnderlayRouteType == routeType.pimasm
   if floodIsAsm:
      cmds.append( 'vxlan flood protocol pim asm' )
   elif saveAll:
      cmds.append( 'vxlan flood protocol pim ssm' )
   return cmds

def getRemoteVniToVrfCmds( intfConfig ):
   """Return one 'vxlan remote vni <vni> vrf <vrf>' command per entry of
   intfConfig.remoteVniToVrf, ordered by VNI.
   """
   vniToVrf = intfConfig.remoteVniToVrf
   return [ 'vxlan remote vni %d vrf %s' % ( vniKey, vniToVrf[ vniKey ] )
            for vniKey in sorted( vniToVrf ) ]

def getMcastOverlayAfCmds( intfConfig, saveAll ):
   """Return the 'vxlan multicast ipv4|ipv6 [disable]' commands.

   Nothing is returned unless the McastVpnOISMV6 toggle is enabled. intfConfig
   may be None. IPv4: 'disable' is produced when intfConfig.mcastOverlayV4 is
   falsy; otherwise the plain form is produced only under saveAll. IPv6: the
   plain form is produced when intfConfig.mcastOverlayV6 is truthy; otherwise
   'disable' is produced only under saveAll.
   """
   if not Toggles.McastVpnLibToggleLib.toggleMcastVpnOISMV6Enabled():
      return []

   cmds = []

   v4Disabled = bool( intfConfig ) and not intfConfig.mcastOverlayV4
   if v4Disabled:
      cmds.append( "vxlan multicast ipv4 disable" )
   elif saveAll:
      cmds.append( "vxlan multicast ipv4" )

   v6Enabled = bool( intfConfig ) and bool( intfConfig.mcastOverlayV6 )
   if v6Enabled:
      cmds.append( "vxlan multicast ipv6" )
   elif saveAll:
      cmds.append( "vxlan multicast ipv6 disable" )

   return cmds

def getUnderlayHerCmds( intfConfig, saveAll ):
   """Return the headend-replication command for intfConfig as a list.

   intfConfig may be None. The enabling command is returned when
   intfConfig.underlayHerEnabled is truthy; otherwise the 'no' form is returned
   only when saveAll is set, and an empty list when it is not.
   """
   herEnabled = intfConfig and intfConfig.underlayHerEnabled
   if herEnabled:
      return [ "vxlan multicast headend-replication" ]
   if saveAll:
      return [ "no vxlan multicast headend-replication" ]
   return []

# Tac type name shared by both CliSave.saver registrations in this file, and the
# mountPath attribute of that type, which is the path passed to the
# saveMcastVpnConfig registration.
typeName = 'Routing::Multicast::IpTunnelGroupConfig'
ipLocalDomainIpTunnelGroupConfigPath = Tac.Type( typeName ).mountPath

@CliSave.saver( typeName, ipLocalDomainIpTunnelGroupConfigPath,
                requireMounts=( 'interface/config/eth/vxlan', ) )
def saveMcastVpnConfig( ipTunnelGroupConfig, root, requireMounts, options ):
   """Save the McastVpn commands for each VTI into its interface config mode.

   For every key of the 'interface/config/eth/vxlan' mount's vtiConfig:
     - when ipTunnelGroupConfig has a non-empty intfConfig for it, all command
       groups are saved;
     - when the intfConfig is missing or empty, only the underlay route type,
       overlay address-family and headend-replication commands are saved, and
       only if options.saveAll is set (the helpers accept a None intfConfig);
     - otherwise the VTI is skipped.
   Commands are added to the 'McastVpn.VxlanIntf' sequence of the VTI's
   IntfConfigMode instance.
   """
   vxlanDir = requireMounts[ 'interface/config/eth/vxlan' ]
   saveAll = options.saveAll

   for vtiKey in vxlanDir.vtiConfig:
      tunnelIntfConfig = ipTunnelGroupConfig.intfConfig.get( vtiKey )
      pending = []

      if tunnelIntfConfig and not tunnelIntfConfig.isConfigEmpty:
         pending += getVlanFloodCmds(
               tunnelIntfConfig, rangeSyntax=vxlanDir.vlanFloodGroupRangeSyntax )
         pending += getVlanMulticastCmds( tunnelIntfConfig )
         pending += getVrfMulticastCmds( tunnelIntfConfig )
         pending += getUnderlayRouteTypeCmds( tunnelIntfConfig, saveAll )
         pending += getRemoteVniToVrfCmds( tunnelIntfConfig )
         pending += getMcastOverlayAfCmds( tunnelIntfConfig, saveAll )
         pending += getUnderlayHerCmds( tunnelIntfConfig, saveAll )
      elif saveAll:
         # Missing or empty config: only the commands that have a saveAll form.
         pending += getUnderlayRouteTypeCmds( tunnelIntfConfig, saveAll )
         pending += getMcastOverlayAfCmds( tunnelIntfConfig, saveAll )
         pending += getUnderlayHerCmds( tunnelIntfConfig, saveAll )
      else:
         continue

      intfMode = root[ IntfCliSave.IntfConfigMode ].getOrCreateModeInstance(
            vtiKey )
      cmdSeq = intfMode[ 'McastVpn.VxlanIntf' ]
      for line in pending:
         cmdSeq.addCommand( line )

class DomainRemoteConfigMode( RemoteDomainMode, CliSave.Mode ):
   """CliSave mode combining RemoteDomainMode with CliSave.Mode.

   It is registered in this file as a child of IntfCliSave.IntfConfigMode.
   """

   def __init__( self, param ):
      """Initialize both base classes from param.

      param is unpacked into exactly three items, which are handed to
      RemoteDomainMode.__init__; the unmodified tuple is handed to
      CliSave.Mode.__init__. The saver in this file passes
      ( "domain", "remote", vti ).
      """
      parent, session, intf = param
      RemoteDomainMode.__init__( self, parent, session, intf )
      CliSave.Mode.__init__( self, param )

   def skipIfEmpty( self ):
      """Always return True.

      NOTE(review): presumably tells CliSave to leave this mode out of the
      saved config when it holds no commands -- confirm against CliSave.Mode.
      """
      return True

# Attach DomainRemoteConfigMode as a child mode of the interface config mode,
# requested to come after the 'McastVpn.VxlanIntf' command sequence, and give it
# the command sequence that saveMcastVpnRemoteDomainConfig adds commands to.
IntfCliSave.IntfConfigMode.addChildMode( DomainRemoteConfigMode,
      after=[ 'McastVpn.VxlanIntf' ] )
DomainRemoteConfigMode.addCommandSequence( 'McastVpn.remoteDomainConfig' )
# remoteDomainMountPath attribute of the same Tac type; this is the path passed
# to the saveMcastVpnRemoteDomainConfig registration.
ipRemoteDomainIpTunnelGroupConfigPath = Tac.Type( typeName ).remoteDomainMountPath

@CliSave.saver( typeName, ipRemoteDomainIpTunnelGroupConfigPath,
                requireMounts=( 'interface/config/eth/vxlan', ) )
def saveMcastVpnRemoteDomainConfig(
      ipTunnelGroupConfig, root, requireMounts, options ):
   """Save the "domain remote" commands for each VTI.

   Does nothing unless the OismGateway toggle is enabled. For every key of the
   'interface/config/eth/vxlan' mount's vtiConfig, the interface mode instance
   and its DomainRemoteConfigMode child instance are obtained (or created)
   first. Then:
     - when ipTunnelGroupConfig has a non-empty intfConfig for the VTI, the
       VLAN and VRF multicast commands (without the vxlan keyword) are added
       to the child mode's 'McastVpn.remoteDomainConfig' sequence;
     - otherwise, when options.saveAll is set, "no domain remote" is added to
       the interface mode's 'McastVpn.VxlanIntf' sequence.
   """
   if not Toggles.McastVpnLibToggleLib.toggleOismGatewayEnabled():
      return

   vxlanDir = requireMounts[ 'interface/config/eth/vxlan' ]
   saveAll = options.saveAll

   for vtiKey in vxlanDir.vtiConfig:
      intfMode = root[ IntfCliSave.IntfConfigMode ].getOrCreateModeInstance(
            vtiKey )
      remoteMode = intfMode[ DomainRemoteConfigMode ].getOrCreateModeInstance(
            ( "domain", "remote", vtiKey ) )
      remoteCfg = ipTunnelGroupConfig.intfConfig.get( vtiKey )

      if remoteCfg and not remoteCfg.isConfigEmpty:
         cmdSeq = remoteMode[ 'McastVpn.remoteDomainConfig' ]
         for line in ( getVlanMulticastCmds( remoteCfg, remoteDomain=True ) +
                       getVrfMulticastCmds( remoteCfg, remoteDomain=True ) ):
            cmdSeq.addCommand( line )
         continue

      if saveAll:
         intfMode[ 'McastVpn.VxlanIntf' ].addCommand( "no domain remote" )
