# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

from typing import Literal

import Ark
import Tac
from CliModel import Bool, Dict, Enum, Float, List, Model, Str, Submodel
from IntfModels import Interface

from CliPlugin.XcvrCliLib import subPort
from CliPlugin.XcvrModel import ( SUPPORTED_FORMATS,
                                 InterfaceTransceiverDetailThresholds,
                                 paramTypeToStr, formatThreshold )
from XcvrLib import ( getXcvrSlotName, noneToNegInf, isCmisCoherentMediaType )

# Extra suffix appended to per-channel parameter labels when rendering POLS
# modules (channel 1 is shown as "Booster", channel 2 as "Pre-amp").
# Format of this is { channel : { paramType : label } }
polsDomAdditionalLabel = {
   1: { 'txPower': '(line)', 'rxPower': '(local)' },
   2: { 'txPower': '(local)', 'rxPower': '(line)' },
}

# ----------------------------------------------------------------------------
#
# Helpers for "show interfaces [ <interface> ] transceiver dom [thresholds]"
#
# ----------------------------------------------------------------------------

def _printDomHeaderDefault() -> None:
   """Print the abbreviation legend shown above the DOM (no thresholds) table."""
   legend = (
      "Ch: Channel, N/A: not applicable, TX: transmit, RX: receive",
      "mA: milliamperes, dBm: decibels (milliwatts), C: Celsius, V: Volts",
   )
   print( "\n".join( legend ) )

def indicator( val: float, highAlarm: float, highWarn: float, lowWarn: float,
               lowAlarm: float ) -> Literal[ '', 'ALARM', 'WARN' ]:
   """Classify val against its thresholds for the DOM "Indicator" column.

   Returns 'ALARM' when val is at/above the high alarm or at/below the low alarm
   threshold, otherwise 'WARN' for the same test on the warning thresholds,
   otherwise ''. A falsy threshold (None or 0) never triggers.

   Nothing is reported ('') when val is -inf (not supported), when val equals
   the BER max value (its initial value), or when the high and low alarm
   thresholds are equal (taken to mean all thresholds are the same).
   """
   # NOTE(review): because thresholds are tested for truthiness, a genuine
   # threshold of 0 is ignored -- confirm this is intended.
   perfMon = Tac.Value( 'Xcvr::Sff8436PerformanceMonitoring' )
   unsupported = val == float( "-inf" ) or val == perfMon.berMaxValue
   if unsupported or highAlarm == lowAlarm:
      return ''

   def _beyond( high, low ) -> bool:
      # True when val crosses a set (truthy) high or low threshold.
      return bool( high and val >= high ) or bool( low and val <= low )

   if _beyond( highAlarm, lowAlarm ):
      return 'ALARM'
   return 'WARN' if _beyond( highWarn, lowWarn ) else ''

# ----------------------------------------------------------------------------
#
# Models for "show interfaces [ <interface> ] transceiver dom [thresholds]"
#
# ----------------------------------------------------------------------------

# Helper to return a presentable string for a parameter type
def parameterTypeToStr( paramType: str ) -> str:
   """Return the display label for paramType.

   Parameter types without a label of their own fall back to the label stored
   under 'unknown' in paramTypeToStr.
   """
   key = paramType if paramType in paramTypeToStr else 'unknown'
   return paramTypeToStr[ key ]

class InterfaceTransceiverDomBase( Model ):
   # Base DOM model: its render methods print every DOM reading as N/A.
   # InterfaceTransceiverDom overrides them to print real readings.
   displayName = Str( help="Interfaces corresponding to the channel of the"
                           " transceiver", optional=True )
   _usePolsLabels = Bool( help="Whether to render using labels specific to "
                               "POLS modules", optional=True )

   def renderModelNoThresholds( self, intfName: str, printFmt: str,
                                printCommonDomInfo: bool ) -> None:
      """Print the value-only DOM table for intfName with every value as N/A.

      printCommonDomInfo is accepted for signature parity with subclasses but
      is not consulted: the header is always printed.
      """
      assert printFmt == 'default', "Unrecognized DOM output format"
      print( "" )
      for headerLine in ( subPort( intfName ),
                          "Last update: %s" % "N/A",
                          "%50s" % ( "Value" ),
                          "%54s" % ( "------------" ) ):
         print( headerLine )
      rowFmt = "   %-40s%5s %s"
      for label in ( "Temperature", "Voltage", "TX bias current",
                     "Optical TX power", "Optical RX power" ):
         print( rowFmt % ( label, "N/A", "N/A" ) )

   def renderModelThresholds( self, intfName: str, printCommonDomInfo: bool )\
                              -> None:
      """Print the DOM threshold table for intfName with every cell as N/A.

      printCommonDomInfo is accepted for signature parity with subclasses but
      is not consulted: the header is always printed.
      """
      print( "" )
      print( subPort( intfName ) )
      print( "Last update: %s" % "N/A" )

      print( "%-38s%11s%11s%11s%11s" %
             ( " ", "High Alarm", "High Warn", "Low Warn", "Low Alarm" ) )
      print( "%-32s%-8s%-11s%-11s%-11s%-9s%6s%11s" %
             ( " ", "Value", "Threshold", "Threshold", "Threshold", "Threshold",
               "Unit", "Indicator" ) )
      print( "%-26s%-77s" % ( " ", "-" * 75 ) )
      rowFmt = "   %-23s%11s%12s%11s%11s%11s %5s%11s"
      # Value, four thresholds, unit and indicator are all N/A.
      naCells = ( "N/A", ) * 7
      for label in ( "Temperature", "Voltage", "TX bias current",
                     "Optical TX power", "Optical RX power" ):
         print( rowFmt % ( ( label, ) + naCells ) )

class InterfaceTransceiverDomParameter( Model ):
   # One monitored DOM parameter: its unit, its per-channel readings and its
   # optional alarm/warning thresholds.
   unit = Str( help="Parameter unit" )
   # Keys are channel numbers as strings (renderers sort them with int()), or
   # '-' for a transceiver-wide parameter. Values may be None.
   channels = Dict( keyType=str,
                    valueType=float,
                    valueOptional=True,
                    help=( "Per channel parameter value. "
                           "'-' for the channel indicates a channel independent "
                           "parameter" ) )
   threshold = Submodel( valueType=InterfaceTransceiverDetailThresholds,
                         help="Threshold for monitored parameter",
                         optional=True )

class InterfaceTransceiverDom( InterfaceTransceiverDomBase ):
   # DOM readings (and optional thresholds) of one transceiver interface.
   # `parameters` maps a parameter name to its per-channel readings. A channel
   # key of '-' marks a transceiver-wide ("common") parameter, printed once per
   # port. Any other key is a channel number, printed under the interface.
   vendorSn = Str( help="Transceiver serial number" )
   mediaType = Str( help="Media type" )
   updateTime = Float( help="Last update time in UTC" )
   _hasMultiChannelDomParam = Bool( help="Has atleast one dom parameter that is"
                                    " channel specific", optional=True )
   _dualLaserModulePresent = Bool( help="Dual laser DCO module requires special "
                                        "dom formatting" )
   parameters = \
         Dict( keyType=str, valueType=InterfaceTransceiverDomParameter,
               help="A mapping of parameter name to per channel value and " +
                     "thresholds" )
   _paramOrder = \
         List( valueType=str, help="List of parameters in the order they " +
               "should be printed" )

   def _getCommonDomParams( self ) -> list:
      """Return the common (channel-independent) parameters, in print order.

      A parameter is common when its channels dict contains the '-' key.
      """
      return [ paramType for paramType in self._paramOrder
               if '-' in self.parameters[ paramType ].channels ]

   def _skipChannelSpecificDom( self ) -> bool:
      """Return True when the per-interface (channel specific) section is skipped.

      Dual laser DCO modules and CMIS coherent media types do not group
      parameters under a per-interface header.
      """
      return bool( self._dualLaserModulePresent or
                   isCmisCoherentMediaType( self.mediaType ) )

   def _printIntfHeader( self, channel: str ) -> None:
      """Print the line introducing this interface's channel-specific rows."""
      if self._usePolsLabels:
         print( channel )
      else:
         print( self.displayName, channel )

   def _printDomLine( self, domLine: str ) -> None:
      """Print a formatted DOM row, showing unsupported (-inf) values as N/A."""
      # " N/A" is as wide as "-inf", so the columns stay aligned.
      print( domLine.replace( "-inf", " N/A" ) )

   def _formatThresholds( self, thresh ) -> tuple:
      """Format the (highAlarm, highWarn, lowWarn, lowAlarm) threshold cells.

      Missing (None) thresholds are converted with noneToNegInf first. The
      *Overridden flags are passed through to formatThreshold. The widths
      (12, 11, 11, 11) match the threshold table columns.
      """
      return tuple(
         formatThreshold( noneToNegInf( value ), overridden, length=length,
                          alignLeft=False )
         for value, overridden, length in (
            ( thresh.highAlarm, thresh.highAlarmOverridden, 12 ),
            ( thresh.highWarn, thresh.highWarnOverridden, 11 ),
            ( thresh.lowWarn, thresh.lowWarnOverridden, 11 ),
            ( thresh.lowAlarm, thresh.lowAlarmOverridden, 11 ) ) )

   def _getSingleChannelDomParamChannelLabel( self, intfName: str ) -> str:
      """Return the channel label printed after the interface display name.

      Returns "" when the display name already mentions a channel, or when no
      channel-specific parameter exists to take the channel number from. The
      previous code raised UnboundLocalError in that second case. Otherwise
      returns "(Channel N)". With POLS labels, channels 1 and 2 are rendered
      as "Booster" and "Pre-amp".

      intfName is unused; it is kept for interface compatibility.
      """
      # A missing display name is treated as "no channel mentioned".
      if "Channel" in ( self.displayName or "" ):
         return ""
      # All channel specific DOM parameters of a single-channel interface carry
      # the same channel number (inserted through the LaneMapping CLI infra),
      # so the first one found is enough.
      channelNum = None
      for param in self.parameters.values():
         if param.channels and '-' not in param.channels:
            channelNum = int( next( iter( param.channels ) ) )
            break
      if channelNum is None:
         return ""
      if self._usePolsLabels:
         if channelNum == 1:
            return "Booster"
         if channelNum == 2:
            return "Pre-amp"
      return f"(Channel {channelNum})"

   def _getPerChannelDomParamLabel( self, paramType: str, channel: str ) -> str:
      """Return the row label for paramType on channel.

      With POLS labels, the polsDomAdditionalLabel suffix for the channel is
      appended when one exists.
      """
      label = parameterTypeToStr( paramType )
      if self._usePolsLabels:
         channelLabels = polsDomAdditionalLabel.get( int( channel ) )
         if channelLabels and paramType in channelLabels:
            label = " ".join( [ label, channelLabels[ paramType ] ] )
      return label

   def renderModelNoThresholds( self, intfName: str, printFmt: str,
                                printCommonDomInfo: bool ) -> None:
      """Print the DOM values (without thresholds) for intfName.

      When printCommonDomInfo is True (once per port), the port header and the
      common parameters are printed first:

        Port 1
        Last update: x:xx:xx ago
                                  Value       Unit
                                 --------    -------
        Temperature                 x           C
        Voltage                     x           V
        ...

      Unless the module is dual laser DCO or CMIS coherent, the interface
      header and its channel-specific parameters follow. With multi-channel
      parameters, each parameter gets a sub-header and one "Channel N" row per
      channel. Otherwise each parameter is a single row.
      """
      assert printFmt == 'default', "Unrecognized DOM output format"
      commonDomParams = self._getCommonDomParams()
      if printCommonDomInfo:
         print( "" )
         updateTime = Ark.timestampToStr( self.updateTime, now=Tac.utcNow() )
         print( subPort( intfName ) )
         print( f"Last update: {updateTime}" )
         print( "%50s" % ( "Value" ) )
         print( "%56s" % ( 16 * '-' ) )
         for commonDomParam in commonDomParams:
            param = self.parameters[ commonDomParam ]
            label = parameterTypeToStr( commonDomParam )
            val = noneToNegInf( param.channels[ '-' ] )
            if any( berFmt in label
                    for berFmt in [ 'BER', 'Post-FEC errored frames' ] ):
               fmtValue = "   %-38s%4.2e %s"
            elif "BlkCount" in commonDomParam:
               if val == float( "-inf" ):
                  # "%d" cannot format infinity (it raises OverflowError), so
                  # print N/A directly in the same 8-character column.
                  fmtValue = "   %-38s%8s %s"
                  val = "N/A"
               else:
                  fmtValue = "   %-38s%8d %s"
            else:
               fmtValue = "   %-40s%6.2f %s"
            self._printDomLine( fmtValue % ( label, val, param.unit ) )

      if self._skipChannelSpecificDom():
         return

      # The interface header is printed for every interface. Its channel label
      # is only used when no parameter is multi-channel.
      channel = ""
      fmtValue = "      %-8s%8s %26.2f %s"
      if not self._hasMultiChannelDomParam:
         channel = self._getSingleChannelDomParamChannelLabel( intfName )
         fmtValue = "   %-34s%12.2f %s"
      self._printIntfHeader( channel )

      for paramType in self._paramOrder:
         # Common params were already handled in the per-port section.
         if paramType in commonDomParams:
            continue
         param = self.parameters[ paramType ]
         unit = param.unit
         if self._hasMultiChannelDomParam:
            print( f"   {parameterTypeToStr( paramType )}" )
         for chan, val in sorted( param.channels.items(),
                                  key=lambda item: int( item[ 0 ] ) ):
            val = noneToNegInf( val )
            if self._hasMultiChannelDomParam:
               domLine = fmtValue % ( "Channel", chan, val, unit )
            else:
               paramTypeStr = self._getPerChannelDomParamLabel( paramType, chan )
               if paramType in [ 'preFecBERCurr', 'uncorrectedBERCurr' ]:
                  domLine = "   %-38s%12.2e %s" % ( paramTypeStr, val, unit )
               else:
                  domLine = fmtValue % ( paramTypeStr, val, unit )
            self._printDomLine( domLine )

   def renderModelThresholds( self, intfName: str,
                              printCommonDomInfo: bool ) -> None:
      """Print the DOM values with their thresholds and indicator for intfName.

      When printCommonDomInfo is True (once per port), the port header and the
      common parameters are printed first:

        Port 1
        Last update: x:xx:xx ago
                     High Alarm   High Warn   Low Warn   Low Alarm
             Value    Threshold   Threshold  Threshold   Threshold   Unit  Indicator

      Unless the module is dual laser DCO or CMIS coherent, the interface
      header and its channel-specific rows follow, in the same layout as
      renderModelNoThresholds. Parameters without a threshold are skipped.
      """
      commonDomParams = self._getCommonDomParams()
      if printCommonDomInfo:
         print( "" )
         updateTime = Ark.timestampToStr( self.updateTime, now=Tac.utcNow() )
         print( subPort( intfName ) )
         print( f"Last update: {updateTime}" )
         print( "%-45s%11s%11s%11s%11s" %
                ( " ", "High Alarm", "High Warn", "Low Warn", "Low Alarm" ) )
         print( "%-39s%-8s%-11s%-11s%-11s%-9s%6s%11s" %
                ( " ", "Value", "Threshold", "Threshold", "Threshold",
                  "Threshold", "Unit", "Indicator" ) )
         print( "%-33s%-77s" % ( " ", "-" * 75 ) )
         for commonDomParam in commonDomParams:
            param = self.parameters[ commonDomParam ]
            thresh = param.threshold
            if not thresh:
               continue
            parameterStr = parameterTypeToStr( commonDomParam )
            paramVal = noneToNegInf( param.channels[ '-' ] )
            if 'BER' in commonDomParam:
               # BER rows use scientific notation for the value and thresholds.
               fmtString = ( "   {:<30}{:>11.2e}{:>12.2e}{:>11.2e}{:>11.2e}"
                             "{:>11.2e} {:>5}{:>11}" )
               thresholds = ( noneToNegInf( thresh.highAlarm ),
                              noneToNegInf( thresh.highWarn ),
                              noneToNegInf( thresh.lowWarn ),
                              noneToNegInf( thresh.lowAlarm ) )
            else:
               fmtString = "   {:<30}{:>11.2f}{}{}{}{} {:>5}{:>11}"
               thresholds = self._formatThresholds( thresh )
            domLine = fmtString.format( parameterStr, paramVal, *thresholds,
                                        param.unit,
                                        indicator( paramVal, thresh.highAlarm,
                                                   thresh.highWarn,
                                                   thresh.lowWarn,
                                                   thresh.lowAlarm ) )
            self._printDomLine( domLine )

      if self._skipChannelSpecificDom():
         return

      # The interface header is printed for every interface. Its channel label
      # is only used when no parameter is multi-channel.
      channel = ""
      fmtValue = "      {:<15}{:>8}{:>15.2f}{}{}{}{} {:>5}{:>11}"
      if not self._hasMultiChannelDomParam:
         channel = self._getSingleChannelDomParamChannelLabel( intfName )
         fmtValue = "   {:<34}{:>7.2f}{}{}{}{} {:>5}{:>11}"
      self._printIntfHeader( channel )

      for paramType in self._paramOrder:
         # Common params were already handled in the per-port section.
         if paramType in commonDomParams:
            continue
         param = self.parameters[ paramType ]
         thresh = param.threshold
         if not thresh:
            continue
         unit = param.unit
         thresholds = self._formatThresholds( thresh )
         if self._hasMultiChannelDomParam:
            print( f"   {parameterTypeToStr( paramType )}" )
         for chan, val in sorted( param.channels.items(),
                                  key=lambda item: int( item[ 0 ] ) ):
            val = noneToNegInf( val )
            indicatorLabel = indicator( val, thresh.highAlarm, thresh.highWarn,
                                        thresh.lowWarn, thresh.lowAlarm )
            if self._hasMultiChannelDomParam:
               domLine = fmtValue.format( "Channel", chan, val, *thresholds,
                                          unit, indicatorLabel )
            else:
               paramTypeStr = self._getPerChannelDomParamLabel( paramType, chan )
               domLine = fmtValue.format( paramTypeStr, val, *thresholds, unit,
                                          indicatorLabel )
            self._printDomLine( domLine )

class InterfacesTransceiverDom( Model ):
   # Top-level model for "show interfaces [ <interface> ] transceiver dom
   # [thresholds]": per-interface DOM models rendered in _interfacesOrder.
   _printFmt = Enum( values=SUPPORTED_FORMATS, help="Type of print format" )
   _detailed = Bool( help="Include warning and alarm thresholds" )
   interfaces = Dict( keyType=Interface, valueType=InterfaceTransceiverDomBase,
                     help="Mapping between interface name and the transceiver \
                           information" )
   _xcvrNames = List( valueType=str,
                      help="Transceiver slot names corresponding to the \
                            requested interfaces", optional=True )
   _interfacesOrder = List( valueType=str,
                            help="Order of interfaces to be displayed",
                            optional=True )
   _domThresholdOverrideEnabled = Bool(
      help="Indicates whether DOM threshold override is enabled", default=False )

   def render( self ) -> None:
      """Render the DOM table, with thresholds when _detailed is set."""
      if self._detailed:
         self._renderThresholds()
      else:
         self._renderNotThresholds()

   def _printDomThresHeaderDefault( self ) -> None:
      """Print the abbreviation legend shown above the DOM threshold table."""
      print( "Ch: Channel, mA: milliamperes, dBm: decibels (milliwatts)," )
      print( "C: Celsius, V: Volts, NA or N/A: not applicable." )
      if self._domThresholdOverrideEnabled:
         print( "(*) Threshold has been overridden via configuration." )

   def _renderPerInterface( self, renderIntf ) -> None:
      """Call renderIntf( intf, printCommonDomInfo ) for each ordered interface.

      printCommonDomInfo is True only for the first interface of a transceiver
      slot not already listed in _xcvrNames. Seen slots are tracked in a local
      set, so rendering does not mutate _xcvrNames. The previous code appended
      to it, which made a second render() drop all common DOM info. A missing
      (None) _xcvrNames is treated as empty.
      """
      seenXcvrNames = set( self._xcvrNames or () )
      # Print according to _interfacesOrder and not sorted( self.interfaces ... )
      for intf in self._interfacesOrder:
         xcvrName = getXcvrSlotName( intf )
         renderIntf( intf, xcvrName not in seenXcvrNames )
         seenXcvrNames.add( xcvrName )

   def _renderThresholds( self ) -> None:
      """Render the legend and every interface's DOM threshold table."""
      assert self._printFmt == 'default', "Unrecognized DOM output format"

      if not self.interfaces:
         # If there are no interfaces, skip even printing headers.
         return

      self._printDomThresHeaderDefault()
      self._renderPerInterface(
         lambda intf, printCommon:
            self.interfaces[ intf ].renderModelThresholds( intf, printCommon ) )

   def _renderNotThresholds( self ) -> None:
      """Render the legend and every interface's DOM values (no thresholds)."""
      assert self._printFmt == 'default', "Unrecognized DOM output format"

      if not self.interfaces:
         # If there are no interfaces, skip even printing headers.
         return

      _printDomHeaderDefault()
      self._renderPerInterface(
         lambda intf, printCommon:
            self.interfaces[ intf ].renderModelNoThresholds( intf, self._printFmt,
                                                             printCommon ) )
