# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import CliModel
from CliPlugin import PhySerdesModelCommon
from CliPlugin import PhyStatusModel
from CliPlugin.PhySerdesModelCommon import ( brPdEnToStr, linktrnStsToStr,
                                             rxPamModeToStr,
                                             SerdesLaneStatus )

from itertools import chain
from collections.abc import Callable
from collections import defaultdict

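# Example output
# In each table, lane 0 shows every column at its maximum width (including each
# comma-separated sub-field); the remaining lanes are left blank for brevity.
# A trailing '*' on TXPPM marks a lane whose TX phase interpolator is enabled.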
# -----------------------------------------------------------------------------
# Link Training
#   Lane 0                    off
#   Lane 1
#   Lane 2
#   Lane 3
# Tx Equalization
#   Lane     pre3  pre2  pre1  main  post1  post2
#   0        -...  -...  -...  -...  -...   -...
#   1
#   2
#   3
# SerDes Statistics
#   Lane  RXPPM  PF(M,L,H)   VGA   DCO  TP(0,1,2)   CDR        TXPPM
#   0     -..... ...,...,... -...  -... -.....,..,. ......-..  -.....*
#   1
#   2
#   3
#   Lane  RXFFE(N3,N2,N1,M,P1,P2)                   DFE(1,2)  FLT(M,S)
#   0     -.....,-.....,-.....,-.....,-.....,-..... -...,-... -.....,.....
#   1
#   2
#   3
#   Lane  SNR(dB)    VEYE(U,M,L)
#   0     .......... -.....,-.....,-.....
#   1
#   2
#   3

def serdesHeaderRow( colWidths, colHeaders ):
   """Build the header line and the row template for a fixed-width table.

   Every column gets a field of the matching width from colWidths. Columns are
   separated by one space, and the whole line is indented by four spaces.

   Returns a ( header, row ) tuple:
     header -- the line with colHeaders already substituted.
     row    -- a str.format template with one left-aligned "{:<w}" field per
               column, to be formatted with a row's values.
   """
   assert len( colWidths ) == len( colHeaders )
   indent = "    "
   headerFields = [ f"{{:{width}}}" for width in colWidths ]
   rowFields = [ f"{{:<{width}}}" for width in colWidths ]
   headerLine = ( indent + " ".join( headerFields ) ).format( *colHeaders )
   rowTemplate = indent + " ".join( rowFields )
   return ( headerLine, rowTemplate )

# Header lines and row templates for the three "SerDes Statistics" sub-tables.
# Each *Header is a printable string; each *Row is a str.format template that
# takes one value per column.
serdesStatsLine1Header, serdesStatsLine1Row = serdesHeaderRow(
   colWidths=( 5, 6, 11, 4, 4, 11, 10, 7 ),
   colHeaders=( "Lane", "RXPPM", "PF(M,L,H)", "VGA", "DCO", "TP(0,1,2)",
                "CDR", "TXPPM" ) )
serdesStatsLine2Header, serdesStatsLine2Row = serdesHeaderRow(
   colWidths=( 5, 41, 9, 12 ),
   colHeaders=( "Lane", "RXFFE(N3,N2,N1,M,P1,P2)", "DFE(1,2)", "FLT(M,S)" ) )
# The last column title is a literal "{}" that survives formatting. The
# resulting serdesStatsLine3Header must therefore be .format()ed once more
# with the eye column title (or "") before printing, so that the VEYE column
# can be hidden.
serdesStatsLine3Header, serdesStatsLine3Row = serdesHeaderRow(
   colWidths=( 5, 10, 20 ),
   colHeaders=( "Lane", "SNR(dB)", "{}" ) )

# Column widths for the "Tx Equalization" table, in column order:
# Lane, pre3, pre2, pre1, main, post1, post2.
txEqColWidths = ( 8, 5, 5, 5, 5, 6, 6 )
# Both strings are still templates needing one value per column: the header
# uses default alignment, and the row left-aligns every field.
txEqHeader = '    ' + ' '.join( f'{{:{width}}}' for width in txEqColWidths )
txEqRow = '    ' + ' '.join( f'{{:<{width}}}' for width in txEqColWidths )

# Display names for the raw tp2 (AFE static BW setting) code. Codes 0-2 are
# shown as the digit itself. Any code not listed falls back to "10GRate";
# being a defaultdict, such a lookup also inserts the unlisted code.
tp2ToString = defaultdict(
   lambda: "10GRate",
   { 0: "0", 1: "1", 2: "2", 3: "LinkCAT", 4: "HalfRate", 5: "10GRate" },
)

def tp2StringToCli( inputStr ):
   """Abbreviate a tp2 display name for the compact TP(0,1,2) column.

   The names at tp2ToString[ 3 ], [ 4 ] and [ 5 ] ("LinkCAT", "HalfRate",
   "10GRate") become "L", "H" and "Q". Any other value (e.g. "0", "1", "2") is
   returned unchanged.
   """
   abbreviations = ( ( tp2ToString[ 3 ], "L" ),
                     ( tp2ToString[ 4 ], "H" ),
                     ( tp2ToString[ 5 ], "Q" ) )
   for fullName, shortName in abbreviations:
      if inputStr == fullName:
         return shortName
   return inputStr

class PeregrinePhySerdesLaneModel( SerdesLaneStatus ):
   """Per-lane SerDes status for Peregrine PHYs.

   Extends SerdesLaneStatus with Peregrine-specific receiver, CDR, AFE, FFE
   and eye-margin fields. Also prints this lane's row in each table that
   PeregrinePhySerdesModel renders.

   Attributes not declared here (_laneId, txTaps, dfeTaps, rxPpm, txPpm, vga,
   _dcOffset) are presumably declared and filled in by SerdesLaneStatus --
   TODO confirm against PhySerdesModelCommon.
   """
   rxPamMode = CliModel.Enum( values=list( rxPamModeToStr.values() ),
                              help="Receiver PAM4 mode" )
   cdrMode = CliModel.Enum( values=list( brPdEnToStr.values() ),
                            help="Baud Rate Phase Detector enable" )
   osForced = CliModel.Bool( help="Whether the OS CDR is forced" )
   brForced = CliModel.Bool( help="Whether the BR CDR is forced" )

   midFrequencyPeakingFilter = CliModel.Int( help="Mid frequency peaking"
                                                  " filter control" )
   lowFrequencyPeakingFilter = CliModel.Int( help="Low frequency peaking"
                                                  " filter control" )
   highFrequencyPeakingFilter = CliModel.Int( help="High frequency peaking"
                                                   " filter control" )

   tp0 = CliModel.Int( help="Echoes 'RX_Channel_Loss_hint' provided by user in"
                            " both PAM4 and NRZ modes" )
   tp1 = CliModel.Int( help="Initial Channel Loss Estimate from 0-50: "
                            "Range corresponds with short to long channels" )
   tp2 = CliModel.Enum( values=list( tp2ToString.values() ),
                        help="AFE Static BW setting. "
                             "0 to 2 - Used for 100G, tuned at link startup on"
                             " short channels." )

   rxFfe = CliModel.Dict( keyType=int, valueType=int,
                          help="A mapping of RX feed-forward equalizer (FFE) tap "
                               "indices to their values" )

   floatingTapMax = CliModel.Int( help="Max magnitude of all floating taps" )
   floatingTapSum = CliModel.Int( help="Absolute sum of magnitudes of all taps" )

   # "off" is an extra value, used when link training is disabled.
   linkTrainingStatus = CliModel.Enum( values=list( chain( linktrnStsToStr.values(),
                                                           [ "off" ] ) ),
                                       help="Link training Status" )

   snrDfe = CliModel.Float( help="SNR in dB. "
                                 "(20.42, 19.46, 18.26 corresponding roughly to "
                                 "1e-6,1e-5,1e-4 PAM4 BER)" )

   txPrecoderEnable = CliModel.Bool( help="TX precoder enable" )

   # The eye margins are optional. toModel sets all three to None when eye
   # display is disabled, and eyeValuesDisabled() checks for that.
   eyeUpper = CliModel.Int( optional=True,
                            help="Upper Eye margin @ 1e-5 (NRZ) or 1e-3 (PAM4) as"
                                 " seen by internal diagnostic slicer in mV" )
   eyeMiddle = CliModel.Int( optional=True,
                             help="Middle Eye margin @ 1e-5 (NRZ) or 1e-3 (PAM4) as"
                                  " seen by internal diagnostic slicer in mV" )
   eyeLower = CliModel.Int( optional=True,
                            help="Lower Eye margin @ 1e-5 (NRZ) or 1e-3 (PAM4) as"
                                 " seen by internal diagnostic slicer in mV" )

   _enable6Taps = CliModel.Int( help="6 tap TX equalization enabled" )

   txPiEn = CliModel.Bool( help="TX phase interpolator enable" )

   def toModel( self, serdes ):
      """Populate this model from the SerDes status object `serdes`.

      First lets SerdesLaneStatus.toModel fill the common fields, then copies
      and converts the Peregrine-specific ones.

      Returns self, so construction can be chained:
      PeregrinePhySerdesLaneModel().toModel( status ).
      """
      super().toModel( serdes )
      # Numeric codes are translated through the lookup tables into the
      # strings the Enum attributes were declared with.
      self.rxPamMode = rxPamModeToStr[ serdes.rxPamMode ]
      self.cdrMode = brPdEnToStr[ serdes.cdrMode ]
      self.brForced = serdes.brForced
      self.osForced = serdes.osForced

      self.midFrequencyPeakingFilter = serdes.midFrequencyPeakingFilter
      self.lowFrequencyPeakingFilter = serdes.lowFrequencyPeakingFilter
      self.highFrequencyPeakingFilter = serdes.highFrequencyPeakingFilter

      self.tp0 = serdes.tp0
      self.tp1 = serdes.tp1
      # Codes missing from tp2ToString fall back to its default entry.
      self.tp2 = tp2ToString[ serdes.tp2 ]

      self.rxFfe = dict( serdes.rxFfe )

      self.floatingTapMax = serdes.floatingTapMax
      self.floatingTapSum = serdes.floatingTapSum

      # When link training is disabled, report "off" instead of the state.
      self.linkTrainingStatus = ( linktrnStsToStr[ serdes.linkTrainingStatus ]
                                  if serdes.linkTrainingEnable else "off" )

      self.snrDfe = serdes.snrDfe
      self.txPrecoderEnable = serdes.txPrecoderEnable == 1

      # Leaving all three margins as None hides this lane's VEYE values.
      showEye = not serdes.disableEyeDisplay
      self.eyeUpper = serdes.eyeUpper if showEye else None
      self.eyeMiddle = serdes.eyeMiddle if showEye else None
      self.eyeLower = serdes.eyeLower if showEye else None

      self._enable6Taps = serdes.enable6Taps
      if not serdes.enable6Taps:
         # Without 6-tap TX equalization only taps -1, 0, 1 (pre1, main,
         # post1) are kept. The others are removed and later print as 'x'.
         # pop() with a default, unlike del, does not raise when the tap is
         # already absent.
         for tapIndex in ( -3, -2, 2 ):
            self.txTaps.pop( tapIndex, None )
      self.txPiEn = serdes.txPiEn
      return self

   def printLinkTrainingInfo( self ):
      """Print this lane's "Link Training" row.

      Delegates to PhySerdesModelCommon.printLinkTrainingInfo, passing whether
      training is enabled (status is not "off") and the status string.
      """
      PhySerdesModelCommon.printLinkTrainingInfo( self._laneId,
                                                  self.linkTrainingStatus != "off",
                                                  self.linkTrainingStatus )

   def printTxEqInfo( self ):
      """Print this lane's "Tx Equalization" row.

      TX taps -3..2 fill the pre3, pre2, pre1, main, post1 and post2 columns.
      A tap that is not present is shown as 'x'.
      """
      taps = ( self.txTaps.get( tapIndex, 'x' ) for tapIndex in range( -3, 3 ) )
      print( txEqRow.format( self._laneId, *taps ) )

   def eyeValuesDisabled( self ):
      """Return True when all three eye margins are None (eye display off)."""
      return all( eyeVal is None for eyeVal in [ self.eyeUpper,
                                                 self.eyeMiddle,
                                                 self.eyeLower ] )

   def printSerdesStatsLine1( self ):
      """Print this lane's row of the first SerDes statistics sub-table."""
      #   Lane  RXPPM  PF(M,L,H)   VGA   DCO  TP(0,1,2)   CDR        TXPPM
      #   0     -..... ...,...,... -...  -... -.....,..,. ......-..  -.....*
      laneId = str( self._laneId )
      rxPpm = str( self.rxPpm )
      pfTaps = ( f"{self.midFrequencyPeakingFilter},"
                 f"{self.lowFrequencyPeakingFilter},"
                 f"{self.highFrequencyPeakingFilter}" )
      vga = str( self.vga )
      dco = str( self._dcOffset )
      tpValues = f"{self.tp0},{self.tp1},{tp2StringToCli( self.tp2 )}"
      # The CDR counts as forced when the forced flag matching the current
      # mode is set: osForced pairs with brPdEnToStr[ 0 ], and brForced pairs
      # with brPdEnToStr[ 1 ].
      forced = ( ( self.cdrMode == brPdEnToStr[ 0 ] and self.osForced ) or
                 ( self.cdrMode == brPdEnToStr[ 1 ] and self.brForced ) )
      autoForcedStr = "forced" if forced else "auto"
      cdr = f"{autoForcedStr}-{self.cdrMode}"
      # A '*' suffix on TXPPM flags an enabled TX phase interpolator.
      txPiMark = "*" if self.txPiEn else ""
      txPpm = str( self.txPpm ) + txPiMark
      print( serdesStatsLine1Row.format(
                laneId, rxPpm, pfTaps, vga, dco, tpValues, cdr, txPpm ) )

   def printSerdesStatsLine2( self ):
      """Print this lane's row of the second SerDes statistics sub-table.

      RX FFE taps -3..2 and DFE taps 1..2 are shown comma-separated. A missing
      tap is shown as 'x'.
      """
      #   Lane  RXFFE(N3,N2,N1,M,P1,P2)                   DFE(1,2)  FLT(M,S)
      #   0     -.....,-.....,-.....,-.....,-.....,-..... -...,-... -.....,.....
      def joinTaps( taps : dict[ int, int ], first : int, last : int ) -> str:
         # Inclusive index range [ first, last ].
         return ",".join( str( taps.get( index, "x" ) )
                          for index in range( first, last + 1 ) )
      laneId = str( self._laneId )
      rxFfe = joinTaps( self.rxFfe, -3, 2 )
      dfe = joinTaps( self.dfeTaps, 1, 2 )
      flt = f"{self.floatingTapMax},{self.floatingTapSum}"
      print( serdesStatsLine2Row.format( laneId, rxFfe, dfe, flt ) )

   def printSerdesStatsLine3( self ):
      """Print this lane's row of the third SerDes statistics sub-table.

      Columns: lane, SNR (two decimals) and, unless eye display is disabled,
      the eye margins as "upper,middle,lower". Otherwise the eye column is
      blank.
      """
      #   Lane  SNR(dB)    VEYE(U,M,L)
      #   0     .......... -.....,-.....,-.....
      laneId = str( self._laneId )
      snr = f"{self.snrDfe:.2f}"
      eye = ""
      if not self.eyeValuesDisabled():
         # rxPamModeToStr[ 0 ] is treated as NRZ mode. There, only the upper
         # margin is printed and middle/lower show as 'x'.
         isNrz = self.rxPamMode == rxPamModeToStr[ 0 ]
         if isNrz:
            eye = f"{self.eyeUpper},x,x"
         else:
            eye = f"{self.eyeUpper},{self.eyeMiddle},{self.eyeLower}"
      print( serdesStatsLine3Row.format( laneId, snr, eye ) )

# Per-lane callback type accepted by PeregrinePhySerdesModel.printLaneInfo and
# callOnOrderedLanes. The return type is `object` rather than None:
# callOnOrderedLanes collects the return values, and printSerdesStats consumes
# the bools returned by eyeValuesDisabled.
SerdesCallable = Callable[ [ PeregrinePhySerdesLaneModel ], object ]

class PeregrinePhySerdesModel( CliModel.Model ):
   """SerDes details for all lanes of a Peregrine PHY.

   Holds one PeregrinePhySerdesLaneModel per lane id and renders them as the
   "Link Training", "Tx Equalization" and "SerDes Statistics" sections.
   """
   serdesStats = CliModel.Dict( keyType=int,
                                valueType=PeregrinePhySerdesLaneModel,
                                help="A mapping of lanes to SerDes statistics" )
   # Lane ids in print order. Only renderSerdesDetail() fills this in; until
   # then the per-section printers visit no lanes.
   _sortedSerdesStats = CliModel.List( valueType=int,
                                       help="Sorted SerDes statistics" )

   def toModel( self, serdes ):
      """Fill serdesStats from `serdes`, a mapping of lane to SerDes status.

      Keys that expose a `laneId` attribute are stored under that id; other
      keys are used as-is. Returns self.
      """
      for laneKey, laneStatus in serdes.items():
         laneModel = PeregrinePhySerdesLaneModel().toModel( laneStatus )
         # XXX kewei: Remove serdesId once BUG336826 is resolved.
         self.serdesStats[ getattr( laneKey, "laneId", laneKey ) ] = laneModel
      return self

   def callOnOrderedLanes( self, func : SerdesCallable ):
      """Apply func to each lane model in _sortedSerdesStats order.

      Returns a list of func's results, one per lane, in that order.
      """
      results = []
      for laneId in self._sortedSerdesStats:
         results.append( func( self.serdesStats[ laneId ] ) )
      return results

   def printLaneInfo( self, func : SerdesCallable, header : str="",
                      legend : str="" ):
      """Print an optional section header and column legend, then call func
      on every lane (in sorted order) so that it prints its own row."""
      if header:
         # I am not getting into updating this
         # pylint: disable=consider-using-f-string
         print( PhyStatusModel.phyDetailLongHeaderFmt % ( header ) )
      if legend:
         print( legend )
      self.callOnOrderedLanes( func )

   def printLinkTrainingInfo( self ):
      """Print the "Link Training" section."""
      self.printLaneInfo( lambda lane: lane.printLinkTrainingInfo(),
                          header="Link Training" )

   def printTxEqInfo( self ):
      """Print the "Tx Equalization" section with its tap column legend."""
      txEqLegend = txEqHeader.format( "Lane", "pre3", "pre2", "pre1", "main",
                                      "post1", "post2" )
      self.printLaneInfo( lambda lane: lane.printTxEqInfo(),
                          header="Tx Equalization", legend=txEqLegend )

   def printSerdesStats( self ):
      """Print the "SerDes Statistics" section: one section header followed
      by three sub-tables, each with its own column legend."""
      # I am not getting into updating this
      # pylint: disable=consider-using-f-string
      print( PhyStatusModel.phyDetailLongHeaderFmt % ( "SerDes Statistics" ) )
      self.printLaneInfo( lambda lane: lane.printSerdesStatsLine1(),
                          legend=serdesStatsLine1Header )
      self.printLaneInfo( lambda lane: lane.printSerdesStatsLine2(),
                          legend=serdesStatsLine2Header )
      # Title the VEYE column only if at least one lane has eye values.
      anyEyeShown = not all(
         self.callOnOrderedLanes( lambda lane: lane.eyeValuesDisabled() ) )
      eyeTitle = "VEYE(U,M,L)" if anyEyeShown else ""
      self.printLaneInfo( lambda lane: lane.printSerdesStatsLine3(),
                          legend=serdesStatsLine3Header.format( eyeTitle ) )

   def renderSerdesDetail( self ):
      """Print every SerDes detail section, with lanes in ascending id order."""
      self._sortedSerdesStats = sorted( self.serdesStats )
      for printSection in ( self.printLinkTrainingInfo,
                            self.printTxEqInfo,
                            self.printSerdesStats ):
         printSection()

