#!/usr/bin/env arista-python
# Copyright (c) 2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

"""Arastra networking utilities.

This module contains utilities relating to networking, in particular, for
manipulating IP addresses.
"""

from functools import total_ordering
import Tac
import numbers
import re
import socket
import struct

@total_ordering
class IpAddress:
   """Represents an IP address (v4/v6).  IpAddress objects are immutable.

   The address is kept in ``addr_`` as the network-byte-order bytes of the
   address reinterpreted as native unsigned 32-bit words: a single int for
   IPv4, a tuple of four ints for IPv6.  ``addrFamily_`` records the family.

   Addresses compare numerically within a family.  Addresses of different
   families are ordered by their address-family constant (AF_INET before
   AF_INET6), so that <, >, <= and >= form a consistent total order.
   Comparing with a non-IpAddress returns NotImplemented (so == is False
   and ordering raises TypeError).
   """

   def __init__( self, addr, addrFamily=socket.AF_INET ):
      """Constructs an IpAddress (v4/v6) object from a string, an int
      (in host byte order), or another IpAddress object of the same family.

      Raises ValueError if addr is not a valid address for addrFamily
      (including an out-of-range integer or an IpAddress of the other
      family), or if addrFamily is neither AF_INET nor AF_INET6."""
      self.addrFamily_ = addrFamily
      # Only reuse another IpAddress's storage when it has the same family;
      # otherwise addr_ would have the wrong shape (int vs. 4-tuple).  A
      # mismatched IpAddress falls through to string parsing below, which
      # rejects it with ValueError.
      # pylint: disable-next=unidiomatic-typecheck
      sameFamilyCopy = ( type( addr ) == type( self ) and
                         addr.addrFamily_ == addrFamily )
      if addrFamily == socket.AF_INET6:
         if sameFamilyCopy:
            self.addr_ = addr.addr_
         elif isinstance( addr, numbers.Integral ):
            if addr < 0 or addr > ( 1 << 128 ) - 1:  # check range
               raise ValueError( "invalid IPv6 address: %d" % addr )
            # Split into four 32-bit words, most significant first, each
            # converted to network byte order.
            self.addr_ = tuple( socket.htonl( ( addr >> shift ) & 0xFFFFFFFF )
                                for shift in ( 96, 64, 32, 0 ) )
         else:
            addrStr = str( addr )
            try:
               self.addr_ = struct.unpack( '4I', socket.inet_pton(
                                                         addrFamily, addrStr ) )
            except OSError:
               errMsg = "invalid IPv6 address: %s, type(addr): %s" \
                           % ( addrStr, type( addr ) )
               raise ValueError( errMsg ) # pylint: disable=raise-missing-from
      elif addrFamily == socket.AF_INET:
         if isinstance( addr, numbers.Integral ):
            if addr < 0 or addr > 4294967295:  # check range
               raise ValueError( "invalid IP address: %d" % addr )
            tmpAddr = socket.htonl( addr )
            if tmpAddr < 0:
               # Normalise a signed result to its unsigned 32-bit value.
               tmpAddr = struct.unpack( 'I', struct.pack( 'i', tmpAddr ) )[ 0 ]
            self.addr_ = tmpAddr
         elif sameFamilyCopy:
            self.addr_ = addr.addr_
         else:
            addrStr = str( addr )
            try:
               self.addr_ = struct.unpack( 'I', socket.inet_aton( addrStr ) )[ 0 ]
            except OSError:
               errMsg = "invalid IP address: %s, type(addr): %s" \
                           % ( addrStr, type( addr ) )
               raise ValueError( errMsg ) # pylint: disable=raise-missing-from
      else:
         errMsg = "addrFamily must be either AF_INET or AF_INET6"
         raise ValueError( "%s: %s" % ( errMsg, addrFamily ) )

   def __eq__( self, other ):
      """Checks if both addresses have the same family and numeric value."""
      if not isinstance( other, IpAddress ):
         return NotImplemented
      return ( self.addrFamily_ == other.addrFamily_ and
               self.toNum() == other.toNum() )

   def __ne__( self, other ):
      return not self == other

   def __lt__( self, other ):
      """Orders by address family first, then numerically by address value."""
      if not isinstance( other, IpAddress ):
         return NotImplemented
      return ( ( self.addrFamily_, self.toNum() ) <
               ( other.addrFamily_, other.toNum() ) )

   def __hash__( self ):
      # Must agree with __eq__: same family and same numeric value.
      return hash( ( self.addrFamily_, self.toNum() ) )

   def __str__( self ):
      """Returns the IP address as a string (formatted by socket.inet_ntop)."""
      if self.addrFamily_ == socket.AF_INET6:
         packed = struct.pack( '4I', *self.addr_ )
      else:
         packed = struct.pack( 'I', self.addr_ )
      return socket.inet_ntop( self.addrFamily_, packed )

   def toNum( self ):
      """Returns the IP address as a non-negative int in host byte order."""
      if self.addrFamily_ == socket.AF_INET6:
         intAddr = 0
         for word in self.addr_:
            intAddr = ( intAddr << 32 ) | socket.ntohl( word )
         return intAddr
      intAddr = socket.ntohl( self.addr_ )
      if intAddr < 0:
         # Normalise a signed result to its unsigned 32-bit value.
         intAddr = struct.unpack( 'I', struct.pack( 'i', intAddr ) )[ 0 ]
      return intAddr

class Mask:
   """Represents an IPv4/v6 subnet mask.  Mask objects are immutable.

   A mask is stored as its prefix length (``masklen_``) together with its
   address family (``addrFamily_``, socket.AF_INET or socket.AF_INET6)."""

   # Static lookup tables between prefix lengths and mask values (host byte
   # order).  _ipPrefixLenToMask[ n ] / _ip6PrefixLenToMask[ n ] is the mask
   # with the top n bits set; the *MaskToPrefixLen dicts are the inverse
   # mappings and therefore contain only contiguous masks.
   _ipMaskToPrefixLen = {}
   _ipPrefixLenToMask = []
   _ipPrefixLenToMask.extend( range( 33 ) )
   _ip6MaskToPrefixLen = {}
   _ip6PrefixLenToMask = []
   _ip6PrefixLenToMask.extend( range( 129 ) )
   for ii in range( 33 ):
      i = 32 - ii
      mask = ( ( ( ( 2 ** 32 ) - 1 ) >> ii) << ii)
      _ipPrefixLenToMask[ i ] = mask
      _ipMaskToPrefixLen[ mask ] = i
   del i, ii, mask

   for ii in range( 129 ):
      i = 128 - ii
      mask = ( ( ( ( 2 ** 128 ) - 1 ) >> ii) << ii)
      _ip6PrefixLenToMask[ i ] = mask
      _ip6MaskToPrefixLen[ mask ] = i
   del i, ii, mask

   def __init__( self, mask, inverse=False, addrFamily=socket.AF_INET ):
      """Constructs a Mask object from a string (an address-formatted mask,
      e.g. dotted decimal for IPv4), an int (the prefix length), or another
      Mask object.

      If inverse is True, a string mask is interpreted as wildcard bits
      (e.g. '0.0.0.255' for an IPv4 /24).  Raises ValueError for an
      unsupported address family, a prefix length outside the family's
      range, or a non-contiguous netmask."""

      if addrFamily not in [ socket.AF_INET, socket.AF_INET6 ]:
         errMsg = "addrFamily must be either AF_INET or AF_INET6"
         raise ValueError( "%s: %s" % ( errMsg, addrFamily ) )

      self.addrFamily_ = addrFamily
      maxLen = self._addrBits()

      if isinstance( mask, numbers.Integral ):
         if mask < 0 or mask > maxLen:
            raise ValueError( "invalid mask length: %d" % mask )
         self.masklen_ = int( mask )
      elif type( mask ) == type( self ): # pylint: disable=unidiomatic-typecheck
         # The source Mask may belong to the other family; its length must
         # still fit this family or the lookup tables would be indexed out
         # of range later.
         if mask.masklen_ > maxLen:
            raise ValueError( "invalid mask length: %d" % mask.masklen_ )
         self.masklen_ = mask.masklen_
      else:
         num = IpAddress( mask, addrFamily=addrFamily ).toNum()
         if inverse:
            # Wildcard bits: flip every bit to get the ordinary netmask.
            num ^= ( 1 << maxLen ) - 1
         if addrFamily == socket.AF_INET:
            maskToLen = Mask._ipMaskToPrefixLen
         else:
            maskToLen = Mask._ip6MaskToPrefixLen
         try:
            self.masklen_ = maskToLen[ num ]
         except KeyError:
            # Not a contiguous run of leading one bits.
            # pylint: disable-next=raise-missing-from
            raise ValueError( "invalid netmask: %s" % mask )

   def _addrBits( self ):
      """Returns the address width in bits (32 or 128) for this family."""
      return 32 if self.addrFamily_ == socket.AF_INET else 128

   def __str__( self ):
      """Returns the subnet mask as a string in address format (dotted
      decimal for IPv4)."""
      return str( IpAddress( self.toNum(), addrFamily=self.addrFamily_ ) )

   def inverseStr( self ):
      """Returns the subnet mask as a wildcard-bits (inverse) string
      in address format, e.g. '0.0.0.255' for an IPv4 /24."""
      allOnes = ( 1 << self._addrBits() ) - 1
      return str( IpAddress( self.toNum() ^ allOnes,
                             addrFamily=self.addrFamily_ ) )

   def toNum( self ):
      """Returns the subnet mask as a non-negative int in host byte order."""
      if self.addrFamily_ == socket.AF_INET:
         return Mask._ipPrefixLenToMask[ self.masklen_ ]
      return Mask._ip6PrefixLenToMask[ self.masklen_ ]

   def numHostInSubnet( self ):
      """Returns the number of IP addresses that are contained within a subnet that
      has this mask."""
      return 2 ** ( self._addrBits() - self.masklen_ )

   @classmethod
   def ipMaskToPrefixLen( cls ):
      """Returns the IPv4 mapping from mask value (host byte order) to prefix
      length.  This is the shared class-level dict, so callers must not
      modify it."""
      return Mask._ipMaskToPrefixLen

   maskLen = property( lambda self: self.masklen_ )
   
      
def Prefix( s ):
   """Constructs a string representing an Arnet::Prefix from a string (in the
   format '1.2.3.4', '1.2.3.4/16', '1:2:3::4' or '1:2:3::4/64').

   A bare address gets a host-length suffix (/32 for IPv4, /128 for IPv6).
   A string that already carries a length is returned unchanged.  IPv6 hex
   digits may be upper or lower case.  Only the overall syntax is checked:
   octet values and prefix lengths are not range-validated here.

   Raises ValueError if s (the whole string, with no trailing characters)
   matches none of the accepted forms."""

   # fullmatch (rather than match + '$') so that a trailing newline is not
   # silently accepted and carried into the returned prefix.
   if re.fullmatch( r"\d+\.\d+\.\d+\.\d+/\d+", s ):
      return s

   if re.fullmatch( r"\d+\.\d+\.\d+\.\d+", s ):
      return s + "/32"

   if re.fullmatch( r"[\da-fA-F:]+/\d+", s ):
      return s

   if re.fullmatch( r"[\da-fA-F:]+", s ):
      return s + "/128"

   raise ValueError( "Expected prefix in addr/len format; got '" + s + "'" )

def compareIpAddress( addr1, addr2 ):
   """Numerically compares two IPv4 addresses.

   Each argument may be anything IpAddress accepts with its default
   (AF_INET) family.  Returns -1, 0 or 1 in the manner of Python 2's cmp()."""
   first, second = IpAddress( addr1 ), IpAddress( addr2 )
   if first < second:
      return -1
   return 1 if first > second else 0

def compareIp6Address( addr1, addr2 ):
   """Numerically compares two IPv6 address objects.

   Each argument must have a ``stringValue`` attribute holding the textual
   IPv6 address.  Returns -1, 0 or 1 in the manner of Python 2's cmp()."""
   ip1, ip2 = ( IpAddress( addr.stringValue, addrFamily=socket.AF_INET6 )
                for addr in ( addr1, addr2 ) )
   return int( ip1 > ip2 ) - int( ip1 < ip2 )

def compareIp6AddressStr( addr1, addr2 ):
   """Numerically compares two IPv6 address strings (or anything IpAddress
   accepts for AF_INET6).  Returns -1, 0 or 1 like Python 2's cmp()."""
   family = socket.AF_INET6
   left = IpAddress( addr1, addrFamily=family )
   right = IpAddress( addr2, addrFamily=family )
   if left < right:
      return -1
   if left > right:
      return 1
   return 0

def _compare_prefix( a1, l1, a2, l2 ):
   """Orders two prefixes by address, then by prefix length.

   a1/a2 are mutually comparable addresses (IpAddress objects in this
   module); l1/l2 are prefix lengths as ints or numeric strings.  Both
   lengths are always converted with int(), even when the addresses differ.
   Returns -1, 0 or 1 in the manner of Python 2's cmp()."""
   def _sign( x, y ):
      return ( x > y ) - ( x < y )

   byAddr = _sign( a1, a2 )
   byLen = _sign( int( l1 ), int( l2 ) )
   return byAddr or byLen

def compareIp6PrefixStr( p1, p2 ):
   """Compares two IPv6 prefix strings of the form '2001:db8::/32'.

   Ordering is by address first, then by prefix length.  Returns -1, 0 or 1
   like Python 2's cmp()."""
   addr1, len1 = p1.split( '/' )
   addr2, len2 = p2.split( '/' )
   return _compare_prefix( IpAddress( addr1, addrFamily=socket.AF_INET6 ), len1,
                           IpAddress( addr2, addrFamily=socket.AF_INET6 ), len2 )

def compareIpPrefixStr( p1, p2 ):
   """Compares two IPv4 prefix strings of the form '1.2.3.4/24'.

   Prefixes are ordered by IP address first; ties are broken by prefix
   length, shorter first.  Returns -1, 0 or 1 in the manner of Python 2's
   cmp() (negative when p1 sorts before p2), so it can be used with
   functools.cmp_to_key.  Raises ValueError if either string is not of the
   form 'address/length' or contains an invalid address or length."""

   ( addr1, len1 ) = p1.split( '/' )
   ( addr2, len2 ) = p2.split( '/' )
   return _compare_prefix( IpAddress( addr1 ), len1, IpAddress( addr2 ), len2 )

def compareIpPrefix( prefix1, prefix2 ):
   """Orders two prefix objects by IP address, then by prefix length.

   Each argument must provide ``address`` (anything IpAddress accepts with
   its default AF_INET family) and ``len`` (an int or numeric string).
   Returns -1, 0 or 1 like Python 2's cmp()."""
   first = ( IpAddress( prefix1.address ), prefix1.len )
   second = ( IpAddress( prefix2.address ), prefix2.len )
   return _compare_prefix( *first, *second )

def compareIpAddrOrHostname( addr1, addr2 ):
   """Compares IP address and/or hostname strings.

   Strings that IpAddress accepts (IPv4 only, via socket.inet_aton, which
   also admits short forms such as '127.1') sort before everything else and
   are ordered numerically.  All other strings, including IPv6 literals, are
   treated as hostnames and ordered lexicographically.

   Returns -1, 0 or 1 in the manner of Python 2's cmp().
   Raises TypeError if either argument is not exactly a str.
   """

   # The arguments must be strings (str subclasses are rejected too).
   for addr in ( addr1, addr2 ):
      if type( addr ) is not str: # pylint: disable=unidiomatic-typecheck
         errMsg = "compareIpAddrOrHostname() argument must be a string"
         raise TypeError( "%s: %s" % ( errMsg, addr ) )

   def sortKey( addr ):
      # Rank 0 for valid IP addresses, rank 1 for (presumed) hostnames; the
      # rank enforces that addresses precede hostnames.
      try:
         return ( 0, IpAddress( addr ) )
      except ValueError:
         return ( 1, addr )

   rank1, val1 = sortKey( addr1 )
   rank2, val2 = sortKey( addr2 )
   if rank1 != rank2:
      return rank1 - rank2
   # Same kind: hostnames compare as strings, addresses numerically.
   return ( val1 > val2 ) - ( val1 < val2 )
