#!/usr/bin/env python3
# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-in

import Tracing
import AsuPStore
from CliPlugin.AsuPStoreModel import ReloadHitlessBlockingReason
from CliPlugin.AsuPStoreModel import ReloadHitlessWarningReason
import Tac
from StageGraphUtil import getPlatformType

# Trace handle for this plugin. t0 is shorthand for the handle's trace0
# function and is what every trace call in this file uses.
__defaultTraceHandle__ = Tracing.Handle( "StpAsu" )
t0 = __defaultTraceHandle__.trace0

class StpAsuPStoreEventHandler( AsuPStore.PStoreEventHandler ):
   '''ASU PStore event handler for STP.

      - save() stores the 'preserveOriginal2p5GAnd5GCost' flag, read from the
        STP topology status, into the PStore.
      - hitlessReloadSupported() reports the STP conditions that block, or
        warn about, an ASU2 hitless reload. For example, an STP mode other
        than NoSTP or MSTP blocks it.

      stpAsuContext is either None (the handler is only being asked for its
      supported keys) or a dict with the keys 'stpConfig', 'stpStatus',
      'stpTopoStatus', 'errDisabled' and 'stpPortMode'.
   '''
   def __init__( self, stpAsuContext ):
      self.funcs = { 'preserveOriginal2p5GAnd5GCost' :
                     self.preserveOriginal2p5GAnd5GCost }

      def contextEntry( key ):
         # Without a context every entity reference is None.
         return stpAsuContext[ key ] if stpAsuContext else None

      self.stpConfig_ = contextEntry( 'stpConfig' )
      self.stpStatus_ = contextEntry( 'stpStatus' )
      self.stpTopoStatus_ = contextEntry( 'stpTopoStatus' )
      self.errDisabled_ = contextEntry( 'errDisabled' )
      self.stpPortMode_ = contextEntry( 'stpPortMode' )
      self.mlagConfigured_ = False
      forceForwarding = \
         Tac.Type( 'EthIntf::PortModeConfig::PortMode' ).forceForwarding
      if self.stpPortMode_ and self.stpPortMode_.entity.get( 'mlag' ):
         mlagStpPortModes = self.stpPortMode_.entity[ 'mlag' ].portMode.values()
         # Only the presence of at least one interface in stp/portMode/mlag
         # with portMode 'forceForwarding' matters. In practice only the Mlag
         # peer-link interface is present in that collection anyway.
         self.mlagConfigured_ = forceForwarding in mlagStpPortModes
      AsuPStore.PStoreEventHandler.__init__( self )

   def preserveOriginal2p5GAnd5GCost( self ):
      '''Return the boolean, read from the STP topology status, that says
         whether the original 2.5G/5G cost should be preserved.'''
      t0( self.__class__.__name__, "store preserveOriginal2p5GAnd5GCost" )
      return self.stpTopoStatus_.preserveOriginal2p5GAnd5GCost

   def save( self, pStoreIO ):
      '''Store the value produced for every key in getKeys() into pStoreIO.'''
      for k in self.getKeys():
         t0( 'saving', k )
         pStoreIO.set( k, self.funcs[ k ]() )

   def getKeys( self ):
      '''Keys written by save().'''
      return [ 'preserveOriginal2p5GAnd5GCost' ]

   def getSupportedKeys( self ):
      '''Keys this handler supports.'''
      return [ 'preserveOriginal2p5GAnd5GCost' ]

   # ------------------------------------------------------------------
   # Helpers shared by both hitless-reload checks, so the two versions
   # cannot drift apart on these details.
   # ------------------------------------------------------------------
   def _mstiConfigs( self ):
      '''Return the MST instance configs as a list (iterated several times).'''
      return list( self.stpConfig_.stpiConfig[ 'Mst' ].mstiConfig.values() )

   @staticmethod
   def _mstiPortConfigItems( mstiConfigs ):
      '''Yield ( intf, portConfig ) for each port of each MSTI config in
         mstiConfigs, in iteration order.'''
      for mstiConfig in mstiConfigs:
         yield from mstiConfig.mstiPortConfig.items()

   def _bpduguardDisabled( self, intf ):
      '''Return True if bpduguard is not in effect on intf: it is disabled
         on the port, or left at the default while the global portfast
         bpduguard setting is off.'''
      portBpduguard = self.stpConfig_.portConfig[ intf ].bpduguard
      return ( portBpduguard == 'bpduguardDisabled' or
               ( portBpduguard == 'bpduguardDefault' and
                 not self.stpConfig_.portfastBpduguard ) )

   def _hasErrDisabledPort( self, mstiConfigs ):
      '''Return True (tracing the first such port) if any port of mstiConfigs
         is present in errDisabled_.intfStatus.'''
      if not self.errDisabled_:
         return False
      for mstiConfig in mstiConfigs:
         for intf in mstiConfig.mstiPortConfig:
            if intf in self.errDisabled_.intfStatus:
               t0( 'hitlessReloadSupport: port ', intf, ' is errDisabled' )
               return True
      return False

   # hitlessReloadSupported() and hitlessReloadSupportedDeprecated() are two
   # versions of the same check, selected by platform in
   # hitlessReloadSupported(). Anyone updating one version must update the
   # other as well.
   def hitlessReloadSupportedDeprecated( self ):
      '''Hitless-reload check used on platforms other than strata.

         Only NoSTP and MSTP are supported. With MSTP and no Mlag,
         Stp::StpUtils.checkAllAdminEdgePorts must pass and every enabled
         MSTI port must have bpduguard in effect. An MSTI port present in the
         errdisable status only produces a warning.

         Returns a ( warningList, blockingList ) tuple of
         ReloadHitlessWarningReason / ReloadHitlessBlockingReason objects.
      '''
      warningList, blockingList = [], []
      if not self.stpConfig_:
         return ( warningList, blockingList )

      version = self.stpConfig_.version
      if version != "none" and version != "mstp":
         # Stp does not support Asu2 for modes other than NoSTP and MSTP.
         t0( 'hitlessReloadSupport: stpConfig.version = ', version )
         blockingList.append( ReloadHitlessBlockingReason( reason='stpMode' ) )
         return ( warningList, blockingList )
      if version != "mstp":
         # NoSTP: nothing else to check.
         return ( warningList, blockingList )

      mstiConfigs = self._mstiConfigs()
      if not self.mlagConfigured_:
         stpUtils = Tac.newInstance( "Stp::StpUtils" )
         if not stpUtils.checkAllAdminEdgePorts( self.stpConfig_ ):
            t0( 'hitlessReloadSupport: adminEdgePort is not set for one or'
                ' more interfaces' )
            blockingList.append( ReloadHitlessBlockingReason(
               reason='edgePort' ) )

         for intf, config in self._mstiPortConfigItems( mstiConfigs ):
            if config.enabled and self._bpduguardDisabled( intf ):
               t0( 'hitlessReloadSupport: Bpduguard is not enabled for ',
                   intf )
               blockingList.append( ReloadHitlessBlockingReason(
                  reason='stpBpduguard' ) )
               break

      if self._hasErrDisabledPort( mstiConfigs ):
         warningList.append( ReloadHitlessWarningReason(
            reason='stpErrDisabled' ) )

      return ( warningList, blockingList )

   def hitlessReloadSupported( self ):
      '''Report why an ASU2 hitless reload should be blocked or warned about.

         Platforms other than strata use hitlessReloadSupportedDeprecated().
         Returns a ( warningList, blockingList ) tuple of
         ReloadHitlessWarningReason / ReloadHitlessBlockingReason objects.

         An STP mode other than NoSTP or MSTP is blocking. With MSTP:
         -------------------------------------------------------------------
                 Validate for          |        Blocking/Warning
         -------------------------------------------------------------------
               STP is unstable         |            Warning
               ErrDisabled intf        |            Warning
               Mlag configured         |            Warning
           Designated/Backup role on   |            Blocking
             a non-oper-edge port      |
           Bridge assurance on an      |            Blocking
             enabled network port      |
           Bpduguard not in effect on  |            Blocking
             an enabled admin edge port|
         -------------------------------------------------------------------
         When Mlag is configured the blocking checks are skipped.
      '''
      if getPlatformType() != "strata":
         return self.hitlessReloadSupportedDeprecated()

      warningList, blockingList = [], []
      if not self.stpConfig_:
         return ( warningList, blockingList )

      version = self.stpConfig_.version
      if version != "none" and version != "mstp":
         # Stp does not support Asu2 for modes other than NoSTP and MSTP.
         t0( 'hitlessReloadSupport: stpConfig.version = ', version )
         blockingList.append( ReloadHitlessBlockingReason( reason='stpMode' ) )
         return ( warningList, blockingList )
      if version != "mstp":
         # NoSTP: nothing else to check.
         return ( warningList, blockingList )

      mstiConfigs = self._mstiConfigs()

      # --- Warnings -----------------------------------------------------
      if not self.stpStatus_.stable:
         t0( 'hitlessReloadSupport: stpStatus.stable = ',
             self.stpStatus_.stable )
         warningList.append( ReloadHitlessWarningReason( reason='stpState' ) )

      if self._hasErrDisabledPort( mstiConfigs ):
         warningList.append( ReloadHitlessWarningReason(
            reason='stpErrDisabled' ) )

      if self.mlagConfigured_:
         t0( 'hitlessReloadSupport: Mlag is configured' )
         warningList.append( ReloadHitlessWarningReason(
            reason='mlagConfigured' ) )
         return ( warningList, blockingList )

      # --- Blocking checks (Mlag not configured) -------------------------
      mstStatus = self.stpStatus_.stpiStatus[ 'Mst' ]
      # A set, because it is probed once per MSTI port below.
      operEdgePorts = { intf for intf, status in mstStatus.stpiPortStatus.items()
                        if status.operEdgePort }

      mstiPortStatuses = ( item for mstiStatus in mstStatus.mstiStatus.values()
                           for item in mstiStatus.mstiPortStatus.items() )
      for intf, status in mstiPortStatuses:
         if intf in operEdgePorts:
            continue
         if status.role == 'designated' or status.role == 'backup':
            t0( 'hitlessReloadSupport: portRole is set to ', status.role,
                ' for ', intf )
            blockingList.append( ReloadHitlessBlockingReason(
               reason='portRole' ) )
            break

      # bridgeAssurance is a global setting, so it gates the whole scan.
      if self.stpConfig_.bridgeAssurance:
         for intf, config in self._mstiPortConfigItems( mstiConfigs ):
            if config.enabled and self.stpConfig_.portConfig[ intf ].networkPort:
               t0( 'hitlessReloadSupport: TransmitActive is enabled for ',
                   intf )
               blockingList.append( ReloadHitlessBlockingReason(
                  reason='transmitActive' ) )
               break

      # Enabled admin edge ports must have bpduguard in effect.
      for intf, config in self._mstiPortConfigItems( mstiConfigs ):
         if ( config.enabled and
              self.stpConfig_.portConfig[ intf ].adminEdgePort and
              self._bpduguardDisabled( intf ) ):
            t0( 'hitlessReloadSupport: Bpduguard is not enabled for ', intf )
            blockingList.append( ReloadHitlessBlockingReason(
               reason='stpBpduguard' ) )
            break

      return ( warningList, blockingList )

def Plugin( ctx ):
   '''Register the STP ASU PStore event handler with ctx.

      For the 'GetSupportedKeys' opcode nothing is mounted and the handler is
      created without a context. For any other opcode, the following are
      mounted and passed to the handler as its context dict:
        - the STP config, status and topology status
        - the stp/portMode directory
        - the bpduguard errdisable cause status
      The mount group is closed with blocking=True before the handler is
      created.
   '''
   featureName = 'Stp'
   if ctx.opcode() == 'GetSupportedKeys':
      handler = StpAsuPStoreEventHandler( None )
      ctx.registerAsuPStoreEventHandler( featureName, handler )
      return

   # ( context key, mount path, entity type, mount flags ), in mount order.
   mountSpecs = (
      ( 'stpConfig', 'stp/config', 'Stp::Config', 'r' ),
      ( 'stpStatus', 'stp/status', 'Stp::Status', 'r' ),
      ( 'stpTopoStatus', 'stp/topology/status', 'Stp::Topology::Status', 'r' ),
      ( 'stpPortMode', 'stp/portMode', 'Tac::Dir', 'ri' ),
      ( 'errDisabled', 'interface/errdisable/cause/bpduguard',
        'Errdisable::CauseStatus', 'r' ),
   )
   mountGroup = ctx.entityManager().mountGroup()
   stpAsuContext = { key: mountGroup.mount( path, entityType, flags )
                     for key, path, entityType, flags in mountSpecs }
   mountGroup.close( blocking=True )

   handler = StpAsuPStoreEventHandler( stpAsuContext )
   ctx.registerAsuPStoreEventHandler( featureName, handler )
