# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import Tac
import CliSave
import Tracing
import Arnet
from CliSavePlugin.IntfCliSave import IntfConfigMode
from CliMode.McastDns import McastDnsMode, McastDnsServiceMode
from RoutingIntfUtils import allRoutingProtocolIntfNames
from Toggles.McastDnsToggleLib import toggleMcastDnsFloodSuppressionEnabled
# Alias for Tracing.trace0.
t0 = Tracing.trace0

# Handles to the Tac value/types used by the savers in this file.
# DefaultPort.dsoPort is the port the savers treat as the default: a tcp-port
# equal to it is omitted from the saved command unless saveAll is requested.
DefaultPort = Tac.Value( 'McastDns::DefaultPort' )
# Key type passed to entity.getLinkConfig(): ( intfId, address family ).
IntfKey = Tac.Type( 'McastDns::IntfKey' )
AddressFamily = Tac.Type( 'Arnet::AddressFamily' )
# A default-constructed IntfId() is compared against to detect "no interface".
IntfId = Tac.Type( 'Arnet::IntfId' )

#------------------------------------------------------------------------------
# McastDns Mode Saver
#------------------------------------------------------------------------------
class McastDnsServiceConfigSaveMode( McastDnsServiceMode, CliSave.Mode ):
   """CliSave mode for one mdns service rule.

   Combines the CLI mode definition (McastDnsServiceMode) with CliSave.Mode.
   ``param`` is handed unchanged to both base-class initializers; the saver in
   this file creates instances keyed by the service rule name.
   """
   def __init__( self, param ):
      # Both bases are initialized explicitly, each with the same param.
      McastDnsServiceMode.__init__( self, param )
      CliSave.Mode.__init__( self, param )

# Command sequence that receives the per-service-rule commands produced by
# getServiceRule().
McastDnsServiceConfigSaveMode.addCommandSequence( 'McastDns.serviceRule' )

class McastDnsConfigSaveMode( McastDnsMode, CliSave.Mode ):
   """CliSave mode for the top-level mdns configuration.

   The saver in this file obtains it through getSingletonInstance(), so there
   is a single instance per saved configuration.
   """
   def __init__( self, param ):
      # McastDnsMode takes no parameter; only CliSave.Mode receives param.
      McastDnsMode.__init__( self )
      CliSave.Mode.__init__( self, param )

   def skipIfEmpty( self ):
      # NOTE(review): presumably tells CliSave to omit this mode from the output
      # when it contains no commands -- confirm against CliSave.Mode.
      return True

# Config with query/response links needs to come after interfaces so that we know
# which interfaces mdns will use, for example, as query and response links.
# Note: nothing in this file adds commands to the 'McastDns.global' sequence.
CliSave.GlobalConfigMode.addCommandSequence( 'McastDns.global',
                                             after=[ IntfConfigMode ] )
# The mdns mode itself is also ordered after the interface modes.
CliSave.GlobalConfigMode.addChildMode( McastDnsConfigSaveMode,
                                       after=[ IntfConfigMode ] )

# Inside the mdns mode: the mode's own commands first, then the per-service
# child modes.
McastDnsConfigSaveMode.addCommandSequence( 'McastDns.config' )
McastDnsConfigSaveMode.addChildMode( McastDnsServiceConfigSaveMode,
                                     after=[ 'McastDns.config' ] )

# Per-interface 'mdns ipv4 ...' commands are saved under each interface mode.
IntfConfigMode.addCommandSequence( 'McastDns.intf' )

#------------------------------------------------------------------------------
# Entity Saver
#------------------------------------------------------------------------------
@CliSave.saver( 'McastDns::Config', 'mdns/config',
                requireMounts=( 'interface/config/all', 'interface/status/all' ) )
def saveMcastDnsConfig( entity, root, requireMounts, options ):
   """Save the mdns configuration held by ``entity`` into the CliSave tree.

   Emits, in order: the per-interface 'mdns ipv4' commands (via
   saveMcastDnsIntfConfig), then inside the mdns mode the enable state, flooding
   suppression (only when the feature toggle is on), remote gateways, the DSO
   server, and finally one child mode per service rule.

   Args:
      entity: the McastDns::Config entity being saved.
      root: CliSave root mode collection.
      requireMounts: mounts requested in the decorator, forwarded to the
         interface saver.
      options: save options; only ``options.saveAll`` is read here.  When it is
         set, default values and 'no ...' forms are written out explicitly.
   """
   # mdns can still be enabled under interfaces even if mdns is disabled.
   # 'no mdns ipv4' should appear under each interface if saveAll or saveAllDetail
   # is true, not just under the interfaces specified as query and response links.
   saveMcastDnsIntfConfig( entity, root, options, requireMounts )

   saveAll = options.saveAll

   mode = root[ McastDnsConfigSaveMode ].getSingletonInstance()
   cmds = mode[ 'McastDns.config' ]

   # Disabled is the default state, so 'disabled' is only written under saveAll.
   if entity.enabled:
      cmds.addCommand( 'no disabled' )
   elif saveAll:
      cmds.addCommand( 'disabled' )

   if toggleMcastDnsFloodSuppressionEnabled():
      if entity.floodSuppression:
         cmds.addCommand( 'flooding suppression' )
      elif saveAll:
         cmds.addCommand( 'no flooding suppression' )

   if entity.remoteGateway:
      for key, gw in entity.remoteGateway.items():
         cmd = f'remote-gateway ipv4 {key}'
         # The default port is implicit unless saveAll asks for it to be shown.
         if saveAll or gw.port != DefaultPort.dsoPort:
            cmd += f' tcp-port {gw.port}'
         cmds.addCommand( cmd )
   elif saveAll:
      cmds.addCommand( 'no remote-gateway ipv4' )

   # Emit exactly one 'dso server' line.  Previously, with saveAll and a
   # configured server port, the command was added once by the "configured"
   # branch and a second time by the saveAll branch.
   if entity.serverPort:
      cmd = 'dso server ipv4'
      if saveAll or entity.serverPort != DefaultPort.dsoPort:
         cmd += f' tcp-port {entity.serverPort}'
      cmds.addCommand( cmd )
   elif saveAll:
      cmds.addCommand( 'no dso server ipv4' )

   for serviceName, serviceRule in entity.serviceRule.items():
      serviceMode = mode[ McastDnsServiceConfigSaveMode ].getOrCreateModeInstance(
         serviceName )
      serviceCmds = serviceMode[ 'McastDns.serviceRule' ]
      for serviceCommand in getServiceRule( serviceRule, saveAll ):
         serviceCmds.addCommand( serviceCommand )

def saveMcastDnsIntfConfig( entity, root, options, requireMounts ):
   """Save the per-interface 'mdns ipv4' commands.

   Without ``options.saveAll`` only the interfaces that appear as keys of
   ``entity.link`` are visited.  With it, every Ethernet/Vlan name returned by
   allRoutingProtocolIntfNames() is visited, and interfaces that have no ipv4
   link config get an explicit 'no mdns ipv4'.

   Args:
      entity: the McastDns::Config entity being saved.
      root: CliSave root mode collection.
      options: save options; only ``options.saveAll`` is read.
      requireMounts: forwarded to allRoutingProtocolIntfNames().
   """
   includeDefaults = options.saveAll
   configuredIntfs = [ linkKey.intfId for linkKey in entity.link ]

   if includeDefaults:
      # mdns is in a modelet that only applies to Ethernet and Vlan interfaces.
      candidates = allRoutingProtocolIntfNames( requireMounts,
                                                includeEligible=True )
      targetIntfs = { name for name in candidates
                      if name.startswith( ( 'Ethernet', 'Vlan' ) ) }
   else:
      targetIntfs = set( configuredIntfs )

   for intfId in targetIntfs:
      # The interface mode is fetched (and created if needed) before the link
      # lookup, whether or not a command ends up being added to it.
      intfCmds = root[ IntfConfigMode ].getOrCreateModeInstance(
         intfId )[ 'McastDns.intf' ]
      linkConfig = entity.getLinkConfig( IntfKey( intfId, AddressFamily.ipv4 ) )
      if linkConfig:
         parts = [ 'mdns ipv4' ]
         if linkConfig.linkName:
            parts.append( f'link {linkConfig.linkName}' )
         if linkConfig.defaultTag:
            parts.append( f'default-tag {linkConfig.defaultTag}' )
         intfCmds.addCommand( ' '.join( parts ) )
      elif includeDefaults:
         intfCmds.addCommand( 'no mdns ipv4' )

def getServiceRule( rule, saveAll ):
   """Render one mdns service rule as a list of CLI command strings.

   For each attribute of ``rule`` that is set, the matching command is emitted;
   for each attribute that is unset, the 'no ...' form is emitted only when
   ``saveAll`` is true.  Commands are returned in this order: type, query,
   response interface, response link, match by-tag, match group.

   NOTE(review): 'query' and 'response interface' join their arguments with
   ', ' while 'type', 'response link' and 'match group' use a single space --
   confirm this is what the CLI parser expects.
   """
   lines = []

   def sortedJoin( separator ):
      # Build a renderer that sorts its values and joins them with separator.
      return lambda values: separator.join( sorted( values ) )

   def emit( keyword, values, render=None ):
      # Set -> 'keyword [args]'; unset -> 'no keyword', but only under saveAll.
      if values:
         lines.append( f'{keyword} {render( values )}' if render else keyword )
      elif saveAll:
         lines.append( f'no {keyword}' )

   spaced = sortedJoin( ' ' )
   commaSeparated = sortedJoin( ', ' )

   emit( 'type', rule.serviceType, spaced )
   emit( 'query', rule.queryLink, commaSeparated )

   # A response entry contributes to 'response interface' when its intfId is not
   # the default-constructed IntfId, and to 'response link' when its name is not
   # empty.  With no response entries both lists are empty.
   responses = rule.responseLink or ()
   intfs = [ entry.intfId for entry in responses
             if IntfId( entry.intfId ) != IntfId() ]
   links = [ entry.name for entry in responses if entry.name != '' ]
   emit( 'response interface', intfs,
         lambda values: ', '.join( Arnet.sortIntf( values ) ) )
   emit( 'response link', links, spaced )

   emit( 'match by-tag', rule.matchByTag )
   emit( 'match group', rule.matchGroup, spaced )
   return lines

