# Copyright (c) 2016 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac, CliSave, Tracing, EthIntfUtil
from CliSavePlugin.IntfCliSave import IntfConfigMode
from CliMode.Stp import MstMode
from StpCliUtil import loopGuardAllowed

# Module-level trace handle for this plugin, created with the 'StpCli' name.
# NOTE(review): presumably picked up by the Tracing module via this dunder
# name rather than referenced directly here -- confirm.
__defaultTraceHandle__ = Tracing.Handle( 'StpCli' )

class StpConfigMode( MstMode, CliSave.Mode ):
   """CliSave mode combining the CLI MstMode with CliSave.Mode.

   Both base classes are initialized explicitly (not through super()),
   MstMode first and CliSave.Mode second, each with the same mode parameter.
   """

   def __init__( self, param ):
      # Initialize each base directly, in the order the bases are declared.
      for base in ( MstMode, CliSave.Mode ):
         base.__init__( self, param )

@CliSave.saver( 'Stp::Input::Config', 'stp/input/config/cli',
                requireMounts=( 'bridging/input/config/cli',
                                'hwEpoch/status' ) )
def saveStpLoopGuardConfig( entity, root, requireMounts, options ):
   """Save spanning-tree loop guard configuration.

   Emits the global 'spanning-tree guard loop default' command (or its 'no'
   form) into the 'Stp.global' command sequence. Emits a per-interface
   'spanning-tree guard loop' command into the 'Stp.intf' sequence of each
   interface whose port guard is 'loopguardEnabled'. Nothing is saved when
   loopGuardAllowed() rejects the 'hwEpoch/status' mount.

   Args:
      entity: the Stp::Input::Config being saved.
      root: CliSave root; provides the 'Stp.global' sequence and the
         IntfConfigMode instances.
      requireMounts: mounted entities keyed by path. 'hwEpoch/status' is
         read directly, and the whole mapping is passed to
         EthIntfUtil.allSwitchportNames.
      options: save options. saveAll also emits default values and considers
         unconfigured switchports. saveAllDetail additionally considers
         eligible switchports.
   """
   if not loopGuardAllowed( requireMounts[ 'hwEpoch/status' ] ):
      return

   cmds = root[ 'Stp.global' ]
   saveAll = options.saveAll
   saveAllDetail = options.saveAllDetail

   # Choose the command form from the actual value rather than assuming the
   # default is "disabled", so a non-default value renders correctly either way.
   if saveAll or entity.loopGuardEnabled != entity.loopGuardEnabledDefault:
      if entity.loopGuardEnabled:
         cmds.addCommand( 'spanning-tree guard loop default' )
      else:
         cmds.addCommand( 'no spanning-tree guard loop default' )

   # Ports to consider for per-interface loop guard commands. The "all"
   # variants add switchports without explicit STP port config. Explicitly
   # configured ports are always included, so "all detail" output is a
   # superset of "all" output.
   if saveAllDetail:
      cfgPorts = set( EthIntfUtil.allSwitchportNames( requireMounts,
                                                      includeEligible=True ) )
      cfgPorts.update( entity.portConfig )
   elif saveAll:
      cfgPorts = set( EthIntfUtil.allSwitchportNames( requireMounts ) )
      cfgPorts.update( entity.portConfig )
   else:
      cfgPorts = entity.portConfig

   for portName in cfgPorts:
      portConfig = entity.portConfig.get( portName )
      if portConfig is None:
         if not saveAll:
            continue
         # Unconfigured port: evaluate it against a default-constructed config.
         portConfig = Tac.newInstance( 'Stp::Input::PortConfig', portName )
      if portConfig.guard == 'loopguardEnabled':
         mode = root[ IntfConfigMode ].getOrCreateModeInstance( portConfig.name )
         mode[ 'Stp.intf' ].addCommand( 'spanning-tree guard loop' )

