# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from typing import NamedTuple
import os
import re
import time
import traceback
import json
import Tac
import Tracing

from Fdl import (
      MacAddrAdd,
)

from SfFruHelper import (
      DEVICE,
      FIRST_MAC,
      MAC,
      PCI,
      DRIVER,
      PORT,
      ROLE,
      TYPE,
      VENDOR_ID,
      DEVICE_ID,
      VF2PF,
)

# Trace handle for this module; t0 and t9 are shorthands for its trace0 and
# trace9 levels.
th = Tracing.Handle( "Fdl.CaravanCommon" )
t0, t9 = th.trace0, th.trace9

# Fixed attributes that addDpsToDeviceCache records for the DPS interface.
# The PCI address and vendor/device ids appear to be placeholder values rather
# than real hardware identifiers -- NOTE(review): confirm.
DPS_INTF = "et100"
DPS_DRIVER = "BESS"
DPS_PCI = "0000:ff:ff.f"
DPS_PORT = 100
DPS_VENDOR_ID = 0xffff
DPS_DEVICE_ID = 0xffff

class PciDevIds( NamedTuple ):
   """A PCI (vendor id, device id) pair; hashable, so it is usable as a dict key."""
   vendor: int
   device: int

   def __repr__( self ) -> str:
      vendorHex = format( self.vendor, "#x" )
      deviceHex = format( self.device, "#x" )
      return f"vendorId: {vendorHex}, deviceId: {deviceHex}"

# ---------------------------------
# Independence
# ---------------------------------
# Note: some device ids repeat across the platforms below (0x1889 for the
# Independence, Councilbluffs and Longhorn VFs; 0x1572/0x154c for XL710 and
# Shorthorn; 0x154c for Appaloosa VFs), so the PciDevIds built from them
# compare equal.
INTEL_PCI_VENDOR_ID = 0x8086
INDEPENDENCE_NAC_SW_VF_PCI_DEVICE_ID = 0x1889
INDEPENDENCE_NAC_SW_PCI_DEVICE_ID = 0x188f
XL710_SW_PCI_DEVICE_ID = 0x1572
XL710_SW_VF_PCI_DEVICE_ID = 0x154c

XL710_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=XL710_SW_PCI_DEVICE_ID )
XL710_VF_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=XL710_SW_VF_PCI_DEVICE_ID )
IND_NAC_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=INDEPENDENCE_NAC_SW_PCI_DEVICE_ID )
IND_NAC_VF_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=INDEPENDENCE_NAC_SW_VF_PCI_DEVICE_ID )

# ---------------------------------
# Councilbluffs
# ---------------------------------
COUNCILBLUFFS_NAC_SW_PCI_DEVICE_ID = 0x1895
COUNCILBLUFFS_NAC_SW_VF_PCI_DEVICE_ID = 0x1889
CBL_NAC_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=COUNCILBLUFFS_NAC_SW_PCI_DEVICE_ID )
CBL_NAC_VF_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=COUNCILBLUFFS_NAC_SW_VF_PCI_DEVICE_ID )

# ---------------------------------
# Willamette
# ---------------------------------
WILLAMETTE_X553_RJ45_SW_PCI_DEVICE_ID = 0x15e4
Q8_X553_RJ45_SW_PCI_DEVICE_ID = 0x15e5
WILLAMETTE_X553_SFP_SW_PCI_DEVICE_ID = 0x15c4
WILLAMETTE_I226V_SW_PCI_DEVICE_ID = 0x125c
WILLAMETTE_I226IT_SW_PCI_DEVICE_ID = 0x125d
WILLAMETTE_I350_SW_PCI_DEVICE_ID = 0x1521
WILLAMETTE_X553_SFP_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=WILLAMETTE_X553_SFP_SW_PCI_DEVICE_ID )
WILLAMETTE_X553_RJ45_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=WILLAMETTE_X553_RJ45_SW_PCI_DEVICE_ID )
Q8_X553_RJ45_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=Q8_X553_RJ45_SW_PCI_DEVICE_ID )
WILLAMETTE_I226V_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=WILLAMETTE_I226V_SW_PCI_DEVICE_ID )
WILLAMETTE_I226IT_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=WILLAMETTE_I226IT_SW_PCI_DEVICE_ID )
WILLAMETTE_I350_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=WILLAMETTE_I350_SW_PCI_DEVICE_ID )

# ---------------------------------
# Shorthorn Device Constants
# ---------------------------------
SHORTHORN_SW_PCI_DEVICE_ID = 0x1572
SHORTHORN_SW_VF_PCI_DEVICE_ID = 0x154c
SHORTHORN_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=SHORTHORN_SW_PCI_DEVICE_ID )
SHORTHORN_VF_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=SHORTHORN_SW_VF_PCI_DEVICE_ID )

# ---------------------------------
# Longhorn Device Constants
# ---------------------------------
LONGHORN_SW_PCI_DEVICE_ID = 0x1592
LONGHORN_SW_VF_PCI_DEVICE_ID = 0x1889
LONGHORN_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=LONGHORN_SW_PCI_DEVICE_ID )
LONGHORN_VF_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=LONGHORN_SW_VF_PCI_DEVICE_ID )

# ---------------------------------
# Appaloosa Device Constants
# ---------------------------------
APPALOOSA_SW_PCI_DEVICE_ID = 0x15ff
APPALOOSA_SW_VF_PCI_DEVICE_ID = 0x154c
APPALOOSA_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=APPALOOSA_SW_PCI_DEVICE_ID )
APPALOOSA_VF_PCI_IDS = PciDevIds(
   vendor=INTEL_PCI_VENDOR_ID, device=APPALOOSA_SW_VF_PCI_DEVICE_ID )

# sysfs directories scanned by the helpers below, and the JSON file in which
# the device cache is persisted (see loadDeviceCache/saveDeviceCache).
SYS_CLASS_NET = "/sys/class/net"
SFA_FRU_DEVICE_CACHE = "/var/run/sfaFruPluginDevices.json"
SYS_CLASS_PCIBUS = "/sys/class/pci_bus"

# Maps a PciDevIds pair to the device-type string returned by getDeviceType
# (and stored under TYPE in the device cache); ids missing here are rejected.
# Several keys compare equal: CBL_NAC_VF_PCI_IDS and IND_NAC_VF_PCI_IDS are
# both (0x8086, 0x1889), and XL710_VF_PCI_IDS, SHORTHORN_VF_PCI_IDS and
# APPALOOSA_VF_PCI_IDS are all (0x8086, 0x154c), so each group collapses into
# a single dict entry. They are listed per platform for readability only.
# The Willamette/Q8 entries use non-VF device ids (presumably those NICs are
# used without SR-IOV -- NOTE(review): confirm).
supportedPciDeviceIds = {
   CBL_NAC_VF_PCI_IDS: "NAC",
   IND_NAC_VF_PCI_IDS: "NAC",
   XL710_VF_PCI_IDS: "XL710",
   SHORTHORN_VF_PCI_IDS: "XL710",
   # BUG830692 - the Longhorn VF device ID (0x1889) is the same as the
   # IND NAC VF device ID, so this key would collide with the NAC entry.
   # Need to update the getDeviceType function to handle this case
   # LONGHORN_VF_PCI_IDS : "E810",
   APPALOOSA_VF_PCI_IDS: "XL710",
   WILLAMETTE_X553_SFP_PCI_IDS: "x553",
   WILLAMETTE_X553_RJ45_PCI_IDS: "x553",
   Q8_X553_RJ45_PCI_IDS: "x553",
   WILLAMETTE_I226V_PCI_IDS: "i226",
   WILLAMETTE_I226IT_PCI_IDS: "i226",
   WILLAMETTE_I350_PCI_IDS: "i350",
}

def inBuildEnv():
   """Return a truthy value when running on a build server.

   The first non-empty value among the P4USER, ABUILD and A4_CHROOT
   environment variables is returned. If none of them is non-empty, the
   (falsy) value of A4_CHROOT is returned: None if unset, "" if empty.
   """
   value = None
   for envVar in ( "P4USER", "ABUILD", "A4_CHROOT" ):
      value = os.environ.get( envVar )
      if value:
         break
   return value

def initDeviceCache():
   """Return a fresh, empty device cache.

   DEVICE holds the list of cached interface names and FIRST_MAC a string.
   Every other key maps an interface name to one attribute of that interface
   (filled in by e.g. addToDeviceCache / addDpsToDeviceCache).
   """
   t9( "Initializing empty device cache" )
   return {
      DEVICE: [],
      FIRST_MAC: "",
      MAC: {},
      PCI: {},
      DRIVER: {},
      PORT: {},
      ROLE: {},
      TYPE: {},
      VENDOR_ID: {},
      DEVICE_ID: {},
      VF2PF: {},
   }

def loadDeviceCache( devCachePath=SFA_FRU_DEVICE_CACHE, mustBePresent=False ):
   """Load the device cache stored at devCachePath.

   Returns:
      - None when running on a build server (see inBuildEnv);
      - the JSON-decoded contents of devCachePath when that file exists;
      - otherwise a fresh cache from initDeviceCache, unless mustBePresent is
        set, in which case an AssertionError is raised.
   """
   if inBuildEnv():
      return None

   if os.path.exists( devCachePath ):
      t0( "Loading interfaces from existing device cache file" )
      with open( devCachePath, 'r' ) as cacheFile:
         return json.load( cacheFile )

   if not mustBePresent:
      return initDeviceCache()

   assert 0, "Device cache file not found, must be present"
   return None

def saveDeviceCache( deviceCache, devCachePath=SFA_FRU_DEVICE_CACHE ):
   """Write deviceCache to devCachePath as JSON (indent=3).

   The JSON is first written to the sibling file "<devCachePath>.tmp" and then
   moved over devCachePath with os.replace. An interrupted or failed write
   therefore cannot leave a truncated cache file for loadDeviceCache to read.
   The temporary file is removed if anything fails.
   """
   tmpPath = f"{devCachePath}.tmp"
   try:
      with open( tmpPath, "w" ) as f:
         json.dump( deviceCache, f, indent=3 )
      # The rename is atomic on POSIX, and a sibling file lives on the same
      # filesystem as the destination
      os.replace( tmpPath, devCachePath )
   finally:
      # Only still present if dump/replace failed
      if os.path.exists( tmpPath ):
         os.remove( tmpPath )

def getDeviceType( pciDevIds ):
   """Return the device-type string registered for pciDevIds in
   supportedPciDeviceIds (e.g. "NAC", "XL710", "x553").

   Raises AssertionError if pciDevIds is not a supported PCI device.
   """
   deviceType = supportedPciDeviceIds.get( pciDevIds )
   if deviceType is None:
      assert 0, f"Unsupported device type: {pciDevIds}"
   return deviceType

def getDriver( deviceName ):
   """Return the kernel driver name of deviceName, taken from the ``driver:``
   line of ``ethtool -i`` output, or None if it cannot be determined.

   Failures are traced at level 0 rather than raised.
   This method is from SfFruHelper with slight modifications.
   """
   try:
      output = Tac.run( [ "ethtool", "-i", deviceName ], stdout=Tac.CAPTURE )
      match = re.search( "driver: (.*)", output )
   except Exception: # pylint: disable=broad-except
      # Best effort: a missing/broken interface just yields no driver.
      # KeyboardInterrupt/SystemExit are deliberately not swallowed.
      t0( f"Unexpected exception: {traceback.format_exc()}" )
      return None
   if match is None:
      t0( f"No driver reported by ethtool for {deviceName}" )
      return None
   return match.group( 1 )

def getMac( deviceName ):
   """Return the MAC address of deviceName, read from its ``address`` file
   under /sys/class/net with surrounding whitespace stripped.

   This method is from SfFruHelper with slight modifications"""
   addressPath = os.path.join( SYS_CLASS_NET, deviceName, "address" )
   with open( addressPath ) as addressFile:
      contents = addressFile.read()
   return contents.strip()

def getPci( deviceName ):
   """Return the PCI address of deviceName, taken from the ``bus-info:`` line
   of ``ethtool -i`` output, or None if it cannot be determined.

   Failures are traced at level 0 rather than raised.
   This method is from SfFruHelper with slight modifications.
   """
   try:
      output = Tac.run( [ "ethtool", "-i", deviceName ], stdout=Tac.CAPTURE )
      match = re.search( r".*bus-info: (.*)", output )
   except Exception: # pylint: disable=broad-except
      # Best effort: callers treat None as "no PCI address".
      # KeyboardInterrupt/SystemExit are deliberately not swallowed.
      t0( f"Unexpected exception: {traceback.format_exc()}" )
      return None
   if match is None:
      t0( f"No bus-info reported by ethtool for {deviceName}" )
      return None
   return match.group( 1 )

def findPfForVf( physfn, vfPciAddr ):
   """Return the kernel name of the PF interface that backs a VF.

   physfn is the VF's sysfs ``physfn`` path; the PF interfaces are the entries
   of its ``net`` subdirectory, assumed to be named ``eth<N>``.
   - A single PF interface (XL710 case: one VF per PF) is returned as-is.
   - Otherwise (NAC case: several VFs per PF) the PCI function number of
     vfPciAddr selects, by position, one of the sorted ``eth<N>`` numbers.
   Raises AssertionError if no function number can be parsed from vfPciAddr.
   """
   lsOutput = Tac.run( [ "ls", physfn + "/net" ], stdout=Tac.CAPTURE )
   pfNames = [ name for name in lsOutput.split( '\n' ) if name ]
   # numeric suffixes of the "eth<N>" kernel names, in ascending order
   pfNumbers = sorted( int( name[ 3 : ] ) for name in pfNames )

   if len( pfNames ) == 1:
      return pfNames[ 0 ]

   fnRgx = r'[0-9a-f]{4}:[0-9a-f]{2}:[0-9a-f]{2}\.([0-9a-f]{1})'
   fnMatch = re.search( fnRgx, vfPciAddr )
   if not fnMatch or not fnMatch.group( 1 ):
      assert 0, ( f"Could not find fnNo for vfPciAddr: {vfPciAddr}, "
                  f"physfn: {physfn}" )
   fnNo = int( fnMatch.group( 1 ) )
   # indexing the existing PF numbers guarantees the returned interface exists
   return "eth" + str( pfNumbers[ fnNo ] )

def addVf2Pf( deviceCache ):
   """Fill deviceCache[ VF2PF ] with the PF interface backing each cached VF.

   Every ``physfn`` entry found under /sys/devices is taken to belong to a VF
   whose PCI address is the name of the directory containing it. Each cached
   device whose deviceCache[ PCI ] equals that address is mapped to the PF
   interface returned by findPfForVf.
   """
   t9( "adding Vf to Pf Mapping in deviceCache" )

   findOutput = Tac.run( [ "find", "/sys/devices", "-name", "physfn" ],
                         stdout=Tac.CAPTURE )
   physfnPaths = [ path for path in findOutput.split( '\n' ) if path ]
   # captures the 12-character PCI address of the directory holding physfn
   vfAddrRgx = r'\/sys\/devices\/.*\/([0-9a-f:.]{12})\/physfn'

   for physfnPath in physfnPaths:
      addrMatch = re.search( vfAddrRgx, physfnPath )
      if addrMatch is None or not addrMatch.group( 1 ):
         continue
      vfPciAddr = addrMatch.group( 1 )
      for devName in deviceCache[ DEVICE ]:
         if deviceCache[ PCI ][ devName ] != vfPciAddr:
            continue
         deviceCache[ VF2PF ][ devName ] = findPfForVf( physfnPath, vfPciAddr )

def addPortToSfePhy( sfePhyDir, invPort, kernelIntfName ):
   """Create a phy named kernelIntfName in sfePhyDir, attach invPort to it,
   and return the new phy."""
   newPhy = sfePhyDir.newPhy( kernelIntfName )
   newPhy.port = invPort
   t0( f"{invPort} invPort added for kernel interface {kernelIntfName}" )
   return newPhy

def addPortToEthPortDir( ethPortDir, deviceCache, deviceName, desc ):
   """Create a port in ethPortDir for cached interface deviceName and return it.

   The port is keyed and labelled by the interface's cached PORT number, and
   takes its role and MAC address from the cache. desc becomes its description.
   """
   portId = deviceCache[ PORT ][ deviceName ]
   newPort = ethPortDir.newPort( portId )
   newPort.description = desc
   newPort.label = portId
   newPort.role = deviceCache[ ROLE ][ deviceName ]
   newPort.macAddr = deviceCache[ MAC ][ deviceName ]
   return newPort

def renameInterface( oldDevName, newDevName ):
   """Rename kernel interface oldDevName to newDevName and bring it up.

   In some cases (like for the Sfe Agent) interfaces need specific names. The
   interface is brought down, renamed to an intermediate "<oldDevName>tmp"
   name, renamed to newDevName, and brought back up.
   This method is from SfFruHelper.

   Returns newDevName, or oldDevName untouched when running on a build server
   (see inBuildEnv). Raises AssertionError if newDevName cannot be brought up.
   Errors from the preceding ``ip`` commands are not caught here.
   """
   if inBuildEnv():
      t0( "Interface rename aborted since we seem to be on a build server" )
      return oldDevName
   # Rename interface to temporary name
   tmpDevName = oldDevName + "tmp"
   Tac.run( [ "ip", "link", "set", "dev", oldDevName, "down" ] )
   Tac.run( [ "ip", "link", "set", "dev", oldDevName,
              "name", tmpDevName ] )
   # Now rename the interface to its final name and bring it back up
   Tac.run( [ "ip", "link", "set", "dev", tmpDevName,
              "name", newDevName ] )
   try:
      Tac.run( [ "ip", "link", "set", "dev", newDevName, "up" ] )
   except Tac.SystemCommandError:
      assert 0, f"Unable to bring up link for {newDevName}"
   t9( "renameInterface from", oldDevName, "to", newDevName )
   return newDevName

def addDpsToDeviceCache( deviceCache, macAddr ):
   """Record the DPS interface (DPS_INTF) in deviceCache.

   Apart from macAddr, every attribute comes from the DPS_* constants. The
   interface is marked "Switched"/"VNI" and is recorded as its own VF2PF
   parent.
   """
   t9( "Adding Dps to deviceCache" )
   devName = DPS_INTF
   deviceCache[ PCI ][ devName ] = DPS_PCI
   deviceCache[ DEVICE ].append( devName )
   attributes = (
      ( MAC, macAddr ),
      ( DRIVER, DPS_DRIVER ),
      ( ROLE, "Switched" ),
      ( TYPE, "VNI" ),
      ( PORT, DPS_PORT ),
      ( VENDOR_ID, DPS_VENDOR_ID ),
      ( DEVICE_ID, DPS_DEVICE_ID ),
      ( VF2PF, devName ),
   )
   for table, value in attributes:
      deviceCache[ table ][ devName ] = value

def renameDevFromPci( intfName, pci ):
   """Give the interface that owns PCI address pci the name intfName.

   Scans the ``et*`` entries of SYS_CLASS_NET (so lo and ma1 are ignored) for
   one whose PCI address, per getPci, equals pci. The first such interface not
   already called intfName is renamed to intfName, and the scan stops there.
   """
   candidates = ( dev for dev in os.listdir( SYS_CLASS_NET )
                  if dev.startswith( "et" ) )
   for dev in candidates:
      if getPci( dev ) != pci or dev == intfName:
         continue
      renameInterface( dev, intfName )
      return

def checkDevCachePci( intfName, pci ):
   """Reconcile intfName with the interface that actually owns PCI address pci.

   Scans the ``et*`` entries of SYS_CLASS_NET (so lo and ma1 are ignored) for
   one whose PCI address, per getPci, equals pci. If that interface's name
   differs from intfName, renameInterface( intfName, <found name> ) is called
   and the scan stops.
   """
   # NOTE(review): renameInterface( old, new ) renames old -> new, so this
   # renames the interface *called* intfName to a name that, per the scan,
   # already exists in SYS_CLASS_NET. If the intent was to update the cached
   # name rather than rename a kernel interface, this call looks wrong --
   # confirm with the callers.
   for dev in os.listdir( SYS_CLASS_NET ):
      if not dev.startswith( "et" ):
         continue
      if getPci( dev ) == pci and dev != intfName:
         renameInterface( intfName, dev )
         break

def getKernelIntfNamesForPhy( pciDev, pciBus, numIntfs ):
   """Return numIntfs ``eth<N>`` interface names on PCI bus pciBus whose sysfs
   vendor/device ids equal pciDev.

   Args:
      pciDev: (vendor id, device id) pair, e.g. a PciDevIds.
      pciBus: integer PCI bus number, matched against the bus field of each
         interface's PCI address (per getPci) in domain 0000.
      numIntfs: number of interfaces expected; the scan stops once this many
         have been found.

   Raises AssertionError unless exactly numIntfs matching interfaces are found.
   """
   # e.g. "0000:3b:" -- the trailing colon pins the match to the bus field
   busPrefix = f"0000:{pciBus:02x}:"
   deviceNames = []
   for device in os.listdir( SYS_CLASS_NET ):
      if not re.match( r'eth\d', device ):
         continue
      vendorPath = f"{SYS_CLASS_NET}/{device}/device/vendor"
      devicePath = f"{SYS_CLASS_NET}/{device}/device/device"
      # Skip interfaces without sysfs vendor/device ids (some USB ethernet
      # adapters, for instance -- NOTE(review): confirm)
      if not ( os.path.exists( vendorPath ) and os.path.exists( devicePath ) ):
         continue
      devPci = getPci( device )
      # getPci returns None when ethtool fails; treat that as "not on this bus"
      # instead of crashing on None.startswith
      if not devPci or not devPci.startswith( busPrefix ):
         continue
      with open( vendorPath, 'r' ) as f:
         vendorId = int( f.readline(), 16 )
      with open( devicePath, 'r' ) as f:
         deviceId = int( f.readline(), 16 )
      if ( vendorId, deviceId ) == tuple( pciDev ):
         deviceNames.append( device )
         if len( deviceNames ) == numIntfs:
            break
   assert len( deviceNames ) == numIntfs, \
      f"Found {len( deviceNames )} of {numIntfs} interfaces for " \
      f"{pciDev} on pci bus {pciBus:#x}"
   return deviceNames

def enableVfs( pciDev, vfPciDev, pciBus, numDevEnableVfs, numVfs ):
   """Enable numVfs SR-IOV virtual functions on each of the numDevEnableVfs
   interfaces on PCI bus pciBus that match pciDev.

   The interfaces are found with getKernelIntfNamesForPhy, which raises
   AssertionError if there are not exactly numDevEnableVfs of them. VFs are
   enabled by writing numVfs to each interface's sysfs
   ``device/sriov_numvfs`` file as root. Does nothing on a build server
   (see inBuildEnv).

   Returns None. vfPciDev is currently unused; the resulting VF PCI addresses
   can be looked up with getVfDevices.
   """
   if inBuildEnv():
      return
   t9( f"enabling vfs for {pciDev}" )

   sw_devices = getKernelIntfNamesForPhy( pciDev, pciBus, numDevEnableVfs )

   for dev in sw_devices:
      # NOTE(review): the kernel rejects changing sriov_numvfs from one
      # non-zero value to another (0 must be written first); confirm callers
      # only invoke this while VFs are disabled.
      cmd = f"echo {numVfs} > /sys/class/net/{dev}/device/sriov_numvfs"
      Tac.run( [ "bash", "-c", cmd ], asRoot=True )
      t0( f"{numVfs} vf(s) enabled for interface {dev}" )

def getVfDevices( pciBus, vfPciDevId ):
   """Return the PCI addresses (with domain, via ``lspci -D``) that lspci
   reports on bus pciBus for the vendor/device pair vfPciDevId, one per line of
   the shell-sorted output."""
   idFilter = f"{vfPciDevId.vendor:x}:{vfPciDevId.device:x}"
   # Sorting stays in the shell pipeline so ordering matches `sort` exactly
   cmd = f"lspci -s {pciBus:x}: -d {idFilter}" + ' -D | cut -d " " -f 1 | sort'
   lspciOutput = Tac.run( [ "bash", "-c", cmd ], asRoot=True,
                          stdout=Tac.CAPTURE )
   return lspciOutput.splitlines()

def addPortAssociations( vfDevs ):
   """Write each VF's position (in sorted order) to its sysfs
   ``port_association`` file under /sys/bus/pci/devices.

   Note: vfDevs is sorted in place.
   """
   # The enumeration is only correct on sorted input. Callers should already
   # pass sorted addresses, but sort again as a safeguard.
   vfDevs.sort()
   commands = [
      f"echo {index} > /sys/bus/pci/devices/{device}/port_association"
      for index, device in enumerate( vfDevs ) ]
   for command in commands:
      Tac.run( [ "bash", "-c", command ], asRoot=True )

def waitForVfIntfs( vfPciAddrs, maxWait=30 ):
   """Wait until an ``eth*`` interface exists for every PCI address in
   vfPciAddrs.

   Interfaces can take time to be created, so SYS_CLASS_NET is rescanned
   (using getPci for each ``eth*`` entry) up to maxWait times, sleeping one
   second between attempts.

   Raises AssertionError if the interfaces are still missing after the last
   attempt.
   """
   waitTimeLeft = maxWait
   while True:
      waitTimeLeft -= 1
      count = 0
      for device in os.listdir( SYS_CLASS_NET ):
         if device.startswith( "eth" ) and getPci( device ) in vfPciAddrs:
            count += 1
      if count == len( vfPciAddrs ):
         return

      if waitTimeLeft <= 0:
         assert 0, ( f"{len( vfPciAddrs )} VF Interfaces were not found"
                     f" after {maxWait} seconds" )
      time.sleep( 1 )

def addPciDevToDeviceCache( deviceCache,
      pciAddrToName, kernelIntfPrefix, macAddrBase ):
   """Rename and cache every PCI device listed in pciAddrToName.

   For each PCI address that is present under its bus in SYS_CLASS_PCIBUS, the
   vendor/device ids are read from sysfs and the interface owning that address
   is renamed to the mapped name (renameDevFromPci). The device is then added
   to deviceCache via addToDeviceCache, unless the mapped name is already in
   deviceCache[ DEVICE ].

   Args:
      deviceCache: cache dict (see initDeviceCache/loadDeviceCache).
      pciAddrToName: maps PCI addresses (presumably full "dddd:bb:ss.f" form --
         confirm) to the desired interface names.
      kernelIntfPrefix: prefix stripped from the name to obtain the port.
      macAddrBase: base MAC address passed on to addToDeviceCache.
   """
   # domain:bus part of a PCI address, e.g. "0000:3b" of "0000:3b:00.0"
   pcibusReg = r'^([\w]+:[\w]+)'
   for pciAddr, newDevName in pciAddrToName.items():
      busMatch = re.match( pcibusReg, pciAddr )
      if not busMatch:
         t0( f"Skipping malformed pci address {pciAddr}" )
         continue
      pciDevPath = f"{SYS_CLASS_PCIBUS}/{busMatch.group( 1 )}/device/"
      # only handle devices actually present on their pci bus (exact name
      # match, not a substring test)
      if not os.path.exists( pciDevPath ) or \
         pciAddr not in os.listdir( pciDevPath ):
         continue
      with open( f"{pciDevPath}{pciAddr}/vendor", 'r' ) as f:
         vendorId = int( f.readline(), 16 )
      with open( f"{pciDevPath}{pciAddr}/device", 'r' ) as f:
         deviceId = int( f.readline(), 16 )
      renameDevFromPci( newDevName, pciAddr )
      # Don't record the same interface twice in deviceCache[ DEVICE ]
      if newDevName in deviceCache[ DEVICE ]:
         continue
      addToDeviceCache( deviceCache,
                        pciAddr,
                        kernelIntfPrefix,
                        PciDevIds( vendorId, deviceId ),
                        pciAddrToName, macAddrBase )

def addVfIntfsToDeviceCache( deviceCache,
      pciAddrToName, kernelIntfPrefix, macAddrBase ):
   """Rename and cache the VF interfaces listed in pciAddrToName.

   Walks SYS_CLASS_NET for interfaces that still carry a kernel ``eth<N>``
   name, expose sysfs vendor/device ids, and have a PCI address (per getPci)
   that is a key of pciAddrToName. Each one is renamed to
   pciAddrToName[ pciAddr ] and recorded in deviceCache via addToDeviceCache.
   """
   for kernelName in os.listdir( SYS_CLASS_NET ):
      sysDevDir = f"{SYS_CLASS_NET}/{kernelName}/device"
      vendorPath = f"{sysDevDir}/vendor"
      devicePath = f"{sysDevDir}/device"
      # only interfaces that have not been renamed yet and expose PCI ids
      if not re.match( r'eth\d', kernelName ):
         continue
      if not ( os.path.exists( vendorPath ) and os.path.exists( devicePath ) ):
         continue
      # Only the Ethernet Adaptive Virtual Functions named in pciAddrToName
      # belong in the deviceCache
      pciAddr = getPci( kernelName )
      if pciAddr not in pciAddrToName:
         continue
      with open( vendorPath, 'r' ) as vendorFile:
         vendorId = int( vendorFile.readline(), 16 )
      with open( devicePath, 'r' ) as deviceFile:
         deviceId = int( deviceFile.readline(), 16 )
      # NOTE(review): this guard tests the pre-rename kernel name, whereas
      # addToDeviceCache records the mapped name, so it presumably never
      # matches -- confirm whether the mapped name was intended.
      if kernelName in deviceCache[ DEVICE ]:
         continue
      renameInterface( kernelName, pciAddrToName[ pciAddr ] )
      addToDeviceCache( deviceCache, pciAddr, kernelIntfPrefix,
                        PciDevIds( vendorId, deviceId ),
                        pciAddrToName, macAddrBase )

def addToDeviceCache( deviceCache, pciAddr,
      kernelIntfPrefix, pciDev,
      pciAddrToNameMapping, macAddrBase ):
   """Record the interface mapped to pciAddr in every table of deviceCache.

   The interface name is pciAddrToNameMapping[ pciAddr ]. Its port number is
   the integer left after stripping kernelIntfPrefix (e.g. 'et' or 'et1_'),
   and its MAC is MacAddrAdd( macAddrBase, port + 1 ). The driver comes from
   getDriver, the type from getDeviceType, and the vendor/device ids from
   pciDev.

   All derived values are computed before deviceCache is modified. A failure
   (e.g. getDeviceType asserting on an unsupported pciDev, or a name without a
   numeric port) therefore does not leave a half-populated entry behind.
   """
   newDevName = pciAddrToNameMapping[ pciAddr ]
   # remove 'et' or 'et1_' prefix to get port
   port = int( newDevName[ len( kernelIntfPrefix ) : ] )
   devType = getDeviceType( pciDev )
   macAddr = MacAddrAdd( macAddrBase, port + 1 )
   driver = getDriver( newDevName )

   deviceCache[ PCI ][ newDevName ] = pciAddr
   deviceCache[ DEVICE ].append( newDevName )
   deviceCache[ MAC ][ newDevName ] = macAddr
   deviceCache[ DRIVER ][ newDevName ] = driver
   deviceCache[ ROLE ][ newDevName ] = "Switched"
   deviceCache[ VENDOR_ID ][ newDevName ] = pciDev.vendor
   deviceCache[ DEVICE_ID ][ newDevName ] = pciDev.device
   deviceCache[ TYPE ][ newDevName ] = devType
   deviceCache[ PORT ][ newDevName ] = port
   t0( "Added device", newDevName, "to deviceCache" )
