#!/usr/bin/env python3
# Copyright (c) 2008-2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

import Tac, Tracing, Arnet.Device, os, re
import zlib
import Cell
import SharedMem
import Smash
from Arnet.NsLib import DEFAULT_NS
from EbraTestBridgePort import EbraTestPort
from EbraTestBridge import IEEE_LACP_ADDR, PKTEVENT_ACTION_NONE, IEEE_LLDP_ADDR
from EbraTestBridge import IEEE_8021X_ADDR
from PtpPacketLib import ptpPacketType
from Arnet import PktParserTestLib # pylint: disable=ungrouped-imports

# Tac type of the per-interface BFD per-link mode; any value other than
# perlinkMode.none means per-link BFD is enabled on that interface.
perlinkMode = Tac.Type( "Bfd::PerlinkMode" )

handle = Tracing.Handle( 'EbraTestBridge' )
t2 = handle.trace2    # link aggregation specific events
t4 = handle.trace4    # misc important events
t5 = handle.trace5    # Packet tx/rx, error level
t6 = handle.trace6    # Packet tx/rx, alert level
t7 = handle.trace7    # Packet tx/rx, normal level
t8 = handle.trace8    # Packet tx/rx, debug level
t8b = handle.trace3   # Packet tx/rx, super-debug level

# The globals below are None at import time.  agentInit() assigns the mounted
# entities, and its onMountComplete() callback assigns the state machines and
# helpers.

# Global for EbraTestLagPort to get at the Lag::Input::IntfDeviceNameDir
lagInputDevnameDir = None

# Globals required to 'short-circuit' initialMember to programmedMember
hwLagConfig = None
hwLagStatus = None
lagProgDefaultSm = None

# Global to get which interfaces are on perlink mode
bfdConfigIntf = None

# Global to get the bridge mac
bridgingConfig = None

# Globals required for EthIntf::EthIntfVrfSm managing Lag intfs
ethLagIntfStatusDir = None
ethLagIntfStatusLocalDir = None
mngdIntfStatusDir = None
mngdIntfStatusLocalDir = None
mngdIntfStatusSelector = None
ethIntfVrfSmHelper = None

# Globals required for LagKernelDeviceNameAndOperStateReactor
allIntfStatusDir = None
allIntfStatusLocalDir = None
kniStatus = None
lagKernelDeviceOperStateHelper = None

# Global required for EtbaDutLagHwConfigReactor
hwLagInputConfigEbraTest = None

# UDP destination ports that isBfdPacket() treats as BFD traffic needing
# per-link handling:
#  - single-hop BFD control and echo (RFC 5881)
#  - micro-BFD on LAG member links (RFC 7130)
BFD_PORT_CTRL = 3784
BFD_PORT_ECHO = 3785
BFD_PORT_LAG = 6784

def isBfdPacket( data ):
   """Return True if the frame `data` carries a UDP datagram addressed to a
   BFD port that needs per-link support: control, echo, or LAG micro-BFD.
   Frames without a UDP header return False."""
   parsed = PktParserTestLib.parsePktStr( data )
   udpHdr = PktParserTestLib.findHeader( parsed[ 1 ], "UdpHdr" )
   if not udpHdr:
      return False
   t6( "UDP Packet, dest port =", udpHdr.dstPort )
   perLinkPorts = ( BFD_PORT_LAG, BFD_PORT_CTRL, BFD_PORT_ECHO )
   return udpHdr.dstPort in perLinkPorts

class EbraTestLagMemberReactor( Tac.Notifiee ):
   """Manages the LAG port's member based on LagIntfStatus.member.

   Keeps EbraTestLagPort.ports in sync with the notifier's `member`
   collection.  Ports that join are routed into the LAG's processFrame and
   ports that leave are detached.  MAC table entries are flushed on
   membership changes and when the LAG link goes down.
   """

   notifierTypeName = 'Interface::EthLagIntfStatus'

   def __init__( self, ethIntfStatus, lagPort ):
      # Capture the LAG port and its bridge before registering as a notifiee.
      self.lagPort_ = lagPort
      self.bridge_ = lagPort.bridge_
      Tac.Notifiee.__init__( self, ethIntfStatus )

   @Tac.handler( 'member' )
   def handleMember( self, key ):
      """Resynchronize the LAG's port list with the current `member` set.

      key: the member entry that changed.  It is also passed straight to
      bridge.macTableFlushPort(); EbraTestLagPort.updateStatusConfig calls
      this with None to force a full resync.
      """
      # With MLAG the port channel may stay active without local members,
      # but real hardware removes it and clears its MAC entries; emulate
      # that by flushing the LAG port when no local member remains.
      status = self.notifier_
      lagPort = self.lagPort_
      t8( 'EbraTestLagMemberReactor handleMember', key, status.intfId,
          list( status.member ) )

      # Snapshot the current ports: remember every name, and which
      # positions are no longer members (removal happens after additions).
      knownNames = set()
      stale = {}
      for idx, port in enumerate( lagPort.ports ):
         portName = port.name()
         knownNames.add( portName )
         if portName not in status.member:
            stale[ idx ] = portName

      # Attach members that the LAG does not know about yet.
      localPortsInLag = False
      for memberName in status.member:
         if not memberName.startswith( "Peer" ):
            localPortsInLag = True
         if memberName in knownNames:
            continue
         if memberName not in self.bridge_.port:
            # Probably a peer interface. Ignore.
            continue
         newPort = self.bridge_.port[ memberName ]
         lagPort.ports.append( newPort )
         newPort.setProcessFrame( lagPort.processFrame )
         t2( "handleMember adding", memberName, "to", status.intfId )

      # Detach the stale positions; newly appended ports are never stale.
      keptPorts = []
      for idx, port in enumerate( lagPort.ports ):
         if idx in stale:
            port.setProcessFrame( None )
            t2( "handleMember removing", stale[ idx ], "from", status.intfId )
         else:
            keptPorts.append( port )
      lagPort.ports = keptPorts
      self.bridge_.macTableFlushPort( key )

      if not localPortsInLag:
         self.bridge_.macTableFlushPort( status.intfId )

   @Tac.handler( 'linkStatus' )
   def handleLinkStatus( self ):
      """Flush MAC entries learned on the LAG when its link goes down."""
      status = self.notifier_
      if status.linkStatus != 'linkDown':
         return
      self.bridge_.macTableFlushPort( status.intfId )
      
class EbraTestLagDirReactor:
   """We wait until a lag intf appears in the interface/{config,status}/eth/intf
   collections before adding it to the test bridge.  We don't just watch the
   interface/{config,status}/eth/lag collections, because there is a time delay
   between showing up there and showing up in the 'intf' collections.  The test
   bridge looks in the 'intf' collections, so that's what we watch for.

   This class is a base class for the reactors to each of the config and status
   collections since the code for each is nearly identical.
   """

   def __init__( self, bridge ):
      self.bridge_ = bridge
      self.ethIntfConfigDir_ = bridge.ethIntfConfigDir_
      self.ethIntfStatusDir_ = bridge.ethIntfStatusDir_

   def handleIntf( self, key ):
      """Add, update or remove the bridge port for Port-Channel `key`.

      - The port is added once both an intfConfig and an intfStatus exist
        for `key`.
      - It is updated if the intfStatus object was replaced.
      - It is removed once either entry disappears.
      - Names that are not Port-Channels, and subinterfaces (names
        containing '.'), are ignored.
      - A key of None triggers a full resync.
      """
      t2( "EbraTestLagDirReactor for", key )
      if key is None:
         self._resyncAll()
         return

      # Only whole Port-Channels become bridge ports; skip subinterfaces.
      if not key.startswith( 'Port-Channel' ) or '.' in key:
         return

      t2( "EbraTestLagDirReactor checking", key )
      intfConfig = self.ethIntfConfigDir_.intfConfig.get( key )
      intfStatus = self.ethIntfStatusDir_.intfStatus.get( key )
      if intfConfig and intfStatus:
         if key not in self.bridge_.port:
            t2( "EbraTestLagDirReactor adding", key, "to ports" )
            self.bridge_.addPort(
               EbraTestLagPort( self.bridge_, None, None, intfConfig, intfStatus ) )
         else:
            # intfStatus object changed.  This could be a result of
            # an Mlag forming on the secondary dut.
            port = self.bridge_.port[ key ]
            t2( 'lag port', key, 'contains ports', port.ports )
            oldIntfStatus = port.intfStatus_
            if intfStatus != oldIntfStatus:
               t2( 'EbraTestLagDirReactor, change in intfStatus, modify port' )
               port.updateStatusConfig( intfStatus, intfConfig )

      elif key in self.bridge_.port:
         # Make sure that the underlying ports are removed from the Lag.
         # The EbraTestLagMemberReactor may not have gotten around to it
         # before the Lag was deleted.
         t2( "EbraTestLagDirReactor removing", key, "from ports" )
         for port in self.bridge_.port[ key ].ports:
            port.setProcessFrame( None )
         self.bridge_.delPort( key )

   def _resyncAll( self ):
      """Re-evaluate every interface after a notification without a key.

      Walking only the configured interfaces would miss Port-Channels deleted
      from both collections in the same multi-key notification.  Those would
      linger on the bridge as orphan ports, so the Port-Channel ports the
      bridge already holds are revisited as well.
      """
      configured = self.ethIntfConfigDir_.intfConfig
      for intfName in configured:
         self.handleIntf( intfName )
      # Iterate over a snapshot: handleIntf() may delete ports from the bridge.
      for intfName in list( self.bridge_.port ):
         if intfName.startswith( 'Port-Channel' ) and intfName not in configured:
            self.handleIntf( intfName )

class EbraTestLagStatusReactor( Tac.Notifiee, EbraTestLagDirReactor ):
   """Routes changes in the bridge's Interface::EthIntfStatusDir to the shared
   Port-Channel handling in EbraTestLagDirReactor.handleIntf()."""

   notifierTypeName = 'Interface::EthIntfStatusDir'

   def __init__( self, bridge ):
      statusDir = bridge.ethIntfStatusDir_
      EbraTestLagDirReactor.__init__( self, bridge )
      Tac.Notifiee.__init__( self, statusDir )

   @Tac.handler( 'intfStatus' )
   def handleStatus( self, key ):
      # Status and config changes get identical treatment; handleIntf()
      # also accepts None.
      return self.handleIntf( key )

class EbraTestLagConfigReactor( Tac.Notifiee, EbraTestLagDirReactor ):
   """Routes changes in the bridge's Interface::EthIntfConfigDir to the shared
   Port-Channel handling in EbraTestLagDirReactor.handleIntf()."""

   notifierTypeName = 'Interface::EthIntfConfigDir'

   def __init__( self, bridge ):
      configDir = bridge.ethIntfConfigDir_
      EbraTestLagDirReactor.__init__( self, bridge )
      Tac.Notifiee.__init__( self, configDir )

   @Tac.handler( 'intfConfig' )
   def handleConfig( self, key ):
      # Status and config changes get identical treatment; handleIntf()
      # also accepts None.
      return self.handleIntf( key )

class EtbaDutLagHwConfigReactor( Tac.Notifiee ):
   """Publishes the simulated DUT's LAG hardware limits.

   Output goes to the module-global hwLagInputConfigEbraTest: a single lag
   group, 'switch1', that every Ethernet phy interface in
   bridge.allEthPhyIntfStatusDir_ belongs to.  The entries are kept in sync
   as phy interfaces come and go.
   """

   notifierTypeName = 'Interface::AllEthPhyIntfStatusDir'

   def __init__( self, bridge ):
      self.ethPhyIntfStatusDir_ = bridge.allEthPhyIntfStatusDir_
      # Module global, mounted by agentInit().
      self.hwLagInputConfig_ = hwLagInputConfigEbraTest
      Tac.Notifiee.__init__( self, bridge.allEthPhyIntfStatusDir_ )
      # Populate entries for the interfaces that already exist.
      self.etbaDutLagHwConfig( None )

   @Tac.handler( 'intfStatus' )
   def etbaDutLagHwConfig( self, intfName ):
      """(Re)write the lag group and per-interface membership.

      intfName: the phy interface that changed, or None for a full refresh.
      If intfName is no longer in the status dir, its phyIntf entry is
      deleted.
      """
      intfStatuses = self.ethPhyIntfStatusDir_.intfStatus
      add = not intfName or intfName in intfStatuses
      t4( "etbaDutLagHwConfig handle", intfName, "add" if add else "delete" )

      # To decrease limits for testing purposes, environment variable
      # LAG_MAX_LAGS can be set to the desired limit
      PortChannelNum = Tac.Type( "Lag::PortChannelNum" )
      maxLags = int( os.environ.get( "LAG_MAX_LAGS", PortChannelNum.max ) )
      maxMembersPerLag = 16
      # ( group name, max LAGs, max members per LAG, member interfaces )
      lagGroupList = [ ( 'switch1', maxLags, maxMembersPerLag,
                         list( intfStatuses ) ) ]
      ethLagGroup = self.hwLagInputConfig_.lagGroup
      ethPhyIntfLagGroup = self.hwLagInputConfig_.phyIntf
      for groupName, maxLagIntfs, maxPortsPerLagIntf, members in lagGroupList:
         lg = ethLagGroup.get( groupName, None )
         if not lg:
            lg = ethLagGroup.newMember( groupName )
         lg.maxLagIntfs = maxLagIntfs
         lg.maxPortsPerLagIntf = maxPortsPerLagIntf
         for p in members:
            lp = ethPhyIntfLagGroup.get( p, None )
            if not lp:
               lp = ethPhyIntfLagGroup.newMember( p )
            lp.lagGroup = lg
      # The deleted interface may never have been recorded in phyIntf (for
      # example, a repeated delete notification), so only remove an existing
      # entry rather than deleting a missing key.
      if not add and intfName in ethPhyIntfLagGroup:
         del ethPhyIntfLagGroup[ intfName ]

class EbraTestLagPort( EbraTestPort ):
   """Simulates a LAG port.

   Receive path: member ports hand their frames to processFrame(), which
   re-injects them into the bridge as if they arrived on the aggregate.
   Per-link control traffic is the exception and stays on the member.

   Transmit path: frames sent on the aggregate are hashed onto a single
   member port.

   A tap device created here exchanges frames with the aggregate:
   frames read from it are transmitted via sendFrame(), and trapFrame()
   writes frames to it.
   """

   def __init__( self, bridge, tapDevice, trapDevice, intfConfig, intfStatus ):
      """Initialize the simulated LAG.  Create reactors to LAG state
      to maintain collection of ports in the LAG.

      tapDevice and trapDevice must be None.  The LAG creates its own tap
      device and registers the device name in lagInputDevnameDir.  The name
      is 'po<N>' inside a namespace and '<bridge>-po<N>' otherwise.
      """

      assert tapDevice is None
      assert trapDevice is None
      assert intfConfig
      assert intfStatus

      # Set before EbraTestPort.__init__, presumably because the base
      # initializer may call updateStatusConfig(), which uses both
      # attributes (TODO: confirm against EbraTestPort).
      self.lagPortReactor_ = None
      self.ports = []

      EbraTestPort.__init__( self, bridge, intfConfig, intfStatus )

      intfName = intfStatus.intfId
      # Port-Channel number, e.g. '10' for 'Port-Channel10'.
      lagNum = re.search( r'\D*(\d+)', intfName ).group( 1 )
      if bridge.inNamespace():
         devName = 'po%s' % lagNum
      else:
         devName = bridge.name() + '-po' + lagNum
      bridgeMac = bridgingConfig.bridgeMacAddr
      self.trapDevice_ = Arnet.Device.Tap( devName, hw=bridgeMac )
      self.trapFile_ = Tac.File( self.trapDevice_.fileno(),
                                 self._trapDeviceReadHandler,
                                 self.trapDevice_.name,
                                 readBufferSize=16000 )

      intfDeviceName = Tac.Value( 'Lag::Input::IntfDeviceName',
                                  intfId=intfName,
                                  genId=intfStatus.genId,
                                  deviceName=self.trapDevice_.name )
      lagInputDevnameDir.intfDeviceName.addMember( intfDeviceName )
      t2( "EbraTestLagPort", intfName, "initialized. Device is:",
         self.trapDevice_.name )

   def close( self ):
      """Unregister the device name and close the tap device."""
      del lagInputDevnameDir.intfDeviceName[ self.intfStatus_.intfId ]
      self.trapFile_.close()
      self.trapDevice_.close()
      t2( "EbraTestLagPort", self.name(), "closed" )

   def updateStatusConfig( self, intfStatus, intfConfig ):
      """Adopt a new status/config pair and resync members from it."""
      EbraTestPort.updateStatusConfig( self, intfStatus, intfConfig )
      # Replace the member reactor so that it watches the new intfStatus.
      self.lagPortReactor_ = EbraTestLagMemberReactor( intfStatus, self )
      self.lagPortReactor_.handleMember( None )

   def bfdPerLinkConfigured( self ):
      """Return True if per-link BFD is configured on this LAG."""
      if ( self.name() in bfdConfigIntf.perLink and
            bfdConfigIntf.perLink[ self.name() ] != perlinkMode.none ):
         return True
      return False

   def processFrame( self, data, srcMacAddr, dstMacAddr, intf, tracePkt ):
      """An incoming interface has handed us a packet that
      the bridge needs to think came in on the aggregate.

      LACP, LLDP, 802.1X/MACsec and PTP frames stay on the member port
      `intf`.  So does BFD when per-link BFD is configured.  All other
      frames are remapped to this aggregate port.
      """
      if dstMacAddr == IEEE_LACP_ADDR:
         t7( 'Passing LACP control frame from', srcMacAddr, 'to', intf.name() )
         egressIntf = intf
      elif dstMacAddr == IEEE_LLDP_ADDR:
         t7( 'Passing LLDP frame from', srcMacAddr, 'to', intf.name() )
         egressIntf = intf
      elif dstMacAddr == IEEE_8021X_ADDR:
         t7( 'Passing Dot1x/Macsec frame from', srcMacAddr, 'to', intf.name() )
         egressIntf = intf
      elif ptpPacketType( data ) != None: # pylint: disable=singleton-comparison
         t7( 'Passing Ptp frame from', srcMacAddr, 'to', intf.name())
         egressIntf = intf
      # Check the cheap config lookup first, so ordinary traffic is not run
      # through the packet parser unless per-link BFD is actually configured.
      elif self.bfdPerLinkConfigured() and isBfdPacket( data ):
         t7( 'Passing Bfd frame from', srcMacAddr, 'to', intf.name() )
         egressIntf = intf
      else:
         t7( 'Remapping incoming frame with source:', srcMacAddr,
             'destination:', dstMacAddr, 'from physical port', intf.name(),
             'to aggregate port', self.name() )
         egressIntf = self
      self.bridge_.processFrame( data, srcMacAddr, dstMacAddr, egressIntf, tracePkt )

   def sendFrame( self, data, srcMacAddr, dstMacAddr, srcPortName,
                  priority, vlanId, vlanAction ):
      """Map the packet to the right set of underlying ports and
      hash to pick a single port.

      Nothing is sent if the LAG is disabled or has no member ports.  The
      member is chosen from a CRC-32 of the source and destination MACs.
      """
      if not self.intfConfig_.enabled:
         t8( "Refusing to transmit on disabled port %s" % self.name() )
         return
      if not self.ports:
         t8( "Refusing to transmit on aggregate port %s with no members" %
                     self.name() )
         return
      # CRC-32 instead of hash(): str hashes are randomized per interpreter
      # run (PYTHONHASHSEED), which made the member chosen for a MAC pair
      # change between runs.  CRC-32 gives the same member for a MAC pair in
      # every run.
      macKey = ( str( srcMacAddr ) + str( dstMacAddr ) ).encode()
      lagHash = zlib.crc32( macKey )
      t8( 'lag hashes ( SMAC %s, DMAC %s ) to %d' %
          ( srcMacAddr, dstMacAddr, lagHash ) )
      dst = self.ports[ lagHash % len( self.ports ) ]
      t7( 'Remapping outgoing frame with source: %s, destination: %s from '
          'aggregate port %s to physical port %s' % ( srcMacAddr, dstMacAddr,
            self.name(), dst.name() ) )
      dst.sendFrame( data, srcMacAddr, dstMacAddr, srcPortName,
                     priority, vlanId, vlanAction )

   def _trapDeviceReadHandler( self, data ):
      """Transmit a frame read from the tap device on the aggregate."""
      t8( "trapping frame output on %s" % self.name(), Tracing.HexDump( data ) )
      self.sendFrame( data, None, None, None, None, None, PKTEVENT_ACTION_NONE )

   def trapFrame( self, data ):
      """Write `data` to the tap device."""
      t8( "trapping frame input on %s" % self.name(), Tracing.HexDump( data ) )
      os.write( self.trapDevice_.fileno(), data )
      

def bridgeInit( bridge ):
   """Bridge-init hook.

   Attaches the LAG directory reactors and the hardware lag-group publisher
   to `bridge`, then marks the LAG-related hardware capabilities as
   supported.
   """
   t2( "Lag plugin bridgeInit" )
   bridge.lagPluginLagStatusReactor_ = EbraTestLagStatusReactor( bridge )
   bridge.lagPluginLagConfigReactor_ = EbraTestLagConfigReactor( bridge )
   bridge.etbaDutLagHwConfigReactor_ = EtbaDutLagHwConfigReactor( bridge )
   hwCaps = bridge.em().entity( 'bridging/hwcapabilities' )
   bridge.hwCapabilities_ = hwCaps
   for capability in ( 'extendedLagIdSupported',
                       'mixedSpeedLagSupported',
                       'lagHwInterlockSupported' ):
      setattr( hwCaps, capability, True )

def agentInit( em ):
   """Agent-init hook.

   Mounts the Sysdb and Smash state this plugin uses and stores it in the
   module globals.  Once the mount group completes, it starts the LAG helper
   state machines.

   em: the agent's entity manager.  Used for mountGroup(), mount(),
   createLocalEntity() and cEntityManager(), and to build a SharedMem
   entity manager.
   """
   t2( "Lag plugin agentInit" )
   global lagInputDevnameDir
   global hwLagInputConfigEbraTest
   mg = em.mountGroup()
   # Writable; EbraTestLagPort registers its tap device names here.
   lagInputDevnameDir = mg.mount( 'lag/input/devname/ebratest',
                                  'Lag::Input::IntfDeviceNameDir', 'wc' )
   # The returned entity is not kept; presumably mounted so that the phy
   # status dir is available to other users (TODO: confirm).
   mg.mount( 'interface/status/eth/phy', 'Tac::Dir', 'ri' )
   # Writable; EtbaDutLagHwConfigReactor publishes lag-group limits here.
   hwLagInputConfigEbraTest = mg.mount( 'hardware/lag/input/config/ebratest',
                                        'Hardware::Lag::Input::Config', 'wc' )

   global bfdConfigIntf
   bfdConfigIntf = mg.mount( 'bfd/config/intf', 'Bfd::ConfigIntf', 'r' )

   global hwLagConfig
   global hwLagStatus
   # Input and output of the LagProgDefaultSm created in onMountComplete().
   hwLagConfig = mg.mount( 'hardware/lag/config',
                           'Hardware::Lag::Config', 'rO' )
   hwLagStatus = mg.mount( 'hardware/lag/status',
                           'Hardware::Lag::Status', 'w' )

   global bridgingConfig
   bridgingConfig = mg.mount( 'bridging/config', 'Bridging::Config', 'r' )

   # Mounted directly on em, outside the mount group; bridgeInit() later
   # fetches this entity and sets the LAG capability flags.
   em.mount( 'bridging/hwcapabilities', 'Bridging::HwCapabilities', 'w' )

   global ethLagIntfStatusDir
   # Created as a local entity rather than mounted.
   ethLagIntfStatusDir = em.createLocalEntity( 'interface/status/eth/lag',
                                               'Interface::EthLagIntfStatusDir' )
   global ethLagIntfStatusLocalDir
   ethLagIntfStatusLocalDir = mg.mount(
      Cell.path( 'interface/status/eth/lag/local' ),
      'Interface::EthLagIntfStatusLocalDir', 'r' )
   global allIntfStatusDir
   global allIntfStatusLocalDir
   global kniStatus
   allIntfStatusDir = mg.mount( 'interface/status/all',
                     'Interface::AllIntfStatusDir', 'r' )
   allIntfStatusLocalDir = mg.mount(
                     Cell.path( 'interface/status/local' ),
                     'Interface::AllIntfStatusLocalDir', 'r' )
   # Kernel network info for the default namespace comes from Smash
   # shared memory, not Sysdb.
   shmemEm = SharedMem.entityManager( sysdbEm=em )
   kniStatus = shmemEm.doMount( "kni/ns/%s/status" % DEFAULT_NS,
                                "KernelNetInfo::Status",
                                Smash.mountInfo( 'keyshadow' ) )


   def onMountComplete():
      """Runs after the mount group closes: starts the state machines and
      helpers that consume the entities mounted above."""
      global lagProgDefaultSm
      global mngdIntfStatusDir
      global mngdIntfStatusLocalDir
      global mngdIntfStatusSelector
      global ethIntfVrfSmHelper
      global lagKernelDeviceOperStateHelper

      lagInputDevnameDir.priority = 10
      lagInputDevnameDir.containsLocalDevices = True
      # state-machine that copies initialMembers to programmedMembers
      lagProgDefaultSm = Tac.newInstance( "Hardware::Lag::LagProgDefaultSm",
                                          hwLagConfig, hwLagStatus )

      # Starting LagKernelDeviceNameAndOperStateReactor
      lagKernelDeviceOperStateHelper = Tac.newInstance(
                           "Lag::LagKernelDeviceOperStateHelper",
                           ethLagIntfStatusDir,
                           allIntfStatusDir,
                           allIntfStatusLocalDir,
                           kniStatus )

      # Starting EthIntf::EthIntfVrfSm ( which manages the lag interfaces )
      # in Etba
      mngdIntfStatusDir = \
                  Tac.newInstance( 'Interface::ManagedIntfStatusDir', 'statusDir' )
      mngdIntfStatusLocalDir = \
         Tac.newInstance( 'Interface::ManagedIntfStatusLocalDir', 'statusLocalDir' )
      # Feeds the LAG intf status (global and local) into the managed dirs
      # that the VRF SM below is started on.
      mngdIntfStatusSelector = \
         Tac.newInstance( "Lag::ManagedIntfStatusSelector",
                          ethLagIntfStatusDir, ethLagIntfStatusLocalDir,
                          mngdIntfStatusDir, mngdIntfStatusLocalDir )
      ethIntfVrfSmHelper = Tac.newInstance( "EthIntf::EthIntfVrfSmHelper",
                                             em.cEntityManager() )
      ethIntfVrfSmHelper.doStartSm( mngdIntfStatusDir, mngdIntfStatusLocalDir )

   mg.close( onMountComplete )

def Plugin( ctx ):
   """Plugin entry point.

   Registers EbraTestLagPort as the port type for
   Interface::EthLagIntfStatus, plus this module's bridge-init and
   agent-init hooks.
   """
   t2( "Lag plugin registering" )
   lagIntfStatusType = 'Interface::EthLagIntfStatus'
   ctx.registerInterfaceHandler( lagIntfStatusType, EbraTestLagPort )
   ctx.registerBridgeInitHandler( bridgeInit )
   ctx.registerAgentInitHandler( agentInit )
