# Copyright (c) 2013 Arista Networks, Inc.  All rights reserved.
# Arista Networks, Inc. Confidential and Proprietary.

import socket
import struct

import Tracing

import IpcConnectionContext_pb2 as IpcConnectionContext
import RpcPayloadHeader_pb2 as RpcPayloadHeader

from google.protobuf.message import DecodeError # pylint: disable=F0401

# Module-wide trace handle.  t0, t1, t2 and t5 are shorthands for its
# trace0 / trace1 / trace2 / trace5 methods (presumably ordered by increasing
# verbosity -- confirm against the Tracing module).
traceHandle = Tracing.defaultTraceHandle()
t0 = traceHandle.trace0
t1 = traceHandle.trace1
t2 = traceHandle.trace2
t5 = traceHandle.trace5


# Exclusive upper bound on the sizes we're willing to accept from the wire:
# the byte length of a string or the element count of an array.  The value is
# arbitrary and set aggressively low.
MAX_SIZE = 1024 * 1024


def serializeVLong( n ):
   """Encodes an integer as a variable-length long.

   Values in [-112, 127] are emitted as a single signed byte.  Any other
   value is emitted as a marker byte followed by the big-endian magnitude
   (the one's complement of `n' when `n' is negative) stripped of its leading
   zero bytes.  The marker encodes both the sign and the number of magnitude
   bytes that follow.

   This code has been adapted from asynchbase's HBaseRpc.java

   Args:
     - n: The integer to encode.
   Returns:
     The encoded value, as bytes.
   """
   if -112 <= n <= 127:
      return struct.pack( "b", n )

   # Markers count down from 0x90 (positive) or 0x88 (negative), one step per
   # magnitude byte.
   if n < 0:
      magnitude, marker = ~n, 0x88
   else:
      magnitude, marker = n, 0x90
   marker -= ( magnitude.bit_length() + 7 ) // 8

   # The low 3 bits of the marker are the number of leading bytes to drop
   # from the fixed 8-byte big-endian representation.
   prefix = struct.pack( "B", marker )
   return prefix + struct.pack( ">Q", magnitude )[ marker & 0x07 : ]

class Socket:
   """Class to wrap a socket and add a few convenience methods and buffered reads.

   All the read*() helpers decode big-endian values and are served by
   recvall(), which consumes an internal buffer and refills it from the
   underlying socket when it runs short.
   """

   def __init__( self, sock, engine, timeout=60 ):
      """Constructor.

      Args:
        - sock: The underlying socket.socket() we're wrapping.
        - engine: The RpcEngine we intend to use on this connection.
        - timeout: The socket-level timeout, in seconds.
      """
      self.sock_ = sock
      # When we read from the socket, we attempt to read more than we need and
      # we store the result in this buffer, which we consume until it's empty,
      # at which point we attempt to read again from the socket.
      self.buf_ = b""
      self.timeout_ = timeout
      self.setTimeout( self.timeout_ )
      self.sock_.setsockopt( socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1 )
      self.engine_ = engine

   def setTimeout( self, timeout ):
      """Records the timeout and applies it to the socket, if still open.

      Args:
        - timeout: The socket-level timeout, in seconds.
      """
      self.timeout_ = timeout
      if self.sock_:
         self.sock_.settimeout( timeout )

   def engine( self ):
      """Returns the RpcEngine used on this socket."""
      return self.engine_

   def connect( self, *args ):
      """Forwards to connect() on the underlying socket."""
      return self.sock_.connect( *args )

   def sendall( self, *args ):
      """Forwards to sendall() on the underlying socket.

      Raises:
        - OSError: if the send fails, after having closed this Socket.
      """
      try:
         self.sock_.sendall( *args )
      except OSError:
         self.close()
         raise

   def recvall( self, length, *args ):
      """Like recv() but always tries to return `length' bytes.

      Args:
        - length: The exact number of bytes to return.
        - args: Extra arguments forwarded to recv() on the underlying socket.
      Returns:
        `length' bytes, taken from the front of the buffer.
      Raises:
        - ConnectionClosedException: if the connection is closed on us while
          we're trying to read from it.
        - OSError: if reading from the socket fails.  This Socket is closed
          before the exception propagates.
      """
      if length == 0:
         return b""
      while True:
         if length <= len( self.buf_ ):  # We can satisfy the read from the buffer
            resp = self.buf_[ : length ]
            self.buf_ = self.buf_[ length : ]
            return resp

         try:
            received = self.sock_.recv( 4096, *args )
         except OSError:
            self.close()
            raise

         if not received:
            # Server closed the connection, let's close our side of the
            # connection and raise a socket exception for our clients.
            # The exception is built before close() because it needs the
            # socket to look up the peer's address.
            try:
               raise ConnectionClosedException( self.sock_ )
            finally:
               self.close()
         if Tracing.enabled( 5 ):
            t5( repr( received ) )
         if self.buf_:  # We already had some data, but not enough...
            self.buf_ += received  # ... so keep accumulating.
         else:  # We didn't have any data, so just assign the buffer we received
            self.buf_ = received  # from the socket to free any slice we may have.

   def unread( self, buf ):
      """Pushes data back at the beginning of our socket buffer.
      This data will then be the next thing that we'll read in recvall().

      Args:
        - buf: The bytes that we'll prepend to our buffer.
      """
      self.buf_ = buf + self.buf_

   def readBool( self ):
      """Reads a boolean: one byte, False if and only if it is zero."""
      return self.recvall( 1 ) != b"\x00"

   def readShort( self ):
      """Reads a signed 16-bit integer."""
      return struct.unpack( ">h", self.recvall( 2 ) )[ 0 ]

   def readInt( self ):
      """Reads a signed 32-bit integer."""
      return struct.unpack( ">i", self.recvall( 4 ) )[ 0 ]

   def readLong( self ):
      """Reads a signed 64-bit integer."""
      return struct.unpack( ">q", self.recvall( 8 ) )[ 0 ]

   def readVLong( self ):
      """Decodes a variable-length encoded integer and returns a python int.

      This is the inverse of serializeVLong().

      This code has been adapted from asynchbase's HBaseRpc.java
      """
      b = struct.unpack( "B", self.recvall( 1 ) )[ 0 ]
      if b & 0xF0 != 0x80:
         # Not a marker byte: the byte itself is the (signed) value.
         return b if b <= 127 else b - 256

      # The low 3 bits of the marker give the number of leading zero bytes
      # that were stripped from the 8-byte big-endian magnitude.
      result = ( b & 0x07 ) * b"\000" + self.recvall( 8 - ( b & 0x07 ) )
      result = struct.unpack( ">Q", result )[ 0 ]
      # Bit 3 of the marker is set for positive values; negative values are
      # transmitted as their one's complement.
      return result if b & 0b00001000 else ~result

   def readFloat( self ):
      """Reads a 32-bit IEEE 754 floating point value."""
      return struct.unpack( ">f", self.recvall( 4 ) )[ 0 ]

   def readString( self ):
      """Reads a string prefixed by its signed 32-bit size.

      Returns:
        The raw bytes of the string (they are not decoded), or the empty str
        "" when the size on the wire is 0 or -1.
      Raises:
        - InvalidRpcResponseException: if the size is negative (other than -1)
          or not smaller than MAX_SIZE.
      """
      length = self.readInt()
      if length in ( -1, 0 ):
         return ""
      ensure( 0 <= length < MAX_SIZE, "Invalid string length: %d" % length )
      return self.recvall( length )

   def readUTF8( self ):
      """Reads a string prefixed by its unsigned 16-bit size.

      Returns:
        The decoded string, as a str.
      """
      # The size prefix is unsigned: Writable.serializeString() writes it with
      # ">H".  Reading it as a signed short would reject every string of
      # 32768 bytes or more.
      length = struct.unpack( ">H", self.recvall( 2 ) )[ 0 ]
      return self.recvall( length ).decode()

   def readText( self ):
      """Reads a string prefixed by its vint size.

      Returns:
        The decoded string, as a str.
      Raises:
        - InvalidRpcResponseException: if the size is negative or not smaller
          than MAX_SIZE.
      """
      length = self.readVLong()
      ensure( 0 <= length < MAX_SIZE, "Invalid string length: %d" % length )
      return self.recvall( length ).decode()

   def close( self ):
      """Closes the underlying socket and drops any buffered data.

      Does nothing if the socket is already closed.
      """
      if self.sock_:
         self.buf_ = b""
         self.sock_.close()
         self.sock_ = None

   def connected( self ):
      """Returns True as long as close() hasn't released the socket."""
      return bool( self.sock_ )


class HadoopException( Exception ):
   """Root of the exception hierarchy for everything Hadoop related.

   Catching this class catches every exception defined in this module.
   """


class NullInstanceException( HadoopException ):
   """Raised when Hadoop answers an RPC with a NullInstance."""

   def __init__( self, message ):
      """Constructor.

      Args:
        - message: The human-readable description of the failure.
      """
      super().__init__( message )


class RemoteRpcException( HadoopException ):
   """Exception raised when the RPC fails with a remote exception.

   Attributes:
     - klass: Class name of the remote exception, as a str.
   """

   def __init__( self, klass, message ):
      """Constructor.

      Both arguments are accepted either as str or as bytes.
      Socket.readString() returns raw bytes for non-empty strings, and
      formatting bytes into the message would show their repr (b'...')
      instead of the text, so bytes are decoded here.

      Args:
        - klass: Class name of the remote exception.
        - message: The error message sent by the remote side.
      """
      HadoopException.__init__( self, self._toText( message ) )
      self.klass = self._toText( klass )  # Class name of the remote exception

   @staticmethod
   def _toText( value ):
      """Returns `value' decoded as UTF-8 if it's bytes, or unchanged otherwise."""
      if isinstance( value, bytes ):
         # Never let a malformed byte hide the remote error we're reporting.
         return value.decode( "utf-8", "replace" )
      return value

   def __str__( self ):
      base = super().__str__()
      return f"{self.klass}: {base}"


class ConnectionClosedException( HadoopException ):
   """The server unexpectedly closed the connection on us."""

   def __init__( self, sock ):
      """Constructor.

      Args:
        - sock: The socket.socket() that got closed, used to name the peer in
          the message when its address can still be retrieved.
      """
      try:
         peer = sock.getpeername()
         # Only ( host, port ) pairs get the "host:port" treatment.  Checking
         # the length alone would also match a 2-character string address and
         # make the formatting below raise a TypeError.
         if isinstance( peer, tuple ) and len( peer ) == 2:
            peer = "%s:%s" % peer
         msg = f"Connection to {peer} reset by peer"
      except OSError:  # ECONNRESET is typical here.
         msg = "Connection reset by peer"
      HadoopException.__init__( self, msg )


class InvalidRpcResponseException( HadoopException ):
   """Raised when a response read from the wire can't be de-serialized
   because it is illegal."""


def ensure( condition, message ):
   """Validates a condition while reading a response, like an assert would.

   Nothing happens when the condition holds.  Otherwise the failure is traced
   and reported with an exception.

   Args:
     - condition: Any value.  If falsy, causes an exception to be raised.
     - message: The message used to construct the exception.  Must be a string
       or a callable.  A callable is only invoked on failure and must return
       a string.
   Raises:
     InvalidRpcResponseException
   """
   if condition:
      return
   text = message() if callable( message ) else message
   t0( "Assertion failed", text )
   raise InvalidRpcResponseException( text )


class WritableMetaClass( type ):
   """Meta class that registers Writable subclasses in Writable.CLASSMAP.

   Every class created with this meta class, except the ones named "Writable"
   and "HadoopArray", must define HADOOP_CLASS_NAME in its own body.  That
   name must not be registered already, and becomes the key under which the
   new class is stored.
   """

   # pylint: disable-next=bad-mcs-classmethod-argument
   def __new__( mcs, name, bases, fields ):
      klass = type.__new__( mcs, name, bases, fields )
      if name in ( "Writable", "HadoopArray" ):
         return klass  # The base classes aren't registered.

      hadoopName = fields.get( "HADOOP_CLASS_NAME" )
      assert hadoopName, "Class %r is missing HADOOP_CLASS_NAME" % name
      registry = Writable.CLASSMAP
      assert hadoopName not in registry, repr( hadoopName )
      registry[ hadoopName ] = klass
      return klass

class Writable( metaclass=WritableMetaClass ):
   """Base class for classes that mimic Hadoop's Writable class.

   Subclasses list their attributes in __slots__.  The attributes whose name
   doesn't start with an underscore are the 'public' ones: they alone take
   part in equality comparisons and hashing.
   """

   CLASSMAP = {}  # Maps strings to subclasses of Writable.
                  # Filled by subclass creation with the meta class.

   USER_DEFINED_TYPE = True  # Set to false for built-in types (e.g. "int").
   __slots__ = []

   @staticmethod
   def classFor( name ):
      """Given a Java class name (string) returns the corresponding Python class.

      Names starting with "[" are array types and are handled by
      HadoopArray.of().

      Raises:
        - InvalidRpcResponseException: if the class name is unknown.  This
          includes the empty name, which can come straight from the wire.
      """
      if name.startswith( "[" ):
         return HadoopArray.of( name )
      cls = Writable.CLASSMAP.get( name )
      if cls is None:
         raise InvalidRpcResponseException( "Unsupported class %r" % name )
      return cls

   def serialize( self ):
      """Returns a serialized representation of this object as bytes."""
      raise NotImplementedError()

   @classmethod
   def readFrom( cls, sock ):
      """Deserializes an instance of cls from the given Socket object."""
      raise NotImplementedError

   @staticmethod
   def serializeString( s ):
      """Returns a serialized version of the given string with 16-bit length.

      The length is the unsigned, big-endian size of the encoded string.
      """
      b = s.encode()
      return struct.pack( ">H", len( b ) ) + b

   @staticmethod
   def serializeText( s ):
      """Returns a serialized version of the given string with vint length."""
      b = s.encode()
      return serializeVLong( len( b ) ) + b

   def __hash__( self ):
      """Returns a hash of all the 'public' attributes."""
      # Private attributes are skipped, exactly like in __eq__, so that two
      # objects that compare equal always have the same hash.
      return hash( tuple( getattr( self, attr ) for attr in self.__slots__
                          if attr[ 0 ] != "_" ) )

   def __eq__( self, other ):
      """Equality comparison operator based on all the 'public' attributes."""
      if self.__class__ != other.__class__:
         return False
      for attr in self.__slots__:
         if attr[ 0 ] == "_":
            continue
         if getattr( self, attr ) != getattr( other, attr ):
            return False
      return True

   def __ne__( self, other ):
      """Inequality comparison operator based on all the 'public' attributes."""
      return not self == other

   def __repr__( self ):
      """Returns ClassName(attr=value, ...) with the attributes sorted by name."""
      attributes = sorted( self.__slots__ )
      return "{}({})".format( self.__class__.__name__,
                          ", ".join( f"{attr}={getattr( self, attr )!r}"
                                     for attr in attributes ) )


class Boolean( Writable ):
   """A Writable holding a built-in bool, mapped to Java's "boolean"."""
   HADOOP_CLASS_NAME = "boolean"
   USER_DEFINED_TYPE = False
   __slots__ = [ 'value' ]

   def __init__( self, value ):
      """Constructor.

      Args:
        - value: The bool to wrap.
      """
      super().__init__()
      assert isinstance( value, bool ), f"Invalid argument: {value!r}"
      self.value = value

   def serialize( self ):
      """Returns a single byte: 1 for a true value, 0 otherwise."""
      if self.value:
         return b"\x01"
      return b"\x00"

   @classmethod
   def readFrom( cls, sock ):
      """Reads one boolean from the Socket, returned as a plain bool."""
      return sock.readBool()


class Integer( Writable ):
   """A Writable holding a built-in int, mapped to Java's "int"."""
   HADOOP_CLASS_NAME = "int"  # Always 32-bit signed in Java.
   USER_DEFINED_TYPE = False
   __slots__ = [ 'value' ]

   def __init__( self, value ):
      """Constructor.

      Args:
        - value: The int to wrap.
      """
      super().__init__()
      assert isinstance( value, int ), f"Invalid argument: {value!r}"
      self.value = value

   def serialize( self ):
      """Returns the value as 4 bytes: big-endian, signed."""
      packed = struct.pack( ">i", self.value )
      return packed

   @classmethod
   def readFrom( cls, sock ):
      """Reads one 32-bit integer from the Socket, returned as a plain int."""
      return sock.readInt()


class Long( Writable ):
   """A Writable holding a built-in int, mapped to Java's "long"."""
   HADOOP_CLASS_NAME = "long"  # Always 64-bit signed in Java.
   USER_DEFINED_TYPE = False
   __slots__ = [ 'value' ]

   def __init__( self, value ):
      """Constructor.

      Args:
        - value: The int to wrap.
      """
      super().__init__()
      assert isinstance( value, int ), f"Invalid argument: {value!r}"
      self.value = value

   def serialize( self ):
      """Returns the value as 8 bytes: big-endian, signed."""
      packed = struct.pack( ">q", self.value )
      return packed

   @classmethod
   def readFrom( cls, sock ):
      """Reads one 64-bit integer from the Socket, returned as a plain int."""
      return sock.readLong()


class String( Writable ):
   """Wraps a java string in a Writable."""
   HADOOP_CLASS_NAME = "java.lang.String"
   USER_DEFINED_TYPE = False
   __slots__ = [ 'value' ]

   def __init__( self, value ):
      """Constructor.

      Args:
        - value: The str to wrap.
      """
      super().__init__()
      self.value = value

   @classmethod
   def readFrom( cls, sock ):
      """Reads a 16-bit length prefixed string from the Socket.

      Returns:
        A String instance wrapping the decoded text.
      """
      value = sock.readUTF8()
      return String( value )

   def serialize( self ):
      """Returns the encoded string prefixed by its 16-bit length."""
      # Go through the shared helper so the length prefix is unsigned, like
      # for every other string we put on the wire.  Packing it as a signed
      # short would fail for strings of 32768 bytes or more.
      return self.serializeString( self.value )

class GenericWritable( Writable ):
   """Wraps a GenericWritable response for an RPC.

   Responses like NullInstance are serialized as the UTF8 string
   org.apache.hadoop.io.Writable, followed by the name of the instance, such
   as org.apache.hadoop.io.ObjectWritable$NullInstance.  What follows depends
   on the instance.
   """
   HADOOP_CLASS_NAME = "org.apache.hadoop.io.Writable"
   USER_DEFINED_TYPE = False
   HADOOP_NULL_INSTANCE = "org.apache.hadoop.io.ObjectWritable$NullInstance"
   __slots__ = [ 'instance', 'nullClass' ]

   def __init__( self, instance, nullClass ):
      """Constructor.

      Args:
        - instance: Stored as-is; not used by the methods of this class.
        - nullClass: The class name written by serializeNullInstance().
      """
      super().__init__()
      self.instance = instance
      self.nullClass = nullClass

   @classmethod
   def readFrom( cls, sock ):
      """Reads the instance name from the Socket.  Always raises.

      Raises:
        - NullInstanceException: if the instance is a NullInstance.
        - InvalidRpcResponseException: for any other instance.
      """
      instance = sock.readUTF8()
      if instance == GenericWritable.HADOOP_NULL_INSTANCE:
         nullClass = sock.readUTF8()
         raise NullInstanceException( 'Received Null Instance from %s' % nullClass )
      raise InvalidRpcResponseException( "Unsupported %s instance %s" %
                                         ( GenericWritable.HADOOP_CLASS_NAME,
                                           instance ) )

   def serialize( self ):
      raise NotImplementedError()

   def serializeNullInstance( self ):
      """Returns the serialized form of a NullInstance of self.nullClass.

      That is three 16-bit length prefixed strings: the generic Writable class
      name, the NullInstance class name, and self.nullClass.
      """
      buf = Writable.serializeString( GenericWritable.HADOOP_CLASS_NAME )
      buf += Writable.serializeString( GenericWritable.HADOOP_NULL_INSTANCE )
      # Same unsigned length prefix as the two strings above (a signed prefix
      # would fail for names of 32768 bytes or more).
      buf += Writable.serializeString( self.nullClass )
      return buf

class HadoopArray( Writable ):
   """Represents an array of values of the same type."""
   __slots__ = [ 'name', 'values' ]

   def __init__( self, name, values ):
      """Constructor.

      Args:
        - name: The Java class name of the elements.
        - values: The elements.  Each must have a serialize() method.
      """
      super().__init__()
      self.name = name
      self.values = values

   def serialize( self ):
      """Returns the array type name, the 32-bit element count, and then each
      element serialized in order."""
      # TODO(tsuna): Handle arrays of native types.
      parts = [ self.serializeString( "[L" + self.name + ";" ),
                struct.pack( ">I", len( self.values ) ) ]
      parts.extend( value.serialize() for value in self.values )
      return b"".join( parts )

   @staticmethod
   def of( name ):
      """Returns a reader class for the given Java array type name.

      Args:
        - name: An array type name of the form "[Lsome.element.Class;".
      Returns:
        A class with a readFrom( sock ) class method.  It returns a dict
        indexed by key() when the element class has a `key' attribute, and a
        list otherwise.
      Raises:
        - InvalidRpcResponseException: if the name is malformed or the element
          class is unknown.
      """
      # TODO(tsuna): Handle arrays of native types.
      ensure( len( name ) > 3
              and name[ 0 ] == "["
              and name[ 1 ] == "L"
              and name[ -1 ] == ";", repr( name ) )
      klass = Writable.classFor( name[ 2 : -1 ] )

      def readOne( sock ):
         """Reads one element, checking the class name(s) that precede it."""
         got = sock.readUTF8()
         # pylint is confused about life here.
         # pylint: disable-msg=E1103
         want = klass.HADOOP_CLASS_NAME
         ensure( got == want, f"Got {got!r} but expected {want!r}" )
         if issubclass( klass, Writable ): # Stupidity of the Hadoop protocol...
            got2 = sock.readUTF8()         # ... the type name is repeated twice.
            # Report the name that actually failed the check (the 2nd one).
            ensure( got2 == want, f"Got {got2!r} but expected {want!r}" )
         return klass.readFrom( sock )

      class DictReader:
         """Reads an array into a dict mapping each element's key() to it."""
         USER_DEFINED_TYPE = False
         @classmethod
         def readFrom( cls, sock ):
            length = sock.readInt()
            ensure( 0 <= length < MAX_SIZE, "Invalid array length: %d" % length )
            return { obj.key(): obj for obj in ( readOne( sock ) for _ in
                                                 range( length ) ) }

      class ListReader:
         """Reads an array into a list, preserving the order on the wire."""
         USER_DEFINED_TYPE = False
         @classmethod
         def readFrom( cls, sock ):
            length = sock.readInt()
            ensure( 0 <= length < MAX_SIZE, "Invalid array length: %d" % length )
            return [ readOne( sock ) for _ in range( length ) ]

      if hasattr( klass, "key" ):
         return DictReader
      else:
         return ListReader

   @classmethod
   def readFrom( cls, sock ):
      # Arrays are read through the reader classes returned by of().
      raise NotImplementedError

   def __repr__( self ):
      return f"HadoopArray({self.name!r}, {self.values!r})"

class HadoopRpc:
   """Common ancestor of every Hadoop RPC.

   An RPC is the name of a remote method along with its arguments.  How the
   arguments are put on the wire is up to the subclasses.
   """
   __slots__ = [ 'method_', 'args_' ]

   def __init__( self, method, *args ):
      """Constructor.

      Args:
        - method: The name of the remote method to call.
        - args: The arguments of the call, kept as a tuple.
      """
      self.args_ = args
      self.method_ = method

   def method( self ):
      """Returns the name of the remote method."""
      return self.method_

   def serialize( self ):
      """Returns the serialized arguments.  To be implemented by subclasses."""
      raise NotImplementedError()


class WritableRpc( HadoopRpc ):
   """Base class for Hadoop RPCs whose arguments are all Writable objects."""

   CLIENT_VERSION = None  # Must be set by subclasses.
   PROTOCOL = None  # Must be set by subclasses.

   def serialize( self ):
      """Returns the serialized arguments of this RPC.

      The result is the 32-bit argument count followed, for each argument, by
      its class name (written twice for user-defined types) and its
      serialized value.
      """
      chunks = [ struct.pack( ">I", len( self.args_ ) ) ]
      for arg in self.args_:
         assert isinstance( arg, Writable ), f"Invalid RPC argument: {arg!r}"
         className = Writable.serializeString( arg.HADOOP_CLASS_NAME )
         chunks.append( className )
         if arg.USER_DEFINED_TYPE:
            chunks.append( className )
         chunks.append( arg.serialize() )
      return b"".join( chunks )

   @staticmethod
   def getRpcParam( sock ):
      """Reads one RPC parameter from the socket and returns its python
      representation.

      The parameter starts with its class name, which selects the python class
      used to read the rest.

      Raises:
        - InvalidRpcResponseException: if the class name is unknown, or if its
          expected repetition doesn't match.
      """
      name = sock.readUTF8()
      cls = Writable.classFor( name )
      # Stupidity of the Hadoop protocol: the type name is repeated twice for
      # user-defined types.
      repeated = issubclass( cls, Writable ) and cls.USER_DEFINED_TYPE
      if repeated:
         name2 = sock.readUTF8()
         ensure( name2 == name, f"2nd class name was {name2!r}, expected {name!r}" )
      return cls.readFrom( sock )

class RpcEngine:
   """Interface for RPC engines.

   An engine knows how to speak one flavor of the Hadoop RPC protocol.  There
   are several versions / flavors of it:
     - "Writable"-based RPCs in up to and including Hadoop 1.x
     - Protobuf-based RPCs in Hadoop 2.x
     - Mixed-mode (part-Protobuf, part-"Writable") in "Hadoop 2.0 with MR1"
       used in CDH4.
     - Secure RPC.
   """

   def __init__( self, user ):
      """Constructor.

      Args:
        - user: The credentials to use in the RPCs.
      """
      self.user = user

   def helloPreamble( self, rpc ):
      """Returns the 'hello' preamble, as bytes.

      The first thing a client must send after connecting to a Hadoop server
      is a preliminary message that declares the protocol and version it uses,
      along with its credentials.  This message is also known as the
      "connection header" or the "initial handshake".

      Args:
        - rpc: The first RPC we're trying to send so that we know the name of
          the protocol we're trying to use.
      Returns:
        The bytes that will be the first thing written on the socket upon
        connecting to the server.
      """
      raise NotImplementedError()

   def serializeRpc( self, rpc, callid ):
      """Serializes an RPC so it can be sent to the wire.

      Args:
        - rpc: The RPC object we're trying to serialize.
        - callid: A 32-bit integer of the ID of this RPC.
      Returns:
        The serialized RPC, as the bytes that should be written to the wire.
      """
      raise NotImplementedError()

   def recvResponse( self, sock ):
      """De-serializes a response from the wire.

      Args:
        - sock: An instance of Socket to read from.
      Returns:
        An object representing the de-serialized response.
      """
      raise NotImplementedError()


class Hadoop1xWritableEngine( RpcEngine ):
   """Legacy Writable-based RPC engine.

   This flavor of the RPC engine is for Hadoop 1.x and distros based on it,
   such as HDP 1.x.
   """
   # See Client.writeRpcHeader()
   PREAMBLE = ( b"hrpc"  # Server.HEADER
                b"\4"    # Server.CURRENT_VERSION
                b"P"     # AuthMethod.SIMPLE
                )

   # enum ipc.Status
   RPC_SUCCESS = 0
   RPC_ERROR = 1
   RPC_FATAL = -1

   def helloPreamble( self, rpc ):
      """Returns the preamble followed by the size-prefixed connection header."""
      # See ConnectionHeader.write()
      header = b"".join( (
         Writable.serializeText( rpc.PROTOCOL ),
         b"\x01",  # Boolean: true (indicates a UserGroupInformation follows)
         Writable.serializeString( self.user ),
         b"\x00",  # Boolean: false (we're not a proxy user)
      ) )
      size = struct.pack( ">I", len( header ) )
      return self.PREAMBLE + size + header

   def serializeRpc( self, rpc, callid ):
      """Returns the call ID, the method name and the serialized arguments."""
      # See Client.sendParam()
      prefix = struct.pack( ">I", callid )
      # See RPC$Invocation.write()
      methodName = Writable.serializeString( rpc.method() )
      return prefix + methodName + rpc.serialize()

   def recvResponse( self, sock ):
      """Reads a response: the call ID, the status, and then the payload.

      Raises:
        - RemoteRpcException: if the status isn't RPC_SUCCESS.  The socket is
          closed first when the status is RPC_FATAL.
      """
      # See Client.receiveResponse()
      callid = sock.readInt()  # TODO: Check it matches the callid we sent.
      status = sock.readInt()
      if status == Hadoop1xWritableEngine.RPC_SUCCESS:
         return WritableRpc.getRpcParam( sock )

      t2( "RPC #", callid, "failed with status", status )
      try:
         # The class name of the remote exception comes first, then its
         # message.
         klass = sock.readString()
         message = sock.readString()
         raise RemoteRpcException( klass, message )
      finally:
         if status == Hadoop1xWritableEngine.RPC_FATAL:
            sock.close()


class ProtoBufWithWritableEngine( RpcEngine ):
   """Mixed-mode RPC engine, using a Protocol Buffer header.

   This transitional mode (RPC header in protobuf, payload in Writable) is
   present only in CDH4.
   """
   # See Client.writeConnectionHeader()
   PREAMBLE = ( b"hrpc"  # Server.HEADER
                b"\7"    # Server.CURRENT_VERSION
                b"P"     # AuthMethod.SIMPLE
                b"\0"    # Server.IpcSerializationType.PROTOBUF
                )

   WRITABLE_RPC_VERSION = 2

   def helloPreamble( self, rpc ):
      """Returns the preamble followed by the size-prefixed, protobuf-encoded
      connection context (effective user and protocol name)."""
      # See Client.writeConnectionContext()
      # pylint is ignorant of how protobufs work (sigh), so silent it here.
      # pylint: disable-msg=E1101
      context = IpcConnectionContext.IpcConnectionContextProto()
      context.userInfo.effectiveUser = self.user
      context.protocol = rpc.PROTOCOL
      context = context.SerializeToString()
      return self.PREAMBLE + struct.pack( ">I", len( context ) ) + context

   def serializeRpc( self, rpc, callid ):
      """Returns the length-delimited protobuf header followed by the
      Writable invocation (version, protocol, method, client version, methods
      hash and arguments).

      Raises:
        - InvalidRpcResponseException: if the protobuf header is too large for
          its length to fit on a single byte.
      """
      # See Client.sendParam()
      header = RpcPayloadHeader.RpcPayloadHeaderProto() # pylint: disable=no-member
      # Because we chose RPC_WRITABLE, our payload is a
      # WritableRpcEngine$Invocation.
      header.rpcKind = RpcPayloadHeader.RPC_WRITABLE # pylint: disable=no-member
      header.rpcOp = RpcPayloadHeader.RPC_FINAL_PAYLOAD # pylint: disable=no-member
      header.callId = callid
      # pylint is ignorant of how protobufs work (sigh), so silent it here.
      # pylint: disable-msg=E1101
      header = header.SerializeToString()
      # For some stupid reason, the code uses header.writeDelimitedTo()
      # so we have a vint for the header size first.
      # Here the vint will always be on 1 byte.
      ensure( len( header ) <= 127, "Header unexpectedly large: %r" % header )
      header = struct.pack( ">B", len( header ) ) + header
      # See WritableRpcEngine$Invocation.write()
      header += struct.pack( ">Q", ProtoBufWithWritableEngine.WRITABLE_RPC_VERSION )
      header += Writable.serializeString( rpc.PROTOCOL )
      header += Writable.serializeString( rpc.method() )
      header += struct.pack( ">Q", rpc.CLIENT_VERSION )
      header += struct.pack( ">I", 0xDEADBEEF )  # Methods hash (unused).
      return header + rpc.serialize()

   def recvResponse( self, sock ):
      """Reads a response: a length-delimited protobuf header, then the payload.

      Raises:
        - InvalidRpcResponseException: if the header length is out of range or
          the header can't be decoded.
        - RemoteRpcException: if the status in the header isn't SUCCESS.  The
          socket is closed first when the status is FATAL.
      """
      # See Client.receiveResponse()
      # The length is a vint but we know it's always going to be on 1 byte.
      length = ord( sock.recvall( 1 ) )
      # A 1-byte vint can hold any value up to and including 127, which is
      # also the largest header serializeRpc() is willing to send.
      ensure( 0 < length <= 127, lambda: "RPC header length unexpectedly %r: %r"
              % ( length, sock.recvall( min( length, MAX_SIZE ) ) ) )
      header = RpcPayloadHeader.RpcResponseHeaderProto() # pylint: disable=no-member
      # pylint is ignorant of how protobufs work (sigh), so silent it here.
      # pylint: disable-msg=E1101
      headerBytes = sock.recvall( length )
      try:
         header.ParseFromString( headerBytes )
      except DecodeError as e:
         t0( "Failed to decode protobuf header", repr( headerBytes), e )
         raise InvalidRpcResponseException( "Invalid RPC header (check that"
               " the RPC port of the JobTracker is correctly configured)" ) from e
      # TODO: Check the callid matches the one we sent.
      if header.status != RpcPayloadHeader.SUCCESS:
         t2( "RPC #", header.callId, "failed with status", header.status )
         try:
            raise RemoteRpcException( sock.readString(), sock.readString() )
         finally:
            if header.status == RpcPayloadHeader.FATAL:
               sock.close()
      return WritableRpc.getRpcParam( sock )

class HadoopRpcClient:
   """Base class for Hadoop RPC clients.

   Subclasses can implement a Hadoop RPC protocol by defining additional
   methods that must then call into `_sendRpc()'.
   """

   def __init__( self, host, port, user, timeout=60 ):
      """Constructor.  No connection is made until the first RPC is sent.

      Args:
        - host: The host to connect to.
        - port: The TCP port to connect to, from 1 to 65535 inclusive.
        - user: The credentials to use in the RPCs.
        - timeout: The amount of time, in seconds, the client will wait for a
          server response.
      """
      assert 0 < port < 2**16, "Invalid port: %r" % port
      assert host
      assert user
      self.host_ = host
      self.port_ = port
      self.sock_ = None
      self.callid_ = 0  # Next RPC call ID to use.
      self.timeout_ = timeout
      self.user_ = user
      # The RPC engine is what actually performs the handshake, serialization
      # and de-serialization for the version of the protocol we need to use on
      # the wire.
      self.engine = None  # Instance of an RpcEngine subclass.

   def setTimeout( self, timeout ):
      """ Set a new socket timeout in seconds, returns the old timeout """
      oldTimeout = self.timeout_
      self.timeout_ = timeout
      if self.sock_:
         self.sock_.setTimeout( self.timeout_ )
      return oldTimeout

   def _connect( self ):
      """Opens a new connection to the server, using the current engine.

      Raises:
        - OSError: if the connection fails, after having closed the socket.
      """
      self.sock_ = Socket( socket.socket(), self.engine, timeout=self.timeout_ )
      hostport = ( self.host_, self.port_ )
      t1( "Connecting to", hostport )
      try:
         self.sock_.connect( hostport )
      except OSError:
         self.sock_.close()
         raise

   def _sendRpc( self, rpc, expectedReturnType ):
      """Sends an RPC, connecting first if needed, and returns the response.

      Args:
        - rpc: The HadoopRpc to send.
        - expectedReturnType: The type the response must be an instance of.
          When the response is a list or a dict, the check applies to each
          item or value instead.
      Returns:
        The de-serialized response.
      Raises:
        - InvalidRpcResponseException: if the response is illegal or not of
          the expected type.  The connection is closed.
        - ConnectionClosedException: if the server closed the connection and
          no other engine is left to try.
      """
      assert isinstance( rpc, HadoopRpc ), "unexpected RPC: %r" % rpc
      assert isinstance( expectedReturnType, type ), repr( expectedReturnType )
      self.callid_ += 1

      # Do we need to create a new connection?
      newConnection = not self.sock_ or not self.sock_.connected()
      if newConnection:
         # We start off with the most recent RPC engine and, through a
         # trial and error process, we eventually find the "correct" one
         # to use for the version of Hadoop we're trying to connect to.
         self.engine = ProtoBufWithWritableEngine( self.user_ )

      def doSendRpc( newConnection ):
         def _checkType( item ):
            if not isinstance( item, expectedReturnType ):
               raise InvalidRpcResponseException( "Expected a response of type"
                        " %s to RPC %s but got a %s: %r"
                        % ( expectedReturnType.__name__, rpc.method(),
                            item.__class__.__name__, item ) )

         t2( "Sending RPC #", self.callid_, rpc.method() )
         payload = self.engine.serializeRpc( rpc, self.callid_ )
         # Regardless of the engine, the payload is always prefixed with its size.
         payload = struct.pack( ">I", len( payload ) ) + payload
         if newConnection:   # If we need to create a new connection ...
            self._connect()  # ... then actually connect and ...
            # ... prefix the payload with the preamble.
            payload = self.engine.helloPreamble( rpc ) + payload
         self.sock_.sendall( payload )
         try:
            response = self.engine.recvResponse( self.sock_ )
            if isinstance( response, list ):
               for item in response:
                  _checkType( item )
            elif isinstance( response, dict ):
               for item in response.values():
                  _checkType( item )
            else:
               _checkType( response )
            return response
         except InvalidRpcResponseException:
            # We can't trust what's left to read on this connection.
            self.sock_.close()
            raise

      try:
         return doSendRpc( newConnection )
      except ConnectionClosedException:
         if not newConnection:
            t1( "Connection unexpectedly closed by peer" )
            raise
         # If we get disconnected after a new connection, it's typically because
         # we're talking to Hadoop with the wrong protocol version.  So let's
         # try again but using a different RPC engine.
         if isinstance( self.engine, Hadoop1xWritableEngine ):
            raise  # We already tried a different engine, so give up here.
         self.engine = Hadoop1xWritableEngine( self.user_ )
         t1( "Connection closed on us, trying engine",
             self.engine.__class__.__name__ )
         try:
            return doSendRpc( True )  # Try again.
         except ConnectionClosedException as e:
            e.args = ( e.args[ 0 ]
                  + ".  Check that this version of Hadoop is compatible.", )
            raise

   def close( self ):
      """ Closes a RPC connection if opened """
      if self.sock_:
         t1( "Closing the connection and re-initializing" )
         self.sock_.close()
         self.sock_ = None
         self.engine = None
         self.callid_ = 0
