[RPC] Allow back pressure from writer (#250)
* [RPC] Allow backpressure from writer * fix * fix
This commit is contained in:
Родитель
c6d4f5af20
Коммит
8d241b9d86
|
@ -235,8 +235,9 @@ class RequestHandler(tornado.web.RequestHandler):
|
|||
self.page = open(kwargs.pop("file_path")).read()
|
||||
web_port = kwargs.pop("rpc_web_port", None)
|
||||
if web_port:
|
||||
self.page.replace(r"ws://localhost:9888/ws",
|
||||
r"ws://localhost:%d/ws" % web_port)
|
||||
self.page = self.page.replace(
|
||||
"ws://localhost:9190/ws",
|
||||
"ws://localhost:%d/ws" % web_port)
|
||||
super(RequestHandler, self).__init__(*args, **kwargs)
|
||||
|
||||
def data_received(self, _):
|
||||
|
@ -468,14 +469,14 @@ def websocket_proxy_server(url, key=""):
|
|||
logging.info("Connection established")
|
||||
msg = msg[4:]
|
||||
if msg:
|
||||
on_message(bytearray(msg))
|
||||
on_message(bytearray(msg), 3)
|
||||
|
||||
while True:
|
||||
try:
|
||||
msg = yield conn.read_message()
|
||||
if msg is None:
|
||||
break
|
||||
on_message(bytearray(msg))
|
||||
on_message(bytearray(msg), 3)
|
||||
except websocket.WebSocketClosedError as err:
|
||||
break
|
||||
logging.info("WebSocketProxyServer closed...")
|
||||
|
|
|
@ -29,7 +29,7 @@ def main():
|
|||
help='the hostname of the server')
|
||||
parser.add_argument('--port', type=int, default=9090,
|
||||
help='The port of the PRC')
|
||||
parser.add_argument('--web-port', type=int, default=9888,
|
||||
parser.add_argument('--web-port', type=int, default=9190,
|
||||
help='The port of the http/websocket server')
|
||||
parser.add_argument('--example-rpc', type=bool, default=False,
|
||||
help='Whether to switch on example rpc mode')
|
||||
|
|
|
@ -32,18 +32,18 @@ class CallbackChannel final : public RPCChannel {
|
|||
PackedFunc fsend_;
|
||||
};
|
||||
|
||||
PackedFunc CreateEvenDrivenServer(PackedFunc fsend, std::string name) {
|
||||
PackedFunc CreateEventDrivenServer(PackedFunc fsend, std::string name) {
|
||||
std::unique_ptr<CallbackChannel> ch(new CallbackChannel(fsend));
|
||||
std::shared_ptr<RPCSession> sess = RPCSession::Create(std::move(ch), name);
|
||||
return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
|
||||
bool ret = sess->ServerOnMessageHandler(args[0]);
|
||||
int ret = sess->ServerEventHandler(args[0], args[1]);
|
||||
*rv = ret;
|
||||
});
|
||||
}
|
||||
|
||||
TVM_REGISTER_GLOBAL("contrib.rpc._CreateEventDrivenServer")
|
||||
.set_body([](TVMArgs args, TVMRetValue* rv) {
|
||||
*rv = CreateEvenDrivenServer(args[0], args[1]);
|
||||
*rv = CreateEventDrivenServer(args[0], args[1]);
|
||||
});
|
||||
} // namespace runtime
|
||||
} // namespace tvm
|
||||
|
|
|
@ -752,19 +752,23 @@ void RPCSession::ServerLoop() {
|
|||
channel_.reset(nullptr);
|
||||
}
|
||||
|
||||
|
||||
bool RPCSession::ServerOnMessageHandler(const std::string& bytes) {
|
||||
int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) {
|
||||
std::lock_guard<std::recursive_mutex> lock(mutex_);
|
||||
reader_.Write(bytes.c_str(), bytes.length());
|
||||
TVMRetValue rv;
|
||||
RPCCode code = handler_->HandleNextEvent(&rv, false, nullptr);
|
||||
while (writer_.bytes_available() != 0) {
|
||||
RPCCode code = RPCCode::kNone;
|
||||
if (bytes.length() != 0) {
|
||||
reader_.Write(bytes.c_str(), bytes.length());
|
||||
TVMRetValue rv;
|
||||
code = handler_->HandleNextEvent(&rv, false, nullptr);
|
||||
}
|
||||
if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) {
|
||||
writer_.ReadWithCallback([this](const void *data, size_t size) {
|
||||
return channel_->Send(data, size);
|
||||
}, writer_.bytes_available());
|
||||
}
|
||||
CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck);
|
||||
return code != RPCCode::kShutdown;
|
||||
if (code == RPCCode::kShutdown) return 0;
|
||||
if (writer_.bytes_available() != 0) return 2;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Get remote function with name
|
||||
|
|
|
@ -86,13 +86,18 @@ class RPCSession {
|
|||
* \brief Message handling function for event driven server.
|
||||
* Called when the server receives a message.
|
||||
* Event driven handler will never call recv on the channel
|
||||
* and always relies on the ServerOnMessageHandler
|
||||
* and always relies on the ServerEventHandler.
|
||||
* to receive the data.
|
||||
*
|
||||
* \param bytes The incoming bytes.
|
||||
* \return Whether need continue running, return false when receive a shutdown message.
|
||||
* \param in_bytes The incoming bytes.
|
||||
* \param event_flag 1: read_available, 2: write_avaiable.
|
||||
* \return State flag.
|
||||
* 1: continue running, no need to write,
|
||||
* 2: need to write
|
||||
* 0: shutdown
|
||||
*/
|
||||
bool ServerOnMessageHandler(const std::string& bytes);
|
||||
int ServerEventHandler(const std::string& in_bytes,
|
||||
int event_flag);
|
||||
/*!
|
||||
* \brief Call into remote function
|
||||
* \param handle The function handle
|
||||
|
@ -161,7 +166,7 @@ class RPCSession {
|
|||
return table_index_;
|
||||
}
|
||||
/*!
|
||||
* \brief Create a RPC session with given socket
|
||||
* \brief Create a RPC session with given channel.
|
||||
* \param channel The communication channel.
|
||||
* \param name The name of the session, used for debug
|
||||
* \return The session.
|
||||
|
|
|
@ -5,7 +5,7 @@ import time
|
|||
import multiprocessing
|
||||
from tvm.contrib import rpc
|
||||
|
||||
def rpc_proxy_test():
|
||||
def rpc_proxy_check():
|
||||
"""This is a simple test function for RPC Proxy
|
||||
|
||||
It is not included as nosetests, because:
|
||||
|
@ -47,4 +47,4 @@ def rpc_proxy_test():
|
|||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
rpc_proxy_test()
|
||||
rpc_proxy_check()
|
||||
|
|
|
@ -31,7 +31,7 @@
|
|||
<li> run "python tests/web/websock_rpc_test.py" to run the rpc client.
|
||||
</ul>
|
||||
<h2>Options</h2>
|
||||
Proxy URL<input name="proxyurl" id="proxyURL" type="text" value="ws://localhost:9888/ws"><br>
|
||||
Proxy URL<input name="proxyurl" id="proxyURL" type="text" value="ws://localhost:9190/ws"><br>
|
||||
RPC Server Key<input name="serverkey" id="proxyKey" type="text" value="js"><br>
|
||||
<button onclick="connect_rpc()">Connect To Proxy</button>
|
||||
<button onclick="clear_log()">Clear Log</button>
|
||||
|
|
|
@ -9,6 +9,6 @@ var Module = require("../lib/libtvm_web_runtime.js");
|
|||
const tvm_runtime = require("../web/tvm_runtime.js");
|
||||
const tvm = tvm_runtime.create(Module);
|
||||
|
||||
var websock_proxy = "ws://localhost:9888/ws";
|
||||
var websock_proxy = "ws://localhost:9190/ws";
|
||||
var num_sess = 100;
|
||||
tvm.startRPCServer(websock_proxy, "js", num_sess)
|
||||
|
|
|
@ -503,7 +503,7 @@ var tvm_runtime = tvm_runtime || {};
|
|||
* @return {boolean} Whether f is PackedFunc
|
||||
*/
|
||||
this.isPackedFunc = function(f) {
|
||||
return (typeof f._tvm_function !== "undefined");
|
||||
return (typeof f == "function") && f.hasOwnProperty("_tvm_function");
|
||||
};
|
||||
var isPackedFunc = this.isPackedFunc;
|
||||
/**
|
||||
|
@ -633,7 +633,7 @@ var tvm_runtime = tvm_runtime || {};
|
|||
}
|
||||
} else if (tp == "number") {
|
||||
this.setDouble(i, v);
|
||||
} else if (typeof v._tvm_function !== "undefined") {
|
||||
} else if (tp == "function" && v.hasOwnProperty("_tvm_function")) {
|
||||
this.setString(i, v._tvm_function.handle, kFuncHandle);
|
||||
} else if (v === null) {
|
||||
this.setHandle(i, 0, kNull);
|
||||
|
@ -907,12 +907,15 @@ var tvm_runtime = tvm_runtime || {};
|
|||
}
|
||||
logging(server_name + "init end...");
|
||||
if (msg.length > 4) {
|
||||
if (!message_handler(new Uint8Array(event.data, 4, msg.length -4))) {
|
||||
if (message_handler(
|
||||
new Uint8Array(event.data, 4, msg.length -4),
|
||||
new TVMConstant(3, "int32")) == 0) {
|
||||
socket.close();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (!message_handler(new Uint8Array(event.data))) {
|
||||
if (message_handler(new Uint8Array(event.data),
|
||||
new TVMConstant(3, "int32")) == 0) {
|
||||
socket.close();
|
||||
}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче