# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from CliPlugin import MacsecModel
from CliPlugin.MacsecShowCli import (
   _macsecCounters,
   _macsecStatus,
   _macsecMkaStatus,
   _macsecInput,
   _hwStatusSliceDir
)
from CliPlugin.MaintenanceCliLib import parentIntf, isSubIntf
import Tac
import TacSigint

def getParentIntf( intfName ):
   """Return parentIntf( intfName ) when isSubIntf( intfName ) is true;
   otherwise return intfName unchanged."""
   if isSubIntf( intfName ):
      return parentIntf( intfName )
   return intfName

#----------------------------------------------------------------------------------
# The "show mac security interface [ intf <intfName> ] [ detail ]" command.
#----------------------------------------------------------------------------------
def showInterfaces( mode, args ):
   """Build the MacsecModel.MacsecInterfaces model for the
   "show mac security interface" command.

   args may carry 'INTFS' (the interfaces to report; when absent every key of
   _macsecMkaStatus.portStatus is used) and 'detail'. mode is not used.
   An interface is reported only when it has a port status, a cp status and
   an interface status.
   """
   detail = 'detail' in args

   def firstInSlices( lookup ):
      # Apply lookup to each hardware status slice directory and return the
      # first truthy result, or None when no slice yields one.
      for sliceDir in _hwStatusSliceDir.values():
         found = lookup( sliceDir )
         if found:
            return found
      return None

   def getHwIntfStatus( intfName ):
      return firstInSlices( lambda sliceDir: sliceDir.status.get( intfName ) )

   def getHwCapabilities( intfName ):
      return firstInSlices(
            lambda sliceDir: sliceDir.hwCapabilities.get( intfName ) )

   def getHwPostStatus( intfName ):
      # The POST status is read from the status entry of the parent interface.
      parentStatus = getHwIntfStatus( getParentIntf( intfName ) )
      return parentStatus.hwPostStatus if parentStatus else None

   fipsRestrictions = _macsecStatus.fipsStatus.fipsRestrictions
   result = MacsecModel.MacsecInterfaces()

   for intfId in args.get( 'INTFS', _macsecMkaStatus.portStatus ):
      intfStatus = _macsecStatus.intfStatus.get( intfId )
      # Interfaces using a static SAK are left out of the detailed output.
      if intfStatus and intfStatus.staticSakInUse and detail:
         continue

      portStatus = _macsecMkaStatus.portStatus.get( intfId )
      cpStatus = _macsecStatus.cpStatus.get( intfId )
      if not ( portStatus and cpStatus and intfStatus ):
         continue

      # The hardware POST status is only fetched under FIPS restrictions.
      hwPostStatus = None
      if fipsRestrictions:
         hwPostStatus = getHwPostStatus( intfId )
      intfModel = MacsecModel.MacsecInterface()

      hwIntfStatus = getHwIntfStatus( intfId )
      hwCapabilities = getHwCapabilities( intfId )
      ptpBypassesSupported = (
            hwCapabilities.ptpBypassesSupported if hwCapabilities else [] )

      intfModel.fromTacc( portStatus, cpStatus, intfStatus,
                          hwPostStatus, hwIntfStatus,
                          ptpBypassesSupported, fipsRestrictions,
                          detail=detail )
      result.interfaces[ intfId ] = intfModel
      TacSigint.check()

   return result

#----------------------------------------------------------------------------------
# The "show mac security mka counters [intf <intfName>] [ detail ]" command.
#----------------------------------------------------------------------------------
def showCounters( mode, args ):
   """Build the MacsecModel.MacsecMessageCounters model for the
   "show mac security mka counters" command.

   args may carry 'INTFS' (the interfaces to report; when absent or empty
   every key of _macsecCounters.msgCounter is used) and 'detail'. mode is not
   used. Requested interfaces without an entry in the msgCounter collection
   are silently omitted.
   """
   detail = 'detail' in args

   messageCountersModel = MacsecModel.MacsecMessageCounters()

   # Per-interface counters live in the msgCounter collection, so both the
   # default interface list and the lookup below must go through it (the
   # lookup was previously attempted on _macsecCounters itself).
   msgCounters = _macsecCounters.msgCounter
   intfs = args.get( 'INTFS' ) or msgCounters

   for intfId in intfs:
      # A missing entry yields a falsy result and the interface is skipped,
      # which makes a separate membership pre-filter unnecessary.
      if intfCounter := msgCounters.get( intfId ):
         messageCounterModel = MacsecModel.MacsecMessageCountersInterface()
         messageCounterModel.fromTacc( intfCounter, detail )
         messageCountersModel.interfaces[ intfId ] = messageCounterModel
         TacSigint.check()

   return messageCountersModel

#----------------------------------------------------------------------------------
# The "show mac security participants [ intf <intfName> ] [ detail ]" command.
#----------------------------------------------------------------------------------
def showParticipants( mode, args ):
   """Build the MacsecModel.MacsecParticipants model for the
   "show mac security participants" command.

   args may carry 'INTFS' (the interfaces to report; when absent every key of
   _macsecMkaStatus.portStatus is used) and 'detail'. mode is not used.
   Interfaces with no port status or no actor status are omitted.
   """
   wantDetail = 'detail' in args

   allParticipants = MacsecModel.MacsecParticipants()

   for intfId in args.get( 'INTFS', _macsecMkaStatus.portStatus ):
      portStatus = _macsecMkaStatus.portStatus.get( intfId )
      if not portStatus or not portStatus.actorStatus:
         continue

      intfModel = MacsecModel.MacsecParticipantsInterface()
      for actor in portStatus.actorStatus.values():
         participant = MacsecModel.MacsecParticipant()
         participant.fromTacc( actor, detail=wantDetail )
         # Participants are keyed by the actor's ckn.
         intfModel.participants[ actor.ckn ] = participant
         TacSigint.check()
      allParticipants.interfaces[ intfId ] = intfModel

   return allParticipants

#----------------------------------------------------------------------------------
# The "show mac security status" command.
#----------------------------------------------------------------------------------
def showStatus( mode, args ):
   """Build the MacsecModel.MacsecStatus model for the
   "show mac security status" command.

   The admin state passed to the model is the negation of the 'shutdown'
   attribute of the "cli" entry in _macsecInput. mode and args are not used.
   """
   statusModel = MacsecModel.MacsecStatus()
   isShutdown = _macsecInput[ "cli" ].shutdown
   statusModel.fromTacc( _macsecStatus, not isShutdown )
   return statusModel

#----------------------------------------------------------------------------------
# The "show mac security profile [ <profile_name> ] [ source <src> ]" command
#----------------------------------------------------------------------------------
def showMacsecProfile( mode, args ):
   """Build the MacsecModel.MacsecAllProfiles model for the
   "show mac security profile" command.

   args may carry "PROFILE_NAME" (restrict the output to one profile) and
   "SRC_NAME" (restrict the output to one entry of _macsecInput). A SRC_NAME
   that is not in _macsecInput is reported through mode.addErrorAndStop.
   """

   def buildProfile( profileConfig, source, priority ):
      # Populate one profile model from its config entry.
      model = MacsecModel.MacsecProfile()
      model.source = source
      model.priority = priority
      model.fromTacc( profileConfig, source )
      return model

   def buildProfiles( config, src, profileName=None ):
      # Collect the profiles of one source; with a profileName only that
      # profile is considered, and an unknown name gives an empty model.
      priority = config.configPriority
      model = MacsecModel.MacsecProfiles()
      if profileName is None:
         names = config.profile
      elif profileName in config.profile:
         names = [ profileName ]
      else:
         names = []

      for name in names:
         profileConfig = config.profile.get( name )
         if not profileConfig:
            continue
         model.profiles[ name ] = buildProfile( profileConfig, src, priority )
      return model

   profileName = args.get( "PROFILE_NAME" )
   srcName = args.get( "SRC_NAME" )
   sources = None
   if srcName is None:
      sources = _macsecInput
   elif srcName in _macsecInput:
      sources = [ srcName ]
   else:
      mode.addErrorAndStop( f"source { srcName } doesn't exist." )

   allProfiles = MacsecModel.MacsecAllProfiles()
   for srcName in sources:
      source = _macsecInput.get( srcName )
      if not source:
         continue
      allProfiles.sources[ srcName ] = buildProfiles(
            source, src=srcName, profileName=profileName )
   return allProfiles

#----------------------------------------------------------------------------------
# The "show mac security sak interface [ <intfName> ]" command
#----------------------------------------------------------------------------------
def showSak( mode, args ):
   """Build the MacsecModel.MacsecSaks model for the
   "show mac security sak interface" command.

   args may carry 'INTFS' (the interfaces to report; when absent every key of
   _macsecMkaStatus.sakRecord is used). mode is not used. An interface is
   reported only when it has a SAK record, an interface status and a cp
   status.
   """
   sakModels = MacsecModel.MacsecSaks()
   for intfName in args.get( 'INTFS', _macsecMkaStatus.sakRecord ):
      # All lookups and the resulting model entry are keyed by Arnet::IntfId.
      intfId = Tac.newInstance( 'Arnet::IntfId', intfName )
      record = _macsecMkaStatus.sakRecord.get( intfId )
      cpStatus = _macsecStatus.cpStatus.get( intfId )
      intfStatus = _macsecStatus.intfStatus.get( intfId )
      if not ( record and intfStatus and cpStatus ):
         continue

      sakModel = MacsecModel.MacsecSak()
      sakModel.fromTacc( record, cpStatus, intfStatus )
      sakModels.interfaces[ intfId ] = sakModel

   return sakModels

#----------------------------------------------------------------------------------
# The "show mac security ptp bypass" command
#----------------------------------------------------------------------------------
def showPtpBypass( mode, args ):
   """Build the MacsecModel.MacsecPtpBypass model for the
   "show mac security ptp bypass" command.

   Walks every key of _macsecMkaStatus.portStatus and reports the interfaces
   whose bypassPtp is set to ptpBypass but whose supported bypasses do not
   include full support. mode and args are not used.
   """
   def getPtpBypasses( intfName ):
      # Returns the bypasses that are supported on the interface
      for subDir in _hwStatusSliceDir.values():
         if hwCapabilities := subDir.hwCapabilities.get( intfName ):
            return hwCapabilities.ptpBypassesSupported
      return []

   # The Tac types do not depend on the interface, so resolve them once
   # rather than on every loop iteration.
   PtpBypass = Tac.Type( "Macsec::PtpBypass" )
   PtpBypassSupport = Tac.Type( "Macsec::PtpBypassSupport" )

   macsecPtpBypass = MacsecModel.MacsecPtpBypass()
   for intfId in _macsecMkaStatus.portStatus:
      cpStatus = _macsecStatus.cpStatus.get( intfId )
      intfStatus = _macsecStatus.intfStatus.get( intfId )
      if not cpStatus or not intfStatus:
         continue
      # Only display interfaces that have been explicitly configured to bypass PTP.
      if intfStatus.bypassPtp != PtpBypass.ptpBypass:
         continue

      # Only display interfaces without full PTP bypass support.
      bypassesSupported = getPtpBypasses( intfId )
      if PtpBypassSupport.full in bypassesSupported:
         continue

      # Build the per-interface model only once the interface is known to be
      # displayed, instead of constructing it and then discarding it.
      macsecPtpBypassInterface = MacsecModel.MacsecPtpBypassInterface()
      macsecPtpBypassInterface.fromTacc( bypassesSupported )
      macsecPtpBypass.interfaces[ intfId ] = macsecPtpBypassInterface
      TacSigint.check()

   return macsecPtpBypass

# ----------------------------------------------------------------------------------
# The "show interfaces [ INTF ] mac security capabilities" command
# ----------------------------------------------------------------------------------
def showMacsecCapableInterfaces( mode, args ):
   """Build the MacsecModel.MacsecCapableInterfaces model for the
   "show interfaces [ INTF ] mac security capabilities" command.

   args may carry 'INTFS' to restrict the output to those interfaces; when it
   is absent or empty, every interface found in the status collection of any
   hardware status slice is considered. mode is not used. An interface is
   reported only when its slice has both a status entry and a hwCapabilities
   entry for it.
   """
   def capabilitiesOf( hwSlice, intf ):
      # Gather the capability fields for intf from its slice, or None when
      # either the status or the hwCapabilities entry is missing.
      status = hwSlice.status.get( intf )
      hwCapabilities = hwSlice.hwCapabilities.get( intf )
      if not status or not hwCapabilities:
         return None
      return {
         "subInterface": status.subIntfSupported,
         "ciphers": hwCapabilities.ciphersSupported,
         "ptpBypasses": hwCapabilities.ptpBypassesSupported,
         "perIntfPtpBypass": hwCapabilities.ptpBypassProfileSupported
      }

   # Map every interface to the slice holding its status; if an interface
   # shows up in several slices, the last slice visited wins.
   sliceByIntf = {}
   for hwSlice in _hwStatusSliceDir.values():
      for intf in hwSlice.status:
         sliceByIntf[ intf ] = hwSlice

   requested = args.get( 'INTFS' )
   if requested:
      wanted = set( requested )
      selected = { intf: hwSlice
                   for intf, hwSlice in sliceByIntf.items()
                   if intf in wanted }
   else:
      selected = sliceByIntf

   model = MacsecModel.MacsecCapableInterfaces()
   for intf, hwSlice in selected.items():
      capabilities = capabilitiesOf( hwSlice, intf )
      if not capabilities:
         continue
      intfModel = MacsecModel.MacsecCapableInterface()
      intfModel.fromTacc( capabilities )
      model.interfaces[ intf ] = intfModel
      TacSigint.check()

   return model
