# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

from __future__ import absolute_import, print_function

import logging

from ztn.controllers.discovery.mdns import MdnsDiscovery
from ztn.controllers.discovery.peer import PeerDiscovery
from ztn.manifest import ManifestRequest
import ztn.settings

class ZtnmTransactionManager:
    """Find a ZTN controller and request this device's manifest from it.

    Candidate controller addresses come from up to three sources, tried in
    this order on every attempt: static addresses (from settings or
    ``add_static_addresses``), mDNS discovery, and peer discovery.
    """

    def __init__(self, log, mac, interface, platform_name, mdns_client_port=None):
        """Create a transaction manager.

        Each of ``log``, ``mac``, ``interface`` and ``platform_name`` falls
        back to the corresponding ``ztn.settings`` value when falsy.  When an
        interface is known but the MAC is still ``None``, the MAC is read
        from ``/sys/class/net/<interface>/address``.

        :param log: logger used for progress and error messages.
        :param mac: MAC address passed to manifest requests (the interface
            name is used instead when no MAC is available).
        :param interface: interface used for discovery and manifest requests.
        :param platform_name: platform string passed to manifest requests.
        :param mdns_client_port: optional ``client_port`` for mDNS discovery.
        """
        self.log = log or ztn.settings.ZTNM_DEFAULT_LOGGER

        #
        # All methods of finding a controller and requesting
        # a manifest require these inputs.
        #
        self.mac = mac or ztn.settings.ZTNM_DEFAULT_MAC
        self.interface = interface or ztn.settings.ZTNM_MANAGEMENT_INTERFACE
        self.platform_name = platform_name or ztn.settings.ZTNM_PLATFORM_NAME
        self.static_addresses = []
        self.mdns_client_port = mdns_client_port

        if self.interface and self.mac is None:
            # Read the mac from sysfs
            with open(f"/sys/class/net/{self.interface}/address", encoding='utf-8') as f:
                self.mac = f.read().strip()

        # Add static addresses from local configuration if available.
        # An unset (None) setting is ignored by add_static_addresses().
        self.add_static_addresses(ztn.settings.ZTNM_STATIC_ADDRESSES)

    def add_static_addresses(self, address_or_list):
        """Append one address, or a list/tuple of addresses, to the static list.

        ``None`` is ignored so an unset configuration value does not inject a
        bogus ``None`` controller address.  Strings are always treated as a
        single address, never iterated character by character.
        """
        if address_or_list is None:
            return
        if isinstance(address_or_list, (list, tuple)):
            self.static_addresses.extend(address_or_list)
        else:
            self.static_addresses.append(address_or_list)

    def transact_controllers(self, controllers, timeout):
        """Ask each controller in turn for a manifest.

        :param controllers: iterable of controller addresses; ``None`` is
            treated as empty.
        :param timeout: passed as ``timeout`` to ``ManifestRequest.get``.
        :returns: ``(manifest, controller)`` for the first controller that
            returns a truthy manifest, otherwise ``(None, None)``.
        """
        mr = ManifestRequest(self.log, self.interface)
        # Identify the device by MAC when known, otherwise by interface name.
        mac_or_interface = self.mac or self.interface
        for controller in controllers or ():
            mf = mr.get(controller, mac_or_interface,
                        self.platform_name,
                        timeout=timeout)
            if mf:
                return (mf, controller)
        return (None, None)

    def discover_controllers_mdns(self):
        """Run mDNS discovery on the interface and return its controllers.

        The mDNS discovery process yields a list of possible controller
        addresses, which are later asked for a manifest.
        """
        obj = MdnsDiscovery(self.interface, client_port=self.mdns_client_port, log=self.log)
        obj.run()
        return obj.controllers

    def discover_controllers_peer(self):
        """Run peer discovery on the interface and return its controllers."""
        obj = PeerDiscovery(self.interface, log=self.log)
        obj.run()
        return obj.controllers

    def transact(self, attempts,
                 static=True, mdns=True, peer=True,
                 manifest_request_timeout=10):
        """Try to obtain a manifest, retrying the enabled methods ``attempts`` times.

        :param attempts: number of passes over the enabled methods.
        :param static: try the static address list (forced off when empty).
        :param mdns: try controllers found via mDNS discovery.
        :param peer: try controllers found via peer discovery.
        :param manifest_request_timeout: ``timeout`` passed to each manifest
            request (presumably seconds -- confirm against ManifestRequest).
        :returns: ``(manifest, controller, method)`` where ``method`` is one of
            ``'static'``, ``'mdns'`` or ``'peer'``; ``(None, None, None)`` if
            every attempt fails.
        :raises ValueError: if no method is left enabled.
        """
        #
        # The ZTN transaction process can be broken down into three steps:
        #
        # * Collect possible addresses which *may* return a manifest.
        # * Ask those addresses for a manifest.
        # * Return when a manifest is received from one of those addresses.
        #
        #
        # Addresses are collected as follows:
        #
        # - Static Addresses
        #   These come from local user configuration, or from DHCP,
        #   or from known possible or previous addresses, etc.
        #
        #   Static addresses should be added by the client prior to
        #   requesting a transaction (collected from wherever they
        #   are in the local environment) using .add_static_addresses()
        #
        # - Addresses discovered via MDNS Queries
        #   - In link-local environments the ZTN controller is discovered
        #     through an MDNS request.
        #     - The MDNS response contains the address.
        #     - The address is used to request a manifest.
        #
        # - Addresses discovered via neighbor discovery.
        #   - In cases where the MDNS request/response fails, all
        #     neighbors are discovered and simply asked for a manifest;
        #     only the controller for this device should respond.
        #
        if not self.static_addresses:
            static = False

        if not (static or mdns or peer):
            raise ValueError("At least one of static, mdns, or peer must be True.")

        # Ordered (method name, enabled, address source) entries.  Sources are
        # callables so discovery only runs when that method is actually tried.
        methods = (
            ('static', static, lambda: self.static_addresses),
            ('mdns', mdns, self.discover_controllers_mdns),
            ('peer', peer, self.discover_controllers_peer),
        )

        # Attempts are numbered from 1 so the success log agrees with the
        # "failed after N attempts" message.
        for attempt in range(1, attempts + 1):
            for method, enabled, source in methods:
                if not enabled:
                    continue
                (result, controller) = self.transact_controllers(source(),
                                                                 manifest_request_timeout)
                if result:
                    self.log.info("succeeded on attempt %d method=%s %s",
                                  attempt, method, controller)
                    return (result, controller, method)

        self.log.error("transaction failed after %d attempts.", attempts)
        return (None, None, None)


if __name__ == '__main__':
    # Manual smoke run: fixed MAC/interface/platform, one known static
    # controller address, three attempts; prints the (manifest, controller,
    # method) tuple returned by transact().
    logging.basicConfig()
    demo_log = logging.getLogger('transaction')
    demo_log.setLevel(logging.INFO)

    manager = ZtnmTransactionManager(
        demo_log,
        '52:54:00:a2:4e:94',
        'ma1',
        'x86-64-bigswitch-bs3240-r0',
    )
    manager.add_static_addresses('10.2.1.48')
    outcome = manager.transact(3)
    print(outcome)
