#!/usr/bin/env python3
# Copyright (c) 2014 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import argparse
from itertools import chain
import sys

import PLSmbusUtil
import Smbus_pb2

# Maps each --backend name accepted on the command line to the Smbus_pb2 backend
# value that main() hands to PLSmbusUtil as 'backend'.
backendLookup = dict(
   simulated=Smbus_pb2.SIMULATED,
   scd=Smbus_pb2.SCD,
   kernel=Smbus_pb2.KERNEL_DEV,
   ioport=Smbus_pb2.IOPORT,
   celestica=Smbus_pb2.CELESTICA,
   plm=Smbus_pb2.PLM,
   fbdomfpga=Smbus_pb2.FBDOMFPGA,
)

# Longest register/command address (in bytes) accepted on the command line.
MAX_ADDR_SIZE = 4

def dumpBytes( data, step=16 ):
   """Dump data in the same fashion as i2cdump

      'data' is a bytes string
   """
   if not data:
      return

   # Print heading
   print( '   ',
          ' '.join( f'{v: 2x}' for v in range( step ) ),
          '  ',
          ''.join( f'{v:x}' for v in range( step ) ) )

   # Unprintable ASCII characters are replaced with '?', however 0x00 and 0xff should
   # instead use '.' to match i2cdump behavior.
   unprintableCharacters = bytes( chain( range( 0x20 ), range( 0x7f, 0xff + 1 ) ) )
   replacementCharacters = b'.' + b'?' * ( len( unprintableCharacters ) - 2 ) + b'.'
   replaceTable = bytes.maketrans( unprintableCharacters, replacementCharacters )

   begin = 0
   while sub := data[ begin : begin + step ]:
      hexOutput = sub.hex( ' ' )
      asciiOutput = sub.translate( replaceTable ).decode()
      print( f'{begin:02x}: {hexOutput}    {asciiOutput}' )

      begin += step

def main():

   def intType( x ):
      ret = int( x, 0 )
      assert ret <= 0xff
      return ret

   def hexStringType( inputStr ):
      # Converts sequence of hexadecimal digits (with leading 0x prefix) into bytes
      # string. Otherwise (no 0x prefix) the input is assumed to be a base-10 int.
      if inputStr.lower().startswith( "0x" ):
         inputStr = inputStr[ 2 : ]
         if len( inputStr ) % 2 == 1:
            # If input length is odd, assume the user forgot to include the leading 0
            inputStr = f'0{inputStr}'
         return bytes.fromhex( inputStr )
      else:
         return bytes( [ intType( inputStr ) ] )

   commandList = [ "read8", "read16", "reads",
                   "write8", "write16", "writes",
                   "recvByte", "sendByte",
                   "processCall", "blockProcessCall",
                   "dump" ]

   delays = {
      "0": Smbus_pb2.ZERO,
      "1ms": Smbus_pb2.ONE_MS,
      "10ms": Smbus_pb2.TEN_MS,
      "50ms": Smbus_pb2.FIFTY_MS,
   }

   usage = ( "plsmbus [-h] COMMAND [--pci PCI] [--scd ACCELID] bus deviceId "
             "[register] [data]\n"
             "where COMMAND = {read8|read16|reads|write8|write16|"
             "writes|recvByte|sendByte|processCall|blockProcessCall|dump}" )

   parser = argparse.ArgumentParser( description="Perform Smbus transactions",
                                     usage=usage )
   parser.add_argument( "command", choices=commandList )
   parser.add_argument( "--backend", default="", help="Backend to use" )
   parser.add_argument( "--pec", action="store_true", help="Send PEC" )
   parser.add_argument( "--pci", default=0, help="A BDF PCI address "
                        "[<domain>:]bus:device.function" )
   parser.add_argument( "--scd", dest="accelId", type=intType,
                        default=None, help="interpreted as hex" )
   parser.add_argument( "bus", type=intType, help="Bus Id, interpeted as hex" )
   parser.add_argument( "deviceId", type=intType,
                        help="Device Address, interpeted as hex" )
   parser.add_argument( "register", type=hexStringType, nargs='?', default=bytes(),
                        help="Command/Register, interpeted as hex (up to 4 bytes)" )
   parser.add_argument( "data", type=hexStringType, help="interpeted as hex",
                        nargs='?' )
   parser.add_argument( "--writeNoStopReadCurrent", action="store_true",
                        default=False,
                        help="enable write no stop read current" )
   parser.add_argument( "--delay", choices=delays, default="0" )


   args = parser.parse_args()

   pciAddr = 0
   if args.pci:
      pciAddr = PLSmbusUtil.encodePCIAddress( args.pci )

   backend = backendLookup.get( args.backend )

   if args.command != "recvByte":
      # The sendByte command uses 'register' to specify the byte to send so recvByte
      # is the only command where 'register' is not used.
      assert args.register

   assert len( args.register ) <= MAX_ADDR_SIZE

   sock = None
   try:
      sock = PLSmbusUtil.connect()
   except PermissionError as permErr:
      socketPath = PLSmbusUtil.socketPath()
      print( permErr )
      print( f"This script requires write privileges to {socketPath}.\n"
             "Please change permissions or try again using 'sudo'" )
   except OSError as osErr:
      print( osErr )
      print( "Could not connect to PlutoSmbus" )
   finally:
      if not sock:
         if backend == Smbus_pb2.IOPORT:
            print( "Falling back to raw access" )
         else:
            sys.exit( 1 )

   if args.command == "read8":
      result = PLSmbusUtil.read( sock, pciAddr, args.accelId,
                                 args.bus, args.deviceId, args.register, count=1,
                                 backend=backend, delay=delays[ args.delay ] )
      print( f'0x{result.hex()}' )

   elif args.command == "read16":
      result = PLSmbusUtil.read( sock, pciAddr, args.accelId,
                                 args.bus, args.deviceId, args.register, count=2,
                                 backend=backend, delay=delays[ args.delay ] )
      # SMBus read word returns data in little endian
      byteSwapped = result[ :: -1 ]
      print( f'0x{byteSwapped.hex()}' )

   elif args.command == "recvByte":
      result = PLSmbusUtil.read( sock, pciAddr, args.accelId, args.bus,
                                 args.deviceId, args.register, backend=backend,
                                 delay=delays[ args.delay ] )
      print( f'0x{result.hex()}' )

   elif args.command == "reads":
      if not args.data:
         result = PLSmbusUtil.read( sock, pciAddr, args.accelId, args.bus,
                                    args.deviceId, args.register,
                                    readType=Smbus_pb2.BLOCK_READ,
                                    backend=backend, delay=delays[ args.delay ] )
      elif len( args.data ) == 1:
         result = PLSmbusUtil.read( sock, pciAddr, args.accelId, args.bus,
                                    args.deviceId, args.register,
                                    readCurrent=args.writeNoStopReadCurrent,
                                    count=args.data[ 0 ], backend=backend,
                                    delay=delays[ args.delay ] )
      else:
         assert False

      print( result.hex( ' ' ) )

   elif args.command == "write8":
      assert len( args.data ) == 1
      PLSmbusUtil.write( sock, pciAddr, args.accelId, args.bus,
                         args.deviceId, args.register, args.data,
                         backend=backend, pec=args.pec, delay=delays[ args.delay ] )

   elif args.command == "write16":
      assert len( args.data ) == 2
      # SMBus write word expects data to be little endian
      byteSwapped = args.data[ :: -1 ]
      PLSmbusUtil.write( sock, pciAddr, args.accelId, args.bus,
                         args.deviceId, args.register, byteSwapped,
                         backend=backend, pec=args.pec, delay=delays[ args.delay ] )

   elif args.command == "sendByte":
      assert not args.data
      PLSmbusUtil.write( sock, pciAddr, args.accelId, args.bus,
                         args.deviceId, args.register, b"",
                         backend=backend, pec=args.pec, delay=delays[ args.delay ] )

   elif args.command == "writes":
      assert args.data
      PLSmbusUtil.write( sock, pciAddr, args.accelId, args.bus,
                         args.deviceId, args.register, args.data,
                         writeType=Smbus_pb2.BLOCK_WRITE,
                         backend=backend, pec=args.pec, delay=delays[ args.delay ] )

   elif args.command == "processCall":
      assert len( args.data ) == 2
      # SMBus processCall expects data to be little endian
      dataByteSwapped = args.data[ :: -1 ]
      result = PLSmbusUtil.processCall( sock, pciAddr, args.accelId, args.bus,
                                        args.deviceId, args.register,
                                        dataByteSwapped, backend=backend,
                                        delay=delays[ args.delay ] )
      # SMBus processCall returns data in little endian
      resultByteSwapped = result[ :: -1 ]
      print( f'0x{resultByteSwapped.hex()}' )

   elif args.command == "blockProcessCall":
      assert args.data
      result = PLSmbusUtil.processCall( sock, pciAddr, args.accelId, args.bus,
                                        args.deviceId, args.register, args.data,
                                        callType=Smbus_pb2.BLOCK_PROCESS_CALL,
                                        backend=backend, delay=delays[ args.delay ] )
      print( result.hex( ' ' ) )

   elif args.command == "dump":
      begin = int.from_bytes( args.register, "big" )
      end = ( args.data or 0xff ) + 1
      res = bytes()
      for i in range( begin, end ):
         try:
            byte = PLSmbusUtil.read( sock, pciAddr, args.accelId,
                                     args.bus, args.deviceId, i, count=1,
                                     backend=backend, delay=delays[ args.delay ] )
            res += byte
         except Exception: # pylint: disable=broad-except
            res += b'\xff'
      dumpBytes( res )

   return 0

if __name__ == "__main__":
   # Use main()'s return value as the process exit status.
   status = main()
   sys.exit( status )
