# Copyright (c) 2023 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import print_function
import os
import shlex
import subprocess
import unittest
import shutil

# Directory containing this test file (symlinks resolved), and the path of the
# bessctl script the tests expect to find right next to it.
this_file_dir = os.path.split(os.path.realpath(__file__))[0]
bessctl = os.path.join(this_file_dir, 'bessctl')

def exec_cmd(cmd):
    '''Run a command and return its combined stdout/stderr as bytes.

    cmd is either a shell-like string, which is tokenised with shlex.split
    (no shell is spawned), or an already-split list/tuple of arguments,
    which is used as-is (handy when an argument contains spaces).

    Raises subprocess.CalledProcessError if the command exits non-zero; the
    exception's ``output`` attribute holds the captured stdout and stderr.
    '''
    if isinstance(cmd, (list, tuple)):
        args = list(cmd)
    else:
        args = shlex.split(cmd)
    # stderr is folded into stdout so that error text reaches the caller,
    # including through CalledProcessError.output on failure.
    return subprocess.check_output(args, stderr=subprocess.STDOUT)

class TestPlugins(unittest.TestCase):
    """
    A list of test cases to verify if the bessctl plugin loader works
    as expected.

    setUpClass creates a default plugin dir next to this file and a custom
    plugin dir under /tmp, then starts a bessd daemon; tearDownClass undoes
    all three. Each test writes a throw-away plugin module (removed again
    via addCleanup), points BESSCTL_PLUGINS_DIR where it needs to (restored
    after the test), then runs `bessctl show <name>` and checks the result.
    """

    # Source of a `show <s>` handler that writes <s> to the CLI output.
    show_cmd_func_template =\
'''
def show_{s}(cli):
    cli.fout.write('{s}')
'''
    # A complete plugin: the handler above registered as the `show <s>` cmd.
    plugin_template =\
'''
import commands
@commands.cmd('show {s}', '{s}')
''' + show_cmd_func_template
    plugin_dir_name = '/bessctlPlugins'
    custom_plugin_dir = '/tmp' + plugin_dir_name
    default_plugin_dir = this_file_dir + plugin_dir_name
    # Environment variable naming the plugin dir bessctl should load from.
    _PLUGINS_DIR_ENV = 'BESSCTL_PLUGINS_DIR'

    @classmethod
    def setUpClass(cls):
        '''Create the plugin dirs and start the bessd daemon for the tests.

        unittest does not call tearDownClass when setUpClass raises, so any
        dir created here is removed again before the error propagates.
        '''
        created = []
        try:
            for path in (cls.default_plugin_dir, cls.custom_plugin_dir):
                os.mkdir(path)
                created.append(path)
            exec_cmd('%s daemon start -m 0 -buffers 2 -skip_root_check'
                     ' -i /tmp/bessd.pid' % bessctl)
        except Exception:
            for path in created:
                shutil.rmtree(path, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        '''Stop the daemon and remove the plugin dirs.

        The dirs are removed even if stopping the daemon fails.
        '''
        try:
            exec_cmd('%s daemon stop' % bessctl)
        finally:
            shutil.rmtree(cls.default_plugin_dir)
            shutil.rmtree(cls.custom_plugin_dir)

    def setUp(self):
        '''Restore BESSCTL_PLUGINS_DIR after every test so that tests do not
        depend on the order in which they run.
        '''
        saved = os.environ.get(self._PLUGINS_DIR_ENV)
        self.addCleanup(self._set_plugins_dir, saved)

    def _set_plugins_dir(self, path):
        '''Point BESSCTL_PLUGINS_DIR at path, or unset it when path is None.'''
        if path is None:
            os.environ.pop(self._PLUGINS_DIR_ENV, None)
        else:
            os.environ[self._PLUGINS_DIR_ENV] = path

    def _write_plugin(self, plugin_dir, plugin_name, source):
        '''Write source to <plugin_dir>/<plugin_name>.py.

        The file is deleted when the test finishes, pass or fail.
        '''
        path = os.path.join(plugin_dir, plugin_name + '.py')
        with open(path, 'wt') as f:
            self.addCleanup(os.remove, path)
            f.write(source)

    def _run_show(self, plugin_name):
        '''Run `bessctl show <plugin_name>` and return its decoded output.'''
        return exec_cmd('%s show %s' % (bessctl, plugin_name)).decode()

    def _assert_show_unknown(self, plugin_name):
        '''Assert that `bessctl show <plugin_name>` exits non-zero and reports
        the command as unknown.
        '''
        with self.assertRaises(subprocess.CalledProcessError) as ctx:
            self._run_show(plugin_name)
        self.assertIn('Unknown command "show %s"' % plugin_name,
                      ctx.exception.output.decode('utf-8', 'replace'))

    def testPluginLoadDefaultDir(self):
        '''Verify that loading a bessctl plugin from the default plugin dir
        works as expected and the show cmd implemented in the plugin produces
        expected output.
        '''
        self._set_plugins_dir(None)
        plugin_name = 'sample'
        self._write_plugin(self.default_plugin_dir, plugin_name,
                           self.plugin_template.format(s=plugin_name))
        self.assertEqual(plugin_name, self._run_show(plugin_name))

    def testPluginLoadCustomDir(self):
        '''Verify that loading a bessctl plugin from a custom plugin dir
        works as expected and the show cmd implemented in the plugin produces
        expected output.
        '''
        plugin_name = 'dummy'
        self._set_plugins_dir(self.custom_plugin_dir)
        self._write_plugin(self.custom_plugin_dir, plugin_name,
                           self.plugin_template.format(s=plugin_name))
        self.assertEqual(plugin_name, self._run_show(plugin_name))

    def testPluginLoadInvalidDir(self):
        '''Verify that a bessctl plugin is not loaded if present within a plugin dir
        which is different from the one defined within the environment variable
        BESSCTL_PLUGINS_DIR and on executing the show cmd an exception is raised.
        '''
        plugin_name = 'dummy'
        self._set_plugins_dir('invalid')
        self._write_plugin(self.custom_plugin_dir, plugin_name,
                           self.plugin_template.format(s=plugin_name))
        self._assert_show_unknown(plugin_name)

    def testPluginLoadBadPluginCode(self):
        '''Verify that a bessctl plugin with bad code is not loaded
        and on executing the show cmd an exception is raised.
        '''
        plugin_name = 'badPlugin'
        # The bad plugin lives in the default dir, so make sure that is the
        # dir bessctl would load from; otherwise the test proves nothing.
        self._set_plugins_dir(None)
        # The junk suffix makes the generated `def show_...` a syntax error.
        bad_source = self.plugin_template.format(s=plugin_name + '!@#$%^&*()')
        self._write_plugin(self.default_plugin_dir, plugin_name, bad_source)
        self._assert_show_unknown(plugin_name)

    def testPluginLoadInvalidPlugin(self):
        '''Verify that a bessctl plugin without the @commands.cmd decorator does not
        extend bessctl and on executing the show cmd an exception is raised.
        '''
        plugin_name = 'invalidPlugin'
        self._set_plugins_dir(self.custom_plugin_dir)
        self._write_plugin(self.custom_plugin_dir, plugin_name,
                           self.show_cmd_func_template.format(s=plugin_name))
        self._assert_show_unknown(plugin_name)

if __name__ == '__main__':
    # Executed directly as a script: discover and run this module's tests.
    unittest.main(module='__main__')
