#!/usr/bin/env python3
# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import os
import threading

import BasicCliModes
from CliCommon import AlreadyHandledError
from CliDynamicSymbol import CliDynamicPlugin
import CliGlobal
from CliPlugin import ForwardingDestinationCommon
import ConfigMount
import LazyMount
import Tac
from TypeFuture import TacLazyType

from Ark import synchronized

# Handles obtained through CliDynamicPlugin, followed by the symbols this module
# takes from each of them. Every plugin name is requested exactly once.
ForwardingDestinationModel = CliDynamicPlugin( 'ForwardingDestinationModel' )
EgressInterface = ForwardingDestinationModel.EgressInterface
PacketDestination = ForwardingDestinationModel.PacketDestination
PromptTree = CliDynamicPlugin( 'PromptTree' )
FieldException = PromptTree.FieldException
ForwardingDestinationHelper = CliDynamicPlugin( 'ForwardingDestinationHelper' )
ForwardingDestinationPromptTree = CliDynamicPlugin(
   'ForwardingDestinationPromptTree' )

# Tac types from the PacketTracer namespace, referenced by name through
# TacLazyType. They are used below for the agent state check, for building
# request ids and for interpreting the state of a result.
PacketTracerState = TacLazyType( 'PacketTracer::PacketTracerState' )
RequestId = TacLazyType( 'PacketTracer::RequestId' )
State = TacLazyType( 'PacketTracer::State' )

class RequestState:
   # Bookkeeping shared by the request helpers below, together with the locks
   # that protect it.

   # Hold this lock whenever count is read or changed
   countLock = threading.Lock()
   # Hold this lock whenever currentId is read or changed
   idLock = threading.Lock()

   # Number of requests currently marked active; incremented by
   # markActiveRequest and decremented by cleanupActiveRequest
   count = 0
   # Numeric part of the most recently issued request id; stays None until
   # getNewRequestId runs for the first time
   currentId = None

# Single shared instance used by the helper functions in this module
requestState = RequestState()

# Module-wide holder for the mount handles used by the handlers in this file.
# Every attribute starts out as None and is assigned in Plugin():
#   activeClientDir      - 'packettracer/activeClient' (writable Tac::Dir)
#   clientConfig         - 'packettracer/config/<CLI client name>' (writable)
#   clientStatus         - 'packettracer/status/<CLI client name>' (read-only)
#   packetTracerHwStatus - 'packettracer/hwstatus' (read-only)
#   packetTracerSwStatus - 'packettracer/swstatus' (read-only)
gv = CliGlobal.CliGlobal(
   activeClientDir=None,
   clientConfig=None,
   clientStatus=None,
   packetTracerHwStatus=None,
   packetTracerSwStatus=None )

# -----------------------------------------------------------------------------------
# Helper functions
# -----------------------------------------------------------------------------------
def shouldWaitForSleep():
   # Tests export PACKETTRACER_TEST so that waitFor can run on virtual time
   # rather than sleeping. Returns True (really sleep) only when that variable
   # is absent from the environment; its value, even an empty one, is ignored.
   return os.environ.get( 'PACKETTRACER_TEST' ) is None

@synchronized( requestState.countLock )
def cleanupActiveRequest():
   # Counterpart of markActiveRequest: drop one active request from the count.
   # Runs under countLock (see the decorator).
   remaining = requestState.count - 1
   requestState.count = remaining
   assert remaining >= 0

   if remaining != 0:
      # Some other request is still outstanding, keep the CLI client entry
      return

   # That was the last active request, so remove the CLI client entry
   gv.activeClientDir.deleteEntity(
      ForwardingDestinationCommon.CLI_CLIENT_NAME )

@synchronized( requestState.countLock )
def markActiveRequest():
   # Record that one more session is making a request, then make sure the CLI
   # client is present in the activeClient directory (newEntity is issued on
   # every call, whether or not the entry already exists). Runs under
   # countLock (see the decorator).
   requestState.count = requestState.count + 1
   gv.activeClientDir.newEntity(
      'Tac::Dir', ForwardingDestinationCommon.CLI_CLIENT_NAME )

@synchronized( requestState.idLock )
def getNewRequestId():
   # Hand out the next RequestId for the CLI client. Ids count up from 1 and
   # the counter lives in requestState.currentId; runs under idLock (see the
   # decorator).
   previousId = requestState.currentId
   if previousId is None:
      # First id requested since start-up. A ConfigAgent crash between sending
      # a request to the PD agent and receiving its response leaves the
      # request/response dangling in Sysdb, so wipe every request before the
      # counter starts over. This keeps us from re-using an id that was
      # already used.
      gv.clientConfig.request.clear()
      previousId = 0

   nextId = previousId + 1
   requestState.currentId = nextId
   return RequestId( clientName=ForwardingDestinationCommon.CLI_CLIENT_NAME,
                     id=nextId )

def generateRequest():
   # Allocate a fresh request id, create the request stored under that id in
   # the CLI client config, and return the new request.
   freshId = getNewRequestId()
   newRequest = gv.clientConfig.newRequest( freshId )
   return newRequest

def sendRequest( request ):
   # Flag the request as valid and wait for a valid result with the same id to
   # show up in gv.clientStatus.result.
   #
   # Returns that result entry, or None if none arrived within the hardware
   # status timeout plus PROCESSING_DELAY.
   reqId = request.requestId
   request.valid = True

   def resultIsReady():
      return ( reqId in gv.clientStatus.result and
               gv.clientStatus.result[ reqId ].valid )

   waitLimit = ( gv.packetTracerHwStatus.timeout +
                 ForwardingDestinationCommon.PROCESSING_DELAY )
   try:
      Tac.waitFor( resultIsReady,
                   sleep=shouldWaitForSleep(), timeout=waitLimit,
                   warnAfter=None,
                   maxDelay=ForwardingDestinationCommon.MAX_WAITFOR_DELAY )
   except Tac.Timeout:
      # Nothing came back in time; the caller treats None as a failure
      return None

   # The wait succeeded, so the entry is present and valid
   return gv.clientStatus.result[ reqId ]

def handleResult( args, mode, requestId, result ):
   # Convert the outcome of a request into the model returned to the CLI and
   # remove the request from the client config.
   #
   # args: parsed command arguments; only the presence of 'detail' is checked
   #    here, the mapping is otherwise passed through to the detail hook.
   # mode: CLI mode, used to report errors and to reach the session data.
   # requestId: key of the request to delete from gv.clientConfig.request.
   # result: entry from gv.clientStatus.result, or None when sendRequest timed
   #    out waiting for one.
   #
   # Returns a PacketDestination for the timeout and success states. For any
   # other state (including a missing result) an error is added to the mode,
   # the session's stored packet is cleared and None is returned.
   state = result.state if result else None
   model = None

   try:
      if state == State.timeout:
         model = PacketDestination( status='timeout' )
      elif state == State.success:
         if 'detail' in args:
            ( details, sectionOrder ) = showForwardingDestinationDetailHandler(
                                             mode, args, result.requestId )
            model = PacketDestination( status='success',
                                       details=details,
                                       _sectionOrder=sectionOrder )
         else:
            model = PacketDestination( status='success' )
         for resultEntry in result.egressIntf:
            ei = EgressInterface()
            ei.logical = resultEntry.logicalIntf
            ei.physical = resultEntry.physicalIntf
            model.egressInterfaces.append( ei )
      else:
         error = ( 'An error occurred with the request. '
                   'Clearing packet configuration.' )
         if state == State.invalidRequest:
            error = ( 'The request you made is invalid. '
                      'Clearing packet configuration.' )
         mode.addError( error )
         mode.session.sessionDataIs( 'PacketTracer.Packet', None )
   finally:
      # Clear out the request. Doing this in a finally block means a failure
      # while building the model (for instance in a detail hook) does not
      # leave the request behind in the client config.
      del gv.clientConfig.request[ requestId ]

   return model

def showForwardingDestinationDetailHandler( mode, args, requestId ):
   # Call every hook registered on showForwardingDestinationHook whose guard
   # returns None for this mode and hand back its ( details, sectionOrder )
   # pair. At most one such hook is expected (asserted below). When there is
   # none, the result is an empty dict and an empty list.
   hooks = ForwardingDestinationCommon.showForwardingDestinationHook.extensions()

   outcomes = []
   for hook, guard in hooks:
      if guard( mode, None ) is not None:
         # The guard did not return None, so this hook is skipped
         continue
      hookDetails, hookOrder = hook( mode, args, requestId )
      outcomes.append( ( hookDetails, hookOrder ) )

   assert len( outcomes ) <= 1, "Found too many possible results"
   if not outcomes:
      return ( {}, [] )
   return outcomes[ -1 ]

def verifyPacketTracerState( mode ):
   # Make sure the PD agent is up and running before a request is issued.
   #
   # Returns normally when the software status state is 'running'. Otherwise
   # the problem is reported through mode.addError and AlreadyHandledError is
   # raised.
   genericError = 'An error occurred with the request.'

   def agentHasStarted():
      # The agent counts as started once it is in any state but 'unknown'
      return gv.packetTracerSwStatus.state != PacketTracerState.unknown

   try:
      Tac.waitFor( agentHasStarted,
                   sleep=shouldWaitForSleep(),
                   timeout=ForwardingDestinationCommon.AGENT_START_DELAY,
                   warnAfter=None,
                   maxDelay=ForwardingDestinationCommon.MAX_WAITFOR_DELAY )
   except Tac.Timeout as e:
      mode.addError( genericError )
      raise AlreadyHandledError from e

   if gv.packetTracerSwStatus.state == PacketTracerState.running:
      # The agent is currently processing requests, nothing to report
      return

   # Not running: prefer the errors published in the software status and fall
   # back to the generic message when there are none.
   publishedErrors = gv.packetTracerSwStatus.error
   if publishedErrors:
      for error in publishedErrors:
         mode.addError( error )
   else:
      mode.addError( genericError )
   raise AlreadyHandledError

def processRequest( mode, args ):
   # Build a packet description from the stored configuration, the command
   # arguments and (in interactive sessions) user prompts, submit it as a
   # request and return the resulting model.
   #
   # mode: CLI mode the command runs in; used for errors, session queries and
   #    the stored packet configuration.
   # args: parsed command arguments. NOTE: list/tuple values are replaced in
   #    place by their first element, so the caller's mapping is modified.
   #
   # Returns whatever handleResult returns for the request, or None when
   # prompting fails with a FieldException or checkForMissingFields returns a
   # true value. verifyPacketTracerState raises AlreadyHandledError when the
   # agent is not running.

   # Verify that packet tracer is in state where it can process requests
   verifyPacketTracerState( mode )

   # Due to the use of iteration rules most arguments will be in list form. It's
   # not expected that any of them be an actual list longer than one element due
   # to using maxMatches=1. Go through each arg and change lists to be the value
   # itself.
   for key in args:
      if isinstance( args[ key ], ( list, tuple ) ):
         args[ key ] = args[ key ][ 0 ]

   # treeDict holds the packet fields; newConfiguration is only used further
   # down to choose between the 'Saved' and 'Updated' messages.
   treeDict, newConfiguration = \
         ForwardingDestinationHelper.fetchConfiguredPacket( mode )

   # Update packet fields with provided args. Also generate a user readable list
   # of all updated fields.
   updatedAttributes = []
   for key in args:
      if key in ForwardingDestinationHelper.ArgToLabel:
         updatedAttributes.append( ForwardingDestinationHelper.ArgToLabel[ key ] )
         treeDict[ key ] = args[ key ]

   # Determine packet information from fields that are present in the treeDict
   packetInfo = ForwardingDestinationHelper.generatePacketTypes( treeDict )
   treeDict[ '<packetType>' ] = packetInfo.packetType
   treeDict[ '<innerPacketType>' ] = packetInfo.innerPacketType
   treeDict[ '<l4Type>' ] = packetInfo.l4Type
   treeDict[ '<ipVersion>' ] = packetInfo.ipVersion
   treeDict[ '<greType>' ] = packetInfo.greType

   # Check if we should prompt for a raw packet. This needs hardware support
   # and is never offered outside of EnableMode.
   rawPacketPrompt = gv.packetTracerHwStatus.rawPacketSupported
   if not isinstance( mode, BasicCliModes.EnableMode ):
      rawPacketPrompt = False

   # Check if inner packet types are supported; both the hardware and the
   # software status have to advertise extended tunnel types.
   innerPacketPrompt = ( gv.packetTracerHwStatus.extendedTunnelTypesSupported and
                         gv.packetTracerSwStatus.extendedTunnelTypesSupported )

   # NOTE(review): the return value is ignored, so this presumably reports
   # incompatible fields by raising - confirm against the helper.
   ForwardingDestinationHelper.checkForIncompatibleFields( mode, treeDict )

   if mode.session.isInteractive() and mode.session.shouldPrint():
      # Prompt for all other required fields
      pt = ForwardingDestinationPromptTree.PacketPromptTree( rawPacketPrompt,
                                                             innerPacketPrompt )
      # Snapshot taken after the args were applied, so the comparison below only
      # picks up fields that were added or changed while prompting.
      originalTreeDict = treeDict.copy()
      reprompt = 'edit' in args
      try:
         pt.prompt( mode, treeDict, reprompt )
      except FieldException as e:
         mode.addError( str( e ) )
         return None

      for k in treeDict:
         # Fields listed in PacketInfoFields are never reported as updates
         if k in ForwardingDestinationCommon.PacketInfoFields:
            continue
         if ( k not in originalTreeDict or
              originalTreeDict[ k ] != treeDict[ k ] ):
            # NOTE(review): a label already appended for an arg above can be
            # appended a second time here if prompting changes the same field.
            updatedAttributes.append(
                  ForwardingDestinationHelper.ArgToLabel[ k ] )
   else:
      # No prompting in this branch: give up when checkForMissingFields returns
      # a true value for the required fields.
      if ForwardingDestinationHelper.checkForMissingFields(
            mode,
            ForwardingDestinationHelper.generateRequiredFields( packetInfo ),
            treeDict ):
         return None

   if treeDict[ '<packetType>' ] == 'raw':
      # We don't store any information when it is a raw packet
      ForwardingDestinationHelper.clearConfiguredPacket( mode )
   elif mode.session.shouldPrint():
      # The user can opt to run the command without any parameters, which will
      # just re-use the existing config. In this case no attributes are updated.
      if updatedAttributes:
         if newConfiguration:
            print( 'Saved packet configuration' )
         else:
            print( 'Updated packet configuration with new field(s): ' +
                   ', '.join( updatedAttributes ) )

   # Update ethertype and L4 type based on packet type
   ForwardingDestinationHelper.updateEthertypeAndProtocol( treeDict )

   # Validate fields
   # NOTE(review): the return value is ignored, so validation failures
   # presumably surface as exceptions - confirm against the helper.
   ForwardingDestinationHelper.validateFields( mode, treeDict,
                                               gv.packetTracerHwStatus,
                                               gv.packetTracerSwStatus )

   # Generate request in Sysdb and update it with provided fields, then print it out
   # NOTE(review): the request is only deleted inside handleResult, so an
   # exception raised between here and that call leaves it in
   # gv.clientConfig.request until getNewRequestId clears requests after a
   # restart - confirm this is acceptable.
   request = generateRequest()
   request = ForwardingDestinationHelper.updateRequest( request, treeDict )
   if mode.session.shouldPrint():
      ForwardingDestinationHelper.printRequest( request )

   # Send request; result is None when no valid result arrived in time
   result = sendRequest( request )
   return handleResult( args, mode, request.requestId, result )

def commandHandler( mode, args ):
   # Entry point for the command: returns whatever processRequest produces.
   #
   # The CLI should only be listed as an active PacketTracer client while it is
   # actually processing a request, so register before the work starts and
   # always unregister afterwards, whether processRequest returns or raises.
   markActiveRequest()
   try:
      model = processRequest( mode, args )
   finally:
      cleanupActiveRequest()
   return model

def clearPacketHandler( mode, args ):
   # Discard the packet configuration stored for this mode and confirm it on
   # stdout. args is accepted as part of the handler signature but not used.
   ForwardingDestinationHelper.clearConfiguredPacket( mode )
   confirmation = 'Cleared packet configuration'
   print( confirmation )

def Plugin( entityManager ):
   # Plugin entry point: populate every handle in gv by mounting the
   # PacketTracer paths through the given entityManager. The status paths are
   # mounted 'r'; the client config and the activeClient directory 'w'.
   clientName = ForwardingDestinationCommon.CLI_CLIENT_NAME
   statusPath = 'packettracer/status/{}'.format( clientName )
   configPath = 'packettracer/config/{}'.format( clientName )

   gv.packetTracerHwStatus = LazyMount.mount(
      entityManager, 'packettracer/hwstatus', 'PacketTracer::HwStatus', 'r' )
   gv.packetTracerSwStatus = LazyMount.mount(
      entityManager, 'packettracer/swstatus', 'PacketTracer::SwStatus', 'r' )
   gv.clientStatus = LazyMount.mount(
      entityManager, statusPath, 'PacketTracer::ClientStatus', 'r' )
   # The client config is the only handle mounted through ConfigMount
   gv.clientConfig = ConfigMount.mount(
      entityManager, configPath, 'PacketTracer::ClientConfig', 'w' )
   gv.activeClientDir = LazyMount.mount(
      entityManager, 'packettracer/activeClient', 'Tac::Dir', 'w' )
