# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Arnet
import BasicCliUtil
import CliPlugin.IntfCli as IntfCli # pylint: disable=consider-using-from-import
from CliPlugin.IntfPolicyManager import getRcdToIntfListMap, getIpm
import LazyMount
import Tracing
from TypeFuture import TacLazyType

# Trace handle for this plugin. NOTE(review): the dunder name is presumably
# what the Tracing framework looks up as the module's default handle — keep it.
__defaultTraceHandle__ = Tracing.Handle( "AlePhyIntfNoShutdownHook" )
# Shorthand for the handle's trace0 function, used by the checks below.
t0 = __defaultTraceHandle__.trace0

# Lazily resolved Tac type; its intfPermissions / intfRestricted values are
# passed as the `reason` filter to ipm.getDisabledIntfs() in the SerDes check.
ResourceType = TacLazyType( "Interface::ResourceType" )

promptTemplate = \
"""The following interfaces will be disabled due to insufficient {reason}:

{interfaces}"""

def createPrompt( intfNames, reason ):
   """Build the warning text listing interfaces a command would disable.

   Args:
      intfNames: interface names to list; each is rendered on its own line,
         indented by two spaces.
      reason: name of the exhausted resource, substituted into promptTemplate.

   Returns:
      The formatted warning, or an empty string when intfNames is empty
      (callers treat an empty string as "nothing to warn about").
   """
   if not intfNames:
      return ""
   # pylint: disable-next=consider-using-f-string
   listing = [ "  %s\n" % name for name in intfNames ]
   return promptTemplate.format( reason=reason, interfaces="".join( listing ) )

# Ask the user if they would like to continue
def transitionPrompt( intfModes, warningPrompt ):
   """Show a warning and ask the user whether to go ahead with the command.

   Args:
      intfModes: CLI interface modes the command applies to. The first one is
         used to emit the warning, the confirmation prompt and any message.
      warningPrompt: warning text to display. When empty, nothing is shown
         and the command is allowed to proceed.

   Returns:
      True to proceed, False if the user declined the confirmation.
   """
   if not warningPrompt:
      # Nothing to warn about, so no confirmation is needed.
      return True

   mode = intfModes[ 0 ]
   mode.addWarning( warningPrompt )
   confirmed = BasicCliUtil.confirm(
      mode, "Do you wish to proceed with this command? [y/N]",
      answerForReturn=False )
   if confirmed:
      return True

   # The user declined: say which interfaces the command was aborted for.
   affectedNames = ", ".join( m.intf.name for m in intfModes )
   mode.addMessage( f"Command aborted by user for {affectedNames}" )
   return False

# ---------------------------------------
# Logical Ports Checks
# ---------------------------------------
LogicalPortAllocationPriority = TacLazyType(
   "AlePhy::LogicalPortAllocationPriority::LogicalPortAllocationPriority" )
# Only resource consumers whose logicalPortAllocationPriority equals this value
# are examined by getLogicalPortsAffectedIntfs(); all others are skipped.
PRIORITY_INTFID_AND_INTF_ENABLED = LogicalPortAllocationPriority.intfIdAndIntfEnabled

# Resource consumer slice directory, mounted by Plugin(); None until then.
# Both checks pass it to getRcdToIntfListMap().
resourceConsumerSliceDir = None

def getLogicalPortsAffectedIntfs( intfList ):
   """Return interfaces reported disabled once intfList is enabled (logical ports).

   For every resource consumer (rcd) covering intfList that uses the
   PRIORITY_INTFID_AND_INTF_ENABLED logical port allocation priority, the
   rcd's interface policy manager (ipm) is synchronized, each requested
   interface is marked enabled, the ipm re-processes its interfaces, and the
   interfaces it then reports as disabled are collected.

   Args:
      intfList: interface objects (each exposing ``.name``) being no-shut.

   Returns:
      The union of reported interfaces across all rcds, ordered by
      Arnet.sortIntf().
   """
   disabledIntfs = set()
   rcdMap = getRcdToIntfListMap( intfList, resourceConsumerSliceDir )
   for rcd, rcdIntfs in rcdMap.items():
      # Only rcds using the intfIdAndIntfEnabled allocation priority are
      # checked; any other allocation scheme is skipped.
      if rcd.logicalPortAllocationPriority != PRIORITY_INTFID_AND_INTF_ENABLED:
         continue

      ipm = getIpm( rcd )
      if not ipm:
         # No policy manager for this rcd; nothing to evaluate.
         continue

      # Sync the ipm, then apply the pending "enabled" config for each intf.
      ipm.synchronize()
      for intf in rcdIntfs:
         ipm.enableInterface( intf.name )
      ipm.processIntfs()

      disabledIntfs |= ipm.getDisabledIntfs()

   return Arnet.sortIntf( disabledIntfs )

def canNoShutdownLogicalPortsCheck( intfModes ):
   """Warn about and confirm interfaces that would be disabled for logical ports.

   Args:
      intfModes: CLI interface modes the command applies to.

   Returns:
      A (proceed, warning) tuple: proceed is the result of transitionPrompt();
      warning is the text shown, or "" if no interface is affected.
   """
   affected = getLogicalPortsAffectedIntfs( [ m.intf for m in intfModes ] )
   warning = createPrompt( affected, "logical ports" )
   return transitionPrompt( intfModes, warning ), warning

# ---------------------------------------
# Serdes checks
# ---------------------------------------
def getSerdesAffectedIntfs( intfList ):
   """Return interfaces reported disabled once intfList is enabled (SerDes).

   For every resource consumer (rcd) covering intfList that has a non-empty
   releaseOnShutIntfs, the rcd's interface policy manager (ipm) is
   synchronized. Each requested interface listed in rcd.releaseOnShutIntfs is
   marked enabled, the ipm re-processes its interfaces, and the interfaces it
   reports as disabled for the intfPermissions or intfRestricted reasons are
   collected.

   Args:
      intfList: interface objects (each exposing ``.name``) being no-shut.

   Returns:
      The union of reported interfaces across all rcds, ordered by
      Arnet.sortIntf().
   """
   allIntfNames = set()
   rcdToIntfListMap = getRcdToIntfListMap( intfList, resourceConsumerSliceDir )
   for rcd, intfsInRcd in rcdToIntfListMap.items():
      t0( "Processing rcd: ", rcd )
      # Skip products not enabling releaseOnShut.
      if not rcd.releaseOnShutIntfs:
         continue

      ipm = getIpm( rcd )
      if not ipm:
         continue

      ipm.synchronize()

      # Apply configs, only for interfaces that are release-on-shut.
      for intf in intfsInRcd:
         if intf.name in rcd.releaseOnShutIntfs:
            ipm.enableInterface( intf.name )

      ipm.processIntfs()
      # Union the results into our own set. The old code applied `|=` to the
      # object returned by getDisabledIntfs(); if that method returns the ipm's
      # internal set, the in-place update would corrupt the ipm's state.
      allIntfNames |= ipm.getDisabledIntfs( reason=ResourceType.intfPermissions )
      # BUG488999: remove intfRestrictions after everything is converted
      # to use intf permissions
      allIntfNames |= ipm.getDisabledIntfs( reason=ResourceType.intfRestricted )

   return Arnet.sortIntf( allIntfNames )

def canNoShutdownSerdesCheck( intfModes ):
   """Warn about and confirm interfaces that would be disabled for SerDes.

   Args:
      intfModes: CLI interface modes the command applies to.

   Returns:
      A (proceed, warning) tuple: proceed is the result of transitionPrompt();
      warning is the text shown, or "" if no interface is affected.
   """
   affected = getSerdesAffectedIntfs( [ m.intf for m in intfModes ] )
   t0( "Affected serdes: ", affected )
   warning = createPrompt( affected, "SerDes" )
   return transitionPrompt( intfModes, warning ), warning

def canNoShutdownCheck( intfModes ):
   """Extension for IntfCli.canNoShutdownIfHook: run the resource checks.

   Runs the SerDes check and then the logical ports check. Each one may warn
   and ask the user for confirmation.

   Args:
      intfModes: CLI interface modes the command applies to.

   Returns:
      True only if both checks allow the command to proceed.
   """
   t0( "Shutdown hook called" )
   serdesOk, serdesWarning = canNoShutdownSerdesCheck( intfModes )
   logicalOk, logicalWarning = canNoShutdownLogicalPortsCheck( intfModes )
   # The two checks are expected never to both warn for the same command.
   # NOTE(review): this invariant is only checked after both checks have
   # already had the chance to prompt the user.
   assert not ( serdesWarning and logicalWarning )
   return serdesOk and logicalOk

# Register canNoShutdownCheck on IntfCli's canNoShutdownIfHook at import time.
# NOTE(review): its boolean result presumably decides whether "no shutdown"
# proceeds — confirm against IntfCli.
IntfCli.canNoShutdownIfHook.addExtension( canNoShutdownCheck )

def Plugin( em ):
   """Plugin hook: mount the resource consumer slice directory via LazyMount.

   The result is stored in the module-level resourceConsumerSliceDir, which
   both the SerDes and logical ports checks pass to getRcdToIntfListMap().

   Args:
      em: presumably the CLI entity manager handed to plugins — confirm.
   """
   global resourceConsumerSliceDir
   sliceDirPath = "interface/resources/consumers/slice"
   resourceConsumerSliceDir = LazyMount.mount( em, sliceDirPath,
                                               "Tac::Dir", "ri" )
