#!/usr/bin/env python3
# Copyright (c) 2024 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import abc
import os
import socket
import sys

import Swag
import Tac
import Tracing

t0 = Tracing.trace0

# Environment variables that override how ACR sockets are provisioned: the
# address to listen on, and the inclusive range of TCP ports to try binding.
TCP_ADDRESS_ENVVAR = 'ACR_TCP_ADDRESS'
PORT_RANGE_MIN_ENVVAR = 'ACR_TCP_PORT_RANGE_MIN'
PORT_RANGE_MAX_ENVVAR = 'ACR_TCP_PORT_RANGE_MAX'

# Default (min, max) TCP port range for networked ACRs, taken from the
# Arnet::PrivateTcpPorts type.
PrivateTcpPorts = Tac.Type( 'Arnet::PrivateTcpPorts' )
DEFAULT_TCP_PORT_RANGE = (
   PrivateTcpPorts.networkAcrDefaultMin,
   PrivateTcpPorts.networkAcrDefaultMax,
)

class SocketProvider( metaclass=abc.ABCMeta ):
   """
   An abstract base class for factory classes to be used to obtain sockets for
   servicing ACRs.

   Subclasses implement `_createSocketImpl` to create and bind a socket, and
   `deleteSocket` to release any resources (beyond the socket itself) that are
   associated with a provisioned socket.
   """
   def createSocket( self, requestName, timeout ):
      """
      Create a bound, listening socket for the request named `requestName`.

      Args:
         requestName: name of the request, passed through to `_createSocketImpl`.
         timeout: value passed to `settimeout` on the new socket.

      Returns:
         A `( sock, targetSpec )` pair from `_createSocketImpl`, or
         `( None, None )` if the subclass could not create a socket.

      Raises:
         Any exception raised by `settimeout` or `listen` (e.g. OSError, or
         ValueError for a negative timeout). The socket is closed before the
         exception propagates so that it is not leaked.
      """
      sock, targetSpec = self._createSocketImpl( requestName )
      if sock is None:
         return None, None

      try:
         sock.settimeout( timeout )
         # Backlog of 1: presumably a single client connects per request.
         sock.listen( 1 )
      except Exception:
         # The socket is already bound; close it rather than leaking it when it
         # cannot be put into the listening state.
         sock.close()
         raise

      return sock, targetSpec

   @abc.abstractmethod
   def _createSocketImpl( self, requestName ):
      """
      Create and bind (but do not listen on) a socket for `requestName`.

      Returns a `( sock, targetSpec )` pair, or `( None, None )` on failure.
      """
      raise NotImplementedError

   @abc.abstractmethod
   def deleteSocket( self, requestName, targetSpec ):
      """
      Release resources associated with a socket described by `targetSpec`.

      Closing the socket itself is the caller's responsibility.
      """
      raise NotImplementedError

class UnixSocketProvider( SocketProvider ):
   """
   ACR socket factory class that returns UNIX domain sockets created under the /tmp
   directory.
   """
   def _createSocketImpl( self, requestName ):
      """
      Create a UNIX stream socket bound to `/tmp/<requestName>`.

      Returns:
         A `( sock, targetSpec )` pair, where `targetSpec` is an
         `Agent::AgentCommandRequestTarget` naming the UNIX socket, or
         `( None, None )` if a pre-existing socket file could not be removed or
         the socket could not be created.

      Raises:
         OSError: if binding the socket fails. The socket is closed first.
      """
      socketFile = '/tmp/' + requestName

      # Make sure the socket does not already exist. A missing file is fine; a
      # failed unlink that leaves the file in place aborts this request.
      try:
         os.unlink( socketFile )
      except OSError:
         if os.path.exists( socketFile ):
            t0( 'Error: existing socket' )
            return None, None

      # Create the socket and bind to the name.
      t0( 'Initializing socket' )
      try:
         sock = socket.socket( socket.AF_UNIX, socket.SOCK_STREAM )
      except OSError as e:
         # Log the actual error (not just its type).
         t0( f'Error on socket creation: {e}' )
         return None, None

      try:
         sock.bind( socketFile )
      except OSError:
         # Don't leak the file descriptor when the bind fails.
         sock.close()
         raise

      return sock, Tac.Value( 'Agent::AgentCommandRequestTarget',
                              unixSocket=requestName )

   def deleteSocket( self, requestName, targetSpec ):
      """
      Remove the filesystem entry backing a socket made by `_createSocketImpl`.

      The responsibility of closing the socket falls on the caller. This function
      performs any additional resource cleanup (beyond closing the socket) that is
      associated with the provisioned socket.

      Raises:
         OSError: if the socket file cannot be unlinked.
      """
      t0( 'Deleting socket' )

      if not targetSpec.containsUnixSocket():
         # Without a UNIX socket target there is no file for us to unlink.
         t0( 'Error on socket cleanup: targetSpec is not a UNIX socket: '
             f'{targetSpec}' )
         return

      socketFile = '/tmp/' + targetSpec.unixSocket
      os.unlink( socketFile )

class NetworkSocketProvider( SocketProvider ):
   """
   ACR socket factory class that returns TCP network sockets.

   Sockets are bound to `networkAddress` on the first port that can be bound in
   the inclusive range `portRange` ( a `( min, max )` pair ).
   """
   def __init__( self, networkAddress, portRange=DEFAULT_TCP_PORT_RANGE ):
      self._networkAddress = networkAddress
      self._portRangeStart, self._portRangeEnd = portRange

   def _bindSocket( self, sock, bindAddr ):
      """
      Bind `sock` to `bindAddr` on the first free port in the configured range.

      Returns:
         The bound port number, or None if no port in the range could be bound.
      """
      bindAddr = str( bindAddr )
      for portCandidate in range( self._portRangeStart, self._portRangeEnd + 1 ):
         try:
            sock.bind( ( bindAddr, portCandidate ) )
         except OSError:
            # This port could not be bound (e.g. already in use); try the next.
            continue
         return portCandidate

      t0( 'Failed to bind ACR port' )
      return None

   def _createSocketImpl( self, requestName ):
      """
      Create an AF_INET TCP socket bound to the configured address and port range.

      `requestName` is unused: network targets are identified by address and port.

      Returns:
         A `( sock, targetSpec )` pair, where `targetSpec` wraps an
         `Agent::AgentCommandNetworkTarget` for the bound address and port, or
         `( None, None )` if the socket could not be created or bound.
      """
      # Create and bind the socket.
      t0( 'Initializing socket' )
      try:
         sock = socket.socket( socket.AF_INET, socket.SOCK_STREAM )
      except OSError as e:
         # Log the actual error (not just its type).
         t0( f'Error on socket creation: {e}' )
         return None, None

      try:
         localPort = self._bindSocket( sock, self._networkAddress )
      except Exception:
         # `_bindSocket` only absorbs OSError; anything else (e.g. OverflowError
         # for a port outside 0-65535) must not leak the socket.
         sock.close()
         raise

      if localPort is None:
         sock.close()
         return None, None

      return sock, Tac.Value(
         'Agent::AgentCommandRequestTarget',
         networkSocket=Tac.Value( 'Agent::AgentCommandNetworkTarget',
                                  self._networkAddress, localPort ) )

   def deleteSocket( self, requestName, targetSpec ):
      # Closing the socket is sufficient, and the responsibility of closing the
      # socket falls on the caller.
      pass

def networkSocketParams():
   """
   Determine the parameters for a NetworkSocketProvider from the environment.

   The network address comes from the TCP_ADDRESS_ENVVAR environment variable,
   falling back to this switch's SWAG member IP address when one is available.
   The port range defaults to DEFAULT_TCP_PORT_RANGE and may be overridden by
   setting both PORT_RANGE_MIN_ENVVAR and PORT_RANGE_MAX_ENVVAR. Invalid
   overrides are traced and ignored in favor of the default range.

   Returns:
      `( networkAddr, ( portMin, portMax ) )`, or None when no network address is
      available (ACR then falls back to UNIX domain sockets).
   """
   # Largest valid TCP port number.
   maxTcpPort = 65535

   env = os.environ

   portOverrideMin = env.get( PORT_RANGE_MIN_ENVVAR )
   portOverrideMax = env.get( PORT_RANGE_MAX_ENVVAR )

   def tracePortRangeError( msg ):
      t0( f'{msg};'
          f' {PORT_RANGE_MIN_ENVVAR}={portOverrideMin!r}'
          f' {PORT_RANGE_MAX_ENVVAR}={portOverrideMax!r}' )

   # Determine the network address.
   networkAddr = env.get( TCP_ADDRESS_ENVVAR )
   if not networkAddr:
      swagMemberId = Swag.memberId()
      if swagMemberId is not None:
         networkAddr = Swag.memberIpAddr( swagMemberId )
   if not networkAddr:
      # If no network address is configured, ACR will fall back on using Unix domain
      # sockets. If the environment has been configured with port overrides, emit a
      # warning.
      if portOverrideMin or portOverrideMax:
         tracePortRangeError(
            'ACR network port range overrides unused due to running on non-SWAG '
            'switch without ACR network address override' )
      return None

   # If no port range override has been specified, use the default port range.
   paramsWithDefaultPortRange = ( networkAddr, DEFAULT_TCP_PORT_RANGE )
   if not ( portOverrideMin or portOverrideMax ):
      return paramsWithDefaultPortRange

   # Check that both limits have been specified in the port range override.
   if not ( portOverrideMin and portOverrideMax ):
      tracePortRangeError( 'ACR network port range overrides must be specified in '
                           'pairs' )
      return paramsWithDefaultPortRange

   # Parse the port range override to integers.
   try:
      portOverrideMin = int( portOverrideMin )
      portOverrideMax = int( portOverrideMax )
   except ValueError:
      tracePortRangeError( 'Failed to parse ACR network port range to integers' )
      return paramsWithDefaultPortRange

   # Perform a basic check that the port range override makes sense.
   if portOverrideMin > portOverrideMax:
      tracePortRangeError( 'ACR network port range override has lower limit '
                           'greater than upper limit' )
      return paramsWithDefaultPortRange

   # Reject ports outside 1-65535. Outside 0-65535, `socket.bind` raises
   # OverflowError, which the bind loop does not catch. Port 0 would bind an
   # arbitrary ephemeral port while 0 is reported as the bound port.
   if portOverrideMin < 1 or portOverrideMax > maxTcpPort:
      tracePortRangeError( 'ACR network port range override is outside the valid '
                           f'TCP port range 1-{maxTcpPort}' )
      return paramsWithDefaultPortRange

   return networkAddr, ( portOverrideMin, portOverrideMax )

def _defaultSocketProvider():
   """
   Choose the ACR socket provider for this process.

   Returns a NetworkSocketProvider built from `networkSocketParams()` when that
   yields parameters, and a UnixSocketProvider otherwise.
   """
   params = networkSocketParams()
   if params is None:
      t0( 'ACR using Unix domain socket provider' )
      return UnixSocketProvider()

   t0( f'ACR using networked socket provider: {params}' )
   return NetworkSocketProvider( *params )

# The provider is selected once, when this module is imported; socketProvider()
# hands back that module-level instance.
_socketProvider = _defaultSocketProvider()

def socketProvider():
   """Return the ACR SocketProvider selected when this module was imported."""
   return _socketProvider
