#!/usr/bin/arista-python
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import argparse
import json
import Logging
import re
import subprocess
import sys
import time
from XcvrLib import TRANSCEIVER_POWER_CYCLE

# Script version, printed by main() when --version is given.
version = "1.0"

# Module-level flags set by main() from --dryRun and --verbose.
# dryRun  : when True, runCommands() returns without executing anything and
#           the simulated data below stands in for real CLI output.
# verbose : when True, runCommands() prints each command string.
dryRun = False
verbose = False

# These are used when --dryRun is selected. It allows testing the
# code path without writing anything to the system.
# Running with --dryRun and --verbose will allow seeing the commands
# that would be run to recover the system - the PCI address will not be
# accurate as it comes from the simulated data below.
#
# Simulated "show pci | json" output. cardPciAddr() replaces "Slot3" with
# the slot of the requested linecard before parsing it.
dryRunPciInfo = '''
{
    "pciIds": {
        "8c:00.0": {
            "name": "Slot3:pciFpga0"
        }
    },
    "switchs": {
        "Slot51cardSwitch": 0
    }
}
'''

# Simulated "show module | json" output. validateCard() replaces the module
# key "5" with the requested linecard number before parsing it.
dryRunModuleInfo = '''{
    "modules": {
        "5": {
            "modelName": "7800R3-48CQ2-LC"
        }
    },
    "redundancyProtocol": "rpr",
    "redundancyMode": "active"
}
'''

def runCommands( cmds ):
   '''
   Runs a list of CLI commands in a single FastCli invocation

   The commands are joined with CR/LF and handed to "FastCli -p15 -c".

   cmds : list[str] - list of command strings

   Returns the command output as bytes, or None when dryRun is set (in
   which case nothing is executed).

   Raises subprocess.CalledProcessError if FastCli exits with a non-zero
   status.
   '''

   cmd = '\r\n'.join( cmds )
   if verbose:
      print( f"running commands '{cmd}'" )
   if dryRun:
      return None
   # Hand FastCli an argv list instead of a shell string: the command text
   # then reaches FastCli verbatim, and quotes, '$' or backticks inside a
   # command are not re-interpreted by an intermediate shell.
   return subprocess.check_output( [ 'FastCli', '-p15', '-c', cmd ] )

def riserFromPort( xcvrPort ):
   '''
   Map a front panel port to the riser it sits on.

   Ports are grouped four to a riser: ports 1-4 map to riser 0,
   ports 5-8 to riser 1, and so on up to ports 45-48 on riser 11.

   xcvrPort : str - the port on the front of the card (1-48)

   Returns the zero-based riser number (0-11)
   '''
   portsPerRiser = 4
   portNum = int( xcvrPort )
   assert 1 <= portNum <= 48
   return ( portNum - 1 ) // portsPerRiser

def cardPciAddr( linecard ):
   '''
   Find the PCI address of the linecard SCD

   Scans the "pciIds" section of "show pci | json" for entries whose name
   contains "pciFpga" and keys each one by the slot taken from the part of
   the name before the ':' (e.g. "Slot3:pciFpga0" -> "3").

   linecard : str - the linecard number such as "5"

   Returns the pci address (the "pciIds" key of the matching entry)

   Raises AssertionError if no FPGA entry is found for the linecard.
   '''
   cmds = [ "en", "show pci | json" ]
   out = runCommands( cmds )
   if dryRun:
      # Just make this work for the slot of the interface.
      out = dryRunPciInfo.replace( 'Slot3', f"Slot{linecard}" )
   jout = json.loads( out )

   slotPrefix = 'Slot'
   addrDict = {}
   for addr, info in jout[ 'pciIds' ].items():
      name = info[ 'name' ]
      if 'pciFpga' not in name:
         continue
      slot = name.split( ':' )[ 0 ]
      # Drop the literal "Slot" prefix only. str.lstrip( 'Slot' ) treats its
      # argument as a set of characters, so it would also strip any further
      # leading 'S', 'l', 'o' or 't' from the slot identifier.
      if slot.startswith( slotPrefix ):
         slot = slot[ len( slotPrefix ): ]
      addrDict[ slot ] = addr

   assert linecard in addrDict, \
      f"slot {linecard} FPGA not found in PCI addresses"
   return addrDict[ linecard ]

def configRiserPowerCheck( pciAddr, enable, lc ):
   '''
   Turn the power good check of a linecard on or off by writing SCD
   register 0x58f0 (int( enable ) is the value written).

   pciAddr : str - the Linecard SCD address
   enable : bool - True to enable the check (normal state),
            False to disable it
   lc : str - the linecard number, only used in the progress message
   '''
   powerCheckReg = 0x58f0

   banner = ( f">> Enable power good check for linecard{lc}" if enable
              else f">> Disable power good check for linecard{lc}" )
   print( banner )

   scdWrite = "bash sudo scd {} write {:#06x} {:#06x}".format(
      pciAddr, powerCheckReg, int( enable ) )
   runCommands( [ "enable", scdWrite ] )
def riserPower( pciAddr, port, on ):
   '''
   Switch power on or off for the riser that hosts the given port.

   A 12-bit value is written to SCD register 0x5960. Powering on writes
   0xfff; powering off writes 0xfff with only the target riser's bit
   cleared. Every other riser's bit is therefore always written as set,
   i.e. this method assumes all other risers are powered on.

   pciAddr : str - the PCI address of the linecard SCD
   port : str - the port on the card
   on : bool - True to enable power
        False to disable
   '''
   allRisers = 0xfff
   riser = riserFromPort( port )

   if on:
      op = 'Enable'
      val = allRisers
   else:
      op = 'Disable'
      val = allRisers & ~( 1 << riser )

   # The four front panel ports that share this riser
   firstPort = riser * 4 + 1
   ports = list( range( firstPort, firstPort + 4 ) )
   print( ">> {} riser power for ports {}".format( op, ports ) )

   scdWrite = "bash sudo scd {} write {} {:#06x}".format(
      pciAddr, "0x5960", val )
   runCommands( [ "enable", scdWrite ] )

def simulateRemoved( interfaces, remove ):
   '''
   Simulate removal or insertion of the transceivers on the given
   interfaces through the CLI, one runCommands() call per interface.

   interfaces : list - list of /1 interfaces to act on
   remove : bool - if True, the module is removed.
            if False, the module is inserted (the simulated removal is
            cleared with the "no" form of the command)
   '''
   if remove:
      action = 'Remove'
      no = ''
   else:
      action = 'Insert'
      no = 'no '

   print( f">> {action} transceivers for {interfaces}" )

   # The diag command is identical for every interface; only the interface
   # context in front of it changes.
   diagCmd = f"{no}transceiver diag simulate removed"
   for i in interfaces:
      runCommands( [ 'enable', 'config', f"interface {i}", diagCmd ] )
def validateCard( linecard, pciAddr ):
   '''
   Check that the recovery steps can be used on this linecard.

   Both of the following must hold:
   1) the module is a 7800R3-48CQ2-LC or 7800R3-48CQM2-LC
   2) the SCD version read from register 0x100 is major 0x10 / minor 0x20
      or later (taken here to mean the riser power feature is present)

   linecard : str - the linecard number as in '9'
   pciAddr : str - PCI address of the linecard SCD

   Returns True when both checks pass and False otherwise. A message
   describing the outcome is printed in every case.
   '''
   supportedModels = [ '7800R3-48CQ2-LC', '7800R3-48CQM2-LC' ]

   out = runCommands( [ "en", "show module | json" ] )
   if dryRun:
      # Just make this work for the given linecard
      out = dryRunModuleInfo.replace( '"5"', f'"{linecard}"' )
   jout = json.loads( out )

   modelOk = False
   if 'modules' in jout and linecard in jout[ 'modules' ]:
      modelOk = jout[ 'modules' ][ linecard ][ 'modelName' ] in supportedModels
   if not modelOk:
      print( ">> Linecard {} is not 7800R3-48CQ2-LC or 7800R3-48CQM2-LC".
             format( linecard ) )
      return False

   # Check the SCD version
   out = runCommands( [ "enable", f"bash sudo scd {pciAddr} read 0x100" ] )
   if dryRun:
      out = b'0x100 0x100121 == 0000 0000 0001 0000 0000 0001 0010 0001'

   m = re.search( b'0x100 (0x[a-f0-9]+)[a-f0-9]{2}([a-f0-9]{2}) ==', out )
   if not m:
      print( f"Unable to validate SCD version for linecard{linecard}" )
      return False

   # Group 1 is every hex digit above the low 16 bits (major), group 2 is
   # the lowest byte (minor); the byte in between is skipped by the pattern.
   major = int( m.group( 1 ), 16 )
   minor = int( m.group( 2 ), 16 )
   if ( major, minor ) < ( 0x10, 0x20 ):
      print( ">> SCD version {:#x}_{:#x} does not support this function".
             format( major, minor ) )
      return False

   print( ">> Card {} validated with SCD version {:#x}_{:#x}".
         format( linecard, major, minor ) )
   return True

def main():
   '''
   Parse the command line and power cycle the riser that hosts the given
   interface.

   The four interfaces sharing the riser are all impacted. With
   --show-impact they are only listed; otherwise the linecard is validated
   and the recovery sequence below is run, followed by a
   TRANSCEIVER_POWER_CYCLE log entry naming the impacted ports.

   Exits with status 2 (through argparse) when the interface is not of the
   form '<Ethernet/Fabric><linecard>/<port>/1', and with status 1 when the
   linecard fails validation.
   '''
   global dryRun, verbose

   parser = argparse.ArgumentParser()
   parser.add_argument( "interface", action="store", type=str,
                        help="Interface to recover "
                        "(e.g. Ethernet4/2/1, Fabric4/2/1)" )

   parser.add_argument( '--dryRun', action='store_true',
                        help="Do not actually run commands, just display them " )

   parser.add_argument( '--show-impact', action='store_true',
                        help="Show which ports would be impacted by power cycle"
                        " for the target interface and exit" )

   parser.add_argument( '--verbose', action='store_true',
                        help="Print the commands run" )

   # action='version' prints and exits while parsing, so "--version" works
   # on its own without the otherwise required interface argument.
   parser.add_argument( '--version', action='version',
                        version=f">> Version: {version}",
                        help="Show script version and exit" )

   args = parser.parse_args()
   dryRun = args.dryRun
   verbose = args.verbose

   # fullmatch so that trailing text (e.g. Ethernet4/2/10 or Ethernet4/2/1/3)
   # is rejected instead of being silently treated as the /1 interface.
   m = re.fullmatch( r'(Ethernet|Fabric)(\d+)/(\d+)/1', args.interface )
   if not m:
      parser.error(
         f"Invalid interface { args.interface } - "
         f"expect in the form '<Ethernet/Fabric><linecard>/<port>/1'" )

   intfPrefix, linecard, port = m.groups()

   # impacted interfaces: the four ports that share the riser
   riser = riserFromPort( port )
   firstPort = riser * 4 + 1
   impactedPorts = [ f"{ intfPrefix }{ linecard }/{ p }"
                     for p in range( firstPort, firstPort + 4 ) ]
   impactedIntfs = [ f"{ p }/1" for p in impactedPorts ]

   if args.show_impact:
      print( f">> Impacted interfaces: {impactedIntfs}" )
      sys.exit()

   pciAddr = cardPciAddr( linecard )
   if not validateCard( linecard, pciAddr ):
      # Report the failure to the caller through the exit status.
      sys.exit( 1 )

   # Steps
   # 1) disable riser power good check
   # 2) simulate removal of the xcvr with CLI
   # 3) Disable riser power
   # 4) simulate insertion via CLI
   # 5) Enable riser power
   # 6) simulate removal of xcvr
   # 7) simulate insertion of Xcvr
   # 8) enable riser power good check
   #
   # NOTE(review): there is no cleanup if a step raises part way through;
   # the power good check and riser power are then left as the last
   # completed step set them.

   print( f">> Beginning recovery of {args.interface}" )
   print( f">> Impacted interfaces: {impactedIntfs}" )

   configRiserPowerCheck( pciAddr, False, linecard )
   simulateRemoved( impactedIntfs, True )
   time.sleep( 10 )
   riserPower( pciAddr, port, False )
   time.sleep( 1 )
   simulateRemoved( impactedIntfs, False )
   time.sleep( 10 )
   riserPower( pciAddr, port, True )
   time.sleep( 1 )
   simulateRemoved( impactedIntfs, True )
   time.sleep( 10 )
   simulateRemoved( impactedIntfs, False )
   configRiserPowerCheck( pciAddr, True, linecard )
   print( ">> Recovery complete!" )

   # Log the ports without the trailing /1 breakout suffix
   Logging.log( TRANSCEIVER_POWER_CYCLE, ", ".join( impactedPorts ) )


# Run the recovery only when executed as a script, not when imported.
if __name__ == "__main__":
   main()
