#!/usr/bin/env python3
# Copyright (c) 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

"""Python bindings for Aresolve, the asynchronous host resolver library.
"""

from enum import Enum, auto
from socket import AF_INET, AF_INET6
import collections
import socket
import weakref
import Arnet
import Tac
import Tracing

traceHandle = Tracing.Handle( 'AresolvePy' )
# Trace helpers by level: requests (5), responses (6), state changes (7) and
# reactor notifications (8).
tReq, tResp, tState, tReact = ( traceHandle.trace5, traceHandle.trace6,
                                traceHandle.trace7, traceHandle.trace8 )

# Portability: Linux specific.
# EAI_* error codes used by getaddrinfo[_a] calls; values from netdb.h on Linux.
EAI_SUCCESS = 0
EAI_BADFLAGS = socket.EAI_BADFLAGS
EAI_NONAME = socket.EAI_NONAME
EAI_AGAIN = socket.EAI_AGAIN  # occurs e.g. when no nameserver is configured
EAI_FAIL = socket.EAI_FAIL
EAI_NODATA = socket.EAI_NODATA
EAI_FAMILY = socket.EAI_FAMILY
EAI_SOCKTYPE = socket.EAI_SOCKTYPE
EAI_SERVICE = socket.EAI_SERVICE
EAI_ADDRFAMILY = socket.EAI_ADDRFAMILY
EAI_MEMORY = socket.EAI_MEMORY
EAI_SYSTEM = socket.EAI_SYSTEM
EAI_OVERFLOW = socket.EAI_OVERFLOW
# Not defined in the socket module; values from netdb.h on Linux.
EAI_INPROGRESS = -100
EAI_CANCELED = -101
EAI_NOTCANCELED = -102
EAI_ALLDONE = -103
# Historical misspelling of EAI_ALLDONE, kept so existing importers still work.
EIA_ALLDONE = EAI_ALLDONE
EAI_INTR = -104
EAI_IDN_ENCODE = -105

# Human-readable descriptions of the errors Aresolve.tac can report, keyed by
# signed EAI_* code. Codes missing here are described by their symbolic name.
# NOTE(review): glibc's gai_strerror() describes EAI_NODATA as "No address
# associated with hostname" and uses the text below for EAI_FAIL instead;
# confirm the EAI_NODATA entry intentionally mirrors Aresolve.tac.
EAI_ERROR_TABLE = {
   EAI_INPROGRESS: 'Processing request in progress',
   EAI_NONAME: 'Name or service not known',
   EAI_AGAIN: 'Temporary failure in name resolution',
   EAI_NODATA: 'Non-recoverable failure in name resolution',
   EAI_CANCELED: 'Request canceled',
   EAI_ADDRFAMILY: 'Address family for NAME not supported',
   EAI_MEMORY: 'Memory allocation failure',
}

# Maps every known EAI_* code to its symbolic name; gaiSterror() falls back to
# this when EAI_ERROR_TABLE has no description for a code.
EAI_ERRNO_TO_SYMBOL = { EAI_SUCCESS: 'EAI_SUCCESS' }
# Codes exposed by the socket module: look each one up by its own name so the
# symbol string can never drift from the attribute it describes.
EAI_ERRNO_TO_SYMBOL.update(
   ( getattr( socket, _eaiName ), _eaiName ) for _eaiName in (
      'EAI_BADFLAGS', 'EAI_NONAME', 'EAI_AGAIN', 'EAI_FAIL', 'EAI_NODATA',
      'EAI_FAMILY', 'EAI_SOCKTYPE', 'EAI_SERVICE', 'EAI_ADDRFAMILY',
      'EAI_MEMORY', 'EAI_SYSTEM', 'EAI_OVERFLOW' ) )
# Codes only defined in this module (not provided by the socket module).
# NOTE: 'EIA_ALLDONE' reproduces the historical misspelling byte-for-byte.
EAI_ERRNO_TO_SYMBOL.update( {
   EAI_INPROGRESS: 'EAI_INPROGRESS',
   EAI_CANCELED: 'EAI_CANCELED',
   EAI_NOTCANCELED: 'EAI_NOTCANCELED',
   EIA_ALLDONE: 'EIA_ALLDONE',
   EAI_INTR: 'EAI_INTR',
   EAI_IDN_ENCODE: 'EAI_IDN_ENCODE',
} )


class DnsRecord:
   """Result of a DNS query; the object passed to all Querier callbacks.

   Attributes:
      name: str, the name queried.
      valid: bool, whether the query succeeded. If False, consult lastError.
      lastError: int, the most recent signed EAI_* error for name (typically
         one of the EAI_ERROR_TABLE keys, or EAI_SUCCESS).
      lastRefresh: float, Tac::Seconds of the last time this record updated.
      ipAddress: list of str, zero or more IPv4 addresses for name.
      ip6Address: list of str, zero or more IPv6 addresses for name.
   """

   def __init__( self,
                 name, valid, lastError, lastRefresh,
                 ipAddress, ip6Address ):
      self.name = name
      self.valid = valid
      self.lastError = lastError
      self.lastRefresh = lastRefresh
      self.ipAddress = ipAddress
      self.ip6Address = ip6Address

   def __repr__( self ):
      return ( f"DnsRecord(name={self.name}, valid={self.valid}, "
               f"lastError={self.lastError}, lastRefresh={self.lastRefresh}, "
               f"ipAddress={self.ipAddress}, ip6Address={self.ip6Address})" )

   # Deliberately not __eq__: lastRefresh (and name) are ignored here.
   def isEqual( self, other ):
      """Tells whether other carries the same DNS result as this record.

      Compares validity, error code and the IPv4/IPv6 address sets (address
      order is ignored). Neither name nor lastRefresh take part, since the
      intention is only to compare results."""
      assert isinstance( other, DnsRecord )
      if self.valid != other.valid or self.lastError != other.lastError:
         return False
      return ( sorted( self.ipAddress ) == sorted( other.ipAddress ) and
               sorted( self.ip6Address ) == sorted( other.ip6Address ) )

resolutionRecordTypename = 'Aresolve::ResolutionRecord'

class ResolutionRecordReactor( Tac.Notifiee ):
   """A TACC notifiee reacting to Aresolve::ResolutionRecord changes.

   Tallies results in the querier's counter and forwards the record to the
   callback once it holds a usable result (success or a final error).
   """

   notifierTypeName = resolutionRecordTypename

   def __init__( self, aresolveRecordDir, querier, callback ):
      # NOTE(review): despite its name, aresolveRecordDir is the notifier,
      # presumably a single ResolutionRecord given notifierTypeName; confirm.
      Tac.Notifiee.__init__( self, aresolveRecordDir )
      self.aresolveRecordDir = aresolveRecordDir
      self.querier_ = querier
      self.callback_ = callback

   @Tac.handler( 'lastError' )
   @Tac.handler( 'valid' )
   @Tac.handler( 'lastRefresh' )
   def handler( self ):
      """Fires the callback when the record holds a usable result.

      Aresolve inserts records into its output directory as it issues
      requests, initially with valid False and lastError at its initial
      value (EAI_SUCCESS). The record becomes usable once valid turns True,
      or once lastError reports an error other than EAI_AGAIN (a temporary
      failure).
      """
      record = self.notifier_
      tReact( 'react name:', repr( record.name ) )
      # Sysdb holds lastError as a U32; EAI_* codes in netdb.h are signed.
      errorCode = getSigned( record.lastError )

      if record.valid:
         # Completed; lastError is expected to be EAI_SUCCESS here.
         self.querier_.counter[ errorCode ] += 1
         self.callback_( record )
         return

      if not errorCode:
         # Not valid and no error yet: the request is still incomplete.
         return

      self.querier_.counter[ errorCode ] += 1
      # EAI_AGAIN is a temporary failure, so it is counted but not reported.
      if errorCode != EAI_AGAIN:
         self.callback_( record )

class CallbackStrategy( Enum ):
   """Selects when Querier runs its callback for a received DNS record.

   Always: run the callback for every record received.
   OnRecordUpdate: run it only when the new record differs from the
      previous one for the same name (validity, error code, or the set of
      IPs; the order of IPs does not matter).
   OnUsedIpMissing: run it only when the IP recorded as in use for the
      name is absent from the new record, or no IP is recorded for it.
      The in-use IPs have to be set manually by the caller."""
   # Values come from auto(), so member order determines them.
   Always = auto()
   OnRecordUpdate = auto()
   OnUsedIpMissing = auto()

class Querier:
   """Aresolve Querier: An asynchronous DNS querier.

   This object is a thin wrapper around the C++ Aresolve library and
   manages its state machine and request and response directories.

   Querier provides a callback interface (implemented via
   ResolutionRecordReactor) and a polling interface.

   The callback interface is used by supplying a callback as the
   'callback' initializer argument. That callback is shared by every
   query issued through this instance (see host()); host() itself does
   not take a per-query callback.

   When a full response (or a non-retryable error) is received, the
   callback is run with the DNS result: a DnsRecord carrying the same
   information as the underlying Aresolve::ResolutionRecord. Whether it
   runs for a given result is decided by the callbackStrategy passed at
   initialization (see CallbackStrategy).

   If no callback is supplied, the only way to receive results is to
   poll with resultHost() (or to inspect the records cached in
   dnsRecords).

   After first resolution (either successfully, or in error) queries
   will be re-run every longTime seconds and your callback will be
   fired again (optionally filtered on result changes, see
   callbackStrategy).

   Use this behaviour to dynamically update configuration files or
   agent state. For example, when a temporary DNS error that happens
   during startup corrects itself, your agent can respond to the new
   DnsRecord received by its callback to start working without retry
   code in any agent.
   """

   def __init__( self, callback=None, shortTime=None, longTime=None,
                 callbackStrategy=CallbackStrategy.Always, useDnsQuerySm=False ):
      """Creates an Aresolve (TACC DNS resolver) instance and a reactor.

      The shortTime and longTime values default to those defined in the
      Aresolve::AresolveSm implementation, generally 10 and 300 seconds,
      respectively. Passing None (or 0) keeps that default.

      If no callback argument is provided, the polling interface (e.g., the
      resultHost() method) must be used to obtain query results.

      Args:
         callback: a callable, the callback to be called when a DNS result
            arrives. It is given one argument, a DnsRecord being the DNS
            result record. If no callback is provided, polling mode is the
            only way to retrieve DNS result records.
         shortTime: an int, the number of seconds between ticks of the
            Aresolve short clock. This is the retry period on errors while
            Aresolve is still attempting to resolve.
         longTime: an int, the number of seconds between ticks of the
            Aresolve long clock. This is the query refresh period.
            When the DNS response arrives, the record will be sent to the
            callback depending on callbackStrategy.
         callbackStrategy: a CallbackStrategy member, determining whether the
            callback should be executed for a given record.
         useDnsQuerySm: a bool, if True uses DnsQuerySm instead of AresolveSm.
      """
      self.dnsOut_ = Tac.newInstance( 'Aresolve::ResolutionRecordDir',
                                      'AresolveResolutionRecord' )
      self.dnsIn_ = Tac.newInstance( 'Aresolve::ResolutionRequestDir',
                                     'AresolveResolutionRequest' )

      if useDnsQuerySm:
         self.ns = Tac.newInstance( "Arnet::NamespaceName", "default" )
         self.netConfigDir = Tac.newInstance( "System::NetConfig", "group" )
         self.dnsSm_ = Tac.newInstance( 'Aresolve::DnsQuerySm',
                                         self.dnsOut_, self.dnsIn_,
                                         self.ns, self.netConfigDir )
      else:
         self.dnsSm_ = Tac.newInstance( 'Aresolve::AresolveSm',
                                         self.dnsOut_, self.dnsIn_ )

      # Falsy values (None, 0) leave the state machine's defaults in place.
      if shortTime:
         self.dnsSm_.shortTime = shortTime
      if longTime:
         self.dnsSm_.longTime = longTime

      self.callbackStrategy = callbackStrategy
      # hostname -> IP address currently used by the caller. Filled in by the
      # caller; consulted (and popped) under CallbackStrategy.OnUsedIpMissing.
      self.usedIPs = {}

      # Tally keyed by signed EAI_* code (EAI_SUCCESS included), incremented
      # by ResolutionRecordReactor.
      self.counter = collections.Counter()
      self.callback_ = callback
      # hostname -> most recent DnsRecord received for it.
      self.dnsRecords = {}
      # NOTE(review): the reactor gets a weak proxy to this querier, but the
      # bound method self._queryCallback holds a strong reference to self, so
      # the proxy alone does not break the Querier -> reactor_ cycle. Confirm
      # whether that is intended.
      self.reactor_ = Tac.collectionChangeReactor(
         self.dnsOut_.record,
         ResolutionRecordReactor,
         reactorArgs=( weakref.proxy( self ), self._queryCallback ) )

   def host( self, name ):
      """Starts an A/AAAA DNS query for the host 'name' (str).

      Requests are reference counted: each call adds one reference, and each
      finishHost() call drops one. Any result cached in dnsRecords for name is
      discarded, so the next result for it is treated as new.

      Args:
        name: str, the DNS name to query.
      """
      tReq( 'request name:', repr( name ) )
      # Clear any cached results before re-resolving.
      self.dnsRecords.pop( name, None )

      self.dnsIn_.request[ name ] = self.dnsIn_.request.get( name, 0 ) + 1

   def resultHost( self, name ):
      """Returns the current response record for host 'name'.

      Returns None if there is no current query for name, else returns a DnsRecord.

      This interface can be used by code not running in a TAC activity loop to poll
      for DNS results.
      """
      resolutionRecord = self.dnsOut_.record.get( name )
      tResp( 'result name:', name, 'record:', resolutionRecord )
      if resolutionRecord is None:
         return None
      return self._convert( resolutionRecord )

   def _convert( self, record ):
      """Converts an Aresolve::ResolutionRecord to a DnsRecord.

      lastError is converted from its unsigned Sysdb form to a signed EAI_*
      code, and the IPv6 addresses to their string values.
      """
      return DnsRecord( name=record.hostname,
                        lastRefresh=record.lastRefresh,
                        lastError=getSigned( record.lastError ),
                        valid=record.valid,
                        ipAddress=list( record.ipAddress ),
                        ip6Address=[ i.stringValue for i in record.ip6Address ] )

   def _queryCallback( self, record ):
      """Callback wrapper invoked by ResolutionRecordReactor.

      Converts the Aresolve::ResolutionRecord to a DnsRecord, caches it in
      dnsRecords and runs the user callback (if any) when callbackStrategy
      allows it.

      Raises:
         ValueError: if callbackStrategy is not a known CallbackStrategy.
      """
      tResp( 'received DNS record for name:', record.name )
      newDnsRecord = self._convert( record )
      previousDnsRecord = self.dnsRecords.get( newDnsRecord.name )

      if self.callbackStrategy == CallbackStrategy.OnRecordUpdate:
         executeCallback = ( previousDnsRecord is None or
                             not previousDnsRecord.isEqual( newDnsRecord ) )
      elif self.callbackStrategy == CallbackStrategy.OnUsedIpMissing:
         recordIPs = newDnsRecord.ipAddress + newDnsRecord.ip6Address
         executeCallback = ( record.name not in self.usedIPs or
                             self.usedIPs[ record.name ] not in recordIPs )
         if executeCallback:
            # We execute callback only if our IP is missing in new DNS record,
            # or there's no used IP for the given hostname.
            # In such case, drop it from the list, this makes it clear to the
            # callback receiver that the DNS response has changed and it's time to
            # select a new IP. Otherwise, it would have to compare the records by
            # itself to find out whether the original IP is still okay to use.
            self.usedIPs.pop( record.name, None )
      elif self.callbackStrategy == CallbackStrategy.Always:
         executeCallback = True
      else:
         # An explicit raise (unlike `assert False`) survives `python -O`,
         # where the assert would vanish and executeCallback stay unbound.
         raise ValueError(
            f'Unhandled callback strategy: {self.callbackStrategy!r}' )

      self.dnsRecords[ newDnsRecord.name ] = newDnsRecord

      tResp( f'executeCallback: {executeCallback}' )
      if self.callback_ and executeCallback:
         self.callback_( newDnsRecord )

   def finishHost( self, name ):
      """Finishes an A/AAAA DNS query for (str) name.

      Drops one reference taken by host(). Once no references remain, the
      request is deleted, which stops it from being re-resolved by the
      Aresolve state machine. Names with no request are ignored.
      """
      tState( 'finish:', name )
      refCount = self.dnsIn_.request.get( name )
      if refCount is None:
         return
      if refCount:
         self.dnsIn_.request[ name ] -= 1
         # Re-read the count in case the decrement took it to zero.
         refCount = self.dnsIn_.request[ name ]
      if not refCount:
         tState( 'delete:', name )
         # No more referers, so kill the query.
         try:
            del self.dnsIn_.request[ name ]
         except KeyError:
            # Already gone; nothing left to stop.
            pass

   def handleIpAddrOrHostname( self, ipOrHost, addrFamily=None ):
      """Handles a configuration change of the address of the host.

      ipOrHost : string - contains IPv4/6 address or hostname/FQDN
      addrFamily : socket.AddressFamily - forces the return of particular address

      Returns a tuple of (hostname, list of IP addresses).
      If ipOrHost is empty, (None, None) is returned.
      If ipOrHost is an IP address, first element will be set to None and second
      element will be a list containing that IP address.
      If ipOrHost is a hostname, first element will be set to it. If the hostname
      is already resolved, the second element will be a list of IP addresses.
      Otherwise second element will be set to None, and a query is started via
      host() unless one is already outstanding for the hostname.
      If addrFamily is set, the returned list of IP addresses will be one that
      matches it, otherwise None will be returned."""
      def isIpAddress( addr ):
         try:
            Arnet.IpGenAddr( addr )
            return True
         except ValueError:
            return False

      if not ipOrHost:
         return None, None

      if isIpAddress( ipOrHost ):
         # No querying if ipOrHost is an IP address
         tState( f"ipAddrOrHostname: {ipOrHost} is an IP address" )
         return None, [ ipOrHost ]

      hostname = ipOrHost
      tState( f"ipAddrOrHostname: {hostname} is not an IP address" )
      record = self.dnsRecords.get( hostname )
      if record:
         # Hostname already resolved; return the list matching addrFamily,
         # or the first non-empty list (IPv4 preferred) if none was forced.
         if addrFamily == AF_INET and record.ipAddress:
            return hostname, record.ipAddress
         elif addrFamily == AF_INET6 and record.ip6Address:
            return hostname, record.ip6Address
         elif not addrFamily:
            if record.ipAddress:
               return hostname, record.ipAddress
            elif record.ip6Address:
               return hostname, record.ip6Address
      elif hostname not in self.dnsIn_.request:
         tState( f"Resolve {hostname}" )
         self.host( hostname )
      # We might already have registered the hostname which wasn't resolved just
      # yet. In such case just let Aresolve (re)try resolving it.
      return hostname, None

   def removeRecord( self, hostname ):
      """Removes the cached DNS record for the given hostname, if any."""
      self.dnsRecords.pop( hostname, None )

   def clear( self ):
      """Stops all DNS queries created by this instance.

      Only the request directory is cleared; the cached dnsRecords and
      usedIPs are left untouched.
      """
      tState( "clearing all requests/responses" )
      self.dnsIn_.request.clear()
      # dnsOut_.record is cleared for us by Afetch.tin when we clear requests


def getSigned( number, numBits=32 ):
   """Reinterprets the low numBits bits of number as a two's-complement int.

   Bits of number above numBits do not affect the result.

   Args:
      number: int, the unsigned input value
      numBits: int, the length in bits of the input value

   Returns:
      An int, the signed (two's complement) value of number for length numBits
   """
   signBit = 1 << ( numBits - 1 )
   lowBits = number & ( ( 1 << numBits ) - 1 )
   # With the sign bit set, the value wraps around by 2**numBits.
   return lowBits - ( 1 << numBits ) if lowBits & signBit else lowBits


def gaiSterror( errno, signErrno=True ):
   """Emulates the C gai_strerror(3) call, returning a str for the int errno.

   Args:
      errno: int, an EAI_* error code.
      signErrno: bool, if True errno is first reinterpreted as a signed 32-bit
         value with getSigned() (e.g. a U32 lastError read from Sysdb).

   Returns:
      The description from EAI_ERROR_TABLE when there is one, else the
      symbolic name from EAI_ERRNO_TO_SYMBOL, else 'UNKNOWN_ERROR'.
   """
   code = getSigned( errno ) if signErrno else errno
   if code in EAI_ERROR_TABLE:
      return EAI_ERROR_TABLE[ code ]
   return EAI_ERRNO_TO_SYMBOL.get( code, 'UNKNOWN_ERROR' )
