# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import datetime
from time import ctime
from typing import Optional

import Ark
import Arnet
import Tac
from CliModel import Bool, Dict, Float, Int, List, Model, Str
from IntfModels import Interface

from CliPlugin.XcvrShowDomModel import ( InterfaceTransceiverDomParameter,
                                         parameterTypeToStr, indicator )
from XcvrLib import ( getXcvrSlotName, noneToNegInf )

# --------------------------------------------------------------------------------
#
# Models for
# "show interfaces [ <interface> ] transceiver performance-monitoring [thresholds]"
#
# --------------------------------------------------------------------------------

def _printPerformanceMonitoringHeaderDefault( isQsfpPmPresent=False ):
   # Print the one-line legend for the index/channel column and the "N/A"
   # marker. QSFP enhanced-DOM output labels that column "Ch"; otherwise the
   # column is a composite datapath/electrical-lane/optical-channel index.
   legend = ( "Ch: Channel, N/A: not applicable" if isQsfpPmPresent else
              "Index: datapath, electrical lane, optical channel, "
              "N/A: not applicable" )
   print( legend )

class InterfaceTransceiverPerformanceMonitoringInterval( Model ):
   # Performance-monitoring data for a single PM interval: when it started,
   # when it was last updated, and the monitored parameter values.
   # Start of this interval as a UTC timestamp; may be left unset.
   intervalStartTime = Float(
      help="Time when this interval started in UTC",
      optional=True )
   # Time of the most recent update to this interval's data, UTC; may be unset.
   updateTime = Float( help="Last update time in UTC", optional=True )
   # Keyed by parameter type name (the same names listed in the parameter
   # order lists of InterfaceTransceiverPerformanceMonitoring). The renderers
   # read each value's per-channel readings ("channels"), "unit" and
   # "threshold".
   parameters = Dict(
      keyType=str,
      valueType=InterfaceTransceiverDomParameter,
      help=( "A mapping of performance monitoring parameter names "
             "to per channel value and thresholds in this interval" ),
      optional=True )

class InterfaceTransceiverPerformanceMonitoring( Model ):
   # Per-interface performance-monitoring (PM) model. It holds the PM intervals
   # for one interface plus private ordering/format hints that drive the text
   # renderers renderModelNoThresholds() and renderModelThresholds().
   displayName = Str( help="Interfaces corresponding to the channel of the "
                           "transceiver", optional=True )
   performanceMonitoringIntervals = Dict(
      keyType=int,
      valueType=InterfaceTransceiverPerformanceMonitoringInterval,
      help=( "A mapping of performance monitoring interval number to "
             "per interval information. Interval numbers are integers "
             "starting from 0. Interval 0 is the interval in "
             "progress,Interval 1 is the most recently completed "
             "interval and so on" ),
      optional=True )
   _moduleParamOrder = List( valueType=str,
                             help=( "List of parameters in the order they "
                                    "should be printed before the datapath and "
                                    "index-based parameters" ),
                             optional=True )
   _datapathParamOrder = List( valueType=str,
                               help=( "List of interface scoped parameters in the "
                                      "order they should be printed before the "
                                      "index-based parameters" ),
                               optional=True )
   _paramOrder = List( valueType=str,
                        help="List of index-based parameters in the order they "
                             "should be printed" )

   _qsfpPmPresent = Bool( help="Indicates that PM output should be "
                               "formatted different for QSFP enhanced DOM modules." )

   _rawFormat = Bool( help=( "Indicates that index/channel numbers should be "
                             "displayed to match exactly what is reported in the "
                             "module's EEPROM for each monitored parameter." ) )

   def paramOrderIs( self, paramOrder: list[ str ] ) -> None:
      """Set the print order of the index-based (channelized) parameters."""
      self._paramOrder = paramOrder

   def datapathParamOrderIs( self, datapathParamOrder: list[ str ] ) -> None:
      """Set the print order of the interface (datapath) scoped parameters."""
      self._datapathParamOrder = datapathParamOrder

   def moduleParamOrderIs( self, moduleParamOrder: list[ str ] ) -> None:
      """Set the print order of the module scoped parameters."""
      self._moduleParamOrder = moduleParamOrder

   def _printPerformanceMonitoring( self,
               pm: InterfaceTransceiverPerformanceMonitoringInterval ) -> None:
      """Print one PM interval in the default (non-raw) value-only layout.

      Module-scoped parameters are printed first at the top level. Then the
      interface display name is printed (when there is anything to nest), with
      datapath-scoped parameters (channels "-" and "0" only) and index-based
      parameters beneath it. Index-based values on a real channel are listed
      as "Channel <n>" rows under a single heading naming the parameter.
      """
      intervalStartTime = Ark.timestampToStr( pm.intervalStartTime,
                                              now=Tac.utcNow() )

      # Print the time info for the provided PM interval
      updateTime = Ark.timestampToStr( pm.updateTime, now=Tac.utcNow() )
      fmtTime = "  {: <20} {: >41s}"
      print( fmtTime.format( "Started", intervalStartTime ) )
      print( fmtTime.format( "Last update", updateTime ) )

      def _printPmParams( pm: InterfaceTransceiverPerformanceMonitoringInterval,
                          paramOrder: list[ str ], fmtPmBER: str,
                          fmtPmCount: str, fmtPmFloat: str,
                          fmtChPm: Optional[ str ] = None,
                          acceptedCh: Optional[ tuple[ str, ... ] ] = None
                          ) -> None:
         # Print every parameter in paramOrder, one row per channel (restricted
         # to acceptedCh when given). The value format is chosen from the
         # parameter name: "Exceeded" -> count, "BER" -> BER, else float+unit.
         # With fmtChPm, real-channel rows are nested under a heading line.
         fmtChannel = "  Channel {}"
         for _paramType in paramOrder:
            param = pm.parameters[ _paramType ]
            paramType = parameterTypeToStr( _paramType )
            # The nesting heading is printed once per parameter, not per channel.
            headingPrinted = False
            for chan, value in param.channels.items():
               if acceptedCh is not None and chan not in acceptedCh:
                  continue
               chOrParamStr = paramType
               if fmtChPm is not None and chan != "-":
                  # There are potentially multiple channelized parameter values
                  # that require nesting under the pertaining parameter type.
                  if not headingPrinted:
                     print( fmtChPm.format( paramType ) )
                     headingPrinted = True
                  chOrParamStr = fmtChannel.format( chan )
               if 'Exceeded' in _paramType:
                  print( fmtPmCount.format( chOrParamStr, value ) )
               elif 'BER' in _paramType:
                  print( fmtPmBER.format( chOrParamStr, value ) )
               else:
                  # Currently, only the float output format additionally
                  # displays units.
                  print( fmtPmFloat.format( chOrParamStr, value, param.unit ) )

      fmtModValBER = "  {: <45}{: 13.2e}"
      fmtModValCount = "  {: <45}{: 13d}"
      fmtModValFloat = "  {: <45}{: 13.2f} {: <5}"
      _printPmParams( pm, self._moduleParamOrder, fmtModValBER, fmtModValCount,
                      fmtModValFloat )

      # Print interface display name to nest datapath, host electrical, and
      # media-scoped parameters beneath.
      fmtIntfValBER = ( " " * 4 ) + "{: <43}{: 13.2e}"
      fmtIntfValCount = ( " " * 4 ) + "{: <43}{: 13d}"
      fmtIntfValFloat = ( " " * 4 ) + "{: <43}{: 13.2f} {: <5}"
      if pm.parameters and ( self._paramOrder or self._datapathParamOrder ):
         # pylint: disable-next=consider-using-f-string
         print( "  {}".format( self.displayName ) )

      # Datapath scoped parameters, these parameters will not display channel info
      _printPmParams( pm, self._datapathParamOrder, fmtIntfValBER, fmtIntfValCount,
                      fmtIntfValFloat, acceptedCh=( "-", "0" ) )

      # Media scoped parameters
      fmtChPm = "    {}"
      _printPmParams( pm, self._paramOrder, fmtIntfValBER, fmtIntfValCount,
                      fmtIntfValFloat, fmtChPm=fmtChPm )

   def _printPerformanceMonitoringRaw( self, paramOrder: list[ str ],
               pm: InterfaceTransceiverPerformanceMonitoringInterval ) -> None:
      """Print one PM interval in the raw value-only layout.

      Each row shows the parameter name, the channel/index key exactly as
      stored, and the value. A parameter's name appears only on its first
      row. QSFP enhanced-DOM output uses a two-character narrower name column
      and rewrites BER cells rendered as 2.05e+10 to 0.00e+00.
      """
      fmtValueBER = "  %-45s%3s%13.2e"
      fmtValueCount = "  %-45s%3s%13d"
      fmtTime = "  %-45s%3s %13s"
      # Currently, only the float output format additionally displays units.
      fmtValueFloat = "  %-45s%3s%13.2f %-5s"
      if self._qsfpPmPresent:
         fmtValueBER = "  %-43s%3s%13.2e"
         fmtValueCount = "  %-43s%3s%13d"
         fmtTime = "  %-43s%3s %13s"
         fmtValueFloat = "  %-43s%3s%13.2f %-5s"
      intervalStartTime = Ark.timestampToStr( pm.intervalStartTime,
                                              now=Tac.utcNow() )
      updateTime = Ark.timestampToStr( pm.updateTime, now=Tac.utcNow() )
      print( fmtTime % ( "Started", "-", intervalStartTime ) )
      print( fmtTime % ( "Last update", "-", updateTime ) )
      for _paramType in paramOrder:
         param = pm.parameters[ _paramType ]
         alreadyPrintedParamType = False
         for chan, value in param.channels.items():
            # Print parameter name only once if this parameter is same as previous
            paramType = parameterTypeToStr( _paramType ) \
                  if not alreadyPrintedParamType else ''
            if 'Exceeded' in _paramType:
               print( fmtValueCount % ( paramType, chan, value ) )
            elif 'BER' in _paramType:
               pmLine = fmtValueBER % ( paramType, chan, value )
               if self._qsfpPmPresent:
                  # NOTE(review): 2.05e+10 looks like a sentinel BER reading
                  # that is shown as zero; confirm its origin.
                  pmLine = pmLine.replace( "2.05e+10", "0.00e+00" )
               print( pmLine )
            else:
               print( fmtValueFloat % ( paramType, chan, value, param.unit ) )
            alreadyPrintedParamType = True

   def _printPerformanceMonitoringThreshold( self,
                  pm: InterfaceTransceiverPerformanceMonitoringInterval ) -> None:
      """Print one PM interval in the default (non-raw) layout with thresholds.

      Same nesting as _printPerformanceMonitoring(), with four threshold
      columns (high alarm, high warn, low warn, low alarm) and an indicator
      column appended to every row. Missing thresholds render as "N/A".
      """
      intervalStartTime = (
         Ark.timestampToStr( pm.intervalStartTime, now=Tac.utcNow() ) )
      updateTime = Ark.timestampToStr( pm.updateTime, now=Tac.utcNow() )
      fmtTime = "  {: <13}{: >43}"
      print( fmtTime.format( "Started", intervalStartTime ) )
      print( fmtTime.format( "Last update", updateTime ) )

      def _printPmParamThresholds( pm:
                                   InterfaceTransceiverPerformanceMonitoringInterval,
                                   paramOrder: list[ str ], fmtPmBER: str,
                                   fmtPmCount: str, fmtPmFloat: str,
                                   fmtChPm: Optional[ str ] = None,
                                   acceptedCh: Optional[ tuple[ str, ... ] ] = None
                                   ) -> None:
         # Like _printPmParams in _printPerformanceMonitoring, but each row also
         # carries the four thresholds and the alarm/warning indicator.
         fmtChannel = "  Channel {}"
         for _paramType in paramOrder:
            param = pm.parameters[ _paramType ]
            thresh = param.threshold
            paramType = parameterTypeToStr( _paramType )
            # The nesting heading is printed once per parameter, not per channel.
            headingPrinted = False
            for chan, value in param.channels.items():
               if acceptedCh is not None and chan not in acceptedCh:
                  continue
               chOrParamStr = paramType
               if fmtChPm is not None and chan != "-":
                  # There are potentially multiple channelized parameter values
                  # that require nesting under the pertaining parameter type.
                  if not headingPrinted:
                     print( fmtChPm.format( paramType ) )
                     headingPrinted = True
                  chOrParamStr = fmtChannel.format( chan )
               # noneToNegInf presumably maps an unset threshold to -inf so it
               # survives numeric formatting; "-inf" becomes " N/A" below.
               thresholds = ( noneToNegInf( thresh.highAlarm ),
                              noneToNegInf( thresh.highWarn ),
                              noneToNegInf( thresh.lowWarn ),
                              noneToNegInf( thresh.lowAlarm ) )
               status = indicator( value, thresh.highAlarm, thresh.highWarn,
                                   thresh.lowWarn, thresh.lowAlarm )
               if 'Exceeded' in _paramType:
                  pmLine = fmtPmCount.format( chOrParamStr, value, *thresholds,
                                              status )
               elif 'BER' in _paramType:
                  pmLine = fmtPmBER.format( chOrParamStr, value, *thresholds,
                                            status )
               else:
                  # Currently, only the float output format additionally
                  # displays units.
                  pmLine = fmtPmFloat.format( chOrParamStr, value, *thresholds,
                                              param.unit, status )
               # NOTE(review): 2.05e+10 looks like a sentinel reading that is
               # shown as zero; confirm its origin.
               pmLine = pmLine.replace( "2.05e+10", "0.00e+00" )
               pmLine = pmLine.replace( "-inf", " N/A" )
               print( pmLine )

      # Count threshold cells use an untyped spec: for ints it renders exactly
      # like "d", but it also accepts the float -inf used for an unset threshold,
      # which "d" rejects with ValueError.
      fmtModValBER = \
         "  {: <43}{: >13.2e}{: >11.2e}{: >11.2e}{: >11.2e}{: >11.2e}{: >10s}"
      fmtModValCount = "  {: <43}{: >13d}{: >11}{: >11}{: >11}{: >11}{: >10s}"
      fmtModValFloat = \
         "  {: <43}{: >13.2f}{: >11.2f}{: >11.2f}{: >11.2f}{: >11.2f}{: >6s}{: >11s}"
      _printPmParamThresholds( pm, self._moduleParamOrder, fmtModValBER,
                               fmtModValCount, fmtModValFloat )

      # Print interface display name to nest datapath, host electrical, and
      # media-scoped parameters beneath.
      fmtIntfValBER = (
         ( " " * 4 ) +
         "{: <41}{: >13.2e}{: >11.2e}{: >11.2e}{: >11.2e}{: >11.2e}{: >10s}" )
      fmtIntfValCount = (
         ( " " * 4 ) + "{: <41}{: >13d}{: >11}{: >11}{: >11}{: >11}{: >10s}" )
      fmtIntfValFloat = (
         ( " " * 4 ) +
         "{: <41}{: >13.2f}{: >11.2f}{: >11.2f}{: >11.2f}{: >11.2f}{: >6s}{: >11s}" )
      if pm.parameters and ( self._paramOrder or self._datapathParamOrder ):
         # pylint: disable-next=consider-using-f-string
         print( "  {}".format( self.displayName ) )

      # Datapath scoped parameters will not display channel info
      _printPmParamThresholds( pm, self._datapathParamOrder, fmtIntfValBER,
                               fmtIntfValCount, fmtIntfValFloat,
                               acceptedCh=( "-", "0" ) )

      # Media scoped parameters
      fmtChPm = "    {}"
      _printPmParamThresholds( pm, self._paramOrder, fmtIntfValBER, fmtIntfValCount,
                               fmtIntfValFloat, fmtChPm=fmtChPm )

   def _printPerformanceMonitoringThresholdRaw( self,
               pm: InterfaceTransceiverPerformanceMonitoringInterval ) -> None:
      """Print one PM interval in the raw layout with threshold columns.

      Only the index-based parameters (self._paramOrder) are printed, one row
      per channel key as stored, with the parameter name on its first row
      only. Count and float thresholds are printed with "%s", BER thresholds
      with "%.2e". "-inf" cells become " N/A" and 2.05e+10 becomes 0.00e+00.
      """
      fmtValueBER = "  %-45s%3s%13.2e%11.2e%11.2e%11.2e%11.2e%10s"
      fmtValueCount = "  %-45s%3s%13d%11s%11s%11s%11s%10s"
      fmtTime = "  %-45s%3s %13s%11s%11s%11s%11s%10s"
      # Currently, only the float output format additionally displays units.
      fmtValueFloat = "  %-45s%3s%13.2f%11s%11s%11s%11s%6s%11s"
      if self._qsfpPmPresent:
         # QSFP enhanced-DOM output uses a two-character narrower name column.
         fmtValueBER = "  %-43s%3s%13.2e%11.2e%11.2e%11.2e%11.2e%10s"
         fmtValueCount = "  %-43s%3s%13d%11s%11s%11s%11s%10s"
         fmtTime = "  %-43s%3s %13s%11s%11s%11s%11s%10s"
         fmtValueFloat = "  %-43s%3s%13.2f%11s%11s%11s%11s%6s%11s"
      intervalStartTime = \
            Ark.timestampToStr( pm.intervalStartTime, now=Tac.utcNow() )
      updateTime = Ark.timestampToStr( pm.updateTime, now=Tac.utcNow() )
      print( fmtTime % ( "Started", "-", intervalStartTime, " ",
                         " ", " ", " ", " " ) )
      print( fmtTime % ( "Last update", "-", updateTime, " ", " ", " ", " ", " " ) )
      for _paramType in self._paramOrder:
         param = pm.parameters[ _paramType ]
         thresh = param.threshold
         # Print parameter name only once if this parameter is same as previous
         alreadyPrintedParamType = False
         for chan, value in param.channels.items():
            paramType = parameterTypeToStr( _paramType ) \
                  if not alreadyPrintedParamType else ''
            thresholds = ( noneToNegInf( thresh.highAlarm ),
                           noneToNegInf( thresh.highWarn ),
                           noneToNegInf( thresh.lowWarn ),
                           noneToNegInf( thresh.lowAlarm ) )
            status = indicator( value, thresh.highAlarm, thresh.highWarn,
                                thresh.lowWarn, thresh.lowAlarm )
            if 'Exceeded' in _paramType:
               pmLine = fmtValueCount % ( paramType, chan, value, *thresholds,
                                          status )
            elif 'BER' in _paramType:
               pmLine = fmtValueBER % ( paramType, chan, value, *thresholds,
                                        status )
            else:
               pmLine = fmtValueFloat % ( paramType, chan, value, *thresholds,
                                          param.unit, status )
            pmLine = pmLine.replace( "2.05e+10", "0.00e+00" )
            pmLine = pmLine.replace( "-inf", " N/A" )
            print( pmLine )
            alreadyPrintedParamType = True

   def _printIntervals( self, printInterval ) -> None:
      """Print the current (0) and previous (1) intervals, when present.

      Each interval gets a label line ("Current Interval 0" / "Interval 1"
      for QSFP enhanced-DOM output, "Current Interval" / "Previous Interval"
      otherwise), followed by printInterval( intervalModel ).
      """
      labels = ( ( 0, "Current Interval 0", "Current Interval" ),
                 ( 1, "Interval 1", "Previous Interval" ) )
      for intervalNum, qsfpLabel, defaultLabel in labels:
         if intervalNum not in self.performanceMonitoringIntervals:
            continue
         print( qsfpLabel if self._qsfpPmPresent else defaultLabel )
         printInterval( self.performanceMonitoringIntervals[ intervalNum ] )

   def renderModelNoThresholds( self, intfName: str ) -> None:
      """Print the value-only PM table (header plus intervals) for intfName."""
      if self._qsfpPmPresent:
         print( intfName )
         fmtHeader = "%-45s%3s%13s"
         print( fmtHeader % ( "Parameter", "Ch", "Value" ) )
         print( "-" * 61 )
      else:
         print( "Port", getXcvrSlotName( intfName ) )
         if self._rawFormat:
            fmtHeader = "%-45s%3s%13s%5s"
            print( fmtHeader % ( "Parameter", "Index", "Value", "Unit" ) )
            print( "-" * 68 )
         else:
            fmtHeader = "%-47s%13s%5s"
            print( fmtHeader % ( "Parameter", "Value", "Unit" ) )
            print( "-" * 65 )

      if self._qsfpPmPresent or self._rawFormat:
         self._printIntervals(
            lambda pm: self._printPerformanceMonitoringRaw( self._paramOrder, pm ) )
      else:
         self._printIntervals( self._printPerformanceMonitoring )

   def renderModelThresholds( self, intfName: str ) -> None:
      """Print the PM table with threshold columns for intfName."""
      if self._qsfpPmPresent:
         print( intfName )
         fmtHeader = "%-45s%3s%13s%11s%11s%11s%11s%10s"
         print( fmtHeader % ( " ", " ", " ", "High Alarm", "High Warn", "Low Warn",
                              "Low Alarm", " " ) )
         print( fmtHeader % ( "Parameter", "Ch", "Value", "Threshold", "Threshold",
                              "Threshold", "Threshold", "Indicator" ) )
         print( "-" * 116 )
      else:
         print( "Port", getXcvrSlotName( intfName ) )
         if self._rawFormat:
            fmtHeader = "%-45s%3s%13s%11s%11s%11s%11s%6s%11s"
            print( fmtHeader % ( " ", "     ", "      ", "High Alarm", "High Warn",
                                 "Low Warn", "Low Alarm", " ", " ", ) )
            print( fmtHeader % ( "Parameter", "Index", "Value", "Threshold",
                                 "Threshold", "Threshold", "Threshold", "Unit",
                                 "Indicator" ) )
            print( "-" * 124 )
         else:
            fmtHeader = "{: <45}{: >13}{: >11}{: >11}{: >11}{: >11}{: >6}{: >11}"
            print( fmtHeader.format( " ", "      ", "High Alarm", "High Warn",
                                     "Low Warn", "Low Alarm", "", "", ) )
            print( fmtHeader.format( "Parameter", "Value", "Threshold",
                                     "Threshold", "Threshold", "Threshold", "Unit",
                                     "Indicator" ) )
            print( "-" * 119 )

      if self._qsfpPmPresent or self._rawFormat:
         self._printIntervals( self._printPerformanceMonitoringThresholdRaw )
      else:
         self._printIntervals( self._printPerformanceMonitoringThreshold )

class InterfacesTransceiverPerformanceMonitoring( Model ):
   # Top-level model: PM data for every requested interface, plus the global
   # PM period and private flags that select the text layout.
   _detailed = Bool( help="Include warning and alarm thresholds" )
   performanceMonitoringPeriodSeconds = \
         Int( help="Performance monitoring period in seconds",
              optional=True )
   interfaces = Dict( keyType=Interface,
                      valueType=InterfaceTransceiverPerformanceMonitoring,
                      help="Mapping between interface name and the " +
                           "performance monitoring information" )
   _performanceMonitoringConfigured = Bool( help="Performance monitoring " +
                                           "configured on any interface or not" )
   _qsfpPmPresent = Bool( help="Indicates that PM output should be "
                               "formatted different for QSFP enhanced DOM "
                               "modules." )
   _rawFormat = Bool( help=( "Indicates that index/channel numbers should be "
                             "displayed to match exactly what is reported in the "
                             "module's EEPROM for each monitored parameter." ) )

   def render( self ) -> None:
      """Print the PM summary and then one table per interface.

      When PM is configured, the index legend (raw/QSFP layouts with at least
      one interface), the PM period and the current system time come first.
      Per-interface tables follow only if there are interfaces.
      """
      if self._performanceMonitoringConfigured:
         wantLegend = bool( self.interfaces ) and (
            self._qsfpPmPresent or self._rawFormat )
         if wantLegend:
            _printPerformanceMonitoringHeaderDefault(
               isQsfpPmPresent=self._qsfpPmPresent )
            print( "\n" )
         # timedelta renders the period as "[D day[s], ]H:MM:SS".
         period = datetime.timedelta(
            seconds=self.performanceMonitoringPeriodSeconds )
         # pylint: disable-next=consider-using-f-string
         print( "Performance monitoring period : {} ".format( period ) )
         # pylint: disable-next=consider-using-f-string
         print( "Current system time: {}".format( ctime( Tac.utcNow() ) ) )

      # Without interfaces there is nothing more to show, not even headers.
      if not self.interfaces:
         return

      renderTables = ( self._renderThresholds if self._detailed else
                       self._renderNotThresholds )
      renderTables()

   def _renderEachInterface( self, withThresholds: bool ) -> None:
      # Walk interfaces in Arnet.sortIntf order, separating each table with a
      # blank gap and delegating to the per-interface renderer.
      for intfName in Arnet.sortIntf( self.interfaces ):
         print( "\n" )
         intfModel = self.interfaces[ intfName ]
         if withThresholds:
            intfModel.renderModelThresholds( intfName )
         else:
            intfModel.renderModelNoThresholds( intfName )

   def _renderThresholds( self ):
      # Per-interface tables including threshold and indicator columns.
      self._renderEachInterface( True )

   def _renderNotThresholds( self ):
      # Per-interface tables with values only.
      self._renderEachInterface( False )
