# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

# CmisCdb.py is a platform-independent library that implements
# CMIS4.0 CDB commands to Xcvr modules.

import time
import sys
from datetime import datetime
import Tac
import CmisErrors

# Value of the CDB status byte ( 00l:37, i.e. byte 37 ) that indicates the
# most recent CDB command completed successfully
CDB_SUCCESS = 0x01

class Cdb:
   '''
   CMIS4.0 CDB ( Command Data Block ) command driver.

   Builds CDB command packets, writes them to page 9Fh through the
   platform-dependent broker and polls / decodes the CDB status byte ( 37 ).
   Errors found while decoding are collected in self.errors, keyed by
   CmisErrors error code.
   '''

   def __init__( self, broker, arguments ):
      '''
      Arguments:
      broker - a class that implements platform-dependent registers access
      arguments - options object; its 'verbose', 'veryverbose' and
                  'skip_cdb_block_complete' attributes are read here

      Attributes that callers may set after construction:
      password - defines whether to apply CMIS4.0 8.2.10 password
                 in firmware programming commands ( 'msaPassword' enables it )
      epl_page_delay, epl_command_delay - extra delays added within CMD104
                 ( EPL download command ) for some modules not fully compliant
                 with CMIS4.0 timing
      maxDownloadBlockSize - largest single write used by dataBlockWrite
                 ( 0 writes the whole buffer at once ); CMD103 also pads its
                 LPL payload to this size
      autoPagingSupport - module advances the EPL page on its own ( CMD104 )
      cmdProcessingInOneTransaction - CMD103 writes the whole packet in a
                 single transaction instead of two
      '''
      self._b = broker
      self.password = None
      self.epl_page_delay = None
      self.epl_command_delay = None
      self.maxDownloadBlockSize = 0
      # CmisErrors code -> { 'error', 'time', optional 'errstatus' }
      self.errors = {}
      self.verbose = arguments.verbose or arguments.veryverbose
      self.skipCdbBlockComplete = arguments.skip_cdb_block_complete
      self.autoPagingSupport = False
      self.cmdProcessingInOneTransaction = False

   def smbusBroker( self ):
      '''Return the platform-dependent register access broker.'''
      return self._b

   def _enterPassword( self ):
      '''Write the MSA password starting at byte 122 when password mode is on.'''
      if self.password == 'msaPassword':
         self._b.twi_sbw( 122, b'\x00\x00\x10\x11' )

   def _clearPassword( self ):
      '''Overwrite the password bytes at 122 with zeros when password mode is on.'''
      if self.password == 'msaPassword':
         self._b.twi_sbw( 122, b'\x00\x00\x00\x00' )

   def _dumpLplPacket( self, packet ):
      '''Print a CDB packet ( starting at byte 128 ) field by field.'''
      print( "  CDB command: %s" % " ".join( "0x%.2x" % i for i in packet[ : 2 ] ) )
      print( "  EPL length: %s" % " ".join( "0x%.2x" % i for i in packet[ 2 : 4 ] ) )
      print( "  LPL length: %s" % "0x%.2x" % packet[ 4 ] )
      print( "  CdbChkCode: %s" % "0x%.2x" % packet[ 5 ] )
      print( "  RLPLLen: %s" % "0x%.2x" % packet[ 6 ] )
      print( "  RLPLChkCode: %s" % "0x%.2x" % packet[ 7 ] )
      print( "  Offset: %s" % " ".join( "0x%.2x" % i for i in packet[ 8 : 12 ] ) )
      print( "  Payload: %s" % " ".join( "0x%.2x" % i for i in packet[ 12 : ] ) )

   def cdbErrorParse( self, errStatus ):
      '''
      Translate a CDB status byte into a human-readable failure string.

      Bit 6 flags a failed command and, per CMIS4.0, bits 5-0 carry the
      result code. Statuses without the fail bit, or with a code not listed
      below ( reserved / custom codes ), map to "Unknown status <hex>".
      '''
      errMsg = { 0: "Failed, no specific failure code",
                 1: "CMD code unknown",
                 2: "Param range error or not supported",
                 3: "Previous CMD was not aborted",
                 4: "CMD checking time out",
                 5: "CdbCheckCode error",
                 6: "Password error, insufficient privilege",
                 "default": f"Unknown status {errStatus:#x}" }
      if errStatus & 0x40:
         # Mask all six result-code bits; 0x1F aliased codes such as 0x61
         # onto the standard 0x41 message
         errCode = errStatus & 0x3F
         if errCode in errMsg:
            return errMsg[ errCode ]
      return errMsg[ "default" ]

   def convertStr( self, cmd ):
      '''Return the bytes of 'cmd' as one lowercase hex string ( 2 digits each ).'''
      return "".join( f"{i:02x}" for i in cmd )

   def addCdbError( self, errcode, error, errStatus=None ):
      '''
      Record ( or overwrite ) the error for 'errcode' in self.errors with a
      timestamp and, when given, the raw status byte in hex.
      '''
      cmisError = {}
      cmisError[ 'error' ] = error
      cmisError[ 'time' ] = f"{datetime.now()}"
      if errStatus is not None:
         cmisError[ 'errstatus' ] = f"{errStatus:#x}"
      self.errors[ errcode ] = cmisError

   def decodeCdbErrors( self, cdbBlockCompleteTimeout ):
      '''
      Record errors derived from the CDB status byte ( 37 ) and from the
      CDB Block Complete timeout flag in self.errors.
      '''
      try:
         # Decode CDB Status errors
         cdbStatus = self._b.twi_rr( 37 )
         if cdbStatus & 0x40:
            self.addCdbError(
               CmisErrors.CMIS_STS_FAIL_ERROR,
               CmisErrors.cdbStatusCmdResults.get( cdbStatus,
                                                   "Fail: unknown error" ),
               cdbStatus )
         elif cdbStatus & 0x80:
            self.addCdbError(
               CmisErrors.CMIS_STS_BUSY_TIMEOUT_ERROR,
               CmisErrors.cdbStatusCmdResults.get( cdbStatus,
                                                   "Timeout: unknown error" ),
               cdbStatus )
         elif cdbStatus != CDB_SUCCESS:
            self.addCdbError(
               CmisErrors.CMIS_STS_UNKNOWN_ERROR,
               CmisErrors.cdbStatusCmdResults.get( cdbStatus,
                                                   "Unknown STS error" ),
               cdbStatus )

         # Detect CDB Complete timeouts
         errcode = CmisErrors.CMIS_MODFLAG_TIMEOUT_ERROR
         if cdbBlockCompleteTimeout:
            self.addCdbError(
               errcode, CmisErrors.cdbBlockCompleteTimeout, None )
         else:
            # Clear the error (requested by Microsoft)
            self.errors.pop( errcode, None )

      except AttributeError as e:
         print( e )
         self.addCdbError( CmisErrors.CMIS_CDB_SOFTWARE_ERROR,
                           CmisErrors.cmisCdbSoftwareError +
                           f": {e}" )

      if self.verbose and self.errors:
         print( f"{self.errors}" )

   def decodeRlplErrors( self ):
      '''Record a CDB reply ( RLPL ) checksum error in self.errors.'''
      try:
         self.addCdbError(
            CmisErrors.CMIS_CDB_RLPL_CHKSUM_ERROR,
            CmisErrors.cdbResponseChecksumError )

      except AttributeError as e:
         print( e )
         self.addCdbError( CmisErrors.CMIS_CDB_SOFTWARE_ERROR,
                           CmisErrors.cmisCdbSoftwareError +
                           f": {e}" )

      if self.verbose and self.errors:
         print( f"{self.errors}" )

   # Multibyte Write using advertised CDB smbus write transaction size
   def dataBlockWrite( self, reg, buf ):
      '''
      Write 'buf' starting at register 'reg'. When maxDownloadBlockSize is
      set, the buffer is split into consecutive chunks of at most that size.
      '''
      # TODO(avolinsk): switch to using only the generic chunked case
      step = self.maxDownloadBlockSize
      if step == 0:
         self._b.twi_msbw( reg, buf )
         return
      for offset in range( 0, len( buf ), step ):
         self._b.twi_msbw( reg + offset, buf[ offset : offset + step ] )

   # Execute CDB command
   def cdbCmd( self, cmd ):
      '''
      Write a CDB command header: bytes 130 onward first, then the 2-byte
      command code at 128-129 last ( presumably so the module only starts
      executing once the rest of the header is in place ).
      '''
      self.dataBlockWrite( 130, cmd[ 2 : ] )
      self._b.twi_msbw( 128, cmd[ : 2 ] )

   # Get firmware info
   def cmd0100h( self, timeout=60 ):
      '''
      CMD 0100h: Get firmware info.

      timeout - seconds to wait for the command-in-progress indication to
                clear before treating the command as failed

      Returns ( rlplen, rlp_chkcode, rlp ), or ( 0, 0, [ 0 ] ) on failure.
      '''
      self._enterPassword()
      self._b.twi_sbw( 126, b'\x00\x9f' )
      # CdbChkCode ( byte 133 ) is precomputed as 0xFE for this fixed header
      cmd = bytearray( b'\x01\x00\x00\x00\x00\xFE\x00\x00' )
      self.cdbCmd( cmd )
      time.sleep( 0.5 )
      deadline = time.monotonic() + timeout
      timedOut = False
      while self._b.cdb1_cip():
         if time.monotonic() > deadline:
            timedOut = True
            break
         time.sleep( 1 )
      if timedOut or self._b.twi_rr( 37 ) & 0x40:
         print( "Cmd failed" )
         # Do not leave the password programmed on the failure path
         self._clearPassword()
         return 0, 0, [ 0 ]
      self._clearPassword()
      rlplen = self._b.twi_rr( 134 )
      rlp_chkcode = self._b.twi_rr( 135 )
      msglen = self._b.cdb_rlp( cmd, rlplen, rlp_chkcode )
      rlp = self._b.twi_srr( 136, msglen )
      return rlplen, rlp_chkcode, rlp

   def cdb1FlagsWait( self, timeout=10 ):
      '''
      Poll broker.cdb1_chkflags() every 0.1 s until it returns falsy.

      Returns True if 'timeout' seconds elapsed first, else False. When
      skipCdbBlockComplete is set, just sleeps 0.2 s and returns False.
      '''
      starttime = time.time() # Set start time
      cdbTimeout = False

      # Early exit for modules that do not populate CDB Block Complete bit
      if self.skipCdbBlockComplete:
         time.sleep( 0.2 )
         return cdbTimeout

      while self._b.cdb1_chkflags():
         if time.time() - starttime > timeout: # check time out
            cdbTimeout = True
            break
         time.sleep( 0.1 )
      return cdbTimeout

   def cdb1StatusWait( self, timeout=60 ):
      '''
      Wait for the current CDB command to finish, record any errors via
      decodeCdbErrors and return the CDB status byte ( 37 ).
      '''
      starttime = time.time() # Set start time
      # Start with checking CDB Block Complete flag,
      # to avoid scenario when CDB status is not updated yet
      # Since CDB Block Complete flag is latched, set timeout
      # flag and clear it upon detection of completion
      cdbBlockCompleteTimeout = self.cdb1FlagsWait()
      while self._b.cdb1_cip():
         if time.time() - starttime > timeout: # check time out
            break
         time.sleep( 0.1 )
      # Check CDB Block Complete, it should be updated by now
      if cdbBlockCompleteTimeout:
         cdbBlockCompleteTimeout = self.cdb1FlagsWait( timeout=0 )
      # Decode CDB errors
      self.decodeCdbErrors( cdbBlockCompleteTimeout )
      return self._b.twi_rr( 37 )

   # Start firmware download
   def cmd0101h( self, startBytes, head, size ):
      '''
      CMD 0101h: Start firmware download.

      startBytes - length of 'head'; LPL length is startBytes + 8 ( the 4-byte
                   image size at 136-139 plus 4 zero bytes at 140-143 )
      head - vendor header bytes written from byte 144 onward
      size - image size, written big-endian into bytes 136-139

      Returns the CDB status byte.
      '''
      self._enterPassword()
      self._b.twi_sbw( 126, b'\x00\x9f' )
      time.sleep( 0.3 )
      print( f"Image size is {size}" )
      cmd = bytearray(
         b'\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' )
      cmd[ 132 - 128 ] = startBytes + 8
      cmd[ 136 - 128 ] = ( size >> 24 ) & 0xff
      cmd[ 137 - 128 ] = ( size >> 16 ) & 0xff
      cmd[ 138 - 128 ] = ( size >> 8 ) & 0xff
      cmd[ 139 - 128 ] = ( size >> 0 ) & 0xff

      cmd = cmd + head
      cmd[ 133 - 128 ] = self._b.cdb_chkcode( cmd )
      # Payload first, then the header ( command code last )
      self._b.twi_sbw( 136, cmd[ 8 : ] )
      self.cdbCmd( cmd[ : 8 ] )
      time.sleep( 0.2 )
      res = self.cdb1StatusWait()
      self._clearPassword()
      return res

   # Abort firmware download
   def cmd0102h( self ):
      '''CMD 0102h: Abort firmware download. Returns the CDB status byte.'''
      self._enterPassword()
      self._b.twi_sbw( 126, b'\x00\x9f' )
      time.sleep( 0.1 )
      cmd = bytearray( b'\x01\x02\x00\x00\x00\x00\x00\x00' )
      cmd[ 133 - 128 ] = self._b.cdb_chkcode( cmd )
      self.cdbCmd( cmd )
      time.sleep( 0.5 )
      res = self.cdb1StatusWait()
      self._clearPassword()
      return res

   # Download firmware via LPL
   # Extra delays are added to work with OSFP-LS ( II-VI ) modules
   def cmd0103h( self, addr, data, timeout=60 ):
      '''
      CMD 0103h: Write one firmware block via LPL.

      addr - block address, written big-endian into bytes 136-139
      data - block contents; zero-padded to maxDownloadBlockSize
      timeout - seconds to keep polling for status 0x01 before giving up

      Returns the last CDB status byte read ( 0x01 on success ).
      '''
      # Payload is bytes 140-255, but LPLLen includes headers in bytes 136-139
      lpl_len = len( data ) + 4
      self._b.twi_sbw( 126, b'\x00\x9f' )
      time.sleep( 0.03 )  # presumably sleeping to give time for page switch...
      # 4-byte big-endian representation of addr, written from byte 136
      addrBlock = bytearray( ( addr >> shift ) & 0xff for shift in ( 24, 16, 8, 0 ) )
      # Pad payload with 0s to overwrite any previous values (lpl_len should prevent
      # module from treating those fields as valid payload, but this is an
      # additional cautionary measure)
      padding = '\0' if sys.version_info.major == 2 else b'\0'
      paddedPayload = bytearray( data.ljust( self.maxDownloadBlockSize, padding ) )
      # Status 0x45 ( CdbCheckCode error, see cdbErrorParse ) has been seen here
      # even when the previous command reported 0x01; the polling loop below
      # resends the packet when that happens.
      cmd = bytearray( b'\x01\x03\x00\x00\x00\x00\x00\x00' )
      cmd[ 132 - 128 ] = lpl_len & 0xff
      cmd[ 133 - 128 ] = self._b.cdb_chkcode( cmd + addrBlock + paddedPayload )
      # According to II-VI, writing entire packet in one go is best practice
      packet = cmd + addrBlock + paddedPayload
      if self.cmdProcessingInOneTransaction:
         self._b.twi_msbw( 128, packet )
      else:
         self._b.twi_msbw( 130, packet[ 2 : ] )
         self._b.twi_msbw( 128, packet[ 0 : 2 ] )
      deadline = time.monotonic() + timeout
      # Read status once per iteration so the success check and the 0x45
      # retry check act on the same value
      status = self._b.twi_rr( 37 )
      while status != 0x1:
         if time.monotonic() > deadline:
            print( f"CDB status {status:#x} after {timeout}s, giving up" )
            break
         if status == 0x45:
            p9f = self._b.twi_srr( 128, 128 )
            print( "Sent packet:" )
            self._dumpLplPacket( packet )
            print()
            print( "Page 9Fh after 0x45 status:" )
            self._dumpLplPacket( p9f )
            print()
            # According to II-VI, POLS can retry same packet if 0x45 is seen
            print( "Retry packet..." )
            # Use the conservative approach by splitting into two transactions
            self._b.twi_msbw( 130, packet[ 2 : ] )
            self._b.twi_msbw( 128, packet[ 0 : 2 ] )
         status = self._b.twi_rr( 37 )
      return status

   # Download firmware via EPL
   # Extra delays are added to work with multiple 400GBASE-ZR modules
   def cmd0104h( self, addr, data ):
      '''
      CMD 0104h: Write one firmware block via EPL.

      'data' is written 128 bytes per page into pages A0h, A1h, ...; the
      command itself ( EPL length at 130-131, addr big-endian at 136-139 )
      is then sent on page 9Fh. Returns the CDB status byte.
      '''
      epl_len = len( data )
      pages = epl_len // 128
      if epl_len % 128:
         pages += 1
      # Restore bank 0
      self._b.twi_bw( 126, 0 )
      for i in range( pages ):
         # Some modules implement auto-paging and increment
         # the page automatically. If it is not so,
         # program new page manually
         next_page = 0xa0 + i
         if not self.autoPagingSupport or self._b.twi_rr( 127 ) != next_page:
            self._b.twi_bw( 127, next_page )
         self.dataBlockWrite( 128, data[ 128 * i : 128 * ( i + 1 ) ] )
         # Optional sleep after page write
         if self.epl_page_delay:
            time.sleep( self.epl_page_delay )
      self._b.twi_bw( 127, 0x9f )
      # Optional sleep before sending Cdb command
      if self.epl_command_delay:
         time.sleep( self.epl_command_delay )
      cmd = bytearray( b'\x01\x04\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00' )
      cmd[ 130 - 128 ] = ( epl_len >> 8 ) & 0xff
      cmd[ 131 - 128 ] = ( epl_len >> 0 ) & 0xff
      cmd[ 136 - 128 ] = ( addr >> 24 ) & 0xff
      cmd[ 137 - 128 ] = ( addr >> 16 ) & 0xff
      cmd[ 138 - 128 ] = ( addr >> 8 ) & 0xff
      cmd[ 139 - 128 ] = ( addr >> 0 ) & 0xff
      cmd[ 133 - 128 ] = self._b.cdb_chkcode( cmd )
      self.cdbCmd( cmd )
      res = self.cdb1StatusWait()
      return res

   # Complete firmware download
   def cmd0107h( self ):
      '''CMD 0107h: Complete firmware download. Returns the CDB status byte.'''
      self._enterPassword()
      self._b.twi_sbw( 126, b'\x00\x9f' )
      time.sleep( 0.1 )
      # CdbChkCode ( byte 133 ) precomputed as 0xF7
      cmd = bytearray( b'\x01\x07\x00\x00\x00\xf7\x00\x00' )
      self.cdbCmd( cmd )
      time.sleep( 0.5 )
      res = self.cdb1StatusWait( timeout=90 )
      self._clearPassword()
      time.sleep( 0.1 )
      return res

   # Copy firmware image
   def cmd0108h( self, direction ):
      '''
      CMD 0108h: Copy firmware image. 'direction' is sent as the single LPL
      byte ( 136 ). Returns the CDB status byte.
      '''
      self._enterPassword()
      self._b.twi_sbw( 126, b'\x00\x9f' )
      time.sleep( 0.1 )
      cmd = bytearray( b'\x01\x08\x00\x00\x01\x00\x00\x00\x00' )
      cmd[ -1 ] = direction
      cmd[ 133 - 128 ] = self._b.cdb_chkcode( cmd )
      self.cdbCmd( cmd )
      res = self.cdb1StatusWait()
      self._clearPassword()
      return res

   # Run firmware image
   def cmd0109h( self, mode=0x01 ):
      '''
      CMD 0109h: Run firmware image. 'mode' goes to byte 137.
      Returns the CDB status byte.
      '''
      self._enterPassword()
      self._b.twi_sbw( 126, b'\x00\x9f' )
      cmd = bytearray( b'\x01\x09\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00' )
      cmd[ 137 - 128 ] = mode
      # Bytes 138-139 = 0x0400: delay to reset, presumably in ms ( ~1 sec )
      cmd[ 138 - 128 ] = 4
      cmd[ 133 - 128 ] = self._b.cdb_chkcode( cmd )
      self.cdbCmd( cmd )
      res = self.cdb1StatusWait()
      self._clearPassword()
      return res

   # Commit firmware image
   def cmd010Ah( self ):
      '''CMD 010Ah: Commit firmware image. Returns the CDB status byte.'''
      self._enterPassword()
      self._b.twi_sbw( 126, b'\x00\x9f' )
      # CdbChkCode ( byte 133 ) precomputed as 0xF4
      cmd = bytearray( b'\x01\x0A\x00\x00\x00\xf4\x00\x00' )
      # Debug dump of the raw command, only when verbose output is requested
      if self.verbose:
         print( cmd )
      self.cdbCmd( cmd )
      res = self.cdb1StatusWait()
      self._clearPassword()
      return res

   # QUERY command
   def cmd0000h( self, ms=0 ):
      '''
      CMD 0000h: QUERY. 'ms' is sent big-endian in the 2 LPL bytes
      ( 136-137 ). Returns ( rlplen, rlp_chkcode, rlp ).
      '''
      self._b.twi_sbw( 126, b'\x00\x9f' )
      cmd = bytearray( b'\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00' )
      cmd[ -1 ] = ms & 0xff
      cmd[ -2 ] = ms >> 8
      cmd[ 133 - 128 ] = self._b.cdb_chkcode( cmd )
      self.cdbCmd( cmd )
      self.cdb1StatusWait()
      rlplen = self._b.twi_rr( 134 )
      rlp_chkcode = self._b.twi_rr( 135 )
      rlp = self._b.twi_srr( 136, rlplen )
      return rlplen, rlp_chkcode, rlp

   def _runReplyCmd( self, cmd ):
      '''
      Send a header-only CDB command on page 9Fh, wait for it and return
      ( rlplen, rlp_chkcode, rlp ); the reply length comes from broker.cdb_rlp.
      '''
      self._b.twi_sbw( 126, b'\x00\x9f' )
      self.cdbCmd( cmd )
      self.cdb1StatusWait()
      rlplen = self._b.twi_rr( 134 )
      rlp_chkcode = self._b.twi_rr( 135 )
      msglen = self._b.cdb_rlp( cmd, rlplen, rlp_chkcode )
      rlp = self._b.twi_srr( 136, msglen )
      return rlplen, rlp_chkcode, rlp

   # Module Features Implemented
   def cmd0040h( self ):
      '''CMD 0040h: Module features implemented.'''
      return self._runReplyCmd( bytearray( b'\x00\x40\x00\x00\x00\xBF\x00\x00' ) )

   # Firmware Upgrade Features Implemented
   def cmd0041h( self ):
      '''CMD 0041h: Firmware upgrade features implemented.'''
      return self._runReplyCmd( bytearray( b'\x00\x41\x00\x00\x00\xBE\x00\x00' ) )

   # Performance Monitoring Implemented
   def cmd0042h( self ):
      '''CMD 0042h: Performance monitoring features implemented.'''
      return self._runReplyCmd( bytearray( b'\x00\x42\x00\x00\x00\xBD\x00\x00' ) )

   # BERT and Diagnostics Implemented
   def cmd0043h( self ):
      '''CMD 0043h: BERT and diagnostics features implemented.'''
      return self._runReplyCmd( bytearray( b'\x00\x43\x00\x00\x00\xBC\x00\x00' ) )
