#!/usr/bin/env python3
# Copyright (c) 2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import CliSave, Arnet
from CliSavePlugin import IntfCliSave
from CliSavePlugin.IgmpVrfMode import RouterIgmpBaseConfigMode
from CliSavePlugin.IgmpVrfMode import RouterIgmpVrfConfigMode
# pylint: disable-msg=W0611
import IpUtils
import Tac, McastCommonCliLib
from RoutingIntfUtils import allRoutingProtocolIntfNames
from IpLibConsts import DEFAULT_VRF
from functools import cmp_to_key
from Toggles.IgmpToggleLib import toggleIgmpVirtualQuerierEnabled

# Register the 'Igmp.intf' command sequence on the interface config mode.
# saveIntfConfig() below adds every per-interface IGMP command to this
# sequence; the `after` argument names 'Ira.ipIntf' as the sequence it follows.
IntfCliSave.IntfConfigMode.addCommandSequence( 'Igmp.intf', after=[ 'Ira.ipIntf' ] )

def sourceGroupCompare( sg1, sg2 ):
   """Three-way comparator ordering static-join entries by group, then source.

   Both arguments must expose `groupAddr` and `sourceAddr`; each address is
   converted with Arnet.IpAddress and compared through its numeric `.value`.

   Returns -1, 0 or 1, which makes the function usable with
   functools.cmp_to_key.  Entries are ordered by group value first.  Within
   one group, an entry whose source value is 0 sorts ahead of every entry with
   a non-zero source, and the remaining entries sort by ascending source value.
   """
   group1 = Arnet.IpAddress( sg1.groupAddr ).value
   group2 = Arnet.IpAddress( sg2.groupAddr ).value
   if group1 != group2:
      return -1 if group1 < group2 else 1

   source1 = Arnet.IpAddress( sg1.sourceAddr ).value
   source2 = Arnet.IpAddress( sg2.sourceAddr ).value
   # Equality has to be tested before the zero-source special cases: two
   # entries with the same group and a zero source are equal, and answering -1
   # for them would break the comparator contract ( cmp( a, a ) must be 0 ).
   if source1 == source2:
      return 0
   if source1 == 0:
      return -1
   if source2 == 0:
      return 1
   return -1 if source1 < source2 else 1

def saveSsmAware( igmpConfig, root, options ):
   """Save the SSM-aware setting into the router IGMP base mode.

   Adds "ssm aware" when igmpConfig.ssmAwareEnabled is set.  Otherwise adds
   "no ssm aware", but only when options.saveAll is set.  When neither holds
   the function returns without touching root, so no mode instance is looked
   up for a setting that is not going to be saved.
   """
   if igmpConfig.ssmAwareEnabled:
      command = "ssm aware"
   elif options.saveAll:
      command = "no ssm aware"
   else:
      return

   baseMode = root[ RouterIgmpBaseConfigMode ].getSingletonInstance()
   baseMode[ 'Igmp.config' ].addCommand( command )

def saveIgmpServiceAclConfig( aclCpConfig, root, options ):
   """Save the per-VRF IGMP service ACL ( "ip igmp access-group" ) commands.

   Walks aclCpConfig.serviceAcl, which maps a VRF name to a config whose
   `service` collection is looked up with the key 'igmp'.  A VRF with no ACL
   name is skipped unless options.saveAll is set, in which case the "no" form
   is saved for it.

   Commands for DEFAULT_VRF go into the 'Igmp.config' sequence of the router
   IGMP base mode; commands for any other VRF go into the 'Igmp.vrf.config'
   sequence of that VRF's mode underneath the base mode.
   """
   for vrfName, vrfAclConfig in aclCpConfig.serviceAcl.items():
      igmpAcl = vrfAclConfig.service.get( 'igmp' )
      aclName = igmpAcl.aclName if igmpAcl else None

      if not ( aclName or options.saveAll ):
         continue

      # Both destinations hang off the singleton base mode.
      baseMode = root[ RouterIgmpBaseConfigMode ].getSingletonInstance()
      if vrfName == DEFAULT_VRF:
         cmds = baseMode[ 'Igmp.config' ]
      else:
         vrfModes = baseMode[ RouterIgmpVrfConfigMode ]
         cmds = vrfModes.getOrCreateModeInstance( vrfName )[ 'Igmp.vrf.config' ]

      cmds.addCommand( ( "ip igmp access-group %s" % aclName ) if aclName
                       else "no ip igmp access-group" )

def _defaultIgmpIntfConfig( intfName ):
   """Build a Routing::Igmp::IntfConfig for intfName filled with defaults.

   Used when defaults must be saved for an interface that has no entry in the
   IGMP config.  The new instance gets a gmpQuerierConfig whose fields are
   copied from the matching `...Default` attributes of the instance itself,
   with addRandomDelayToTimer set to True.
   """
   intfConfig = Tac.newInstance( 'Routing::Igmp::IntfConfig', intfName )
   intfConfig.gmpQuerierConfig = ( intfName, )

   # ( gmpQuerierConfig attribute, IntfConfig attribute holding its default )
   querierDefaults = (
      ( 'querierVersion', 'versionDefault' ),
      ( 'lastMemberQueryCount', 'lastMemberQueryCountDefault' ),
      ( 'lastMemberQueryInterval', 'lastMemberQueryIntervalDefault' ),
      ( 'queryResponseInterval', 'queryResponseIntervalDefault' ),
      ( 'queryInterval', 'queryIntervalDefault' ),
      ( 'startupQueryCount', 'startupQueryCountDefault' ),
      ( 'querierAddressVirtual', 'querierAddressVirtualDefault' ),
      ( 'startupQueryInterval', 'startupQueryIntervalDefault' ),
   )
   for querierAttr, defaultAttr in querierDefaults:
      setattr( intfConfig.gmpQuerierConfig, querierAttr,
               getattr( intfConfig, defaultAttr ) )
   intfConfig.gmpQuerierConfig.addRandomDelayToTimer = True
   return intfConfig

@CliSave.saver( 'Routing::Igmp::Config', 'routing/igmp/config',
                requireMounts=( 'routing/hardware/status',
                                'interface/config/all',
                                'interface/status/all',
                                'acl/cpconfig/cli' ) )
def saveIgmpConfig( igmpConfig, root, requireMounts, options ):
   """Save the IGMP configuration: interfaces, service ACLs and SSM-aware.

   Registered through CliSave.saver for 'Routing::Igmp::Config' at
   'routing/igmp/config', with the mounts listed in the decorator.

   The set of interfaces that is saved depends on the save flags:
    - saveAllDetail: allRoutingProtocolIntfNames() with includeEligible=True.
    - saveAll: allRoutingProtocolIntfNames() plus every interface present in
      igmpConfig.intfConfig.
    - otherwise: only the interfaces present in igmpConfig.intfConfig.

   Both flags are forced off when McastCommonCliLib.mcastRoutingSupported()
   reports no support.  That override only applies to the per-interface
   pass; the service ACL and SSM-aware savers are handed `options` as is.
   """
   saveAll, saveAllDetail = options.saveAll, options.saveAllDetail
   hwStatus = requireMounts[ 'routing/hardware/status' ]
   if not McastCommonCliLib.mcastRoutingSupported( None, hwStatus ):
      # Save defaults only if the platform supports multicast routing.
      saveAll = saveAllDetail = False

   if saveAllDetail:
      intfNames = allRoutingProtocolIntfNames( requireMounts,
                                               includeEligible=True )
   elif saveAll:
      # Routing configuration is allowed on switchports as well, so cover all
      # routing protocol interfaces plus any interface holding IGMP config.
      intfNames = set( allRoutingProtocolIntfNames( requireMounts ) )
      intfNames.update( igmpConfig.intfConfig )
   else:
      intfNames = igmpConfig.intfConfig

   for intfName in intfNames:
      intfConfig = igmpConfig.intfConfig.get( intfName )
      if not intfConfig:
         if not saveAll:
            continue
         # Nothing stored for this interface: save from an all-default config.
         intfConfig = _defaultIgmpIntfConfig( intfName )
      saveIntfConfig( intfConfig, root, saveAll, saveAllDetail )

   ipAclCpConfig = requireMounts[ 'acl/cpconfig/cli' ].cpConfig[ 'ip' ]
   saveIgmpServiceAclConfig( ipAclCpConfig, root, options )

   saveSsmAware( igmpConfig, root, options )

def saveIntfConfig( igmpIntfConfig, root, saveAll, saveAllDetail ):
   """Add the IGMP commands of one interface to its 'Igmp.intf' sequence.

   igmpIntfConfig: the interface's IGMP config; its intfId selects the
      interface mode instance the commands are added to.
   root: save tree root, indexed by mode class.
   saveAll: when true, commands are also added for values that equal their
      defaults ( and the "no" forms of disabled settings are added ).
   saveAllDetail: part of the signature but not consulted by this function.

   Commands are added in this order: enable, querier settings ( only when a
   gmpQuerierConfig is present ), static groups, static-group ACLs, and the
   router-alert option.
   """
   cfg = igmpIntfConfig
   intfMode = root[ IntfCliSave.IntfConfigMode ].getOrCreateModeInstance(
         cfg.intfId )
   cmds = intfMode[ 'Igmp.intf' ]

   if cfg.enabled:
      cmds.addCommand( 'ip igmp' )
   elif saveAll:
      cmds.addCommand( 'no ip igmp' )

   querier = cfg.gmpQuerierConfig
   if querier:
      if querier.querierVersion != cfg.versionDefault or saveAll:
         # Only the last character of the version value goes in the command.
         cmds.addCommand( 'ip igmp version %s' %
                          str( querier.querierVersion[ -1 ] ) )

      # ( command format, current value, default value, multiply by 10? )
      # NOTE(review): 'igmp query-max-response-time' is saved without the
      # 'ip ' prefix its siblings carry - confirm this matches the CLI grammar.
      querierCmds = (
         ( 'ip igmp last-member-query-count %d',
           querier.lastMemberQueryCount,
           cfg.lastMemberQueryCountDefault, False ),
         ( 'ip igmp last-member-query-interval %d',
           querier.lastMemberQueryInterval,
           cfg.lastMemberQueryIntervalDefault, True ),
         ( 'igmp query-max-response-time %d',
           querier.queryResponseInterval,
           cfg.queryResponseIntervalDefault, True ),
         ( 'ip igmp query-interval %d',
           querier.queryInterval,
           cfg.queryIntervalDefault, False ),
         ( 'ip igmp startup-query-count %d',
           querier.startupQueryCount,
           cfg.startupQueryCountDefault, False ),
         ( 'ip igmp startup-query-interval %d',
           querier.startupQueryInterval,
           cfg.startupQueryIntervalDefault, True ),
      )
      for fmt, value, default, timesTen in querierCmds:
         if value != default or saveAll:
            cmds.addCommand( fmt % ( value * 10 if timesTen else value ) )

      if toggleIgmpVirtualQuerierEnabled():
         if querier.querierAddressVirtual != cfg.querierAddressVirtualDefault:
            cmds.addCommand( 'ip igmp querier address virtual' )
         elif saveAll:
            cmds.addCommand( 'no ip igmp querier address virtual' )

   # Static joins are saved in sourceGroupCompare order; an entry whose source
   # value is zero is saved in the group-only form.
   orderedJoins = sorted( cfg.staticJoinSourceGroup,
                          key=cmp_to_key( sourceGroupCompare ) )
   for sg in orderedJoins:
      if Arnet.IpAddress( sg.sourceAddr ).value:
         joinCmd = 'ip igmp static-group %s %s' % ( sg.groupAddr,
                                                    sg.sourceAddr )
      else:
         joinCmd = 'ip igmp static-group %s' % sg.groupAddr
      cmds.addCommand( joinCmd )

   for aclName in sorted( cfg.staticJoinAcl ):
      cmds.addCommand( 'ip igmp static-group acl %s' % aclName )

   routerAlert = cfg.routerAlertConfigOption
   if routerAlert != cfg.routerAlertConfigOptionDefault or saveAll:
      # Config option value -> keyword(s) used on the command line.
      routerAlertKeywords = {
            "routerAlertMandatory" : "mandatory",
            "routerAlertOptional"  : "optional",
            "routerAlertOptionalConnected"  : "optional connected",
      }
      assert routerAlert in routerAlertKeywords
      cmds.addCommand( 'ip igmp router-alert %s' %
                       routerAlertKeywords[ routerAlert ] )
