# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import os
import Tac, Tracing
import AsuPStore
from CliPlugin.AsuPStoreModel import ReloadHitlessBlockingReason
from LagAsuLib import LagPortId, LagId
from LacpConstants import StateTimeout

__defaultTraceHandle__ = Tracing.Handle( "LagAsu" )
t0 = Tracing.trace0
t1 = Tracing.trace1

class LagPStoreEventHandler( AsuPStore.PStoreEventHandler ):
   '''ASU PStore event handler for the LAG feature.

   save() snapshots LAG/LACP state into plain Python dicts, one per key
   returned by getKeys(). hitlessReloadSupported() reports the reasons that
   block a hitless reload.

   stores maps names ( 'inputIntfLag', 'lacpStatus', 'lacpConfig',
   'lagConfig', 'ethPhyIntfConfigDir', 'hwLagConfig' ) to the mounted
   entities read by the methods below. Plugin passes None when it only needs
   the supported key list.
   '''

   # PStore keys in save order. Config must be stored before status, so
   # 'lacpConfig' precedes 'lacpStatus'. This is the single source of truth
   # for both getSupportedKeys() and getKeys().
   _pStoreKeys = ( 'inputIntfLag', 'lacpConfig', 'lacpStatus',
                   'collectDistribute' )

   def __init__( self, stores ):
      self.asuStores = stores
      # Dispatch table: PStore key -> method producing that key's snapshot.
      self.funcs = { 'inputIntfLag' : self.storeInputIntfLag,
                     'lacpStatus' : self.storeLacpStatus,
                     'lacpConfig' : self.storeLacpConfig,
                     'collectDistribute' : self.storeCollectDistribute }
      super().__init__()

   def storeInputIntfLag( self ):
      '''Snapshot inputIntfLag (lag/input/interface/lag).

      inputIntfLag is the LAG-specific EthLagIntfStatusDir and should only
      hold local interfaces. Returns
      { port-channel intfId : { 'fallbackEnabled' : ...,
                                'debugFallbackSmState' : ...,
                                'member' : list( intfStatus.member ) } }.
      '''
      t0( self.__class__.__name__, "store members." )

      elisd = self.asuStores[ 'inputIntfLag' ]
      pyCol = {}
      for intfId, intfStatus in elisd.intfStatus.items():
         info = pyCol.setdefault( intfId, {} )
         info[ 'fallbackEnabled' ] = intfStatus.fallbackEnabled
         info[ 'debugFallbackSmState' ] = intfStatus.debugFallbackSmState
         info[ 'member' ] = list( intfStatus.member )
      return pyCol

   def storeLacpConfig( self ):
      '''Snapshot LACP config.

      Returns { 'lacpIntfConfig' : { intfId : {...} },
                'lacpPCConfig' : { lagId : {...} } }.
      '''
      t0( self.__class__.__name__, "store LACP config." )

      lacpConfig = self.asuStores[ 'lacpConfig' ]
      pyCol = {}
      lics = pyCol.setdefault( 'lacpIntfConfig', {} )
      for intfId, lic in lacpConfig.lacpIntfConfig.items():
         # Per-port LacpIntfConfig attributes; individualLagKey is flattened
         # into 'individual', 'lag' and 'key'.
         info = lics.setdefault( intfId, {} )
         info[ 'priority' ] = lic.priority
         info[ 'portId' ] = lic.portId
         info[ 'mode' ] = lic.mode
         info[ 'timeout' ] = lic.timeout
         info[ 'individual' ] = lic.individualLagKey.individual
         info[ 'lag' ] = lic.individualLagKey.lagIntfId
         info[ 'key' ] = lic.individualLagKey.key
      lpccDict = pyCol.setdefault( 'lacpPCConfig', {} )
      for lagId, lpcc in lacpConfig.lacpPCConfig.items():
         info = lpccDict.setdefault( lagId, {} )
         info[ 'defaultLagKey' ] = lpcc.defaultLagKey
         info[ 'lagKey' ] = lpcc.lagKey
         info[ 'sysId' ] = lpcc.sysId
         # Stored as a ( sysId, key, matchRequired ) tuple.
         info[ 'lacpRemoteLagId' ] = ( lpcc.lacpRemoteLagId.sysId,
                                       lpcc.lacpRemoteLagId.key,
                                       lpcc.lacpRemoteLagId.matchRequired )
         # NOTE(review): the key is 'portIdOffset' although the value is
         # portIdMin; presumably the restore side reads this exact key, so
         # it must not be renamed.
         info[ 'portIdOffset' ] = lpcc.portIdInput.portIdMin
         info[ 'portIdMax' ] = lpcc.portIdInput.portIdMax
      return pyCol

   def _storeLacpIntfStatus( self, portStatus ):
      '''Flatten LACP per-port status into { intfId : {...} }.

      Actor and partner ports are stored as repr( LagPortId( ... ) ) strings.
      '''
      t0( self.__class__.__name__, "store LACP port status." )

      pyCol = {}
      for intfId, lis in portStatus.items():
         info = pyCol.setdefault( intfId, {} )
         info[ 'actorSynchronized' ] = lis.actorSynchronized
         # 'portId' is omitted when TEST_LEGACY_RELEASE is set, presumably
         # to emulate an older release's snapshot format in tests.
         if not os.environ.get( 'TEST_LEGACY_RELEASE' ):
            info[ 'portId' ] = lis.portId
         info[ 'actorPort' ] = repr( LagPortId( lis.actorPort ) )
         info[ 'actorState' ] = lis.actorState
         info[ 'partnerPort' ] = repr( LagPortId( lis.partnerPort ) )
         info[ 'partnerState' ] = lis.partnerState
         info[ 'selected' ] = lis.selected
         info[ 'rxSmState' ] = lis.rxSmState
         info[ 'muxSmState' ] = lis.muxSmState
         info[ 'actorCollecting' ] = lis.actorCollectingDistributing.collecting
         info[ 'actorDistributing' ] = lis.actorCollectingDistributing.distributing
      return pyCol

   def _storeAggStatus( self, aggStatus ):
      '''Flatten LACP aggregator status.

      Returns { repr( LagId( lagId ) ) : { 'selected' : [...],
      'standby' : [...], 'aggregate' : [...] } }.
      '''
      t0( self.__class__.__name__, "store LACP agg status." )

      pyCol = {}
      for lagId, agg in aggStatus.items():
         key = LagId( lagId )
         info = pyCol.setdefault( repr( key ), {} )
         info[ 'selected' ] = list( agg.selected )
         info[ 'standby' ] = list( agg.standby )
         info[ 'aggregate' ] = list( agg.aggregate )
      return pyCol

   def _storeLacpLagStatus( self, lacpLagStatus ):
      '''Flatten LACP per-LAG status into { intfId : {...} }.

      'agg' is repr( LagId ) of the attached aggregator, or "" when none.
      '''
      t0( self.__class__.__name__, "store LACP lag status." )

      pyCol = {}
      for intfId, lls in lacpLagStatus.items():
         info = pyCol.setdefault( intfId, {} )
         if lls.agg is None:
            info[ 'agg' ] = ""
         else:
            info[ 'agg' ] = repr( LagId( lls.agg.lagId ) )
         info[ 'otherIntf' ] = list( lls.otherIntf )
         # Omitted under TEST_LEGACY_RELEASE, like 'portId' in port status.
         if not os.environ.get( 'TEST_LEGACY_RELEASE' ):
            info[ 'protoCD' ] = list( lls.protoCollectDistribute )
      return pyCol

   def storeLacpStatus( self ):
      '''Snapshot LACP status as { 'portStatus', 'aggStatus', 'lacpLagStatus' }.
      '''
      t0( self.__class__.__name__, "store LACP status." )

      lacpStatus = self.asuStores[ 'lacpStatus' ]
      pyCol = {}
      pyCol[ 'portStatus' ] = self._storeLacpIntfStatus( lacpStatus.portStatus )
      pyCol[ 'aggStatus' ] = self._storeAggStatus( lacpStatus.aggStatus )
      pyCol[ 'lacpLagStatus' ] = self._storeLacpLagStatus( lacpStatus.lagStatus )
      return pyCol

   def storeCollectDistribute( self ):
      '''Snapshot hardware collecting/distributing flags per LAG member.

      Members of all LAGs are merged into one { memberName : { 'collecting',
      'distributing' } } dict. A member listed under several LAGs keeps the
      last value seen.
      '''
      t0( self.__class__.__name__, "store collect distribute flags." )

      hwLagConfig = self.asuStores[ 'hwLagConfig' ]
      pyCol = {}
      for lagInitialMember in hwLagConfig.initialMember.values():
         for memberName, portProgType in lagInitialMember.member.items():
            pyCol[ memberName ] = { 'collecting' : portProgType.collecting,
                                    'distributing' : portProgType.distributing }
      return pyCol

   def save( self, pStoreIO ):
      '''Store each key's snapshot dict via pStoreIO.set, in getKeys() order.

      pStoreIO presumably serializes each dict to a string.
      '''
      for k in self.getKeys():
         t0( 'saving', k )
         pyCol = self.funcs[ k ]()
         pStoreIO.set( k, pyCol )

   def getSupportedKeys( self ):
      '''Return a new list of the PStore keys this handler supports.'''
      # Need to store config before status
      return list( self._pStoreKeys )

   def getKeys( self ):
      '''Return a new list of the PStore keys to save, in save order.'''
      # Need to store config before status
      return list( self._pStoreKeys )

   def hitlessReloadSupported( self ):
      '''Decide whether the LAG/LACP state allows a hitless reload.

      Only CLI-configured LAG ports that are local (present in
      ethPhyIntfConfigDir) and whose mode is not 'lacpModeOff' are checked.
      The reload is blocked if any such port:
        * is configured locally with the short LACP timeout, or
        * has a partner advertising the short timeout while the port's
          receive state machine is in rxSmCurrent.

      Returns ( warning, blocking ). warning is always empty here; blocking
      holds ReloadHitlessBlockingReason entries.
      '''
      t0( self.__class__.__name__ )

      lacpConfig = self.asuStores[ 'lacpConfig' ]
      lacpStatus = self.asuStores[ 'lacpStatus' ]
      lagConfig = self.asuStores[ 'lagConfig' ]
      ethPhyIntfConfigDir = self.asuStores[ 'ethPhyIntfConfigDir' ]
      blocking = []
      warning = []

      # Tabulate all CLI-configured, local, LACP-enabled ports.
      lacpPorts = [ intfId
                    for intfId, phyIntfConfig in lagConfig.phyIntf.items()
                    if ethPhyIntfConfigDir.intfConfig.get( intfId, None )
                    and phyIntfConfig.mode != 'lacpModeOff' ]

      # Is any local port configured with the short (fast) timeout?
      localIntfConfigs = ( lacpConfig.lacpIntfConfig.get( intfId, None )
                           for intfId in lacpPorts )
      localFastPort = any( lic and lic.timeout == 'lacpShortTimeout'
                           for lic in localIntfConfigs )

      # Does any remote partner advertise the short (fast) timeout?
      # We only believe partnerState & StateTimeout bit when receive (rx) sm is
      # in CURRENT state. In EXPIRED state, partnerState & StateTimeout will be
      # locally set to rate fast. In DEFAULTED, it will stay fast since rx sm
      # has to transition to EXPIRED state first. In PORT_DISABLED state, the
      # partnerState & StateTimeout is not changed. However, we don't know how
      # rx sm transition to PORT_DISABLED. At higher level, when receive sm is
      # in any state other than CURRENT, the port will not be in the
      # port-channel and it should not affect traffic. Hence, even if remote
      # side is configured lacp rate fast, we should be able to allow reload
      # hitless to proceed.
      portStatuses = ( lacpStatus.portStatus.get( intfId, None )
                       for intfId in lacpPorts )
      remoteFastPort = any( lis and lis.rxSmState == 'rxSmCurrent' and
                            lis.partnerState & StateTimeout
                            for lis in portStatuses )

      # Deny reload-hitless when any ports are configured lacp rate fast
      # locally or remotely.
      if localFastPort:
         blocking.append( ReloadHitlessBlockingReason(
                           reason='lacpRateFastLocalConfigured' ) )
      if remoteFastPort:
         blocking.append( ReloadHitlessBlockingReason(
                           reason='lacpRateFastRemoteConfigured' ) )

      return ( warning, blocking )

def Plugin( ctx ):
   '''Register the LAG PStore event handler with the ASU framework.

   For the 'GetSupportedKeys' opcode a handler without stores is registered
   immediately. Otherwise the paths the handler reads are mounted read-only,
   and the handler is registered once those mounts complete.
   '''
   featureName = 'Lag'

   if ctx.opcode() == 'GetSupportedKeys':
      # Only the key list is needed; nothing has to be mounted.
      handler = LagPStoreEventHandler( None )
      ctx.registerAsuPStoreEventHandler( featureName, handler )
      return

   mountGroup = ctx.entityManager().mountGroup()
   stores = {}

   def mountReadOnly( path, entityType ):
      return mountGroup.mount( path, entityType, 'r' )

   inputIntfLag = mountReadOnly( 'lag/input/interface/lag',
                                 'Lag::Input::EthLagIntfStatusDir' )
   lacpConfig = mountReadOnly( 'lag/lacp/config', 'Lacp::Config' )
   lacpStatus = mountReadOnly( 'lag/lacp/status', 'Lacp::LacpStatus' )
   ethPhyIntfConfigDir = mountReadOnly( 'interface/config/eth/phy/all',
                                        'Interface::AllEthPhyIntfConfigDir' )
   lagConfig = mountReadOnly( 'lag/config', 'Lag::Config' )
   hwLagConfig = mountReadOnly( 'hardware/lag/config',
                                'Hardware::Lag::Config' )

   def registerEventHandlers():
      # Within LagPStoreEventHandler, lagConfig and ethPhyIntfConfigDir are
      # only read by hitlessReloadSupported; the other entries also feed
      # save().
      stores.update( inputIntfLag=inputIntfLag,
                     lacpStatus=lacpStatus,
                     lacpConfig=lacpConfig,
                     lagConfig=lagConfig,
                     ethPhyIntfConfigDir=ethPhyIntfConfigDir,
                     hwLagConfig=hwLagConfig )
      ctx.registerAsuPStoreEventHandler( featureName,
                                         LagPStoreEventHandler( stores ) )

   ctx.mountsComplete( mountGroup, 'LagAsuPStore', registerEventHandlers )
