#!/usr/bin/env python3
# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import commands

def start_net_emulator_module( cli, netem_mod_name ):
   """Create a NetEmulator module and schedule one of its tasks on every worker.

   For each worker reported by ``cli.bess.list_workers()``:
     * a traffic class named ``<netem_mod_name>:<wid>`` with the ``timer``
       policy is added, unless a class with that name already exists on that
       worker;
     * the module's ``setTaskForTrafficClass`` command is run for that class;
     * the task id returned by that command is attached to the class.

   Args:
      cli: bessctl CLI object; only ``cli.bess`` is used.
      netem_mod_name: name to give the new NetEmulator module.
   """
   cli.bess.create_module( 'NetEmulator', netem_mod_name )

   # Use the worker ids BESS reports instead of assuming 0..N-1: workers may
   # have been created with non-contiguous ids.
   worker_ids = [ ws.wid for ws in cli.bess.list_workers().workers_status ]
   for wid in worker_ids:
      tc_name = f"{netem_mod_name}:{wid}"
      existing = cli.bess.list_tcs( wid ).classes_status
      # 'class' is a Python keyword, so the field is read via getattr().
      if not any( getattr( tc, 'class' ).name == tc_name for tc in existing ):
         cli.bess.add_tc( tc_name, policy='timer', wid=wid )

      args = {
         'traffic_class' : tc_name,
         'wid' : wid,
         # NOTE(review): -1 presumably lets the module choose the task id;
         # the id actually used is read back from the reply below.
         'tid' : -1
         }
      ret = cli.bess.run_module_command( netem_mod_name,
                                         'setTaskForTrafficClass',
                                         'NetEmulatorSetTaskForTrafficClassArg',
                                         args )
      cli.bess.attach_task( netem_mod_name, tc_name, module_taskid=ret.tid )

def net_emulation_enable( cli, module_name ):
   """Insert a NetEmulator module in front of every input link of a module.

   Each link ``prev:ogate -> module_name`` returned by
   ``_get_previous_module_link`` is rewired to
   ``prev:ogate -> NetEmulator_<module_name>:i``. The emulator's output gate
   ``i`` is then connected to input gate ``i`` of ``module_name``, where ``i``
   is the link's position in that list. The pipeline is paused while the
   graph is modified.

   Errors are printed, not raised (best-effort CLI behaviour).

   Args:
      cli: bessctl CLI object; only ``cli.bess`` is used.
      module_name: name of the module whose inputs are intercepted.
   """
   netem_mod_name = 'NetEmulator_' + module_name
   pause_requested = False

   try:
      prev_modules = cli.bess._get_previous_module_link( module_name )
      # Set before pause_all() so a pause that fails part-way still gets a
      # matching resume_all(). Also ensures a failed lookup above does not
      # resume a pipeline this function never paused.
      pause_requested = True
      cli.bess.pause_all()
      start_net_emulator_module( cli, netem_mod_name )

      for gate_id, link in enumerate( prev_modules ):
         prev_mod_name, prev_mod_ogate, _ = link
         cli.bess.disconnect_modules( prev_mod_name, prev_mod_ogate )
         cli.bess.connect_modules( prev_mod_name, netem_mod_name,
                                   prev_mod_ogate, gate_id )

      # NOTE(review): the third element of each link tuple (presumably the
      # original input gate on module_name) is ignored; input gates are
      # renumbered 0..N-1 in list order. Confirm this matches the topology.
      for gate_id in range( len( prev_modules ) ):
         cli.bess.connect_modules( netem_mod_name, module_name,
                                   gate_id, gate_id )

   except Exception as e: # pylint: disable=broad-except
      print( str( e ) )
   finally:
      if pause_requested:
         cli.bess.resume_all()

def net_emulation_disable( cli, module ):
   """Remove the NetEmulator module spliced in front of ``module``.

   The emulator's output gates 0..N-1 are disconnected. Every upstream link
   ``prev:ogate -> NetEmulator_<module>`` is moved back to
   ``prev:ogate -> module:i``, where ``i`` is the link's position in the list
   returned by ``_get_previous_module_link``. The emulator module is then
   destroyed. Its traffic classes are not removed here. The pipeline is
   paused while the graph is modified.

   Errors are printed, not raised (best-effort CLI behaviour).

   Args:
      cli: bessctl CLI object; only ``cli.bess`` is used.
      module: name of the module the emulator was enabled on.
   """
   netem_mod_name = 'NetEmulator_' + module
   pause_requested = False

   try:
      prev_modules = cli.bess._get_previous_module_link( netem_mod_name )
      # Set before pause_all() so a pause that fails part-way still gets a
      # matching resume_all(). Also ensures a failed lookup above does not
      # resume a pipeline this function never paused.
      pause_requested = True
      cli.bess.pause_all()

      for gate_id in range( len( prev_modules ) ):
         cli.bess.disconnect_modules( netem_mod_name, gate_id )

      # NOTE(review): the third element of each link tuple is ignored and
      # input gates on `module` are assigned 0..N-1 in list order.
      for gate_id, link in enumerate( prev_modules ):
         prev_mod_name, prev_mod_ogate, _ = link
         cli.bess.disconnect_modules( prev_mod_name, prev_mod_ogate )
         cli.bess.connect_modules( prev_mod_name, module,
                                   prev_mod_ogate, gate_id )

      cli.bess.destroy_module( netem_mod_name )

   except Exception as e: # pylint: disable=broad-except
      print( str( e ) )
   finally:
      if pause_requested:
         cli.bess.resume_all()

def net_emulation_set_lossrate( cli, module_name, percent ):
   """Set the packet-loss rate of the NetEmulator in front of a module.

   The percentage is converted to parts per million (100% -> 1000000) and
   sent to ``NetEmulator_<module_name>`` as the ``loss`` field of its
   ``setLoss`` command.

   Args:
      cli: bessctl CLI object; only ``cli.bess`` is used.
      module_name: name of the module the emulator was enabled on.
      percent: loss rate in percent (int or float).
   """
   netem_mod_name = 'NetEmulator_' + module_name
   # round() instead of int(): the float product can land a hair below the
   # exact integer, which int() would truncate to one unit too low.
   loss = round( percent / 100 * 1000000 )
   args = {
     'loss' : loss
   }
   cli.bess.run_module_command( netem_mod_name, 'setLoss',
                                'NetEmulatorSetLossArg', args )

def net_emulator_set_delay( cli, module_name, mean, sd ):
   """Set the delay distribution of the NetEmulator in front of a module.

   ``mean`` and ``sd`` are each multiplied by 10**6, rounded to an integer
   and sent to ``NetEmulator_<module_name>`` as the ``mean`` and ``stdev``
   fields of its ``setDelay`` command.
   TODO(review): confirm the unit conversion the scale factor implies
   (e.g. ms -> ns) against the NetEmulator module.

   Args:
      cli: bessctl CLI object; only ``cli.bess`` is used.
      module_name: name of the module the emulator was enabled on.
      mean: mean delay, in the CLI's input unit.
      sd: standard deviation of the delay, in the same unit as ``mean``.
   """
   netem_mod_name = 'NetEmulator_' + module_name
   # round() instead of int(): float scaling can land a hair below the
   # exact integer, which int() would truncate to one unit too low.
   mean = round( mean * 1000 * 1000 )
   sd = round( sd * 1000 * 1000 )
   args = {
     'mean' : mean,
     'stdev' : sd
   }
   cli.bess.run_module_command( netem_mod_name, 'setDelay',
                                'NetEmulatorSetDelayArg', args )

# Net Emulator CLI commands

@commands.cmd( 'netemulator MODULE enable',
               'Attach net emulator to all input gates of a module' )
def netemulator_enable( cli, module ):
   """CLI handler: `netemulator MODULE enable`."""
   return net_emulation_enable( cli, module_name=module )

@commands.cmd( 'netemulator MODULE disable',
               'Disable Netemulator module' )
def netemulator_disable( cli, module ):
   """CLI handler: `netemulator MODULE disable`."""
   return net_emulation_disable( cli, module=module )

@commands.cmd( 'netemulator MODULE delay DELAY SD',
               'Configure delay and standard deviation in net emulator' )
def netemulator_set_delay( cli, module, delay, sd ):
   """CLI handler: `netemulator MODULE delay DELAY SD`."""
   return net_emulator_set_delay( cli, module_name=module, mean=delay, sd=sd )

@commands.cmd( 'netemulator MODULE lossrate PERCENT',
               'Configure the lossrate in percentage' )
def netemulator_lossrate( cli, module, percent ):
   """CLI handler: `netemulator MODULE lossrate PERCENT`."""
   return net_emulation_set_lossrate( cli, module_name=module, percent=percent )
