#!/usr/bin/env python3
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import sys
import textwrap
import time
from MicrosemiLib import MicrosemiBase
import argparse
import Pci
import Tac

class Microsemi:
   """Implementations of the sub-commands of the microsemi utility.

   Each public command method takes the parsed argparse namespace ``args`` and
   performs one operation through ``self.base`` (the MicrosemiBase object
   built in ``main``), printing any results to stdout.
   """

   # Link speed names, indexed by the speed field of an LTSSM log word.
   speeds = ( "2.5G", "5G", "8G", "16G" )
   # LTSSM major-state names, indexed by the major-state field of a log word.
   majorState = ( "DETECT", "POLLING", "CFG", "L0", "RECOVERY",
                  "DISABLE", "LOOPBACK", "HOT_RESET", "TX_LOS", "L1" )

   def __init__( self, base ):
      self.base = base

   def _majorStateName( self, index ):
      """Return the LTSSM major-state name for ``index``.

      Indices outside ``majorState`` are reported as ``UNKNOWN(<n>)`` instead
      of raising IndexError, so one unexpected log word cannot abort a dump.
      """
      if 0 <= index < len( self.majorState ):
         return self.majorState[ index ]
      return "UNKNOWN(%d)" % index

   def unbind( self, args ):
      """Unbind downstream port ``args.dsp`` of partition ``args.part``.

      Packs partition, DSP and unbind option into the P2P-unbind MRPC command
      word and returns the result of ``base.doGas``.
      """
      assert 0 <= args.part <= 47
      assert 0 <= args.dsp <= 31
      assert 0 <= args.option <= 4
      return self.base.doGas( self.base.MRPC_P2P_UNBIND |
                              ( args.part << 8 ) |
                              ( args.dsp << 16 ) |
                              ( args.option << 24 ),
                              self.base.MRPC_PORTPARTP2P )

   def bind( self, args ):
      """Bind physical port ``args.port`` as DSP ``args.dsp`` of ``args.part``."""
      assert 0 <= args.port <= 47
      assert 0 <= args.part <= 47
      assert 0 <= args.dsp <= 31
      return self.base.bind( args.port, args.dsp, args.part )

   def ltssm( self, args ):
      """Print speed and the set of LTSSM major states logged for each port.

      For every port in ``base.allPorts()`` an LTSSM log dump is requested; on
      success ``numEntries`` 32-bit words are read back from the GAS output
      area. The speed printed is decoded from the last word read.
      """
      numEntries = 128  # log words requested in the command and read back
      for port in self.base.allPorts():
         portNum = int( port )
         cmd = ( ( numEntries << 24 ) | ( portNum << 8 ) |
                 self.base.MRPC_LTSSM_LogDump )
         if self.base.doGas( cmd, self.base.MRPC_DIAG_PORT_LTSSM_LOG ) != 0:
            continue
         log0 = 0
         states = dict()  # used as an insertion-ordered set of state names
         for offset in range( self.base.GAS_OUTPUT_DATA,
                              self.base.GAS_OUTPUT_DATA + numEntries * 4, 4 ):
            log0 = self.base.read32( offset )
            stateName = self._majorStateName( self.base.getBits( log0, 10, 7 ) )
            states[ stateName ] = 1
         print( "Port %d" % portNum, "Speed ",
                self.speeds[ self.base.getBits( log0, 14, 13 ) ],
                "Major State", list( states ) )

   def bifurcation( self, args ):
      """Print the non-zero bifurcation values for stacks 1, 3 and 4."""
      for stack in ( 1, 3, 4 ):
         assert self.base.doGas( stack << 8, self.base.MRPC_STACKBIF ) == 0
         value = self.base.readGroup4( 0x404 )
         for index in range( 0, 7 ):
            if value[ index ] != 0:
               print( "Port %d.%d: %x" % ( stack, index, value[ index ] ) )

   def firmwareDownload( self, args ):
      """Download ``args.fileName`` (or ``base.configImage``) to flash."""
      fileName = args.fileName or self.base.configImage
      self.base.firmwareDownload( fileName )
      print( "Firmware download completed" )

   def firmwareUpgrade( self, args ):
      """Download an image, activate it, reset, and wait for the hardware.

      The image is ``args.fileName`` or, when that is empty, ``base.configImage``.
      After the download the configuration image is toggled, the chip is reset,
      and the method waits up to 15 seconds for ``base.hardwareIsPresent``.
      """
      fileName = args.fileName or self.base.configImage
      self.base.firmwareDownload( fileName )
      print( "Firmware download completed" )
      self.base.firmwareToggle( False, True )
      self.reset( args )
      Tac.waitFor(
            self.base.hardwareIsPresent,
            timeout=15.0,
            description="{}({!r}).hardwareIsPresent()".format(
               self.base.__class__.__name__, str( self.base.mepAddr ) ) )

   def firmwareInfo( self, args ):
      """Print addresses and versions of the active and inactive images."""
      print( "Vendor Table Revision               %08x" % self.base.read32(
         self.base.GAS_VendorTableRevision ) )

      print( "Active Firmware Address             %08x" % self.base.read32(
         self.base.GAS_ActiveFirmwareAddress ) )
      print( "Active Firmware Build Version       %08s" % self.base.getVersion(
         self.base.GAS_ActiveFirmwareVersion ) )
      print( "Active Configuration Address        %08x" % self.base.read32(
         self.base.GAS_ActiveConfigAddress ) )
      print( "Active Configuration Version        %08s" % self.base.getVersion(
         self.base.GAS_ActiveConfigVersion ) )
      # NOTE(review): "%08s" on an int prints space-padded decimal, unlike the
      # "%08x" used for the other raw register reads; confirm intent.
      print( "Active Configuration Build          %08s" % self.base.read32(
         self.base.GAS_ActiveConfigBuild ) )

      print( "Inactive Firmware Address           %08x" % self.base.read32(
         self.base.GAS_InactiveFirmwareAddress ) )
      print( "Inactive Firmware Version           %08s" % self.base.getVersion(
         self.base.GAS_InactiveFirmwareVersion ) )
      print( "Inactive Configuration Address      %08x" % self.base.read32(
         self.base.GAS_InactiveConfigAddress ) )
      print( "Inactive Configuration Version      %08s" % self.base.getVersion(
         self.base.GAS_InactiveConfigVersion ) )
      print( "Inactive Configuration Build        %08s" % self.base.read32(
         self.base.GAS_InactiveConfigBuild ) )

      # NOTE(review): this reads the same register as "Vendor Table Revision"
      # above; confirm whether a distinct config-revision register was meant.
      print( "Config Revision                     %08x" % self.base.read32(
         self.base.GAS_VendorTableRevision ) )

   def toggle( self, args ):
      """Toggle the active firmware and/or configuration image.

      ``--both`` toggles both; otherwise ``--firmware``/``--config`` select.
      Unless ``--skipReset`` is given, the chip is reset and, after a 2 second
      pause, the PCI bus is rescanned.
      """
      if args.both:
         self.base.firmwareToggle( True, True )
      else:
         self.base.firmwareToggle( toggleFirmware=args.firmware,
                                   toggleConfig=args.config )
      if not args.skipReset:
         self.reset( args )
         time.sleep( 2 ) # needs some time to get ready for re-enumeration
         self.base.rescan()

   def updateConfig( self, args ):
      """Update the chip configuration from ``args.fileName``."""
      self.base.updateConfig( args.fileName, True )

   def updateFirmware( self, args ):
      """Update the firmware on the chip via ``base.updateFirmware``."""
      self.base.updateFirmware( True )

   def reset( self, args ):
      """Issue the MRPC reset command."""
      self.base.doGas( 0, self.base.MRPC_RESET )

   def rescan( self, args ):
      """Rescan the PCI bus via ``base.rescan``."""
      self.base.rescan()

   def imageInfo( self, args ):
      """Print type, load address, version, vendor and revision of an image."""
      ( _, imgType, loadAddr,
        version, vendor, revision ) = self.base.getImageInfo( args.fileName )
      if imgType == 4:
         fileType = "configuration"
      else:
         fileType = "firmware"
      print( "Type: %d (%s)" % ( imgType, fileType ) )
      print( "Load addr: %08X" % loadAddr )
      print( "Version: %x" % version )
      print( "Vendor: %x" % vendor )
      print( "Revision: %x " % revision )

   def upgradeNeeded( self, args ):
      """Print "True" if the image revision differs from the chip's, else "False"."""
      ( _, _, _,
        _, _, revision ) = self.base.getImageInfo( args.fileName )
      # TODO: change the comparison below to ">" (upgrade only to newer
      # revisions); "!=" is kept for now for testing.
      if revision != self.base.read32( self.base.GAS_VendorTableRevision ):
         print( "True" )
      else:
         print( "False" )

   def temp( self, args ):
      """Print the die temperature reading integer-divided by 100.

      Issues MRPC_DIETEMP with argument 1 and then 2 before reading the GAS
      output word (presumably hundredths of a degree; confirm).
      """
      self.base.doGas( 1, self.base.MRPC_DIETEMP )
      self.base.doGas( 2, self.base.MRPC_DIETEMP )
      print( self.base.read32( self.base.GAS_OUTPUT_DATA ) // 100 )

   def lnkstat( self, args ):
      """Print one status line per configured link, ordered by physical port.

      Links with a configured width of 0 are skipped. Link rates missing from
      the local name table are printed verbatim instead of raising KeyError.
      """
      linksStates = self.base.linksStates()

      for linkStat in sorted( linksStates.values(),
                              key=lambda st: st.phyPortId ):

         if linkStat.cfgLnkWidth == 0:
            continue

         # NOTE(review): "PcieGen4" is assumed to be the Gen4 rate name (the
         # ``speeds`` table lists 16G); confirm against real linkRate values.
         linkSpeed = {
               "PcieGenUndefined" : "0G",
               "PcieGen1" : "2.5G",
               "PcieGen2" : "5G",
               "PcieGen3" : "8G",
               "PcieGen4" : "16G",
               }.get( linkStat.linkRate, str( linkStat.linkRate ) )
         linkStatDict = {
            attr : getattr( linkStat, attr )
            for attr in linkStat.attributes
         }

         # partId 0xFF marks a port that is not bound to any partition.
         if linkStat.partId == 0xFF:
            partitionStr = "unbound"
         else:
            partitionStr = f"{linkStat.partId:02}.{linkStat.logPortId:02}"

         print( ( "[{phyPortId:02}] part:{partitionStr: <7} "
                "w:cfg[x{cfgLnkWidth:02}]-neg[x{negLnkWidth:02}] "
                "         stk:{stackId}.{stackPortId} {usp} "
                "dl_active:{linkIsUp!s: <5} Rate: {linkSpeed: <4} "
                "LTSSM: {majorState}" ).format(
                   **linkStatDict,
                   partitionStr=partitionStr,
                   usp="usp" if linkStat.portIsUpstream else "dsp",
                   linkSpeed=linkSpeed,
                   majorState=linkStat.ltssmState.major.upper()
                ) )

   def ntInfo( self, args ):
      """Dump non-transparent (NT) bridge partition and BAR information.

      Prints the NTB header (partition id/count, NT map address, requester id).
      It then prints, for NT0 and NT1, the partition state, error registers,
      requester entries, and the configured BARs.
      """
      self.base.managementEndpoint_ = self.base.mepAddr
      self.base.microsemiFunctionBar = None

      print( "Management endpoint is ", self.base.managementEndpoint_ )

      ( partNumber, partId, _, _ ) = self.base.readGroup8(
            self.base.MRPC_NTB_BASE )
      ntMap0 = self.base.read32( self.base.MRPC_NTB_BASE + 4 )
      ntMap1 = self.base.read32( self.base.MRPC_NTB_BASE + 8 )
      ( requesterID, _ ) = self.base.readGroup16( self.base.MRPC_NTB_BASE + 12 )

      # The NT map address is printed as a 64-bit value, high word first.
      print( "partId=%d, numPartitions=%d, NTMap@0x%08x%08x, requesterID=0x%x" % (
            partId, partNumber, ntMap1, ntMap0, requesterID ) )


      for nt in range( 0, 2 ):
         part = self.base.NTPartitionGetInfo( nt )
         print( "NT%d: locked=%d, NTStat=%x, Opc=%x, Control=%x, BarOffset=%x" % (
            nt, part[ 'lockedId' ], part[ 'ntStat' ], part[ 'ntOpc' ],
            part[ 'ntControl' ], part[ 'barOffset' ] ) )
         print( "NT%d: Err: %d, ErrIndex %d, ntRequesterErr=%x, ntTableErr=%x" % (
            nt, part[ 'ntError' ], part[ 'ntErrorIndex' ],
            part[ 'ntRequesterError' ], part[ 'ntTableError' ] ) )
         # Label with the NT index like the two lines above (was partId).
         print( "NT%d: enabled=%s, requester0=%d, proxy0=%d, enabled=%s, "
                "requester1=%d, proxy1=%d" % (
            nt,
            part[ 'Requester0' ][ 'enabled' ],
            part[ 'Requester0' ][ 'RequesterID' ],
            part[ 'Requester0' ][ 'NTProxy' ],
            part[ 'Requester1' ][ 'enabled' ],
            part[ 'Requester1' ][ 'RequesterID' ],
            part[ 'Requester1' ][ 'NTProxy' ] ) )

         # Read BAR Setup

         for barNo in range( 0, 6 ):
            bar = self.base.NTBBarGetInfo( nt, barNo )
            if 'valid' in bar:
               sys.stdout.write( "BAR%d  " % barNo )
               sys.stdout.write( "%s" % bar[ 'mode' ] )
               if not bar[ 'prefetch' ]:
                  sys.stdout.write( "Non-" )
               sys.stdout.write( "prefetchable " )
               sys.stdout.write( "%s " % bar[ 'mappingType' ] )
               if "NT-Direct" in bar[ 'mappingType' ]:
                  print( "BaseAddr=%016x, Size=%08xB, Pos=0x%016x, destPart=%d" %
                     ( bar[ 'BaseAddr' ], bar[ 'NtTranslationSize' ],
                     bar[ 'NtTranslationPosition' ],
                     bar[ 'NtDestinationPartition' ] ) )
               else:
                  print( "" )
            else:
               # Only report an unconfigured BAR if its raw words are non-zero.
               if 'dw' in bar and bar[ 'dw' ] != [ 0, 0, 0, 0 ]:
                  print( "Unconfigured BAR%d: " % barNo, bar )

   def ntEnable( self, args ):
      """Enable the NT port via ``base.ntEnable``."""
      self.base.ntEnable()

   def ntReset( self, args ):
      """Reset the NT port via ``base.ntReset``."""
      self.base.ntReset()


def main():
   """Parse the command line and run the selected Microsemi sub-command.

   Exactly one of ``--mep PCIADDR`` or ``--sup`` selects the switch. Each
   sub-command stores an unbound ``Microsemi`` method as ``action``, which is
   called with the constructed ``Microsemi`` instance and the parsed args.
   """
   def integer( value ):
      # Base 0 accepts decimal as well as 0x/0o/0b prefixed literals.
      return int( value, 0 )
   parser = argparse.ArgumentParser( description=textwrap.dedent(
      '''
      Utility to access the microsemi chips for programming,
      getting status, dumping info etc.
      '''
      ), formatter_class=argparse.ArgumentDefaultsHelpFormatter )
   group = parser.add_mutually_exclusive_group( required=True )
   group.add_argument( "--mep",
                       metavar="PCIADDR",
                       help="Management endpoint PCI address",
                       type=Pci.Address )
   group.add_argument( "--sup",
                       help="Select supervisor PCIe switch",
                       action='store_true' )
   parser.add_argument( "--mmio",
                        help="Directly access device registers",
                        action='store_true' )
   subparsers = parser.add_subparsers( metavar="COMMAND", required=True )

   ltssmCmd = subparsers.add_parser( 'ltssm',
                                     help='Dump ltssm for all ports' )
   ltssmCmd.set_defaults( action=Microsemi.ltssm )

   bifurcationCmd = subparsers.add_parser( 'bifurcation',
                       help='dump bifurcation' )
   bifurcationCmd.set_defaults( action=Microsemi.bifurcation )

   firmwareInfoCmd = subparsers.add_parser( 'firmwareInfo',
                     help='Provide details about the firmware' )
   firmwareInfoCmd.set_defaults( action=Microsemi.firmwareInfo )

   # fileName is optional (nargs='?'): when omitted it stays None and
   # Microsemi.firmwareDownload/firmwareUpgrade fall back to base.configImage.
   # Without nargs='?' argparse always required it and the default was dead.
   firmwareDownloadCmd = subparsers.add_parser( 'firmwareDownload',
                     help='Download the firmware onto the flash' )
   firmwareDownloadCmd.add_argument( 'fileName',
                     nargs='?',
                     help='filePath',
                     default=None )
   firmwareDownloadCmd.set_defaults( action=Microsemi.firmwareDownload )

   firmwareUpgradeCmd = subparsers.add_parser( 'firmwareUpgrade',
                     help='Download the firmware onto the flash and toggle' )
   firmwareUpgradeCmd.add_argument( 'fileName',
                     nargs='?',
                     help='filePath',
                     default=None )
   firmwareUpgradeCmd.set_defaults( action=Microsemi.firmwareUpgrade )

   bindCmd = subparsers.add_parser( 'bind', help='bind port' )
   bindCmd.add_argument( "--part", type=integer, default=0 )
   bindCmd.add_argument( "dsp", type=integer,
                         help="Downstream port number" )
   bindCmd.add_argument( "port", type=integer,
                         help="Physical port number" )
   bindCmd.set_defaults( action=Microsemi.bind )

   unbindCmd = subparsers.add_parser( 'unbind', help='unbind port' )
   unbindCmd.add_argument( "--part", type=integer, default=0 )
   unbindCmd.add_argument( "--option", type=integer, default=0x2 )
   unbindCmd.add_argument( "dsp", type=integer,
                           help="Downstream port number" )
   unbindCmd.set_defaults( action=Microsemi.unbind )

   imageInfoCmd = subparsers.add_parser( 'imageInfo',
         help='Dump all the info. available for the firmware image' )
   imageInfoCmd.add_argument( 'fileName', help='filePath' )
   imageInfoCmd.set_defaults( action=Microsemi.imageInfo )

   upgradeNeededCmd = subparsers.add_parser( 'upgradeNeeded',
         help='Check if the firmware needs to be upgraded ' )
   upgradeNeededCmd.add_argument( 'fileName', help='filePath' )
   upgradeNeededCmd.set_defaults( action=Microsemi.upgradeNeeded )

   lnkstatCmd = subparsers.add_parser( 'lnkstat',
                                       help='Dump the link status' )
   lnkstatCmd.set_defaults( action=Microsemi.lnkstat )

   ntEnableCmd = subparsers.add_parser( 'ntEnable',
                                        help='Enable the nt port' )
   ntEnableCmd.set_defaults( action=Microsemi.ntEnable )

   ntInfoCmd = subparsers.add_parser( 'ntInfo',
                                      help='Dump the info. on the nt port' )
   ntInfoCmd.set_defaults( action=Microsemi.ntInfo )

   ntResetCmd = subparsers.add_parser( 'ntReset', help='Reset the nt port' )
   ntResetCmd.set_defaults( action=Microsemi.ntReset )

   rescanCmd = subparsers.add_parser( 'rescan', help='rescan the ports' )
   rescanCmd.set_defaults( action=Microsemi.rescan )

   resetCmd = subparsers.add_parser( 'reset', help='reset the ports' )
   resetCmd.set_defaults( action=Microsemi.reset )

   tempCmd = subparsers.add_parser( 'temp', help='Read the temperature' )
   tempCmd.set_defaults( action=Microsemi.temp )

   toggleCmd = subparsers.add_parser( 'toggle',
                                      help='toggle the active image' )
   toggleCmd.add_argument( "--config", action="store_true" )
   toggleCmd.add_argument( "--firmware", action="store_true" )
   toggleCmd.add_argument( "--both", action="store_true" )
   toggleCmd.add_argument( "--skipReset", action="store_true" )
   toggleCmd.set_defaults( action=Microsemi.toggle )

   updateConfigCmd = subparsers.add_parser( 'updateConfig',
                                help='update the chip configuration' )
   updateConfigCmd.add_argument( 'fileName',
                                 help='filePath', default=None )
   updateConfigCmd.set_defaults( action=Microsemi.updateConfig )

   updateFirmwareCmd = subparsers.add_parser( 'updateFirmware',
                                  help='update the firmware on the chip' )
   updateFirmwareCmd.set_defaults( action=Microsemi.updateFirmware )

   args = parser.parse_args()
   base = \
      MicrosemiBase.forSupervisor( mmio=args.mmio ) \
      if args.sup else MicrosemiBase( args.mep, mmio=args.mmio )
   M = Microsemi( base )
   # action is an unbound Microsemi method; pass the instance explicitly.
   args.action( M, args )

if __name__ == "__main__":
   # Run the CLI only when executed as a script, not when imported.
   main()
