#!/usr/bin/env python3
# Copyright (c) 2017 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.
#
# Script to read and control the power mode (e.g. enable high power) of Arista
# loopback transceivers. This script has two parts:
# 1. Determine the ports to receive configuration. Port selection is independent
#    of Xcvr Type; the selected ports are then filtered by Xcvr Type.
# 2. Apply power control. This task is dependent on Xcvr Type.

import Tac, EntityManager, Cell, PyClient
import argparse
import re
import sys
import weakref
import SmbusUtil


# ---------------------------------------------------------------------------------#
# Private Module Globals and Helper Functions
# ---------------------------------------------------------------------------------#
_xcvrType = Tac.Type( "Xcvr::XcvrType" )
_mediaType = Tac.Type( "Xcvr::MediaType" )
# Supported xcvr type -> the Mfg media type its loopback variant reports.
# setup() uses this to tell loopback modules apart from ordinary modules of
# the same xcvr type.
_mediaTypeMap = dict( [
   ( _xcvrType.qsfpPlus, _mediaType.xcvr100GBaseMfg ),
   ( _xcvrType.cfp2, _mediaType.xcvr2x100GBaseMfg ),
] )

def isFixed():
   """Return True if Cell reports this system's cell type as "fixed",
   False otherwise."""
   cellType = Cell.cellType()
   return cellType == "fixed"


# ---------------------------------------------------------------------------------#
# Function to parse user arguments and build a list of ports which need power
# control based on the user's input.
# Returns a tuple of ( xcvrType, user args, list of ports )
# ---------------------------------------------------------------------------------#
def setup():
   """Parse user arguments and build the list of ports needing power control.

   Exits the process with status 1 on invalid or conflicting arguments (errors
   reported through parser.error() use argparse's exit status 2).

   Returns:
      A tuple ( xcvrType, args, validPorts ): the Xcvr::XcvrType selected by
      the sub-command, the parsed argparse namespace, and the names of the
      selected ports whose transceiver is present, is of xcvrType, and reports
      that type's loopback (Mfg) media type from _mediaTypeMap.
   """
   parser = _buildParser()
   args = parser.parse_args()

   # Check user input
   xcvrType = _checkXcvrArgs( parser, args )
   _checkPortSelectors( args )

   fixed = isFixed()
   argPorts = None
   linecard = None
   if args.allPorts:
      pass
   elif args.lineCard is not None:
      if fixed:
         # parser.error() prints the usage and exits; it never returns.
         parser.error( "Linecard was specified, but this system is fixed." )
      linecard = "Linecard" + str( args.lineCard )
   elif args.portNumber:
      argPorts = _parsePortNumbers( args.portNumber, fixed )
   else:
      print( "Please specify portNumber, lineCard, or allPorts before specifying "
             "Xcvr Type." )
      sys.exit( 1 )

   # The client is created here, rather than inside the helper, so it stays
   # referenced while the entities it returns are read below.
   pycl = PyClient.PyClient( "ar", "Sysdb" )
   xcvrStati = _collectXcvrStati( pycl, args, fixed, linecard, argPorts )
   validPorts = _filterLoopbackPorts( xcvrStati, xcvrType, argPorts )
   return ( xcvrType, args, validPorts )


# Build the argument parser: generic port selectors plus one sub-command per
# supported xcvr type.
def _buildParser():
   usage = """
   loopbackPowerControl [OPTIONS] XCVRTYPE {xcvr-specific...}

      This program provides control of the power consumption of
      Arista's loopback modules. The utility of this script is for
      Safety Testing.

      Use -h for help
   """

   cfp2DcoUsage = """
   loopbackPowerControl [OPTIONS] cfp2-dco [CFP2-DCO OPTIONS]

      Provides high power toggle for Cfp2-Dco electrical loopback modules.
      Default behavior (no arguments) is to read the module(s) state.
      Use -h for help.
   """

   qsfpUsage = """
   loopbackPowerControl [OPTIONS] qsfp [QSFP OPTIONS]

      Provides high power toggle for Qsfp loopback modules. High power is power
      class 5, 6, and 7.
      Default behavior (no arguments) will read module(s) state.
      Use -h for help.
   """

   parser = argparse.ArgumentParser( prog="loopbackPowerControl", usage=usage )

   # Non xcvr specific arguments
   parser.add_argument( "-a", "--allPorts", action="store_true",
                        default=False,
                        help="Controls power of all applicable modules \
                              in the system" )
   parser.add_argument( "-c", "--lineCard", action="store", type=int,
                        default=None,
                        help="Controls power of modules on the specified \
                              linecard, e.g. 3" )
   parser.add_argument( "-n", "--portNumber", action="store",
                        default=None,
                        help="Controls power for module(s) specified by \
                              port, e.g. 3/1 or 3/2,5/19" )

   # add a subparser per xcvr type
   xcvrTypeParser = parser.add_subparsers( dest="xcvrType", help="Xcvr type" )
   cfp2DcoParser = xcvrTypeParser.add_parser( "cfp2-dco", usage=cfp2DcoUsage )
   qsfpParser = xcvrTypeParser.add_parser( "qsfp", usage=qsfpUsage )

   # qsfp specific args
   qsfpParser.add_argument( "-p", "--powerClass", action="store", type=int,
                            default=None,
                            help="Specifies module(s) power class \
                                  (e.g. power class 7 for 5W)" )
   qsfpParser.add_argument( "-l", "--lowPower", action="store_true",
                            default=False,
                            help="Asserts low power on specified modules." )
   qsfpParser.add_argument( "--smbusAddrSpacing", action="store",
                            default='0x100',
                            help="Specify accelerator address offsets. \
                                  Only applies to select platforms." )

   # cfp2 specific args; only 0 and 1 are meaningful (see help text).
   cfp2DcoParser.add_argument( "-p", "--high", action="store", type=int,
                               default=None, choices=( 0, 1 ),
                               help="Use --high=1 to assert highest Cfp2-dco power \
                                     draw. Use --high=0 to assert default power." )
   return parser


# Validate the xcvr sub-command and its options. Returns the matching
# Xcvr::XcvrType, or exits.
def _checkXcvrArgs( parser, args ):
   if args.xcvrType is None:
      # Sub-commands are optional by default in Python 3, so a missing one
      # shows up as None rather than as an argparse error.
      parser.error( "Please specify an Xcvr Type (qsfp or cfp2-dco)." )

   if args.xcvrType == "qsfp":
      # Compare against None so that e.g. '-p 0' is rejected here instead of
      # slipping past the range check.
      if args.powerClass is not None and args.lowPower:
         print( "Conflicting arguments: --powerClass and --lowPower." )
         sys.exit( 1 )
      if args.powerClass is not None and not 5 <= args.powerClass <= 7:
         print( "Please specify a power class between 5 and 7 inclusive." )
         sys.exit( 1 )
      return _xcvrType.qsfpPlus

   if args.xcvrType == "cfp2-dco":
      return _xcvrType.cfp2

   # leaving this here for future loopback support
   print( "Xcvr Type " + args.xcvrType + " not supported." )
   sys.exit( 1 )


# Exit if more than one of --allPorts, --lineCard and --portNumber was given.
def _checkPortSelectors( args ):
   lineCardGiven = args.lineCard is not None
   if args.allPorts and lineCardGiven:
      print( "Conflicting arguments: --allPorts and --lineCard." )
      sys.exit( 1 )
   if args.allPorts and args.portNumber:
      print( "Conflicting arguments: --allPorts and --portNumber." )
      sys.exit( 1 )
   if lineCardGiven and args.portNumber:
      print( "Conflicting arguments: --lineCard and --portNumber." )
      sys.exit( 1 )


# Translate a --portNumber value such as "3/1" or "3/2,5/19" into Ethernet
# port names. Fixed systems use only the last number ("Ethernet<port>");
# modular systems require "<linecard>/<port>" ("Ethernet<linecard>/<port>").
# Exits with status 1 on an entry that cannot be parsed.
def _parsePortNumbers( portNumber, fixed ):
   argPorts = []
   for entry in portNumber.split( "," ):
      entry = entry.strip()  # tolerate "3/2, 5/19"
      m = re.match( r'((\d+)/)?(\d+)', entry )
      if m is None or ( not fixed and m.group( 2 ) is None ):
         expected = "e.g. 3" if fixed else "e.g. 3/1"
         print( f"Invalid port '{entry}', expected {expected}." )
         sys.exit( 1 )
      if fixed:
         argPorts.append( f"Ethernet{m.group( 3 )}" )
      else:
         argPorts.append( f"Ethernet{m.group( 2 )}/{m.group( 3 )}" )
   return argPorts


# Return the Sysdb xcvr status directories that cover the selected ports.
# Exits with status 1 if a needed linecard has no status directory.
def _collectXcvrStati( pycl, args, fixed, linecard, argPorts ):
   root = pycl.agentRoot()
   if fixed:
      return [ root.entity[ 'hardware/archer/xcvr/status/all' ] ]

   sliceDir = root.entity[ 'hardware/archer/xcvr/status/slice/' ]
   if args.allPorts:
      return list( sliceDir.values() )

   if args.portNumber:
      # One directory per linecard hosting a user-specified port, in the
      # order the ports were given.
      linecards = []
      for port in argPorts:
         lc = "Linecard" + re.match( r'Ethernet(\d+)/\d+', port ).group( 1 )
         if lc not in linecards:
            linecards.append( lc )
   else:
      linecards = [ linecard ]

   xcvrStati = []
   for lc in linecards:
      try:
         xcvrStati.append( sliceDir[ lc ] )
      except KeyError:
         print( "No xcvr status found for " + lc + "." )
         sys.exit( 1 )
   return xcvrStati


# Return the names of ports in xcvrStati that qualify for power control:
#     1. actual xcvrType matches the specified xcvrType
#     2. xcvr is present
#     3. xcvr's media type is the correct variant of Mfg media type
# When argPorts is non-empty, only those ports are considered. Present modules
# of the right type but the wrong media type are reported and skipped.
def _filterLoopbackPorts( xcvrStati, xcvrType, argPorts ):
   wanted = set( argPorts ) if argPorts else None
   loopbackMedia = _mediaTypeMap[ xcvrType ]
   validPorts = []
   for xcvrStatus in xcvrStati:
      for port in xcvrStatus.values():
         # if user specified some ports, filter out the non-specified ones
         if wanted is not None and port.name not in wanted:
            continue
         if not ( port.presence == "xcvrPresent" and
                  port.xcvrType == xcvrType ):
            continue
         if port.mediaType == loopbackMedia:
            validPorts.append( port.name )
         else:
            print( port.name + " is not recognized as a loopback. Skipping..." )
   return validPorts


# ---------------------------------------------------------------------------------#
# Below are functions to control power per supported transceiver.
# Different transceivers have fundamentally different means of controlling power.
# Each supported transceiver should have a power control function that takes in a
# list of ports and the parsed user arguments.
# ---------------------------------------------------------------------------------#

# Function to control power for qsfp loopbacks.
def controlPowerQsfp( ports, args ):
   """Read or set the power class of the given QSFP loopback ports over smbus.

   A read is done when neither args.powerClass nor args.lowPower is set.
   Otherwise power class args.powerClass (or 1 with args.lowPower) is written
   to the module's power control register. One result line is printed per
   port. Exits with status 1 if the requested power class has no register
   encoding, or if a port's smbus path cannot be parsed.

   Args:
      ports: names of the ports to operate on (e.g. "Ethernet3/1").
      args: argparse namespace of the qsfp sub-command.
   """
   # constants related to qsfp devices and smbus
   deviceId_ = 0x50  # smbus device ID for Qsfps
   addr_ = 93  # address of power control function
   addrSize_ = 1
   delayMs_ = 'delay50ms'
   busTimeout_ = 'busTimeout1000ms'
   writeNoStopReadCurrent_ = False

   # Power class -> power control register value. The reverse table used to
   # display register contents is derived from it, so the two cannot drift.
   powerClassToRegVal_ = {
      1 : 0b00000011,
      5 : 0b01000001,
      6 : 0b10000001,
      7 : 0b00000001,
   }
   regValToName_ = { val: f"Power Class {pc}"
                     for pc, val in powerClassToRegVal_.items() }

   # Helper function to return human readable power class string
   # given a register value.
   def regValToPowerClass( val ):
      return regValToName_.get( val, "Unknown" )

   # Decide the operation before touching hardware so an unsupported power
   # class is reported up front rather than as a KeyError mid-loop.
   read = args.powerClass is None and not args.lowPower
   data = None
   if not read:
      pc = 1 if args.lowPower else args.powerClass
      if pc not in powerClassToRegVal_:
         print( f"Unsupported power class {pc}." )
         sys.exit( 1 )
      data = powerClassToRegVal_[ pc ]

   # Grab the smbus topology
   em = EntityManager.Sysdb( sysname='ar' )
   from CliPlugin import CounterCli # pylint: disable=import-outside-toplevel
   CounterCli.Plugin( weakref.proxy( em ) )
   # Invert the device map to: ethernet device name -> smbus path.
   smbusMap = { dev: path
                for path, dev in CounterCli.createSmbusDeviceMap().items()
                if 'ethernet' in dev.lower() }

   # Map each requested device to ( linecard number or None, accelId, busId ),
   # all taken from its smbus path.
   wanted = set( ports )
   pathRe = re.compile( r'(?:linecard(\d+))?/?(\d+)/(\d+)/\dx\d+' )
   devs = {}
   for device, path in smbusMap.items():
      if device not in wanted:
         continue
      m = pathRe.match( path.lower() )
      if m is None:
         print( "No accelerator info for " + device )
         sys.exit( 1 )
      linecardNum, accelId, busId = m.groups()
      devs[ device ] = ( int( linecardNum ) if linecardNum else None,
                         int( accelId ), int( busId ) )

   # Report requested ports without an smbus device instead of silently
   # leaving them out of the results.
   for port in ports:
      if port not in devs:
         print( port + " has no smbus device. Skipping..." )

   fmtStr = "{0:<6} {1:<13} {2:<15} {3:<14}"
   banner = fmtStr.format( "Op", "Result", "Port", "Power Class" )
   bannerLine = fmtStr.format( "-" * 6, "-" * 13, "-" * 15, "-" * 14 )
   print( banner + "\n" + bannerLine )

   # Begin smbus write(s) or read(s)
   op = "read" if read else "write"
   factory = SmbusUtil.Factory()
   for device, ( linecardNum, accelId, busId ) in devs.items():
      helper = factory.device( accelId, busId, deviceId_, addrSize_,
                               readDelayMs=delayMs_, writeDelayMs=delayMs_,
                               busTimeout=busTimeout_,
                               writeNoStopReadCurrent=writeNoStopReadCurrent_,
                               smbusAddrSpacing=args.smbusAddrSpacing,
                               smbusAgentId=linecardNum,
                               pciAddress=None )
      try:
         if read:
            regVal = helper.read8( addr_ )
         else:
            helper.write8( addr_, data )
            regVal = data
      except SmbusUtil.SmbusFailure:
         print( fmtStr.format( op, "FAILED", device, "N/A" ) )
         continue
      print( fmtStr.format( op, "SUCCESS", device,
                            regValToPowerClass( regVal ) ) )


# Function to control power for Cfp2-Dco loopback modules
def controlPowerCfp2Dco( ports, args ):
   """Read or set the power draw of the given Cfp2-Dco loopback ports.

   With args.high None the module's power registers are read. Otherwise all of
   them are written with the maximum value (args.high truthy) or the default
   value (args.high == 0). One line with the resulting wattage is printed per
   port ("Default" when it is zero).

   Args:
      ports: names of the ports to operate on (e.g. "Ethernet3/1").
      args: argparse namespace of the cfp2-dco sub-command.
   """
   # Note: register 0x9401 (0x19401 below) has wattage range of [0,8.7], while
   # 0x9402 and 0x9403 (0x19402, 0x19403) have range of [0,7.6]
   registers_ = [ 0x19401, 0x19402, 0x19403 ]
   defaultPower_ = 0x00
   maxPower_ = 0xFF

   # Helper function to convert actual cfp2-dco register value
   # to programmed wattage (See Multilane Cfp2-Dco loopback spec).
   # value / 255 is the fraction of the register's full-scale wattage. True
   # division is required: floor division reported 0 W for every value
   # below 0xFF.
   def valueToWatt( reg, value ):
      fullScale = 8.7 if reg == 0x19401 else 7.6
      return ( value / 255 ) * fullScale

   # Helper function to read port given cfp2Ctrl entity. Returns total
   # power read from 3 module registers
   def doRead( cfp2Ctrl, port ):
      power = 0
      for reg in registers_:
         evalStr = ".aham.hamImpl.data[ " + str( reg ) + " ]"
         value = cfp2Ctrl[ port ].config.ahamDesc.eval( evalStr )
         power += valueToWatt( reg, value )
      return power

   # Helper function to write port given cfp2Ctrl entity. Returns
   # module power to be configured by this write.
   def doWrite( cfp2Ctrl, port, data ):
      power = 0
      for reg in registers_:
         evalStr = ".aham.hamImpl.data.__setitem__( " \
               + str( reg ) + ", " + str( data ) + " )"
         cfp2Ctrl[ port ].config.ahamDesc.eval( evalStr )
         power += valueToWatt( reg, data )
      return power

   data = maxPower_ if args.high else defaultPower_

   # Print output header
   fmtStr = "{0:<6} {1:<15} {2:<14}"
   print( fmtStr.format( "Op", "Port", "Power (Watt)" ) )
   print( fmtStr.format( "-" * 6, "-" * 15, "-" * 14 ) )
   pwrFmtStr = "{0:.2f}"

   # XcvrAgent name -> ( PyClient, cfp2CtrlSm ). There is one connection per
   # agent, and the client is stored next to its entity so it stays referenced.
   agents = {}
   fixed = isFixed()

   # Return the cfp2CtrlSm responsible for port. Fixed systems use a single
   # "XcvrAgent"; modular systems use "XcvrAgent-Linecard<N>", where N is
   # taken from the port name.
   def cfp2CtrlFor( port ):
      if fixed:
         agentName = "XcvrAgent"
      else:
         lcNum = re.match( r'Ethernet(\d+)/\d+', port ).group( 1 )
         agentName = "XcvrAgent-" + "Linecard" + lcNum
      if agentName not in agents:
         pycl = PyClient.PyClient( "ar", agentName )
         cfp2Ctrl = pycl.agentRoot().entity[ 'xcvrCtrlAgent'
                                             ].xcvrControllerAgentSm.cfp2CtrlSm
         agents[ agentName ] = ( pycl, cfp2Ctrl )
      return agents[ agentName ][ 1 ]

   # perform read(s)/write(s)
   for port in ports:
      cfp2Ctrl = cfp2CtrlFor( port )
      if args.high is None:
         op, power = "Read", doRead( cfp2Ctrl, port )
      else:
         op, power = "Write", doWrite( cfp2Ctrl, port, data )
      print( fmtStr.format( op,
                            port,
                            pwrFmtStr.format( power ) if power else "Default" ) )


# ---------------------------------------------------------------------------------#
# Main
# ---------------------------------------------------------------------------------#
if __name__ == "__main__":
   # Resolve the xcvr type, the parsed user arguments and the ports to act on.
   xcvr, userArgs, prts = setup()

   # Dispatch to the xcvr-specific power control routine. Types without an
   # entry are ignored; the table is left open for future loopbacks.
   powerControlFns = {
      _xcvrType.qsfpPlus : controlPowerQsfp,
      _xcvrType.cfp2 : controlPowerCfp2Dco,
   }
   powerControl = powerControlFns.get( xcvr )
   if powerControl is not None:
      powerControl( prts, userArgs )
