# Copyright (c) 2019 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import Tac
import Tracing
from GenericReactor import GenericReactor
from TypeFuture import TacLazyType

# Module-level bindings for the TAC types referenced by this plugin, looked up
# by their fully qualified TAC type names.
# NOTE(review): PacketTracerResultEntry is bound with Tac.Type while every other
# entry uses TacLazyType; the reason is not visible in this file -- confirm
# before making them consistent.
Dot1qHeader = TacLazyType( 'PacketTracer::Dot1qHeader' )
EthAddr = TacLazyType( 'Arnet::EthAddr' )
IntfId = TacLazyType( 'Arnet::IntfId' )
IpProtocolNumber = TacLazyType( 'Arnet::IpProtocolNumber' )
PacketTracerResult = TacLazyType( 'PacketTracer::Result' )
PacketTracerResultEntry = Tac.Type( 'PacketTracer::ResultEntry' )
PacketTracerState = TacLazyType( 'PacketTracer::PacketTracerState' )
ResultState = TacLazyType( 'PacketTracer::State' )
VlanIntfId = TacLazyType( 'Arnet::VlanIntfId' )

# Tracing handle for this module; t0/t2/t5/t8 are shorthands for the handle's
# trace0/trace2/trace5/trace8 functions (see the level legend below).
handle = Tracing.Handle( 'EtbaPacketTracer' )
t0 = handle.trace0
t2 = handle.trace2
t5 = handle.trace5
t8 = handle.trace8

# Trace levels used:
# 0: Plugin/SM initialization
# 2: PacketTracerSm
# 5: PacketTracerClientSm
# 8: Verbose messages

def trClient( *args, **kwargs ):
   '''
   Trace a ClientSm message, prefixed with "[<clientName>.<id>]:" built from the
   request id.

   Keyword arguments:
   reqId   -- mandatory; object providing the clientName and id attributes
              (a missing reqId raises KeyError)
   verbose -- optional; when truthy the message goes to trace level 8 instead
              of level 5

   Example output:
   [MplsUtils.1]: Processing request
   [MplsUtils.2]: Error: No MPLS label
   '''
   requestId = kwargs[ 'reqId' ]
   if kwargs.get( 'verbose' ):
      traceFn = t8
   else:
      traceFn = t5
   prefix = '[' + '{}.{}'.format( requestId.clientName, requestId.id ) + ']:'
   traceFn( prefix, *args )

class PacketTracerSm:
   '''
   Top-level SM: watches the PacketTracer config and status directories for
   client names. A PacketTracerClientSm is created for a name once that name is
   present in both directories, and is dropped again when that stops being true.
   '''
   def __init__( self, bridge, config, status, packetTracerHwStatus,
                 packetTracerSwStatus ):
      t0( 'Initialize PacketTracerSm' )
      self.bridge = bridge
      self.config = config
      self.status = status
      # Client SMs keyed by client name
      self.clientSm = {}

      # Advertise in the HW status that packet tracing is supported
      self.packetTracerHwStatus = packetTracerHwStatus
      packetTracerHwStatus.tracingSupported = True

      # Advertise MPLS packet support and mark the tracer as running
      self.packetTracerSwStatus = packetTracerSwStatus
      packetTracerSwStatus.mplsPacketSupported = True
      packetTracerSwStatus.state = PacketTracerState.running

      # Pick up the clients that already exist, then react to later changes in
      # either directory
      self.handleClients()
      watched = [ 'entityPtr' ]
      self.configReactor = GenericReactor( config, watched, self.handleClient )
      self.statusReactor = GenericReactor( status, [ 'entityPtr' ],
                                           self.handleClient )

   def processClient( self, name ):
      '''
      Bring the client SM for `name` in line with the directories: (re)create it
      when the name is in both config and status, otherwise delete it if one
      exists.
      '''
      inBothDirs = name in self.config and name in self.status
      if not inBothDirs:
         if name in self.clientSm:
            t2( 'Deleting clientSm:', name )
            del self.clientSm[ name ]
         return

      t2( 'Creating clientSm:', name )
      cfg = self.config[ name ]
      sts = self.status[ name ]
      self.clientSm[ name ] = PacketTracerClientSm(
         self.bridge, cfg, sts, self.packetTracerHwStatus )

   def handleClients( self ):
      "Run processClient for every name currently in the config directory"
      for clientName in self.config:
         self.processClient( clientName )

   def handleClient( self, reactor, name ):
      "Shared callback of the config and status reactors; reactor is unused"
      self.processClient( name )

class PacketTracerClientSm( Tac.Notifiee ):
   '''
   This SM monitors the given PacketTracer client config requests and publishes the
   responses to the corresponding client status.

   Attributes:
   bridge               -- bridge object providing `port` and `processFrame`
   clientConfig         -- client config entity; its `request` collection is the
                           input to this SM
   clientStatus         -- client status entity; its `result` collection is the
                           output of this SM
   packetTracerHwStatus -- source of `maximumPacketSize` for packet construction
   helper               -- PacketTracer::PacketConstructorHelper instance
   requestSm            -- PacketTracerRequestSm per request id, created for
                           requests that were not yet valid when first seen
   '''
   notifierTypeName = 'PacketTracer::ClientConfig'

   def __init__( self, bridge, clientConfig, clientStatus, packetTracerHwStatus ):
      t0( 'Initializing PacketTracerClientSm for', clientConfig.name )
      Tac.Notifiee.__init__( self, clientConfig )
      self.bridge = bridge
      self.clientConfig = clientConfig
      self.clientStatus = clientStatus
      self.packetTracerHwStatus = packetTracerHwStatus
      self.helper = Tac.newInstance( 'PacketTracer::PacketConstructorHelper' )
      self.requestSm = {}
      # Reconcile with whatever is already present: drop results whose request
      # is gone, then process the requests that have no result yet.
      self.cleanStaleResults()
      self.handleRequests()

   def cleanStaleResults( self ):
      "Cleans up the results for requests that have already been deleted"
      # Iterate over a snapshot of the keys: entries are deleted from the
      # collection inside the loop, and mutating a collection while iterating
      # over it directly is not safe in general.
      for requestId in list( self.clientStatus.result ):
         if requestId not in self.clientConfig.request:
            del self.clientStatus.result[ requestId ]

   def handleRequests( self ):
      "Process requests that don't already have a corresponding result"
      for requestId in self.clientConfig.request:
         if requestId not in self.clientStatus.result:
            self.handleRequest( requestId )

   @Tac.handler( 'request' )
   def handleRequest( self, requestId ):
      '''
      Handle a change to the request identified by requestId. Registered as the
      handler for 'request', and also called directly from handleRequests and
      from PacketTracerRequestSm.handleValid.

      - Request no longer in the config: delete its result and its request SM.
      - Request present but not valid: make sure a PacketTracerRequestSm exists
        for it; that SM calls back here once the request becomes valid.
      - Request valid: compute the result and mark it valid.
      '''
      if requestId not in self.clientConfig.request:
         trClient( 'Deleting result', reqId=requestId )
         del self.clientStatus.result[ requestId ]
         self.requestSm.pop( requestId, None )
         return

      request = self.clientConfig.request[ requestId ]
      if not request.valid:
         if requestId not in self.requestSm:
            self.requestSm[ requestId ] = PacketTracerRequestSm( request, self )
         return

      result = self.getResult( request )
      # Set only after the result has been filled in by getResult
      result.valid = True

   def recreatePacket( self, packetRequest, requestId ):
      '''
      Build the packet described by packetRequest using the packet constructor
      helper, passing the maximum packet size from the HW status, and return
      what the helper returns. Emits a verbose trace listing the headers the
      request describes.
      '''
      hasDot1qHeader = packetRequest.l2Header.dot1qHeader != Dot1qHeader()
      hasInnerDot1qHeader = packetRequest.l2Header.innerDot1qHeader != Dot1qHeader()
      # One '| MPLS ' segment is traced per label in the request
      trClient( 'recreatePacket: Creating packet with the headers: L2 ' +
                ( '| Dot1Q ' if hasDot1qHeader else '' ) +
                ( '| Inner Dot1Q ' if hasInnerDot1qHeader else '' ) +
                ( '| L3 ' if packetRequest.hasL3 else '' ) +
                ( '| MPLS ' * packetRequest.numMplsLabels ) +
                ( '| L3 ' if packetRequest.hasInnerL3Header else '' ) +
                ( '| L4 ' if packetRequest.hasL4 else '' ),
                reqId=requestId, verbose=True )

      maxPktSize = self.packetTracerHwStatus.maximumPacketSize
      return self.helper.constructPacket( requestId, packetRequest, maxPktSize )

   def getResult( self, packetRequest ):
      '''
      Given a packet request, it will reconstruct the packet and give it to the
      EbraTestBridge processFrame code to get the list of egress intfs without
      letting it actually send the packet(s) out.

      Returns the result object obtained from clientStatus.newResult for the
      request id. Its state is ResultState.error unless at least one processed
      frame reported egress interfaces, in which case those interfaces are added
      to result.egressIntf and the state is ResultState.success. The caller is
      responsible for setting result.valid.
      '''
      requestId = packetRequest.requestId
      result = self.clientStatus.newResult( requestId )
      # Assume failure until an egress interface is found
      result.state = ResultState.error
      packetStr = self.recreatePacket( packetRequest, requestId )
      # Need to get the EbraTestPhyPort to give to processFrame
      ingressIntf = packetRequest.ingressIntf
      srcPort = self.bridge.port.get( ingressIntf )
      if not srcPort:
         trClient( 'Error: No srcPort found for:', ingressIntf, reqId=requestId )
         return result

      bridgeFrameInfoList = self.bridge.processFrame( packetStr,
                                                      packetRequest.l2Header.srcMac,
                                                      packetRequest.l2Header.dstMac,
                                                      srcPort,
                                                      tracePkt=True )

      # The EbraTestBridge object will first try routing the packet using all
      # registered routing handlers. After that, it will then attempt to bridge
      # all the packets. As a result, it's possible we could have many egress
      # intfs for a given packet. We will consider it a failure if there are no
      # egress intfs.
      if not bridgeFrameInfoList:
         trClient( 'Error: No bridged frames for given request', reqId=requestId )
         return result

      for bridgeFrameInfo in bridgeFrameInfoList:
         egressIntfs = bridgeFrameInfo.egressIntfs
         if not egressIntfs:
            continue

         trClient( 'Adding egress intfs:', egressIntfs, reqId=requestId,
                   verbose=True )
         for intf in egressIntfs:
            result.egressIntf[
               PacketTracerResultEntry.fromPhysicalIntf( intf ) ] = True
         result.state = ResultState.success

      trClient( 'Result state:', result.state, reqId=requestId )
      return result

class PacketTracerRequestSm( Tac.Notifiee ):
   '''
   Watches a single PacketTracer request and hands it back to the parent SM
   (via parentSm.handleRequest) when its `valid` attribute is set.
   '''
   notifierTypeName = "PacketTracer::Request"

   def __init__( self, packetTracerRequest, parentSm ):
      super().__init__( packetTracerRequest )
      self.parentSm = parentSm
      self.packetTracerRequest = packetTracerRequest

   @Tac.handler( "valid" )
   def handleValid( self ):
      "Forward the request id to the parent SM once the request is valid"
      request = self.packetTracerRequest
      if not request.valid or not self.parentSm:
         # Nothing to do for a request that is not valid, or without a parent
         return
      self.parentSm.handleRequest( request.requestId )

def packetTracerSmFactory( bridge ):
   '''
   Build the PacketTracerSm for the given bridge, using the config, status,
   hwstatus and swstatus entities found under 'packettracer/' in the bridge's
   entity manager.
   '''
   entityManager = bridge.em()
   return PacketTracerSm(
      bridge=bridge,
      config=entityManager.entity( 'packettracer/config' ),
      status=entityManager.entity( 'packettracer/status' ),
      packetTracerHwStatus=entityManager.entity( 'packettracer/hwstatus' ),
      packetTracerSwStatus=entityManager.entity( 'packettracer/swstatus' ) )
