# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

"""Discover ZTN controllers with MDNS.
"""

from __future__ import absolute_import

import logging
import select
import socket
import socketserver
import struct
import time

import dns, dns.name, dns.message, dns.rdataclass # pylint: disable=multiple-imports
import dns.exception
import dns.rdatatype

from ztn.controllers import ZtnmController
from ztn.net import Interface
import ztn.settings

# txtvers value assumed for a response whose TXT record carries no "txtvers=" entry.
UNSUPPORTED_VERSION = "1"
# Local UDP port the query sockets bind to when MdnsDiscovery is given no
# client_port; 0 lets the operating system choose an ephemeral port.
DEFAULT_MDNS_CLIENT_PORT = 0

class _UDPHandler(socketserver.BaseRequestHandler):
    """Parse one mDNS response datagram and report the controller it advertises.

    socketserver creates one instance per received datagram.  ``ctx`` is the
    owning ``_Server4``/``_Server6`` object: the handler reads its outstanding
    query (``ctx.q``) and its socket (``ctx.server``), and reports a discovered
    controller through ``ctx.handle_server()``.
    """

    # Labels appended to the per-interface instance label to build the SRV name
    # that is queried: "<label>._v<N>._sub._zerotouch._tcp.local."
    SRV_SUFFIX = ('_v%d' % ztn.settings.ZTNM_VERSION, '_sub', '_zerotouch', '_tcp', 'local', '',)

    def __init__(self, request, client_address, server,
                 ctx, log=None):
        self.log = log or logging.getLogger(self.__class__.__name__)
        self.ctx = ctx
        # The base constructor runs handle() immediately, so the attributes
        # above must be assigned before delegating.
        socketserver.BaseRequestHandler.__init__(self, request, client_address, server)

    @staticmethod
    def _txt_value(raw):
        """Decode a TXT-record value (bytes) to str, replacing invalid UTF-8."""
        if isinstance(raw, bytes):
            return raw.decode('utf-8', 'replace')
        return raw

    def handle(self):
        """Validate a response against our query and report the advertised server.

        The SRV answer gives the port and target name.  An A/AAAA answer for that
        target gives the address.  TXT strings may carry txtvers, zone, ssl and
        timestamp.  ``ctx.handle_server`` is called only when both an address and
        a port were found.
        """
        data = self.request[0]
        self.log.info("client address %s", self.client_address)
        self.log.info("received %d bytes", len(data))

        # Datagrams come straight from the network: drop anything that is not
        # valid DNS instead of letting the exception escape handle().
        try:
            msg = dns.message.from_wire(data)
        except dns.exception.DNSException as what:
            self.log.warning("skipping malformed message: %s", str(what))
            return
        if msg.id != self.ctx.q.id:
            self.log.debug("skipping message id %s", msg.id)
            return
        if not msg.question:
            self.log.debug("skipping empty question")
            return
        if msg.question[0] != self.ctx.q.question[0]:
            self.log.debug("mismatched question")
            return
        if self.client_address == self.ctx.server.socket.getsockname():
            self.log.debug("skipping our own query %s", msg)
            return

        tgt = None
        port = None
        address = None
        txtvers = None
        zone = None
        ssl = False
        timestamp = None

        self.log.debug("query %s %s", self.client_address, msg)
        for ans in msg.answer:

            if (ans.rdtype == dns.rdatatype.SRV
                and ans.rdclass == dns.rdataclass.IN
                and ans.name == self.ctx.q.question[0].name):
                for item in ans.items:
                    port = item.port
                    tgt = item.target
                continue

            # IPv4 / IPv6 address of the SRV target (only after the SRV was seen)
            if (ans.rdtype in (dns.rdatatype.A, dns.rdatatype.AAAA)
                and ans.rdclass == dns.rdataclass.IN
                and tgt is not None
                and tgt == ans.name):
                for item in ans.items:
                    address = item.address
                continue

            if (ans.rdtype == dns.rdatatype.TXT
                and ans.rdclass == dns.rdataclass.IN):
                for item in ans.items:
                    # TXT strings are bytes: values are decoded (or compared
                    # against bytes literals) so later str comparisons and
                    # string concatenation work.
                    for s in item.strings:
                        if s.startswith(b'txtvers='):
                            txtvers = self._txt_value(s[8:])
                        elif s.startswith(b'zone='):
                            zone = self._txt_value(s[5:])
                        elif s == b'ssl':
                            ssl = True
                        elif s.startswith(b'ssl='):
                            ssl = (s[4:] == b'1')
                        elif s.startswith(b'timestamp='):
                            try:
                                timestamp = int(s[10:])
                            except ValueError as what:
                                self.log.warning("invalid timestamp: %s", str(what))
                        else:
                            self.log.warning("unhandled TXT field %s", s)
                continue

            self.log.warning("unhandled answer: %s", ans)

        txtvers = txtvers or UNSUPPORTED_VERSION

        # v2 and v3 mDNS responses are compatible with each other; a v2/v3
        # response is ignored unless we also speak v2/v3.
        if (txtvers in ("2", "3",)
            and str(ztn.settings.ZTNM_VERSION) not in ("2", "3",)):
            self.log.debug("skipping incompatible txtvers %s", txtvers)
            return

        if address is not None and port is not None:
            self.ctx.handle_server(address, server_port=port, zone=zone, ssl=ssl,
                                   timestamp=timestamp)

class _UDPHandlerFactory:
    """Adapt ``handler_klass`` to the constructor signature socketserver expects.

    socketserver instantiates ``RequestHandlerClass(request, client_address,
    server)``.  ``create_handler`` has exactly that signature and forwards the
    stored server context and logger as the extra handler arguments.
    """

    handler_klass = _UDPHandler

    def __init__(self, ctx, log=None):
        self.log = log if log else logging.getLogger(self.__class__.__name__)
        self.ctx = ctx

    def create_handler(self, request, client_address, server):
        """Build a handler for one datagram, binding the stored context and logger."""
        klass = self.handler_klass
        extra = dict(log=self.log)
        return klass(request, client_address, server, self.ctx, **extra)

class _Server4:

    def __init__(self, interface, bind_port=None, log=None):
        self.log = log or logging.getLogger(self.__class__.__name__)
        self.interface = interface
        self.bind_port = bind_port  # port used by the ZTN/mDNS client

        self.server = None
        self.handler_factory = None

        self.q = None
        self.servers = []

    def open(self):

        self.handler_factory = _UDPHandlerFactory(self, log=self.log)

        self.server = socketserver.UDPServer((self.interface.v4, self.bind_port,),
                                             self.handler_factory.create_handler)

        # see http://stackoverflow.com/questions/603852/multicast-in-python
        mreq = struct.pack("4sl", socket.inet_aton('224.0.0.251'), socket.INADDR_ANY)
        self.server.socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
        self.log.info("querying on %s", self.server.socket.getsockname())

    @property
    def zone(self):
        return dns.name.Name((self.interface.mac,) + _UDPHandler.SRV_SUFFIX)

    def query(self):
        self.q = dns.message.make_query(self.zone, dns.rdatatype.SRV, dns.rdataclass.IN)
        buf = self.q.to_wire()
        txt = self.q.to_text()
        self.log.debug("sending query:")
        self.log.debug(txt)
        self.server.socket.sendto(buf, ('224.0.0.251', 5353))

    def handle_server(self, address, server_port, zone=None, ssl=False, timestamp=None):
        if zone is not None:
            self.log.info("found server at [%s%%%s]:%d (ssl=%s)",
                          address, zone, server_port, ssl)
            self.servers.append(ZtnmController(address + '%' + zone, server_port, ssl, timestamp))
        else:
            self.log.info("found server at %s:%d (ssl=%s)",
                          address, server_port, ssl)
            self.servers.append(ZtnmController(address, server_port, ssl, timestamp))

        # do not close, keep gathering servers
        ##self.server.server_close()
        ##self.server = None

    def serve_until(self, future):
        fno = self.server.fileno()
        while self.server is not None:
            now = time.time()
            if now > future:
                break
            wait = min(future-time.time(), 0.5)
            rfd, _, _ = select.select([fno], [], [], wait)
            if fno in rfd:
                self.log.debug("received a packet")
                self.server.handle_request()

class _UDPV6Server(socketserver.UDPServer):
    """Simple V6 UDP server.

    See
    http://www.thecodingforums.com/threads/python-socketserver-with-ipv6.681964/
    """
    address_family = socket.AF_INET6

class _Server6:

    def __init__(self, interface, bind_port=None, log=None):
        self.log = log or logging.getLogger(self.__class__.__name__)
        self.interface = interface
        self.bind_port = bind_port  # port used by the ZTN/mDNS client
        self.server = None
        self.handler_factory = None

        self.q = None
        self.address = None
        self.ssl = False

        self.servers = []

    def open(self):

        self.handler_factory = _UDPHandlerFactory(self, log=self.log)

        self.log.info("starting mDNS responder on %s", self.interface.v6)
        self.server = _UDPV6Server((self.interface.v6, self.bind_port, 0, self.interface.ifindex),
                                  self.handler_factory.create_handler)

        # see http://code.activestate.com/recipes/442490-ipv6-multicast/
        ifn = struct.pack("I", self.interface.ifindex)
        mreq = socket.inet_pton(socket.AF_INET6, "ff02::fb") + ifn
        self.server.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, mreq)

        self.log.info("querying on %s", self.server.socket.getsockname())

    @property
    def zone(self):
        mac_bytes = self.interface.mac.split(':')
        mac_bytes = [int(x, 16) for x in mac_bytes]
        mac_bytes[0] ^= 0x02
        mac_args = mac_bytes + [self.interface.name,]
        lbl = ("fe80::%02x%02x:%02xff:fe%02x:%02x%02x%%%s"
               % tuple(mac_args))
        return dns.name.Name((lbl,) + _UDPHandler.SRV_SUFFIX)

    def query(self):
        self.q = dns.message.make_query(self.zone, dns.rdatatype.SRV, dns.rdataclass.IN)
        buf = self.q.to_wire()
        txt = self.q.to_text()
        self.log.debug("sending query:")
        self.log.debug(txt)
        self.server.socket.sendto(buf, ('ff02::fb', 5353))

    def handle_server(self, address, server_port, zone=None, ssl=False, timestamp=None):
        if zone is not None:
            self.log.info("found server at [%s%%%s]:%d (ssl=%s)",
                          address, zone, server_port, ssl,)
            self.servers.append(ZtnmController(address + '%' + zone, server_port, ssl, timestamp))
        else:
            self.log.info("found server at %s:%d (ssl=%s)",
                          address, server_port, ssl)
            self.servers.append(ZtnmController(address, server_port, ssl, timestamp))

        # do not close, keep gathering servers
        ##self.server.server_close()
        ##self.server = None

    def serve_until(self, future):
        fno = self.server.fileno()
        while self.server is not None:
            now = time.time()
            if now > future:
                break
            wait = min(future-time.time(), 0.5)
            rfd, _, _ = select.select([fno], [], [], wait)
            if fno in rfd:
                self.log.debug("received a packet")
                self.server.handle_request()



class MdnsDiscovery:
    """Discover ZTN controllers on one interface via IPv6 link-local and IPv4 mDNS.

    ``run()`` repeats query rounds until ``response_wait`` seconds have passed
    and accumulates the controllers found in ``self.controllers``, keeping one
    entry per (addr, port).
    """

    def __init__(self, interfaceName, response_wait=5, log=None, client_port=None):
        """
        :param interfaceName: name of the interface to query on (wrapped in ``Interface``).
        :param response_wait: seconds after which no new query round is started.
        :param log: logger to use; defaults to one named after the class.
        :param client_port: local UDP port for the query sockets; None selects
            ``DEFAULT_MDNS_CLIENT_PORT``.
        """
        self.log = log or logging.getLogger(self.__class__.__name__)

        self.interface = Interface(interfaceName)
        self.response_wait = response_wait
        self.controllers = []
        if client_port is None:
            client_port = DEFAULT_MDNS_CLIENT_PORT
        self.client_port = client_port

    def run(self):
        """Run IPv6 then IPv4 query rounds until ``response_wait`` has elapsed.

        Returns 0, or the first non-zero code returned by a discovery step.
        """
        now = time.time()
        future = now + self.response_wait
        while True:

            now = time.time()
            if now > future:
                break

            code = self.discoverIpv6LinkLocal()
            if code != 0:
                return code

            code = self.discoverIpv4()
            if code != 0:
                return code

            # Start rounds at least 0.25 s apart (measured from this round's start).
            future2 = now + 0.25
            now = time.time()
            if now < future2:
                time.sleep(future2-now)

        if self.controllers:
            self.log.info("found %d ztn controllers via mdns in %ds:\n    %s",
                          len(self.controllers), self.response_wait,
                          "\n    ".join([str(ep) for ep in self.controllers]))
        else:
            self.log.error("no ztn controllers via mdns in %ds", self.response_wait)

        return 0

    def shutdown(self):
        """No-op: each discovery round closes its own socket."""
        pass

    def updateControllers(self, *controllers):
        """Merge *controllers* into ``self.controllers``, de-duplicated by (addr, port).

        A later entry with the same key replaces the earlier one but keeps its
        position in the list.
        """
        merged = {(c.addr, c.port,): c for c in self.controllers}
        merged.update({(c.addr, c.port,): c for c in controllers})
        self.controllers = list(merged.values())

    def _discover(self, server_klass):
        """Run one open/query/collect round with *server_klass* and merge its results.

        Socket errors are logged rather than raised, and the round's socket is
        always closed, even when querying or collecting fails part-way.
        """
        svr = server_klass(self.interface, bind_port=self.client_port, log=self.log)
        try:
            svr.open()
            svr.query()
            svr.serve_until(time.time() + 1.0)
        except socket.error as what:
            self.log.warning("cannot start discovery: %s", str(what))
        finally:
            if svr.server is not None:
                svr.server.server_close()

        # Keep whatever was gathered before any failure.
        self.updateControllers(*svr.servers)

        return 0

    def discoverIpv4(self):
        """One IPv4 mDNS round; skipped (returning 0) when the interface has no IPv4 address."""
        if self.interface.v4 is None:
            self.log.warning("no IPv4 address, skipping IPv4 discovery")
            return 0

        return self._discover(_Server4)

    def discoverIpv6LinkLocal(self):
        """One IPv6 link-local mDNS round; skipped (returning 0) without a v6 address."""
        if self.interface.v6 is None:
            self.log.warning("no IPv6 link-local address, skipping IPv6 discovery")
            return 0

        return self._discover(_Server6)


if __name__ == '__main__':
    # Command-line entry point: run discovery once on the named interface.
    import argparse

    logging.basicConfig()
    cli_logger = logging.getLogger('discovery.mdns')
    cli_logger.setLevel(logging.INFO)

    parser = argparse.ArgumentParser('discovery.mdns')
    parser.add_argument("interface", help='The interface on which to run discovery.')
    cli_args = parser.parse_args()

    discovery = MdnsDiscovery(cli_args.interface, log=cli_logger)
    discovery.run()
