# Copyright (c) 2020 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

"""Network Utilities"""

from __future__ import absolute_import

import logging
import os
import random
import socket
import struct
import time

import netifaces
import pyroute2

# ICMP message type of an IPv4 echo request (RFC 792).
ICMP_ECHO_REQUEST = 8
# ICMPv6 message type of an echo request (RFC 4443).
ICMPV6_ECHO_REQUEST = 128
# Linux socket option that restricts a socket to a single network interface.
# Prefer the value exported by the socket module when this Python build
# provides it; 25 is the asm-generic Linux value, which is not used by every
# architecture (e.g. SPARC and PA-RISC define a different number).
SO_BINDTODEVICE = getattr(socket, 'SO_BINDTODEVICE', 25)

def checksum_(source_string):
    """Return the RFC 1071 Internet checksum of *source_string*.

    The data is summed as 16-bit words loaded little-endian (low byte
    first); a trailing odd byte counts as a word whose high byte is zero.
    Carries are folded back in and the one's complement is taken.  Because
    the one's-complement sum only differs by a byte swap between the two
    byte orders (RFC 1071, section 2(B)), swapping the bytes of that result
    gives the checksum as a big-endian (network-order) 16-bit value.  For
    example, the RFC 1071 sample bytes 00 01 f2 03 f4 f5 f6 f7 yield 0x220d.
    create_packet() passes this value through socket.htons() before packing
    it in native byte order.

    Args:
        source_string: bytes-like object (bytes, bytearray, memoryview, ...)
            to checksum.

    Returns:
        int in the range [0, 0xffff].

    Raises:
        TypeError: if *source_string* is not a bytes-like object.
    """
    try:
        data = bytes(memoryview(source_string))
    except TypeError:
        raise TypeError("checksum_() expects a bytes-like object, not %s"
                        % type(source_string).__name__) from None

    even_len = len(data) & ~1
    # Sum 16-bit words, each loaded little-endian (low byte first).
    total = sum(data[i] | (data[i + 1] << 8) for i in range(0, even_len, 2))
    if even_len < len(data):
        # Odd length: the last byte is padded with a zero high byte.
        total += data[-1]
    # End-around carry: Python ints never overflow, so keep folding the
    # high bits back in until the sum fits in 16 bits.
    while total >> 16:
        total = (total >> 16) + (total & 0xffff)
    answer = ~total & 0xffff
    # The sum was accumulated over little-endian words; swapping the two
    # bytes turns it into the network-order checksum value.
    return (answer >> 8) | ((answer << 8) & 0xff00)


def create_packet(pkt_id, typ=ICMP_ECHO_REQUEST):
    """Build an echo request message carrying the identifier *pkt_id*.

    The 8-byte header is packed in native byte order with format 'BBHHh':
    message type *typ*, code 0, checksum, identifier *pkt_id*, and a fixed
    sequence number of 1.  A payload of 192 b'Q' bytes follows.

    The checksum is computed over header + payload with the checksum field
    zeroed, then converted with socket.htons() and written into a freshly
    packed header.

    Args:
        pkt_id: identifier placed in the header (must fit an unsigned short).
        typ: ICMP/ICMPv6 message type; defaults to ICMP_ECHO_REQUEST.

    Returns:
        bytes: the complete message (header followed by payload).
    """
    header_format = 'BBHHh'
    payload = b'Q' * 192

    # First pass: checksum field set to 0, as the checksum covers it.
    zeroed_header = struct.pack(header_format, typ, 0, 0, pkt_id, 1)
    wire_checksum = socket.htons(checksum_(zeroed_header + payload))

    final_header = struct.pack(header_format, typ, 0, wire_checksum, pkt_id, 1)
    return final_header + payload

def ping4(source_address=None, source_intf=None):
    """Send one ICMP echo request to the IPv4 limited-broadcast address.

    No replies are read here.  Presumably the goal is to provoke replies so
    the kernel neighbour table is populated before get_neighbours() reads
    it (TODO confirm).

    Args:
        source_address: if given, bind the socket to this local IPv4 address.
        source_intf: if given and source_address is not, bind the socket to
            this interface name via SO_BINDTODEVICE.

    Raises:
        OSError: from socket creation, option setting or sending, e.g. when
            the process lacks the privilege to open a raw socket.
    """
    # The with-block guarantees the raw socket is closed, even on error.
    with socket.socket(socket.AF_INET, socket.SOCK_RAW,
                       socket.IPPROTO_ICMP) as sock:
        # See http://stackoverflow.com/questions/12607516/python-udp-broadcast-not-sending
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

        if source_address is not None:
            sock.bind((source_address, 0))
        elif source_intf is not None:
            # See http://stackoverflow.com/questions/7221577/help-with-python-setsockopt
            sock.setsockopt(socket.SOL_SOCKET, SO_BINDTODEVICE,
                            (source_intf + '\0').encode())

        packet_id = int((os.getpid() * random.random()) % 65535)
        packet = create_packet(packet_id)
        # A raw socket sends each datagram whole or raises, so one sendto()
        # is enough; re-sending an "unsent tail" would only emit a truncated,
        # malformed ICMP message.  ICMP has no ports, but the address tuple
        # requires one, so a dummy value is supplied.
        sock.sendto(packet, ("255.255.255.255", socket.IPPROTO_ICMP))


# See https://github.com/gvnn3/PCS/blob/master/scripts/ping6.py
def ping6(source_address=None, source_intf=None):
    """Send one ICMPv6 echo request to the link-local all-nodes group ff02::1.

    No replies are read here.  Presumably the goal is to provoke replies so
    the kernel neighbour table is populated before get_neighbours() reads
    it (TODO confirm).

    Args:
        source_address: if given, bind the socket to this local IPv6 address.
        source_intf: if given and source_address is not, bind the socket to
            this interface name via SO_BINDTODEVICE.

    Raises:
        OSError: from socket creation, option setting or sending, e.g. when
            the process lacks the privilege to open a raw socket.
    """
    # The with-block guarantees the raw socket is closed, even on error.
    with socket.socket(socket.AF_INET6, socket.SOCK_RAW,
                       socket.IPPROTO_ICMPV6) as sock:
        if source_address is not None:
            sock.bind((source_address, 0))
        elif source_intf is not None:
            sock.setsockopt(socket.SOL_SOCKET, SO_BINDTODEVICE,
                            (source_intf + '\0').encode())

        packet_id = int((os.getpid() * random.random()) % 65535)
        packet = create_packet(packet_id, typ=ICMPV6_ECHO_REQUEST)
        # A raw socket sends each datagram whole or raises, so one sendto()
        # is enough; re-sending an "unsent tail" would only emit a truncated,
        # malformed ICMPv6 message.  ICMPv6 has no ports, but the address
        # tuple requires one, so a dummy value is supplied.
        sock.sendto(packet, ("ff02::1", socket.IPPROTO_ICMPV6))


# pylint: disable=bad-whitespace

# Neighbour-cache (NUD) state bits, mirroring /usr/include/linux/neighbour.h.
NUD_INCOMPLETE = 1 << 0
NUD_REACHABLE = 1 << 1
NUD_STALE = 1 << 2
NUD_DELAY = 1 << 3
NUD_PROBE = 1 << 4
NUD_FAILED = 1 << 5

# Values the kernel header lists under "Dummy states".
NUD_NOARP = 1 << 6
NUD_PERMANENT = 1 << 7
NUD_NONE = 0

def get_neighbours4():
    """Return destinations of IPv4 neighbour entries in state NUD_REACHABLE.

    Queries the kernel neighbour table through pyroute2.  It keeps only
    entries that carry an NDA_DST attribute and whose state equals
    NUD_REACHABLE.

    Returns:
        list: the NDA_DST values, as reported by pyroute2 (presumably
        address strings -- TODO confirm).
    """
    iproute = pyroute2.IPRoute()
    try:
        neighbours = []
        for entry in iproute.get_neighbours(family=socket.AF_INET):
            dst = dict(entry['attrs']).get('NDA_DST', None)
            if dst is not None and entry['state'] == NUD_REACHABLE:
                neighbours.append(dst)
        return neighbours
    finally:
        # Release the netlink socket even if the query raises.
        iproute.close()

def get_neighbours6():
    """Return scoped destinations of IPv6 neighbours in state NUD_REACHABLE.

    Queries the kernel neighbour table through pyroute2.  It keeps entries
    that have both an NDA_DST attribute and an ifindex and whose state
    equals NUD_REACHABLE.

    Returns:
        list of str: each entry is "<NDA_DST>%<ifindex>", i.e. the address
        with the numeric interface index appended as a scope suffix.
    """
    iproute = pyroute2.IPRoute()
    try:
        neighbours = []
        for entry in iproute.get_neighbours(family=socket.AF_INET6):
            dst = dict(entry['attrs']).get('NDA_DST', None)
            idx = entry.get('ifindex', None)
            if (dst is not None
                    and idx is not None
                    and entry['state'] == NUD_REACHABLE):
                neighbours.append(dst + '%' + str(idx))
        return neighbours
    finally:
        # Release the netlink socket even if the query raises.
        iproute.close()



def get_neighbours(interface_name, wait):
    """Probe *interface_name* and report the reachable neighbours afterwards.

    Sends an ICMPv6 echo request and then an ICMP echo request bound to
    the interface.  It sleeps *wait* seconds when *wait* is positive, then
    reads the IPv4 and IPv6 neighbour tables.

    Args:
        interface_name: name of the interface to send the probes from.
        wait: seconds to sleep between probing and reading (skipped if <= 0).

    Returns:
        dict with keys 'v4' (get_neighbours4() result) and 'v6'
        (get_neighbours6() result).
    """
    for probe in (ping6, ping4):
        probe(source_intf=interface_name)

    if wait > 0:
        time.sleep(wait)

    v4_neighbours = get_neighbours4()
    v6_neighbours = get_neighbours6()
    return {'v4': v4_neighbours, 'v6': v6_neighbours}


class Interface:
    """Addresses and kernel index of one network interface.

    Attributes:
        name: the interface name given to the constructor.
        mac: 'addr' of the first netifaces AF_LINK entry, or None.
        v4: 'addr' of the first netifaces AF_INET entry, or None.
        v6: first netifaces AF_INET6 'addr' whose text starts with "fe80"
            (a link-local address), or None if there is none.
        ifindex: interface index returned by pyroute2's link_lookup.
    """

    def __init__(self, name):
        """Collect the addresses and index of interface *name*.

        Raises:
            IndexError: if pyroute2's link_lookup finds no link named *name*.
            Errors raised by netifaces.ifaddresses propagate unchanged.
        """
        self.name = name

        info = netifaces.ifaddresses(name)

        # Stand-in for a missing address family: one entry without 'addr'.
        empty_family = [{}]
        self.mac = info.get(netifaces.AF_LINK, empty_family)[0].get('addr', None)
        self.v4 = info.get(netifaces.AF_INET, empty_family)[0].get('addr', None)

        self.v6 = None
        for entry in info.get(netifaces.AF_INET6, empty_family):
            addr = entry.get('addr', None)
            # Link-local addresses only.
            if addr and addr.startswith("fe80"):
                self.v6 = addr
                break

        iproute = pyroute2.IPRoute()
        try:
            self.ifindex = iproute.link_lookup(ifname=name)[0]
        finally:
            # Close the netlink socket even when the lookup comes back empty
            # and indexing raises IndexError.
            iproute.close()


if __name__ == '__main__':
    # Command-line entry point: probe an interface, wait, then pretty-print
    # the reachable IPv4/IPv6 neighbours that get_neighbours() returns.
    import argparse
    import pprint

    parser = argparse.ArgumentParser()
    parser.add_argument("interface")
    parser.add_argument("wait", type=int)
    args = parser.parse_args()

    neighbours = get_neighbours(args.interface, args.wait)
    pprint.pprint(neighbours)
