зеркало из https://github.com/mozilla/gecko-dev.git
Bug 1716963 - Guard access to marionette socket with a lock, r=webdriver-reviewers,whimboo
This should ensure that we can't end up with multiple threads interleaving reads or writes on the socket. Differential Revision: https://phabricator.services.mozilla.com/D118148
This commit is contained in:
Родитель
83bd42eb4d
Коммит
154910af70
|
@ -8,22 +8,23 @@ import json
|
||||||
import socket
|
import socket
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
from threading import RLock
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
|
||||||
class SocketTimeout(object):
|
class SocketTimeout(object):
|
||||||
def __init__(self, socket, timeout):
|
def __init__(self, socket_ctx, timeout):
|
||||||
self.sock = socket
|
self.socket_ctx = socket_ctx
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.old_timeout = None
|
self.old_timeout = None
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.old_timeout = self.sock.gettimeout()
|
self.old_timeout = self.socket_ctx.socket_timeout
|
||||||
self.sock.settimeout(self.timeout)
|
self.socket_ctx.socket_timeout = self.timeout
|
||||||
|
|
||||||
def __exit__(self, *args, **kwargs):
|
def __exit__(self, *args, **kwargs):
|
||||||
self.sock.settimeout(self.old_timeout)
|
self.socket_ctx.socket_timeout = self.old_timeout
|
||||||
|
|
||||||
|
|
||||||
class Message(object):
|
class Message(object):
|
||||||
|
@ -90,6 +91,35 @@ class Response(Message):
|
||||||
return Response(data[1], data[2], data[3])
|
return Response(data[1], data[2], data[3])
|
||||||
|
|
||||||
|
|
||||||
|
class SocketContext(object):
|
||||||
|
"""Object that guards access to a socket via a lock.
|
||||||
|
|
||||||
|
The socket must be accessed using this object as a context manager;
|
||||||
|
access to the socket outside of a context will bypass the lock."""
|
||||||
|
|
||||||
|
def __init__(self, host, port, timeout):
|
||||||
|
self.lock = RLock()
|
||||||
|
|
||||||
|
self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
self._sock.settimeout(timeout)
|
||||||
|
self._sock.connect((host, port))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def socket_timeout(self):
|
||||||
|
return self._sock.gettimeout()
|
||||||
|
|
||||||
|
@socket_timeout.setter
|
||||||
|
def socket_timeout(self, value):
|
||||||
|
self._sock.settimeout(value)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.lock.acquire()
|
||||||
|
return self._sock
|
||||||
|
|
||||||
|
def __exit__(self, *args, **kwargs):
|
||||||
|
self.lock.release()
|
||||||
|
|
||||||
|
|
||||||
class TcpTransport(object):
|
class TcpTransport(object):
|
||||||
"""Socket client that communciates with Marionette via TCP.
|
"""Socket client that communciates with Marionette via TCP.
|
||||||
|
|
||||||
|
@ -111,11 +141,11 @@ class TcpTransport(object):
|
||||||
will be used. Setting it to `1` or `None` disables timeouts on
|
will be used. Setting it to `1` or `None` disables timeouts on
|
||||||
socket operations altogether.
|
socket operations altogether.
|
||||||
"""
|
"""
|
||||||
self._sock = None
|
self._socket_context = None
|
||||||
|
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.socket_timeout = socket_timeout
|
self._socket_timeout = socket_timeout
|
||||||
|
|
||||||
self.protocol = self.min_protocol_level
|
self.protocol = self.min_protocol_level
|
||||||
self.application_type = None
|
self.application_type = None
|
||||||
|
@ -130,8 +160,8 @@ class TcpTransport(object):
|
||||||
def socket_timeout(self, value):
|
def socket_timeout(self, value):
|
||||||
self._socket_timeout = value
|
self._socket_timeout = value
|
||||||
|
|
||||||
if self._sock:
|
if self._socket_context is not None:
|
||||||
self._sock.settimeout(value)
|
self._socket_context.socket_timeout = value
|
||||||
|
|
||||||
def _unmarshal(self, packet):
|
def _unmarshal(self, packet):
|
||||||
msg = None
|
msg = None
|
||||||
|
@ -168,89 +198,91 @@ class TcpTransport(object):
|
||||||
# is 4 bytes: "2:{}". In practice the marionette format has some required fields so the
|
# is 4 bytes: "2:{}". In practice the marionette format has some required fields so the
|
||||||
# message is longer, but 4 bytes allows reading messages with bodies up to 999 bytes in
|
# message is longer, but 4 bytes allows reading messages with bodies up to 999 bytes in
|
||||||
# length in two reads, which is the common case.
|
# length in two reads, which is the common case.
|
||||||
recv_bytes = 4
|
with self._socket_context as sock:
|
||||||
|
recv_bytes = 4
|
||||||
|
|
||||||
length_prefix = b""
|
length_prefix = b""
|
||||||
|
|
||||||
body_length = -1
|
body_length = -1
|
||||||
body_received = 0
|
body_received = 0
|
||||||
body_parts = []
|
body_parts = []
|
||||||
|
|
||||||
now = time.time()
|
now = time.time()
|
||||||
timeout_time = (
|
timeout_time = (
|
||||||
now + self.socket_timeout if self.socket_timeout is not None else None
|
now + self.socket_timeout if self.socket_timeout is not None else None
|
||||||
)
|
)
|
||||||
|
|
||||||
while recv_bytes > 0:
|
while recv_bytes > 0:
|
||||||
if timeout_time is not None and time.time() > timeout_time:
|
if timeout_time is not None and time.time() > timeout_time:
|
||||||
raise socket.timeout(
|
raise socket.timeout(
|
||||||
"Connection timed out after {}s".format(self.socket_timeout)
|
"Connection timed out after {}s".format(self.socket_timeout)
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
chunk = self._sock.recv(recv_bytes)
|
|
||||||
except OSError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not chunk:
|
|
||||||
raise socket.error("No data received over socket")
|
|
||||||
|
|
||||||
body_part = None
|
|
||||||
if body_length > 0:
|
|
||||||
body_part = chunk
|
|
||||||
else:
|
|
||||||
parts = chunk.split(b":", 1)
|
|
||||||
length_prefix += parts[0]
|
|
||||||
|
|
||||||
# With > 10 decimal digits we aren't going to have a 32 bit number
|
|
||||||
if len(length_prefix) > 10:
|
|
||||||
raise ValueError(
|
|
||||||
"Invalid message length: {!r}".format(length_prefix)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(parts) == 2:
|
try:
|
||||||
# We found a : so we know the full length
|
chunk = sock.recv(recv_bytes)
|
||||||
err = None
|
except OSError:
|
||||||
try:
|
continue
|
||||||
body_length = int(length_prefix)
|
|
||||||
except ValueError:
|
if not chunk:
|
||||||
err = "expected an integer"
|
raise socket.error("No data received over socket")
|
||||||
else:
|
|
||||||
if body_length <= 0:
|
body_part = None
|
||||||
err = "expected a positive integer"
|
if body_length > 0:
|
||||||
elif body_length > 2 ** 32 - 1:
|
body_part = chunk
|
||||||
err = "expected a 32 bit integer"
|
else:
|
||||||
if err is not None:
|
parts = chunk.split(b":", 1)
|
||||||
|
length_prefix += parts[0]
|
||||||
|
|
||||||
|
# With > 10 decimal digits we aren't going to have a 32 bit number
|
||||||
|
if len(length_prefix) > 10:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid message length: {} got {!r}".format(
|
"Invalid message length: {!r}".format(length_prefix)
|
||||||
err, length_prefix
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
body_part = parts[1]
|
|
||||||
|
|
||||||
# If we didn't find a : yet we keep reading 4 bytes at a time until we do.
|
if len(parts) == 2:
|
||||||
# We could increase this here to 7 bytes (since we can't have more than 10 length
|
# We found a : so we know the full length
|
||||||
# bytes and a seperator byte), or just increase it to int(length_prefix) + 1 since
|
err = None
|
||||||
# that's the minimum total number of remaining bytes (if the : is in the next
|
try:
|
||||||
# byte), but it's probably not worth optimising for large messages.
|
body_length = int(length_prefix)
|
||||||
|
except ValueError:
|
||||||
|
err = "expected an integer"
|
||||||
|
else:
|
||||||
|
if body_length <= 0:
|
||||||
|
err = "expected a positive integer"
|
||||||
|
elif body_length > 2 ** 32 - 1:
|
||||||
|
err = "expected a 32 bit integer"
|
||||||
|
if err is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid message length: {} got {!r}".format(
|
||||||
|
err, length_prefix
|
||||||
|
)
|
||||||
|
)
|
||||||
|
body_part = parts[1]
|
||||||
|
|
||||||
if body_part is not None:
|
# If we didn't find a : yet we keep reading 4 bytes at a time until we do.
|
||||||
body_received += len(body_part)
|
# We could increase this here to 7 bytes (since we can't have more than 10
|
||||||
body_parts.append(body_part)
|
# length bytes and a seperator byte), or just increase it to
|
||||||
recv_bytes = body_length - body_received
|
# int(length_prefix) + 1 since that's the minimum total number of remaining
|
||||||
|
# bytes (if the : is in the next byte), but it's probably not worth optimising
|
||||||
|
# for large messages.
|
||||||
|
|
||||||
body = b"".join(body_parts)
|
if body_part is not None:
|
||||||
if unmarshal:
|
body_received += len(body_part)
|
||||||
msg = self._unmarshal(body)
|
body_parts.append(body_part)
|
||||||
self.last_id = msg.id
|
recv_bytes = body_length - body_received
|
||||||
|
|
||||||
# keep reading incoming responses until
|
body = b"".join(body_parts)
|
||||||
# we receive the user's expected response
|
if unmarshal:
|
||||||
if isinstance(msg, Response) and msg != self.expected_response:
|
msg = self._unmarshal(body)
|
||||||
return self.receive(unmarshal)
|
self.last_id = msg.id
|
||||||
|
|
||||||
return msg
|
# keep reading incoming responses until
|
||||||
return body
|
# we receive the user's expected response
|
||||||
|
if isinstance(msg, Response) and msg != self.expected_response:
|
||||||
|
return self.receive(unmarshal)
|
||||||
|
|
||||||
|
return msg
|
||||||
|
return body
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
"""Connect to the server and process the hello message we expect
|
"""Connect to the server and process the hello message we expect
|
||||||
|
@ -259,18 +291,17 @@ class TcpTransport(object):
|
||||||
Returns a tuple of the protocol level and the application type.
|
Returns a tuple of the protocol level and the application type.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
self._socket_context = SocketContext(
|
||||||
self._sock.settimeout(self.socket_timeout)
|
self.host, self.port, self._socket_timeout
|
||||||
|
)
|
||||||
self._sock.connect((self.host, self.port))
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# Unset so that the next attempt to send will cause
|
# Unset so that the next attempt to send will cause
|
||||||
# another connection attempt.
|
# another connection attempt.
|
||||||
self._sock = None
|
self._socket_context = None
|
||||||
raise
|
raise
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with SocketTimeout(self._sock, 60.0):
|
with SocketTimeout(self._socket_context, 60.0):
|
||||||
# first packet is always a JSON Object
|
# first packet is always a JSON Object
|
||||||
# which we can use to tell which protocol level we are at
|
# which we can use to tell which protocol level we are at
|
||||||
raw = self.receive(unmarshal=False)
|
raw = self.receive(unmarshal=False)
|
||||||
|
@ -301,7 +332,7 @@ class TcpTransport(object):
|
||||||
"""Send message to the remote server. Allowed input is a
|
"""Send message to the remote server. Allowed input is a
|
||||||
``Message`` instance or a JSON serialisable object.
|
``Message`` instance or a JSON serialisable object.
|
||||||
"""
|
"""
|
||||||
if not self._sock:
|
if not self._socket_context:
|
||||||
self.connect()
|
self.connect()
|
||||||
|
|
||||||
if isinstance(obj, Message):
|
if isinstance(obj, Message):
|
||||||
|
@ -313,17 +344,18 @@ class TcpTransport(object):
|
||||||
data = six.ensure_binary(data)
|
data = six.ensure_binary(data)
|
||||||
payload = six.ensure_binary(str(len(data))) + b":" + data
|
payload = six.ensure_binary(str(len(data))) + b":" + data
|
||||||
|
|
||||||
totalsent = 0
|
with self._socket_context as sock:
|
||||||
while totalsent < len(payload):
|
totalsent = 0
|
||||||
sent = self._sock.send(payload[totalsent:])
|
while totalsent < len(payload):
|
||||||
if sent == 0:
|
sent = sock.send(payload[totalsent:])
|
||||||
raise IOError(
|
if sent == 0:
|
||||||
"Socket error after sending {0} of {1} bytes".format(
|
raise IOError(
|
||||||
totalsent, len(payload)
|
"Socket error after sending {0} of {1} bytes".format(
|
||||||
|
totalsent, len(payload)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
else:
|
||||||
else:
|
totalsent += sent
|
||||||
totalsent += sent
|
|
||||||
|
|
||||||
def respond(self, obj):
|
def respond(self, obj):
|
||||||
"""Send a response to a command. This can be an arbitrary JSON
|
"""Send a response to a command. This can be an arbitrary JSON
|
||||||
|
@ -355,20 +387,21 @@ class TcpTransport(object):
|
||||||
|
|
||||||
See: https://docs.python.org/2/howto/sockets.html#disconnecting
|
See: https://docs.python.org/2/howto/sockets.html#disconnecting
|
||||||
"""
|
"""
|
||||||
if self._sock:
|
if self._socket_context:
|
||||||
try:
|
with self._socket_context as sock:
|
||||||
self._sock.shutdown(socket.SHUT_RDWR)
|
try:
|
||||||
except IOError as exc:
|
sock.shutdown(socket.SHUT_RDWR)
|
||||||
# If the socket is already closed, don't care about:
|
except IOError as exc:
|
||||||
# Errno 57: Socket not connected
|
# If the socket is already closed, don't care about:
|
||||||
# Errno 107: Transport endpoint is not connected
|
# Errno 57: Socket not connected
|
||||||
if exc.errno not in (57, 107):
|
# Errno 107: Transport endpoint is not connected
|
||||||
raise
|
if exc.errno not in (57, 107):
|
||||||
|
raise
|
||||||
|
|
||||||
if self._sock:
|
if sock:
|
||||||
# Guard against unclean shutdown.
|
# Guard against unclean shutdown.
|
||||||
self._sock.close()
|
sock.close()
|
||||||
self._sock = None
|
self._socket_context = None
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.close()
|
self.close()
|
||||||
|
|
|
@ -69,7 +69,8 @@ class TestMarionette(MarionetteTestCase):
|
||||||
|
|
||||||
self.assertEqual(current_socket_timeout, self.marionette.client.socket_timeout)
|
self.assertEqual(current_socket_timeout, self.marionette.client.socket_timeout)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
current_socket_timeout, self.marionette.client._sock.gettimeout()
|
current_socket_timeout,
|
||||||
|
self.marionette.client._socket_context._sock.gettimeout(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_application_update_disabled(self):
|
def test_application_update_disabled(self):
|
||||||
|
|
Загрузка…
Ссылка в новой задаче