Bug 1716963 - Guard access to marionette socket with a lock, r=webdriver-reviewers,whimboo

This should ensure that we can't end up with multiple threads
interleaving reads or writes on the socket.

Differential Revision: https://phabricator.services.mozilla.com/D118148
This commit is contained in:
James Graham 2021-07-07 14:22:59 +00:00
Родитель 83bd42eb4d
Коммит 154910af70
2 изменённых файлов: 144 добавлений и 110 удалений

Просмотреть файл

@ -8,22 +8,23 @@ import json
import socket import socket
import sys import sys
import time import time
from threading import RLock
import six import six
class SocketTimeout(object): class SocketTimeout(object):
def __init__(self, socket, timeout): def __init__(self, socket_ctx, timeout):
self.sock = socket self.socket_ctx = socket_ctx
self.timeout = timeout self.timeout = timeout
self.old_timeout = None self.old_timeout = None
def __enter__(self): def __enter__(self):
self.old_timeout = self.sock.gettimeout() self.old_timeout = self.socket_ctx.socket_timeout
self.sock.settimeout(self.timeout) self.socket_ctx.socket_timeout = self.timeout
def __exit__(self, *args, **kwargs): def __exit__(self, *args, **kwargs):
self.sock.settimeout(self.old_timeout) self.socket_ctx.socket_timeout = self.old_timeout
class Message(object): class Message(object):
@ -90,6 +91,35 @@ class Response(Message):
return Response(data[1], data[2], data[3]) return Response(data[1], data[2], data[3])
class SocketContext(object):
"""Object that guards access to a socket via a lock.
The socket must be accessed using this object as a context manager;
access to the socket outside of a context will bypass the lock."""
def __init__(self, host, port, timeout):
self.lock = RLock()
self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._sock.settimeout(timeout)
self._sock.connect((host, port))
@property
def socket_timeout(self):
return self._sock.gettimeout()
@socket_timeout.setter
def socket_timeout(self, value):
self._sock.settimeout(value)
def __enter__(self):
self.lock.acquire()
return self._sock
def __exit__(self, *args, **kwargs):
self.lock.release()
class TcpTransport(object): class TcpTransport(object):
"""Socket client that communciates with Marionette via TCP. """Socket client that communciates with Marionette via TCP.
@ -111,11 +141,11 @@ class TcpTransport(object):
will be used. Setting it to `1` or `None` disables timeouts on will be used. Setting it to `1` or `None` disables timeouts on
socket operations altogether. socket operations altogether.
""" """
self._sock = None self._socket_context = None
self.host = host self.host = host
self.port = port self.port = port
self.socket_timeout = socket_timeout self._socket_timeout = socket_timeout
self.protocol = self.min_protocol_level self.protocol = self.min_protocol_level
self.application_type = None self.application_type = None
@ -130,8 +160,8 @@ class TcpTransport(object):
def socket_timeout(self, value): def socket_timeout(self, value):
self._socket_timeout = value self._socket_timeout = value
if self._sock: if self._socket_context is not None:
self._sock.settimeout(value) self._socket_context.socket_timeout = value
def _unmarshal(self, packet): def _unmarshal(self, packet):
msg = None msg = None
@ -168,89 +198,91 @@ class TcpTransport(object):
# is 4 bytes: "2:{}". In practice the marionette format has some required fields so the # is 4 bytes: "2:{}". In practice the marionette format has some required fields so the
# message is longer, but 4 bytes allows reading messages with bodies up to 999 bytes in # message is longer, but 4 bytes allows reading messages with bodies up to 999 bytes in
# length in two reads, which is the common case. # length in two reads, which is the common case.
recv_bytes = 4 with self._socket_context as sock:
recv_bytes = 4
length_prefix = b"" length_prefix = b""
body_length = -1 body_length = -1
body_received = 0 body_received = 0
body_parts = [] body_parts = []
now = time.time() now = time.time()
timeout_time = ( timeout_time = (
now + self.socket_timeout if self.socket_timeout is not None else None now + self.socket_timeout if self.socket_timeout is not None else None
) )
while recv_bytes > 0: while recv_bytes > 0:
if timeout_time is not None and time.time() > timeout_time: if timeout_time is not None and time.time() > timeout_time:
raise socket.timeout( raise socket.timeout(
"Connection timed out after {}s".format(self.socket_timeout) "Connection timed out after {}s".format(self.socket_timeout)
)
try:
chunk = self._sock.recv(recv_bytes)
except OSError:
continue
if not chunk:
raise socket.error("No data received over socket")
body_part = None
if body_length > 0:
body_part = chunk
else:
parts = chunk.split(b":", 1)
length_prefix += parts[0]
# With > 10 decimal digits we aren't going to have a 32 bit number
if len(length_prefix) > 10:
raise ValueError(
"Invalid message length: {!r}".format(length_prefix)
) )
if len(parts) == 2: try:
# We found a : so we know the full length chunk = sock.recv(recv_bytes)
err = None except OSError:
try: continue
body_length = int(length_prefix)
except ValueError: if not chunk:
err = "expected an integer" raise socket.error("No data received over socket")
else:
if body_length <= 0: body_part = None
err = "expected a positive integer" if body_length > 0:
elif body_length > 2 ** 32 - 1: body_part = chunk
err = "expected a 32 bit integer" else:
if err is not None: parts = chunk.split(b":", 1)
length_prefix += parts[0]
# With > 10 decimal digits we aren't going to have a 32 bit number
if len(length_prefix) > 10:
raise ValueError( raise ValueError(
"Invalid message length: {} got {!r}".format( "Invalid message length: {!r}".format(length_prefix)
err, length_prefix
)
) )
body_part = parts[1]
# If we didn't find a : yet we keep reading 4 bytes at a time until we do. if len(parts) == 2:
# We could increase this here to 7 bytes (since we can't have more than 10 length # We found a : so we know the full length
# bytes and a seperator byte), or just increase it to int(length_prefix) + 1 since err = None
# that's the minimum total number of remaining bytes (if the : is in the next try:
# byte), but it's probably not worth optimising for large messages. body_length = int(length_prefix)
except ValueError:
err = "expected an integer"
else:
if body_length <= 0:
err = "expected a positive integer"
elif body_length > 2 ** 32 - 1:
err = "expected a 32 bit integer"
if err is not None:
raise ValueError(
"Invalid message length: {} got {!r}".format(
err, length_prefix
)
)
body_part = parts[1]
if body_part is not None: # If we didn't find a : yet we keep reading 4 bytes at a time until we do.
body_received += len(body_part) # We could increase this here to 7 bytes (since we can't have more than 10
body_parts.append(body_part) # length bytes and a seperator byte), or just increase it to
recv_bytes = body_length - body_received # int(length_prefix) + 1 since that's the minimum total number of remaining
# bytes (if the : is in the next byte), but it's probably not worth optimising
# for large messages.
body = b"".join(body_parts) if body_part is not None:
if unmarshal: body_received += len(body_part)
msg = self._unmarshal(body) body_parts.append(body_part)
self.last_id = msg.id recv_bytes = body_length - body_received
# keep reading incoming responses until body = b"".join(body_parts)
# we receive the user's expected response if unmarshal:
if isinstance(msg, Response) and msg != self.expected_response: msg = self._unmarshal(body)
return self.receive(unmarshal) self.last_id = msg.id
return msg # keep reading incoming responses until
return body # we receive the user's expected response
if isinstance(msg, Response) and msg != self.expected_response:
return self.receive(unmarshal)
return msg
return body
def connect(self): def connect(self):
"""Connect to the server and process the hello message we expect """Connect to the server and process the hello message we expect
@ -259,18 +291,17 @@ class TcpTransport(object):
Returns a tuple of the protocol level and the application type. Returns a tuple of the protocol level and the application type.
""" """
try: try:
self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket_context = SocketContext(
self._sock.settimeout(self.socket_timeout) self.host, self.port, self._socket_timeout
)
self._sock.connect((self.host, self.port))
except Exception: except Exception:
# Unset so that the next attempt to send will cause # Unset so that the next attempt to send will cause
# another connection attempt. # another connection attempt.
self._sock = None self._socket_context = None
raise raise
try: try:
with SocketTimeout(self._sock, 60.0): with SocketTimeout(self._socket_context, 60.0):
# first packet is always a JSON Object # first packet is always a JSON Object
# which we can use to tell which protocol level we are at # which we can use to tell which protocol level we are at
raw = self.receive(unmarshal=False) raw = self.receive(unmarshal=False)
@ -301,7 +332,7 @@ class TcpTransport(object):
"""Send message to the remote server. Allowed input is a """Send message to the remote server. Allowed input is a
``Message`` instance or a JSON serialisable object. ``Message`` instance or a JSON serialisable object.
""" """
if not self._sock: if not self._socket_context:
self.connect() self.connect()
if isinstance(obj, Message): if isinstance(obj, Message):
@ -313,17 +344,18 @@ class TcpTransport(object):
data = six.ensure_binary(data) data = six.ensure_binary(data)
payload = six.ensure_binary(str(len(data))) + b":" + data payload = six.ensure_binary(str(len(data))) + b":" + data
totalsent = 0 with self._socket_context as sock:
while totalsent < len(payload): totalsent = 0
sent = self._sock.send(payload[totalsent:]) while totalsent < len(payload):
if sent == 0: sent = sock.send(payload[totalsent:])
raise IOError( if sent == 0:
"Socket error after sending {0} of {1} bytes".format( raise IOError(
totalsent, len(payload) "Socket error after sending {0} of {1} bytes".format(
totalsent, len(payload)
)
) )
) else:
else: totalsent += sent
totalsent += sent
def respond(self, obj): def respond(self, obj):
"""Send a response to a command. This can be an arbitrary JSON """Send a response to a command. This can be an arbitrary JSON
@ -355,20 +387,21 @@ class TcpTransport(object):
See: https://docs.python.org/2/howto/sockets.html#disconnecting See: https://docs.python.org/2/howto/sockets.html#disconnecting
""" """
if self._sock: if self._socket_context:
try: with self._socket_context as sock:
self._sock.shutdown(socket.SHUT_RDWR) try:
except IOError as exc: sock.shutdown(socket.SHUT_RDWR)
# If the socket is already closed, don't care about: except IOError as exc:
# Errno 57: Socket not connected # If the socket is already closed, don't care about:
# Errno 107: Transport endpoint is not connected # Errno 57: Socket not connected
if exc.errno not in (57, 107): # Errno 107: Transport endpoint is not connected
raise if exc.errno not in (57, 107):
raise
if self._sock: if sock:
# Guard against unclean shutdown. # Guard against unclean shutdown.
self._sock.close() sock.close()
self._sock = None self._socket_context = None
def __del__(self): def __del__(self):
self.close() self.close()

Просмотреть файл

@ -69,7 +69,8 @@ class TestMarionette(MarionetteTestCase):
self.assertEqual(current_socket_timeout, self.marionette.client.socket_timeout) self.assertEqual(current_socket_timeout, self.marionette.client.socket_timeout)
self.assertEqual( self.assertEqual(
current_socket_timeout, self.marionette.client._sock.gettimeout() current_socket_timeout,
self.marionette.client._socket_context._sock.gettimeout(),
) )
def test_application_update_disabled(self): def test_application_update_disabled(self):