[RPC] More robust tracker protocol (#1085)

* [RPC] More robust tracker protocol

* fix normal rpc
This commit is contained in:
Tianqi Chen 2018-04-06 14:18:44 -07:00 коммит произвёл GitHub
Родитель 2e17e85005
Коммит 1418134003
7 изменённых файлов: 113 добавлений и 28 удалений

Просмотреть файл

@ -34,6 +34,7 @@ class TrackerCode(object):
REQUEST = 4
UPDATE_INFO = 5
SUMMARY = 6
GET_PENDING_MATCHKEYS = 7
RPC_SESS_MASK = 128

Просмотреть файл

@ -230,7 +230,7 @@ class TrackerSession(object):
if value[0] != base.TrackerCode.SUCCESS:
raise RuntimeError("Invalid return value %s" % str(value))
url, port, matchkey = value[1]
return connect(url, port, key + matchkey, session_timeout)
return connect(url, port, matchkey, session_timeout)
except socket.error as err:
self.close()
last_err = err

Просмотреть файл

@ -14,6 +14,7 @@ import socket
import multiprocessing
import errno
import struct
import time
try:
import tornado
@ -45,6 +46,7 @@ class ForwardHandler(object):
self.rpc_key = None
self.match_key = None
self.forward_proxy = None
self.alloc_time = None
def __del__(self):
logging.info("Delete %s...", self.name())
@ -237,6 +239,7 @@ class ProxyServerHandler(object):
self.sock.fileno(), event_handler, self.loop.READ)
self._client_pool = {}
self._server_pool = {}
self.timeout_alloc = 5
self.timeout_client = timeout_client
self.timeout_server = timeout_server
# tracker information
@ -245,8 +248,12 @@ class ProxyServerHandler(object):
self._tracker_conn = None
self._tracker_pending_puts = []
self._key_set = set()
self.update_tracker_period = 2
if tracker_addr:
logging.info("Tracker address:%s", str(tracker_addr))
def _callback():
self._update_tracker(True)
self.loop.call_later(self.update_tracker_period, _callback)
logging.info("RPCProxy: Websock port bind to %d", web_port)
def _on_event(self, _):
@ -271,7 +278,22 @@ class ProxyServerHandler(object):
rhs.send_data(lhs.rpc_key.encode("utf-8"))
logging.info("Pairup connect %s and %s", lhs.name(), rhs.name())
def _update_tracker(self):
def _regenerate_server_keys(self, keys):
"""Regenerate keys for server pool"""
keyset = set(self._server_pool.keys())
new_keys = []
# re-generate the server match key, so old information is invalidated.
for key in keys:
rpc_key, _ = key.split(":")
handle = self._server_pool[key]
del self._server_pool[key]
new_key = base.random_key(rpc_key + ":", keyset)
self._server_pool[new_key] = handle
keyset.add(new_key)
new_keys.append(new_key)
return new_keys
def _update_tracker(self, period_update=False):
"""Update information on tracker."""
try:
if self._tracker_conn is None:
@ -285,13 +307,33 @@ class ProxyServerHandler(object):
# just connect to tracker, need to update all keys
self._tracker_pending_puts = self._server_pool.keys()
if self._tracker_conn and period_update:
# periodically update tracker information
# regenerate key if the key is not in tracker anymore
# and there is no in-coming connection after timeout_alloc
base.sendjson(self._tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS])
pending_keys = set(base.recvjson(self._tracker_conn))
update_keys = []
for k, v in self._server_pool.items():
if k not in pending_keys:
if v.alloc_time is None:
v.alloc_time = time.time()
elif time.time() - v.alloc_time > self.timeout_alloc:
update_keys.append(k)
v.alloc_time = None
if update_keys:
logging.info("RPCProxy: No incoming conn on %s, regenerate keys...",
str(update_keys))
new_keys = self._regenerate_server_keys(update_keys)
self._tracker_pending_puts += new_keys
need_update_info = False
# report new connections
for key in self._tracker_pending_puts:
rpc_key, match_key = key.split(":")
rpc_key = key.split(":")[0]
base.sendjson(self._tracker_conn,
[TrackerCode.PUT, rpc_key,
(self._listen_port, ":" + match_key)])
(self._listen_port, key)])
assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS
if rpc_key not in self._key_set:
self._key_set.add(rpc_key)
@ -305,24 +347,17 @@ class ProxyServerHandler(object):
assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS
self._tracker_pending_puts = []
except (socket.error, IOError) as err:
retry_period = 5
logging.info(
"Lost tracker connection: %s, try reconnect in %g sec",
str(err), retry_period)
str(err), self.update_tracker_period)
self._tracker_conn.close()
self._tracker_conn = None
new_pool = {}
keyset = set(self._server_pool.keys())
# re-generate the server match key, so old information is invalidated.
for key, handle in self._server_pool.items():
rpc_key, _ = key.split(":")
key = base.random_key(rpc_key + ":", keyset)
new_pool[key] = handle
keyset.add(key)
self._server_pool = new_pool
self._regenerate_server_keys(self._server_pool.keys())
if period_update:
def _callback():
self._update_tracker()
self.loop.call_later(retry_period, _callback)
self._update_tracker(True)
self.loop.call_later(self.update_tracker_period, _callback)
def _handler_ready_tracker_mode(self, handler):
"""tracker mode to handle handler ready."""

Просмотреть файл

@ -6,7 +6,7 @@ Server is TCP based with the following protocol:
- Initial handshake to the peer
- [RPC_MAGIC, keysize(int32), key-bytes]
- The key is in format
- {server|client}:device-type[:matchkey] [-timeout=timeout]
- {server|client}:device-type[:random-key] [-timeout=timeout]
"""
from __future__ import absolute_import
@ -75,7 +75,7 @@ def _parse_server_opt(opts):
def _listen_loop(sock, port, rpc_key, tracker_addr):
"""Lisenting loop of the server master."""
def _accept_conn(listen_sock, tracker_conn, ping_period=0.1):
def _accept_conn(listen_sock, tracker_conn, ping_period=2):
"""Accept connection from the other places.
Parameters
@ -89,22 +89,40 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
ping_period : float, optional
ping tracker every k seconds if no connection is accepted.
"""
old_keyset = set()
# Report resource to tracker
if tracker_conn:
matchkey = base.random_key(":")
matchkey = base.random_key(rpc_key + ":")
base.sendjson(tracker_conn,
[TrackerCode.PUT, rpc_key, (port, matchkey)])
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
else:
matchkey = ""
matchkey = rpc_key
unmatch_period_count = 0
unmatch_timeout = 4
# Wait until we get a valid connection
while True:
if tracker_conn:
trigger = select.select([listen_sock], [], [], ping_period)
if not listen_sock in trigger[0]:
base.sendjson(tracker_conn, [TrackerCode.PING])
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
base.sendjson(tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS])
pending_keys = base.recvjson(tracker_conn)
old_keyset.add(matchkey)
# if match key not in pending key set
# it means the key is aqquired by a client but not used.
if matchkey not in pending_keys:
unmatch_period_count += 1
else:
unmatch_period_count = 0
# regenerate match key if key is aqquired but not used for a while
if unmatch_period_count * ping_period > unmatch_timeout + ping_period:
logging.info("RPCServer: no incoming connections, regenerate key ...")
matchkey = base.random_key(rpc_key + ":", old_keyset)
base.sendjson(tracker_conn,
[TrackerCode.PUT, rpc_key, (port, matchkey)])
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
unmatch_period_count = 0
continue
conn, addr = listen_sock.accept()
magic = struct.unpack("@i", base.recvall(conn, 4))[0]
@ -114,7 +132,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
keylen = struct.unpack("@i", base.recvall(conn, 4))[0]
key = py_str(base.recvall(conn, keylen))
arr = key.split()
expect_header = "client:" + rpc_key + matchkey
expect_header = "client:" + matchkey
server_key = "server:" + rpc_key
if arr[0] != expect_header:
conn.sendall(struct.pack("@i", base.RPC_CODE_MISMATCH))

Просмотреть файл

@ -48,6 +48,8 @@ class TCPHandler(object):
def write_message(self, message, binary=True):
assert binary
if self._sock is None:
raise IOError("socket is already closed")
self._pending_write.append(message)
self._update_write()

Просмотреть файл

@ -92,7 +92,9 @@ class PriorityScheduler(Scheduler):
value = self._values.pop(0)
item = heapq.heappop(self._requests)
callback = item[-1]
if not callback(value):
if callback(value[1:]):
value[0].pending_matchkeys.remove(value[-1])
else:
self._values.append(value)
def put(self, value):
@ -124,6 +126,8 @@ class TCPEventHandler(tornado_util.TCPHandler):
self._addr = addr
self._init_req_nbytes = 4
self._info = {"addr": addr}
# list of pending match keys that has not been used.
self.pending_matchkeys = set()
self._tracker._connections.add(self)
def name(self):
@ -189,18 +193,27 @@ class TCPEventHandler(tornado_util.TCPHandler):
if code == TrackerCode.PUT:
key = args[1]
port, matchkey = args[2]
self._tracker.put(key, (self._addr[0], port, matchkey))
self.pending_matchkeys.add(matchkey)
self._tracker.put(key, (self, self._addr[0], port, matchkey))
self.ret_value(TrackerCode.SUCCESS)
elif code == TrackerCode.REQUEST:
key = args[1]
user = args[2]
priority = args[3]
def _cb(value):
self.ret_value([TrackerCode.SUCCESS, value])
# if the connection is already closed
if not self._sock:
return False
try:
self.ret_value([TrackerCode.SUCCESS, value])
except (socket.sock_error, IOError):
return False
return True
self._tracker.request(key, user, priority, _cb)
elif code == TrackerCode.PING:
self.ret_value(TrackerCode.SUCCESS)
elif code == TrackerCode.GET_PENDING_MATCHKEYS:
self.ret_value(list(self.pending_matchkeys))
elif code == TrackerCode.STOP:
# safe stop tracker
if self._tracker._stop_key == args[1]:

Просмотреть файл

@ -23,6 +23,11 @@ def check_server_drop():
tproxy = proxy.Proxy("localhost", 8881,
tracker_addr=("localhost", tserver.port))
tclient = rpc.connect_tracker("localhost", tserver.port)
server0 = rpc.Server(
"localhost", port=9099,
tracker_addr=("localhost", tserver.port),
key="abc")
server1 = rpc.Server(
"localhost", port=9099,
tracker_addr=("localhost", tserver.port),
@ -34,6 +39,10 @@ def check_server_drop():
"localhost", tproxy.port, is_proxy=True,
key="xyz1")
# Fault tolerence to un-handled requested value
_put(tclient, [TrackerCode.REQUEST, "abc", "", 1])
_put(tclient, [TrackerCode.REQUEST, "xyz1", "", 1])
# Fault tolerence to stale worker value
_put(tclient, [TrackerCode.PUT, "xyz", (server1.port, "abc")])
_put(tclient, [TrackerCode.PUT, "xyz", (server1.port, "abcxxx")])
@ -58,14 +67,21 @@ def check_server_drop():
assert f1(10) == 11
f1 = remote2.get_function("rpc.test2.addone")
assert f1(10) == 11
except tvm.TVMError as e:
pass
remote3 = tclient.request("abc")
f1 = remote3.get_function("rpc.test2.addone")
remote3 = tclient.request("xyz1")
f1 = remote3.get_function("rpc.test2.addone")
assert f1(10) == 11
check_timeout(0.01, 0.1)
check_timeout(2, 0)
tserver.terminate()
server2.terminate()
server0.terminate()
server1.terminate()
server2.terminate()
server3.terminate()
tproxy.terminate()
except ImportError: