# Copyright (c) 2006-2009, 2010 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

# pylint: disable=consider-using-f-string

"""Utilities for managing Linux kernel devices.

This module provides classes for managing Linux kernel network devices, including TAP
devices."""

# pkgdeps: specline %{python3_sitelib}/Arnet/Device.py

import atexit
import os
import errno
import sys
import re
import socket
import select
import fcntl
import struct
import weakref

import Arnet
from Arnet.NsLib import DEFAULT_NS, runMaybeInNetNs
from Arnet.Verify import devicePresent
from Arnet import IpTestLib
import ArPyUtils.AsanHelper
from SysConstants.if_ether_h import ETH_P_ALL
from SysConstants.if_tun_h import TUNSETIFF, IFF_TUN, IFF_TAP, IFF_NO_PI
import Tac
import Tracing

t9 = Tracing.trace9

# Devices constructed in this process, keyed by interface name, so they can be
# closed at interpreter exit.  Values are weak references: being registered here
# does not keep a device object alive.
_OPENEDDEVICES = weakref.WeakValueDictionary()

@atexit.register
def _closeOpenedDevs():
   """Close every still-alive registered device at interpreter exit.

   Only objects that define close() (e.g. Tap, Dummy) are closed; the Device
   base class defines no close(), so plain Device wrappers are skipped instead
   of raising AttributeError.  Every device is attempted even if an earlier
   close() fails; the first failure is re-raised afterwards so it is still
   reported.
   """
   firstError = None
   # Snapshot the values: closing may drop the last reference to a device and
   # mutate the weak dictionary while we iterate.
   for dev in list( _OPENEDDEVICES.values() ):
      close = getattr( dev, 'close', None )
      if close is None:
         continue
      try:
         close()
      except Exception as e: # pylint: disable=broad-except
         if firstError is None:
            firstError = e
   if firstError is not None:
      raise firstError

def ifNameToIndex( intfName, netNs=DEFAULT_NS, deviceNs=None ):
   '''Look up the kernel interface index of intfName.

   Reads /sys/class/net/<intfName>/ifindex with `cat`, run via runMaybeInNetNs
   in netNs (deviceNs is passed through unchanged).  Returns the index as an
   int, or 0 when the command fails, i.e. intfName is not an interface there.
   '''
   ifindexFile = "/sys/class/net/%s/ifindex" % intfName
   catCmd = [ "cat", ifindexFile ]
   try:
      return int( runMaybeInNetNs( netNs, catCmd, deviceNs=deviceNs,
                                   stdout=Tac.CAPTURE, stderr=Tac.DISCARD ) )
   except Tac.SystemCommandError:
      # cat failed (typically: no such file because no such interface).
      return 0

class Device:
   """Wrapper for a Linux kernel network device.  Provides methods for configuring IP
   and Ethernet properties of the interface.

   The constructor configures an interface that must already exist in the kernel
   (subclasses such as Tap and Dummy create it before calling Device.__init__).
   Commands are run through runMaybeInNetNs in the device's network namespace
   (self.netNs_), most of them with asRoot=True.  Every constructed device is
   registered in _OPENEDDEVICES under its name for cleanup at interpreter exit.
   """

   nextSuffix_ = {} # key is prefix, value is next suffix (int)

   @classmethod
   def uniqueDevName( cls, prefix ):
      """Return a candidate device name of the form '<prefix><pid>-<n>'.

      n counts up per prefix; the counter dict is shared by Device and all of its
      subclasses.  Not guaranteed to be unique, but if there is a conflict,
      something probably needs to be cleaned up anyway.
      """
      suffix = cls.nextSuffix_.setdefault( prefix, 1 )
      devname = "%s%d-%d" % ( prefix, os.getpid(), suffix )
      cls.nextSuffix_[ prefix ] = suffix + 1
      return devname

   def __init__( self, ifname, hw=None, ip=None, netmask=None,
                 ip6Addr=None, up=True, mtu=None, netNs=DEFAULT_NS,
                 ip6Enabled=True, dadEnabled=True ):
      """Wrap and configure the existing kernel interface ifname in netNs.

      hw, ip, netmask, up and mtu are applied with a single ifconfig call (see
      ifconfig()).  dadEnabled=False disables IPv6 DAD and ip6Enabled=False
      disables IPv6 on the interface; ip6Addr (a string or a list of them) is
      then added with ip6AddrIs().
      """
      self.name_ = ifname
      self.hw_ = hw
      self.ip_ = ip
      self.up_ = up
      self.netmask_ = netmask
      self.mtu_ = mtu
      self.netNs_ = netNs
      self.ifconfig()

      if mtu is None:
         # No MTU was requested (so none was set); assume the usual default.
         # NOTE(review): the interface's actual MTU is not read back.
         self.mtu_ = 1500

      if not dadEnabled:
         self.dadDisableIs()

      # Make sure to set ip6Enabled after enabling/disabling DAD, otherwise the
      # link local address will go through DAD too.
      if not ip6Enabled:
         self.ip6DisableIs()

      # Make sure to set ip6Addr after enabling/disabling DAD, otherwise DAD
      # will run in the background during address binding.
      if ip6Addr is not None:
         self.ip6AddrIs( ip6Addr )

      t9( "Interface", self, "ready" )
      _OPENEDDEVICES[ self.name_ ] = self

   name = property( lambda self: self.name_ )
   deviceName = property( lambda self: self.name_ )
   ip = property( lambda self: self.ip_ )

   def prefix( self ):
      """Return the configured ip/netmask as an Arnet.Subnet string."""
      return str( Arnet.Subnet( self.ip_, self.netmask_ ) )

   def __repr__( self ):
      return "<" + self.name_ + " ethernet device>"

   def ifconfig( self, *arguments ):
      """Run `ifconfig <name> <arguments...>` as root in the device's namespace.

      With no arguments, the argument list is built from the stored state:
      'down', then 'hw ether <hw>', '<ip>', 'netmask <netmask>', 'up' and
      'mtu <mtu>' for whichever of hw_/ip_/netmask_/up_/mtu_ are set.
      """
      if not arguments:
         # 'down' is always present, so the command is never empty.
         arguments = ( 'down', )
         if self.hw_:
            arguments += ( "hw", "ether", self.hw_ )
         if self.ip_:
            arguments += ( self.ip_, )
         if self.netmask_:
            arguments += ( "netmask", self.netmask_ )
         if self.up_:
            arguments += ( "up", )
         if self.mtu_:
            arguments += ( "mtu", str( self.mtu_ ) )
      runMaybeInNetNs( self.netNs_,
                       [ "ifconfig", self.name_ ] + list( arguments ),
                       stdout=Tac.CAPTURE, asRoot=True )

   def ifup( self, up ):
      """Bring the interface up (up truthy) or down, and record the new state."""
      if up:
         self.up_ = True
         arguments = [ "up" ]
      else:
         self.up_ = False
         arguments = [ "down" ]
      runMaybeInNetNs( self.netNs_, [ "ifconfig", self.name_ ] + arguments,
                       stdout=Tac.CAPTURE, asRoot=True )

   def arpSet( self, ip, ether ):
      """Add a static ARP entry ip -> ether on this interface (`arp -s`)."""
      runMaybeInNetNs( self.netNs_,
                       [ "arp", "-i", self.name_, "-s", ip, ether ],
                       stdout=Tac.CAPTURE, asRoot=True )

   def ipAddrIs( self, ipAddr, delete=False ):
      """Add ipAddr to the interface, or remove it when delete is True."""
      runMaybeInNetNs( self.netNs_, [ "ip", "addr", "add" if not delete else "del",
                                      ipAddr, "dev", self.name_ ],
                       stdout=Tac.CAPTURE, asRoot=True )

   @staticmethod
   def _renameIntf( netNs, name, newName ):
      """Rename kernel interface name to newName inside netNs."""
      runMaybeInNetNs( netNs, [ "ip", "link", "set", name, "name", newName ],
                       stdout=Tac.CAPTURE, asRoot=True )

   def rename( self, newName ):
      """Rename the kernel interface to newName and track the new name.

      _OPENEDDEVICES is keyed by name, so this device's entry is moved too;
      otherwise a later device created under the old name would overwrite the
      entry and this device would no longer be closed at interpreter exit.
      """
      oldName = self.name_
      self._renameIntf( self.netNs_, oldName, newName )
      self.name_ = newName
      if _OPENEDDEVICES.get( oldName ) is self:
         del _OPENEDDEVICES[ oldName ]
         _OPENEDDEVICES[ newName ] = self

   def mtuIs( self, mtu ):
      """Set the interface MTU, skipping the command if it is already mtu."""
      if mtu != self.mtu_:
         runMaybeInNetNs( self.netNs_, [ "ip", "link", "set", "dev", self.name_,
                                         "mtu", str( mtu ) ],
                          stdout=Tac.CAPTURE, asRoot=True )
         self.mtu_ = mtu

   def flagSet( self, flagName, set_=True ):
      """Run `ip link set dev <name> <flagName> on|off`."""
      runMaybeInNetNs( self.netNs_, [ "ip", "link", "set", "dev", self.name_,
                                      flagName, "on" if set_ else "off" ],
                       stdout=Tac.CAPTURE, asRoot=True )

   def __getAndSetPath( self, path, value ):
      """Write value to path and return the previous (stripped) contents.

      Both the read and the write go through bash as root in the device's
      namespace.
      """
      cmd = [ 'bash', '-c', f'cat {path}' ]
      ret = runMaybeInNetNs( self.netNs_, cmd, stdout=Tac.CAPTURE, asRoot=True )
      cmd = [ 'bash', '-c', f'echo "{value}" > {path}' ]
      runMaybeInNetNs( self.netNs_, cmd, asRoot=True )
      return ret.strip()

   def ip6DisableIs( self ):
      """Disable IPv6 on the interface; return the previous disable_ipv6."""
      return self.__disableIp6( '1' )

   def ip6EnableIs( self ):
      """Enable IPv6 on the interface; return the previous disable_ipv6."""
      return self.__disableIp6( '0' )

   def dadEnableIs( self ):
      """Set accept_dad to 1; return its previous value."""
      return self.__enableDad( '1' )

   def dadDisableIs( self ):
      """Set accept_dad to 0; return its previous value."""
      return self.__enableDad( '0' )

   def acceptDad( self ):
      """Return the interface's current accept_dad value as an int."""
      path = f'/proc/sys/net/ipv6/conf/{self.name_}/accept_dad'
      cmd = [ 'bash', '-c', f'cat {path}' ]
      value = runMaybeInNetNs( self.netNs_, cmd, stdout=Tac.CAPTURE, asRoot=True )
      return int( value )

   def autoConfEnableIs( self ):
      """Set autoconf and accept_ra to 1; return their previous values."""
      return self.__enableAutoConf( '1' )

   def autoConfDisableIs( self ):
      """Set autoconf and accept_ra to 0; return their previous values."""
      return self.__enableAutoConf( '0' )

   def dadTransmitsIs( self, count ):
      """Set dad_transmits to count; return its previous value."""
      fname = f'/proc/sys/net/ipv6/conf/{self.name_}/dad_transmits'
      return int( self.__getAndSetPath( fname, count ) )

   def neighRetransTimeIs( self, count ):
      """Set the IPv6 neighbour retrans_time to count; return the previous value."""
      fname = f'/proc/sys/net/ipv6/neigh/{self.name_}/retrans_time'
      return int( self.__getAndSetPath( fname, count ) )

   def ip6AddrIs( self, ip6AddrList, delete=False, kernelTimeout=None ):
      """Add (or, with delete=True, remove) IPv6 addresses on the interface.

      ip6AddrList is a single address string or an iterable of addresses (each
      passed through str()).  Adding first re-enables IPv6 on the interface.
      A truthy kernelTimeout is used as both valid_lft and preferred_lft.
      An address equal (as a string) to the MAC-derived link-local address is
      skipped.
      """
      if not delete:
         self.__disableIp6( "0" )
         action = 'add'
      else:
         action = 'del'
      if isinstance( ip6AddrList, str ):
         ip6AddrList = [ ip6AddrList ]
      # do not attempt to install a link local address as that will result
      # in a failed command
      hw = self.hw()
      linkLocal = str( IpTestLib.ip6AddrFromMac( hw ) ) if hw else None
      for addr in ip6AddrList:
         addr = str( addr )
         if kernelTimeout:
            cmdList = [ 'ip', '-6', 'addr', action, addr, 'dev', self.name_,
                        'valid_lft', str( kernelTimeout ), 'preferred_lft',
                        str( kernelTimeout ) ]
         else:
            cmdList = [ 'ip', '-6', 'addr', action, addr, 'dev', self.name_ ]
         if addr != linkLocal:
            runMaybeInNetNs( self.netNs_, cmdList,
                             stdout=Tac.CAPTURE, asRoot=True )

   def ip6AddrDel( self, ip6AddrList ):
      """Remove the given IPv6 address(es); see ip6AddrIs()."""
      self.ip6AddrIs( ip6AddrList, True )

   def ip6Addr( self ):
      """Return the IPv6 addresses the kernel reports for this interface.

      Output of `ip -6 addr show <name>` is parsed by
      IpTestLib.ip6KernelAddrCommon; returns [] if the interface is absent
      from the parsed result.
      """
      cmdStr = "ip -6 addr show %s" % ( self.name_ )
      # "ip -6 addr show" is known to leak memory, disable asan leak check
      with ArPyUtils.AsanHelper.disableSubprocessAsanLeakcheck():
         addrDict = IpTestLib.ip6KernelAddrCommon(
            Arnet.NsLib.runMaybeInNetNs( self.netNs_, cmdStr.split(),
                                         stdout=Tac.CAPTURE ).splitlines() )
      ip6List = []
      if self.name_ in addrDict:
         for i in addrDict[ self.name_ ][ 'addr' ]:
            ip6List.append( i[ 'addr' ] )
      return ip6List

   def ip6Enabled( self ):
      """Return True once the MAC-derived link-local /64 address is assigned.

      We need to match the /64 one; other link locals will have /128.  A line
      ending in something like 'scope link tentative' means DAD is still running
      and the address has not been assigned yet, so it does not match.
      """
      cmdStr = "ip -6 addr show dev %s" % self.name_
      lines = Arnet.NsLib.runMaybeInNetNs(
         self.netNs_, cmdStr.split(), stdout=Tac.CAPTURE ).splitlines()
      if not lines:
         return False
      # Compute the expected address once: ip6LinkLocalAddr() shells out to
      # ifconfig, so calling it for every output line was needlessly costly.
      pattern = re.compile( r'[ ]*inet6 %s/64 scope link[ ]*$' %
                            re.escape( self.ip6LinkLocalAddr() ) )
      return any( pattern.match( line ) for line in lines )

   def ip6LinkLocalAddr( self ):
      """Return the link-local IPv6 address derived from the interface MAC."""
      return "%s" % ( IpTestLib.ip6AddrFromMac( self.hw() ) )

   def netNsIs( self, netNs=DEFAULT_NS, move=False ):
      """Record netNs as the device's namespace.

      With move=True the link is first moved there, with the command run in the
      current namespace.
      """
      if move:
         cmd = [ "ip", "link", "set", self.name_, "netns", netNs ]
         Arnet.NsLib.runMaybeInNetNs( self.netNs_, cmd, asRoot=True )
      self.netNs_ = netNs

   def ifIndex( self ):
      """Return the kernel ifindex of this device (0 if the lookup fails)."""
      return ifNameToIndex( self.name_, self.netNs_ )

   def __disableIp6( self, value ):
      """Write value to disable_ipv6; return the previous value as an int."""
      fname = f'/proc/sys/net/ipv6/conf/{self.name_}/disable_ipv6'
      return int( self.__getAndSetPath( fname, value ) )

   def __enableDad( self, value ):
      """Write value to accept_dad; return the previous value as an int."""
      fname = f'/proc/sys/net/ipv6/conf/{self.name_}/accept_dad'
      return int( self.__getAndSetPath( fname, value ) )

   def __enableAutoConf( self, value ):
      """Write value to autoconf and accept_ra; return previous values.

      Returns a tuple ( autoconf, accept_ra ) of the previous values as ints.
      """
      fname = f'/proc/sys/net/ipv6/conf/{self.name_}/autoconf'
      autoconf = int( self.__getAndSetPath( fname, value ) )
      fname = f'/proc/sys/net/ipv6/conf/{self.name_}/accept_ra'
      accept_ra = int( self.__getAndSetPath( fname, value ) )
      return ( autoconf, accept_ra )

   def hw( self ):
      """Return the Ethernet address shown by `ifconfig <name>`, or None.

      Uses the first output line matching 'ether <hex/colon digits>' followed
      later by a space; None when no line matches (e.g. no 'ether' line).
      """
      ifList = runMaybeInNetNs( self.netNs_, [ 'ifconfig', self.name_ ],
                                stdout=Tac.CAPTURE, asRoot=True ).splitlines()
      exp = re.compile( r'.*ether[ ]+([a-fA-F0-9:]+).* ' )
      for l in ifList:
         m = exp.match( l )
         if m:
            return m.group( 1 )
      return None

   def supportLocalIs( self, enable ):
      """Write enable to /proc/sys/net/ipv4/conf/<name>/accept_local.

      accept_local - BOOLEAN
        Accept packets with local source addresses. In combination with
        suitable routing, this can be used to direct packets between two
        local interfaces over the wire and have them accepted properly.
        default FALSE

      enable is written as text (e.g. 1/0 or '1'/'0').  A bool is converted to
      1/0 first, because the sysctl does not accept the text 'True'/'False'.
      """
      if isinstance( enable, bool ):
         enable = int( enable )
      cmd = [ 'bash', '-c', 'echo %s > /proc/sys/net/ipv4/conf/%s/accept_local' %
              ( enable, self.name_ ) ]
      # Writing under /proc/sys needs root, as for every other /proc write in
      # this class.
      runMaybeInNetNs( self.netNs_, cmd, asRoot=True )

class Tap( Device ):
   """Wrapper for a Linux kernel TUN/TAP device.  When an object of this class is
   created, it creates the corresponding Linux kernel TUN or TAP device.  When the
   object is destroyed, it destroys the Linux kernel TUN/TAP device.  Provides
   methods for sending and receiving packets over the interface.
   To create a TUN interface, specify ethernet=False.
   To make Tap.send wait for the packet to arrive in the tap by default,
   use sync=True
   """

   def __init__( self, ifname=None, hw=None, ip=None, netmask=None,
                 ip6Addr=None, ifprefix='tap', up=True, mtu=None, netNs=DEFAULT_NS,
                 blocking=True, ip6Enabled=True, ethernet=True, dadEnabled=True,
                 sync=False ):
      # Set every attribute close()/__del__ read before anything can fail, so a
      # half-constructed Tap is always safe to close.
      self.sync = sync # default for Tap.send( sync )
      self.l2SocketOnTap = None # used to sync Tap.send
      self.closed_ = False
      self.fd_ = None
      t9( 'Tap.__init__ %s' % ifname )
      if ifname is None:
         ifname = self.uniqueDevName( ifprefix )
         tmpName = ifname
      else:
         # If we want to create an interface in a non-default NS and it already
         # exists in the default NS, opening the tun/tap device fails.  To handle
         # this case, first create the interface with a uniqueDevName, move it
         # to the non-default NS, and then rename it to the desired "ifname".
         if devicePresent( ifname, DEFAULT_NS ) and netNs != DEFAULT_NS:
            tmpName = self.uniqueDevName( ifprefix )
            t9( 'Tap.__init__', ifname, 'using temp intf name', tmpName )
         else:
            tmpName = ifname
      # Both the final name and the name actually passed to TUNSETIFF (packed
      # into a 16-byte field below) must fit.
      assert len( ifname ) <= 15 # hard limit in tuntap code
      assert len( tmpName ) <= 15
      # Create a TAP network device of the chosen name.
      try:
         t9( "opentuntap %s" % tmpName )
         ifr = struct.pack( '16sH', tmpName.encode( 'utf-8' ),
                            ( IFF_TAP if ethernet else IFF_TUN ) | IFF_NO_PI )
         self.fd_ = os.open( '/dev/net/tun', os.O_RDWR )
         fcntl.ioctl( self.fd_, TUNSETIFF, ifr )
         if not blocking:
            fcntl.fcntl( self.fd_, fcntl.F_SETFL, os.O_NONBLOCK )
      except OSError as e:
         self.close()
         if e.errno == errno.EPERM:
            sys.stderr.write(
               "EPERM setting up tap device: is '%s' already taken?\n" % tmpName )
         raise

      try:
         if netNs != DEFAULT_NS:
            cmd = [ "ip", "link", "set", tmpName, "netns", netNs ]
            runMaybeInNetNs( DEFAULT_NS, cmd, asRoot=True )

         # The device cannot be renamed if it's already up, so rename it (if
         # required) before calling Device.__init__.
         if tmpName != ifname:
            self._renameIntf( netNs, tmpName, ifname )
         Device.__init__( self, ifname, hw, ip, netmask, ip6Addr, up, mtu, netNs,
                          ip6Enabled, dadEnabled )
      except BaseException:
         # Release the fd now instead of whenever this half-built object
         # happens to be garbage collected.
         self.close()
         raise

   def __del__( self ):
      # getattr: __init__ may never have run (e.g. object made via __new__).
      if not getattr( self, 'closed_', True ):
         t9( 'Tap.__del__ %s' % self.fd_ )
         self.close()

   def close( self ):
      """Close the sync sniffer socket (if any) and the tap fd; idempotent.

      Afterwards fd_ and l2SocketOnTap are None, so a closed Tap can no longer
      read or write through a descriptor number the OS may have reused.
      """
      t9( 'Tap.close closed:%s fd:%s' % ( self.closed_, self.fd_ ) )
      if self.closed_:
         return
      self.closed_ = True
      sock, self.l2SocketOnTap = self.l2SocketOnTap, None
      fd, self.fd_ = self.fd_, None
      try:
         if sock is not None:
            sock.close()
      finally:
         # Always release the tap fd, even if closing the socket failed.
         if fd is not None:
            os.close( fd )

   def recv( self, mtu=None, timeout=None ):
      """Read one packet that the kernel sent out of the tap.

      mtu: read size in bytes (default self.mtu_ + 100).
      timeout: seconds to wait for the fd to become readable; returns None if
      nothing arrives in time.  With timeout None, os.read is called directly
      (it blocks on a blocking tap).
      """
      if timeout is not None:
         poller = select.poll()
         poller.register( self.fd_, select.POLLIN )
         if not poller.poll( timeout * 1000 ):
            t9( self.name_, ": poller timeout after", timeout, "s" )
            return None
      data = os.read( self.fd_, mtu if mtu else ( self.mtu_ + 100 ) )
      t9( self.name_, ": received packet (from kernel):", Tracing.HexDump( data ) )
      return data

   def _maybeInitL2ListenSocket( self, sync ):
      """Lazily open an AF_PACKET socket bound to this interface for sync sends.

      Does nothing if sync is false or the socket already exists.  Only the
      default namespace is supported.
      """
      if self.l2SocketOnTap is not None or not sync:
         return
      assert self.netNs_ == DEFAULT_NS, \
         'TODO: implement sync mode support for non-default netNs'
      self.l2SocketOnTap = socket.socket( socket.AF_PACKET, socket.SOCK_RAW,
                                          socket.htons( ETH_P_ALL ) )
      self.l2SocketOnTap.bind( ( self.name_, ETH_P_ALL ) )

   def waitPkt_( self, pktsToWait ):
      """Poll the sniffer socket for up to 1 ms.

      Returns True if a packet was read and it is one of pktsToWait.
      """
      poller = select.poll()
      poller.register( self.l2SocketOnTap, select.POLLIN )
      if not poller.poll( 1 ):
         return False
      pktReceived = self.l2SocketOnTap.recv( self.mtu_ + 100 )
      return pktReceived in pktsToWait

   def send( self, data, sync=None, pktsToWait=None ):
      '''Send data through tap
      sync: when true, return only after packet is sniffed on eth port
      (defaults to self.sync)
      pktsToWait: wait until any of these packets is sniffed; default = pkt sent
      We don't wait for data because vlan tags can be stripped in some
      linux kernel versions.
      Returns data.
      '''
      if sync is None:
         sync = self.sync
      t9( self.name_, ": transmitting packet (to kernel, sync ", sync, "):",
          Tracing.HexDump( data ) )
      self._maybeInitL2ListenSocket( sync=sync )
      os.write( self.fd_, data )
      if sync:
         if pktsToWait is None:
            pktsToWait = [ data ]
         else:
            assert isinstance( pktsToWait, list )
         Tac.waitFor( lambda: self.waitPkt_( pktsToWait ), sleep=True,
                      description='packet in tap interface' )
      return data

   def fileno( self ):
      """Return the tap fd (None once closed), e.g. for select/poll."""
      return self.fd_

class Dummy( Device ):
   """Wrapper for a Linux kernel dummy device.

   The constructor creates the link with `ip link add <name> type dummy` in
   netNs and then configures it via Device.__init__.  close() (also called
   from __del__) deletes the link with `ip link del <name>`.
   """

   def __init__( self, ifname=None, hw=None, ip=None, netmask=None,
                 ip6Addr=None, ifprefix='dummy', up=True, mtu=None, netNs=DEFAULT_NS,
                 ip6Enabled=True, dadEnabled=True ):
      # Nothing to delete until the kernel link exists, so close()/__del__ are
      # no-ops if anything below fails before the link has been created.
      self.closed_ = True
      t9( 'Dummy.__init__ %s' % ifname )
      if ifname is None:
         ifname = self.uniqueDevName( ifprefix )
      assert len( ifname ) <= 15 # kernel interface-name length limit

      # Create the dummy device
      cmd = [ "ip", "link", "add", ifname, "type", "dummy" ]
      runMaybeInNetNs( netNs, cmd, asRoot=True )

      # The link now exists, so close() must delete it.  close() needs name_ and
      # netNs_, which Device.__init__ would only set later.
      self.name_ = ifname
      self.netNs_ = netNs
      self.closed_ = False
      try:
         Device.__init__( self, ifname, hw, ip, netmask, ip6Addr, up, mtu, netNs,
                          ip6Enabled, dadEnabled )
      except BaseException:
         # Don't leave a half-configured dummy link behind.
         self.close()
         raise

   def __del__( self ):
      self.close()

   def close( self ):
      """Delete the dummy link; idempotent."""
      # getattr: __init__ may never have run (e.g. object made via __new__).
      if getattr( self, 'closed_', True ):
         return
      self.closed_ = True

      cmd = [ "ip", "link", "del", self.name_ ]
      runMaybeInNetNs( self.netNs_, cmd, asRoot=True )
