#!/usr/bin/env python3

# Copyright (c) 2014-2016, The Regents of the University of California.
# Copyright (c) 2016-2017, Nefeli Networks, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the names of the copyright holders nor the names of their
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

from __future__ import print_function
from __future__ import absolute_import
import sys
import os
import os.path
import io
import tempfile
import time
import cli
import commands

# Suppress scapy IPv6 warning (must be done before importing scapy module)
import logging
logging.getLogger("scapy.runtime").setLevel(logging.ERROR)

# Put the parent of this script's (symlink-resolved) directory on the module
# search path -- presumably where the pybess package lives; confirm against
# the repository layout -- and pull the API names (e.g. BESS) into this
# module's namespace.
try:
    this_dir = os.path.dirname(os.path.realpath(__file__))
    sys.path.insert(1, os.path.join(this_dir, '..'))
    from pybess.bess import *
except ImportError:
    # Print a hint for the user, then re-raise so the original traceback
    # is still shown.
    print('Cannot import the API module (pybess)', file=sys.stderr)
    raise


class BESSCLI(cli.CLI):
    """Command-line interface bound to a BESS client object.

    Variable completion/parsing is delegated to the command database
    (``cmd_db``), and errors raised by the client while a command runs are
    reported to the user and converted into ``HandledError``.
    """

    def __init__(self, bess, cmd_db, **kwargs):
        """Remember the client and command database, then init the base CLI.

        Args:
            bess: client object. This class reads its ``peer`` attribute,
                calls ``is_connected()``, ``is_connection_broken()`` and
                ``disconnect()``, and catches its ``APIError``, ``RPCError``
                and ``Error`` exception classes.
            cmd_db: command database providing ``cmdlist``,
                ``get_var_attrs``, ``split_var`` and ``bind_var``.
            **kwargs: forwarded unchanged to the base class constructor.
        """
        self.bess = bess
        self.cmd_db = cmd_db
        self.this_dir = this_dir

        super(BESSCLI, self).__init__(self.cmd_db.cmdlist, **kwargs)

    def get_var_attrs(self, var_token, partial_word):
        """Delegate variable attribute lookup to the command database."""
        return self.cmd_db.get_var_attrs(self, var_token, partial_word)

    def split_var(self, var_type, line):
        """Split a variable off ``line`` using the command database.

        Falls back to the base class implementation when the command
        database raises ``InternalError``.
        """
        try:
            return self.cmd_db.split_var(self, var_type, line)
        except self.InternalError:
            return super(BESSCLI, self).split_var(var_type, line)

    def bind_var(self, var_type, line):
        """Bind a variable from ``line`` using the command database.

        Falls back to the base class implementation when the command
        database raises ``InternalError``.
        """
        try:
            return self.cmd_db.bind_var(self, var_type, line)
        except self.InternalError:
            return super(BESSCLI, self).bind_var(var_type, line)

    def print_banner(self):
        """Write the one-line startup hint to the output stream."""
        self.fout.write('Type "help" for more information.\n')

    def get_default_args(self):
        """Return ``[self]``.

        NOTE(review): presumably the base CLI passes these as the leading
        arguments of every command handler -- confirm against cli.CLI.
        """
        return [self]

    def _handle_broken_connection(self):
        """Print the daemon crash log for a local peer, then disconnect."""
        peer = self.bess.peer
        # Only a daemon on this machine can have left a crash log here.
        # The truthiness check is defensive: NOTE(review) confirm whether
        # ``peer`` can be unset when this is reached.
        if peer and peer.startswith(('localhost', '127.')):
            self._print_crashlog()
        self.bess.disconnect()

    def call_func(self, func, args):
        """Run a command handler, reporting client errors to the user.

        Args:
            func: command handler, passed through to the base class.
            args: arguments for ``func``, passed through to the base class.

        Raises:
            HandledError: after an ``APIError``, ``RPCError`` or ``Error``
                raised by the client has been reported.
        """
        try:
            super(BESSCLI, self).call_func(func, args)

        except self.bess.APIError as e:
            self.err(e)
            raise self.HandledError()

        except self.bess.RPCError as e:
            # One placeholder each for the peer and the failure reason; the
            # previous format string had a single placeholder and silently
            # dropped the error text.
            self.err('RPC failed to {} - {}'.format(
                self.bess.peer, str(e)))

            self._handle_broken_connection()
            raise self.HandledError()

        except self.bess.Error as e:
            # Imported locally so this handler does not depend on ``errno``
            # leaking in through the ``from pybess.bess import *`` above.
            import errno

            self.err(e.errmsg)

            err_code = errno.errorcode.get(e.code, '<unknown>')

            self.ferr.write('  BESS daemon response - errno=%d (%s: %s)\n' %
                            (e.code, err_code, os.strerror(e.code)))
            for k, v in sorted(e.info.items()):
                self.ferr.write('%12s: %s\n' % (str(k), str(v)))

            raise self.HandledError()

    def _print_crashlog(self):
        """Dump ``bessd_crash.log`` from the temp directory to ``ferr``.

        Best effort: any failure to read the log is itself reported on
        ``ferr`` rather than raised.
        """
        try:
            log_path = os.path.join(tempfile.gettempdir(), 'bessd_crash.log')
            # Close the file deterministically instead of relying on GC.
            with open(log_path) as log_file:
                log = log_file.read()
            ctime = time.ctime(os.path.getmtime(log_path))
            self.ferr.write('From {} ({}):\n{}'.format(log_path, ctime, log))
        except Exception as e:
            self.ferr.write('Abruptly disconnected from bessd, but crashlog is '
                            'not available:\n%s\n' % (str(e)))

    def loop(self):
        """Run the base class loop, then disconnect the client."""
        super(BESSCLI, self).loop()
        self.bess.disconnect()

    def get_prompt(self):
        """Return the prompt string for the current connection state.

        While connected the prompt shows the peer. If the client reports a
        broken connection, it is cleaned up here (crash log printed for a
        local peer, client disconnected) before the disconnected prompt is
        returned.
        """
        if self.bess.is_connected():
            return '{} $ '.format(self.bess.peer)

        if self.bess.is_connection_broken():
            self._handle_broken_connection()

        return '<disconnected> $ '


def run_cli(instream=sys.stdin, grpc_url=None):
    """Connect to the daemon and run the CLI over ``instream``.

    Args:
        instream: stream the CLI reads commands from (``sys.stdin`` by
            default).
        grpc_url: passed to the client's ``connect()``; ``None`` leaves the
            choice of endpoint to the client.

    Exits the process with status 1 if the loop ended with ``stop_loop``
    set, after naming the last command when one is recorded.
    """
    client = BESS()
    shell = BESSCLI(client, commands, fin=instream)

    try:
        client.connect(grpc_url=grpc_url)
    except BESS.RPCError as exc:
        # A failed connect is not fatal: the CLI still starts, and only an
        # interactive session gets the hint.
        if shell.interactive:
            shell.ferr.write(str(exc) + '\n')
            shell.ferr.write('Perhaps bessd daemon is not running locally? '
                             'Try "daemon start".\n')

    shell.loop()

    if not shell.stop_loop:
        return

    # The loop stopped because of an error: tell the user which command
    # failed before exiting with a failure status.
    if shell.last_cmd:
        shell.ferr.write('  Command failed: %s\n' % shell.last_cmd)
    sys.exit(1)


def main():
    """Script entry point: interactive session or batch commands from argv.

    Command-line handling:
      * ``--grpc_url URL`` (optional, anywhere on the line) selects the
        daemon endpoint and is removed before commands are parsed.
      * With no remaining arguments the CLI runs on standard input.
      * Otherwise the remaining arguments form commands; a bare ``--``
        separates one command from the next.

    Exits with status 2 if ``--grpc_url`` is given without a value.
    """
    # Work on a copy so the process-wide sys.argv is left untouched.
    args = list(sys.argv[1:])

    grpc_url = None
    if '--grpc_url' in args:
        idx = args.index('--grpc_url')
        if idx + 1 >= len(args):
            # Previously this surfaced as a bare IndexError traceback.
            sys.stderr.write('--grpc_url requires a value\n')
            sys.exit(2)
        grpc_url = args[idx + 1]
        del args[idx:idx + 2]

    if not args:
        # Nothing but (optionally) --grpc_url was given: run on stdin
        # against the chosen endpoint instead of feeding an empty script.
        run_cli(grpc_url=grpc_url)
        return

    cmds = []
    line_buf = []
    for arg in args:
        if arg == '--':
            cmds.append(u' '.join(line_buf))
            line_buf = []
        else:
            line_buf.append(arg)

    # Flush the final (or only) command.
    cmds.append(u' '.join(line_buf))
    run_cli(io.StringIO(u'\n'.join(cmds)), grpc_url=grpc_url)

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
