[RPC] More robust tracker protocol (#1085)
* [RPC] More robust tracker protocol * fix normal rpc
This commit is contained in:
Родитель
2e17e85005
Коммит
1418134003
|
@ -34,6 +34,7 @@ class TrackerCode(object):
|
|||
REQUEST = 4
|
||||
UPDATE_INFO = 5
|
||||
SUMMARY = 6
|
||||
GET_PENDING_MATCHKEYS = 7
|
||||
|
||||
RPC_SESS_MASK = 128
|
||||
|
||||
|
|
|
@ -230,7 +230,7 @@ class TrackerSession(object):
|
|||
if value[0] != base.TrackerCode.SUCCESS:
|
||||
raise RuntimeError("Invalid return value %s" % str(value))
|
||||
url, port, matchkey = value[1]
|
||||
return connect(url, port, key + matchkey, session_timeout)
|
||||
return connect(url, port, matchkey, session_timeout)
|
||||
except socket.error as err:
|
||||
self.close()
|
||||
last_err = err
|
||||
|
|
|
@ -14,6 +14,7 @@ import socket
|
|||
import multiprocessing
|
||||
import errno
|
||||
import struct
|
||||
import time
|
||||
|
||||
try:
|
||||
import tornado
|
||||
|
@ -45,6 +46,7 @@ class ForwardHandler(object):
|
|||
self.rpc_key = None
|
||||
self.match_key = None
|
||||
self.forward_proxy = None
|
||||
self.alloc_time = None
|
||||
|
||||
def __del__(self):
|
||||
logging.info("Delete %s...", self.name())
|
||||
|
@ -237,6 +239,7 @@ class ProxyServerHandler(object):
|
|||
self.sock.fileno(), event_handler, self.loop.READ)
|
||||
self._client_pool = {}
|
||||
self._server_pool = {}
|
||||
self.timeout_alloc = 5
|
||||
self.timeout_client = timeout_client
|
||||
self.timeout_server = timeout_server
|
||||
# tracker information
|
||||
|
@ -245,8 +248,12 @@ class ProxyServerHandler(object):
|
|||
self._tracker_conn = None
|
||||
self._tracker_pending_puts = []
|
||||
self._key_set = set()
|
||||
self.update_tracker_period = 2
|
||||
if tracker_addr:
|
||||
logging.info("Tracker address:%s", str(tracker_addr))
|
||||
def _callback():
|
||||
self._update_tracker(True)
|
||||
self.loop.call_later(self.update_tracker_period, _callback)
|
||||
logging.info("RPCProxy: Websock port bind to %d", web_port)
|
||||
|
||||
def _on_event(self, _):
|
||||
|
@ -271,7 +278,22 @@ class ProxyServerHandler(object):
|
|||
rhs.send_data(lhs.rpc_key.encode("utf-8"))
|
||||
logging.info("Pairup connect %s and %s", lhs.name(), rhs.name())
|
||||
|
||||
def _update_tracker(self):
|
||||
def _regenerate_server_keys(self, keys):
    """Regenerate match keys for entries of the server pool.

    Each given key is removed from ``self._server_pool`` and its handle is
    stored again under a freshly generated key with the same ``rpc_key``
    prefix, so information published under the old key is invalidated.

    Parameters
    ----------
    keys : iterable of str
        Keys currently present in ``self._server_pool``, each of the form
        ``"rpc_key:match_key"``. May be a live view of the pool's keys.

    Returns
    -------
    new_keys : list of str
        The replacement keys, in the same order as ``keys``.
    """
    # Snapshot first: callers may pass ``self._server_pool.keys()``, which is
    # a live view on Python 3, and the pool is mutated inside the loop below.
    stale_keys = list(keys)
    # Keys already in use; handed to base.random_key together with the prefix
    # (presumably so the generated key does not collide -- confirm in base).
    keyset = set(self._server_pool.keys())
    new_keys = []
    # re-generate the server match key, so old information is invalidated.
    for key in stale_keys:
        rpc_key, _ = key.split(":")
        handle = self._server_pool.pop(key)
        new_key = base.random_key(rpc_key + ":", keyset)
        self._server_pool[new_key] = handle
        keyset.add(new_key)
        new_keys.append(new_key)
    return new_keys
|
||||
|
||||
def _update_tracker(self, period_update=False):
|
||||
"""Update information on tracker."""
|
||||
try:
|
||||
if self._tracker_conn is None:
|
||||
|
@ -285,13 +307,33 @@ class ProxyServerHandler(object):
|
|||
# just connect to tracker, need to update all keys
|
||||
self._tracker_pending_puts = self._server_pool.keys()
|
||||
|
||||
if self._tracker_conn and period_update:
|
||||
# periodically update tracker information
|
||||
# regenerate key if the key is not in tracker anymore
|
||||
# and there is no in-coming connection after timeout_alloc
|
||||
base.sendjson(self._tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS])
|
||||
pending_keys = set(base.recvjson(self._tracker_conn))
|
||||
update_keys = []
|
||||
for k, v in self._server_pool.items():
|
||||
if k not in pending_keys:
|
||||
if v.alloc_time is None:
|
||||
v.alloc_time = time.time()
|
||||
elif time.time() - v.alloc_time > self.timeout_alloc:
|
||||
update_keys.append(k)
|
||||
v.alloc_time = None
|
||||
if update_keys:
|
||||
logging.info("RPCProxy: No incoming conn on %s, regenerate keys...",
|
||||
str(update_keys))
|
||||
new_keys = self._regenerate_server_keys(update_keys)
|
||||
self._tracker_pending_puts += new_keys
|
||||
|
||||
need_update_info = False
|
||||
# report new connections
|
||||
for key in self._tracker_pending_puts:
|
||||
rpc_key, match_key = key.split(":")
|
||||
rpc_key = key.split(":")[0]
|
||||
base.sendjson(self._tracker_conn,
|
||||
[TrackerCode.PUT, rpc_key,
|
||||
(self._listen_port, ":" + match_key)])
|
||||
(self._listen_port, key)])
|
||||
assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS
|
||||
if rpc_key not in self._key_set:
|
||||
self._key_set.add(rpc_key)
|
||||
|
@ -305,24 +347,17 @@ class ProxyServerHandler(object):
|
|||
assert base.recvjson(self._tracker_conn) == TrackerCode.SUCCESS
|
||||
self._tracker_pending_puts = []
|
||||
except (socket.error, IOError) as err:
|
||||
retry_period = 5
|
||||
logging.info(
|
||||
"Lost tracker connection: %s, try reconnect in %g sec",
|
||||
str(err), retry_period)
|
||||
str(err), self.update_tracker_period)
|
||||
self._tracker_conn.close()
|
||||
self._tracker_conn = None
|
||||
new_pool = {}
|
||||
keyset = set(self._server_pool.keys())
|
||||
# re-generate the server match key, so old information is invalidated.
|
||||
for key, handle in self._server_pool.items():
|
||||
rpc_key, _ = key.split(":")
|
||||
key = base.random_key(rpc_key + ":", keyset)
|
||||
new_pool[key] = handle
|
||||
keyset.add(key)
|
||||
self._server_pool = new_pool
|
||||
self._regenerate_server_keys(self._server_pool.keys())
|
||||
|
||||
if period_update:
|
||||
def _callback():
|
||||
self._update_tracker()
|
||||
self.loop.call_later(retry_period, _callback)
|
||||
self._update_tracker(True)
|
||||
self.loop.call_later(self.update_tracker_period, _callback)
|
||||
|
||||
def _handler_ready_tracker_mode(self, handler):
|
||||
"""tracker mode to handle handler ready."""
|
||||
|
|
|
@ -6,7 +6,7 @@ Server is TCP based with the following protocol:
|
|||
- Initial handshake to the peer
|
||||
- [RPC_MAGIC, keysize(int32), key-bytes]
|
||||
- The key is in format
|
||||
- {server|client}:device-type[:matchkey] [-timeout=timeout]
|
||||
- {server|client}:device-type[:random-key] [-timeout=timeout]
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
|
||||
|
@ -75,7 +75,7 @@ def _parse_server_opt(opts):
|
|||
|
||||
def _listen_loop(sock, port, rpc_key, tracker_addr):
|
||||
"""Listening loop of the server master."""
|
||||
def _accept_conn(listen_sock, tracker_conn, ping_period=0.1):
|
||||
def _accept_conn(listen_sock, tracker_conn, ping_period=2):
|
||||
"""Accept connection from the other places.
|
||||
|
||||
Parameters
|
||||
|
@ -89,22 +89,40 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
|
|||
ping_period : float, optional
|
||||
ping tracker every k seconds if no connection is accepted.
|
||||
"""
|
||||
old_keyset = set()
|
||||
# Report resource to tracker
|
||||
if tracker_conn:
|
||||
matchkey = base.random_key(":")
|
||||
matchkey = base.random_key(rpc_key + ":")
|
||||
base.sendjson(tracker_conn,
|
||||
[TrackerCode.PUT, rpc_key, (port, matchkey)])
|
||||
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
|
||||
else:
|
||||
matchkey = ""
|
||||
matchkey = rpc_key
|
||||
|
||||
unmatch_period_count = 0
|
||||
unmatch_timeout = 4
|
||||
# Wait until we get a valid connection
|
||||
while True:
|
||||
if tracker_conn:
|
||||
trigger = select.select([listen_sock], [], [], ping_period)
|
||||
if not listen_sock in trigger[0]:
|
||||
base.sendjson(tracker_conn, [TrackerCode.PING])
|
||||
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
|
||||
base.sendjson(tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS])
|
||||
pending_keys = base.recvjson(tracker_conn)
|
||||
old_keyset.add(matchkey)
|
||||
# if match key not in pending key set
|
||||
# it means the key is acquired by a client but not used.
|
||||
if matchkey not in pending_keys:
|
||||
unmatch_period_count += 1
|
||||
else:
|
||||
unmatch_period_count = 0
|
||||
# regenerate match key if key is acquired but not used for a while
|
||||
if unmatch_period_count * ping_period > unmatch_timeout + ping_period:
|
||||
logging.info("RPCServer: no incoming connections, regenerate key ...")
|
||||
matchkey = base.random_key(rpc_key + ":", old_keyset)
|
||||
base.sendjson(tracker_conn,
|
||||
[TrackerCode.PUT, rpc_key, (port, matchkey)])
|
||||
assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
|
||||
unmatch_period_count = 0
|
||||
continue
|
||||
conn, addr = listen_sock.accept()
|
||||
magic = struct.unpack("@i", base.recvall(conn, 4))[0]
|
||||
|
@ -114,7 +132,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
|
|||
keylen = struct.unpack("@i", base.recvall(conn, 4))[0]
|
||||
key = py_str(base.recvall(conn, keylen))
|
||||
arr = key.split()
|
||||
expect_header = "client:" + rpc_key + matchkey
|
||||
expect_header = "client:" + matchkey
|
||||
server_key = "server:" + rpc_key
|
||||
if arr[0] != expect_header:
|
||||
conn.sendall(struct.pack("@i", base.RPC_CODE_MISMATCH))
|
||||
|
|
|
@ -48,6 +48,8 @@ class TCPHandler(object):
|
|||
|
||||
def write_message(self, message, binary=True):
    """Queue a message for sending and call ``self._update_write()``.

    Parameters
    ----------
    message : object
        Payload appended to ``self._pending_write``.
    binary : bool, optional
        Must be truthy; the call asserts on it.

    Raises
    ------
    IOError
        If ``self._sock`` is None, i.e. the socket is already closed.
    """
    assert binary
    sock = self._sock
    if sock is None:
        raise IOError("socket is already closed")
    pending = self._pending_write
    pending.append(message)
    self._update_write()
|
||||
|
||||
|
|
|
@ -92,7 +92,9 @@ class PriorityScheduler(Scheduler):
|
|||
value = self._values.pop(0)
|
||||
item = heapq.heappop(self._requests)
|
||||
callback = item[-1]
|
||||
if not callback(value):
|
||||
if callback(value[1:]):
|
||||
value[0].pending_matchkeys.remove(value[-1])
|
||||
else:
|
||||
self._values.append(value)
|
||||
|
||||
def put(self, value):
|
||||
|
@ -124,6 +126,8 @@ class TCPEventHandler(tornado_util.TCPHandler):
|
|||
self._addr = addr
|
||||
self._init_req_nbytes = 4
|
||||
self._info = {"addr": addr}
|
||||
# list of pending match keys that have not been used.
|
||||
self.pending_matchkeys = set()
|
||||
self._tracker._connections.add(self)
|
||||
|
||||
def name(self):
|
||||
|
@ -189,18 +193,27 @@ class TCPEventHandler(tornado_util.TCPHandler):
|
|||
if code == TrackerCode.PUT:
|
||||
key = args[1]
|
||||
port, matchkey = args[2]
|
||||
self._tracker.put(key, (self._addr[0], port, matchkey))
|
||||
self.pending_matchkeys.add(matchkey)
|
||||
self._tracker.put(key, (self, self._addr[0], port, matchkey))
|
||||
self.ret_value(TrackerCode.SUCCESS)
|
||||
elif code == TrackerCode.REQUEST:
|
||||
key = args[1]
|
||||
user = args[2]
|
||||
priority = args[3]
|
||||
def _cb(value):
    """Send a matched value back to the requesting connection.

    Parameters
    ----------
    value : object
        Payload returned together with ``TrackerCode.SUCCESS``.

    Returns
    -------
    delivered : bool
        True if the reply was written; False if the connection is already
        closed or the write failed.
    """
    # if the connection is already closed
    if not self._sock:
        return False
    try:
        self.ret_value([TrackerCode.SUCCESS, value])
    except (socket.error, IOError):
        # a failed write is reported the same way as a closed connection
        return False
    return True
|
||||
self._tracker.request(key, user, priority, _cb)
|
||||
elif code == TrackerCode.PING:
|
||||
self.ret_value(TrackerCode.SUCCESS)
|
||||
elif code == TrackerCode.GET_PENDING_MATCHKEYS:
|
||||
self.ret_value(list(self.pending_matchkeys))
|
||||
elif code == TrackerCode.STOP:
|
||||
# safe stop tracker
|
||||
if self._tracker._stop_key == args[1]:
|
||||
|
|
|
@ -23,6 +23,11 @@ def check_server_drop():
|
|||
tproxy = proxy.Proxy("localhost", 8881,
|
||||
tracker_addr=("localhost", tserver.port))
|
||||
tclient = rpc.connect_tracker("localhost", tserver.port)
|
||||
|
||||
server0 = rpc.Server(
|
||||
"localhost", port=9099,
|
||||
tracker_addr=("localhost", tserver.port),
|
||||
key="abc")
|
||||
server1 = rpc.Server(
|
||||
"localhost", port=9099,
|
||||
tracker_addr=("localhost", tserver.port),
|
||||
|
@ -34,6 +39,10 @@ def check_server_drop():
|
|||
"localhost", tproxy.port, is_proxy=True,
|
||||
key="xyz1")
|
||||
|
||||
# Fault tolerance to un-handled requested value
|
||||
_put(tclient, [TrackerCode.REQUEST, "abc", "", 1])
|
||||
_put(tclient, [TrackerCode.REQUEST, "xyz1", "", 1])
|
||||
|
||||
# Fault tolerance to stale worker value
|
||||
_put(tclient, [TrackerCode.PUT, "xyz", (server1.port, "abc")])
|
||||
_put(tclient, [TrackerCode.PUT, "xyz", (server1.port, "abcxxx")])
|
||||
|
@ -58,14 +67,21 @@ def check_server_drop():
|
|||
assert f1(10) == 11
|
||||
f1 = remote2.get_function("rpc.test2.addone")
|
||||
assert f1(10) == 11
|
||||
|
||||
except tvm.TVMError as e:
|
||||
pass
|
||||
remote3 = tclient.request("abc")
|
||||
f1 = remote3.get_function("rpc.test2.addone")
|
||||
remote3 = tclient.request("xyz1")
|
||||
f1 = remote3.get_function("rpc.test2.addone")
|
||||
assert f1(10) == 11
|
||||
|
||||
check_timeout(0.01, 0.1)
|
||||
check_timeout(2, 0)
|
||||
tserver.terminate()
|
||||
server2.terminate()
|
||||
server0.terminate()
|
||||
server1.terminate()
|
||||
server2.terminate()
|
||||
server3.terminate()
|
||||
tproxy.terminate()
|
||||
except ImportError:
|
||||
|
|
Загрузка…
Ссылка в новой задаче