From aee214453a410f9faafadc5b48035fa36c86957c Mon Sep 17 00:00:00 2001 From: Valentin Gosu Date: Wed, 9 Nov 2022 14:37:13 +0000 Subject: [PATCH] Bug 1799860 - Add WebSocketConnection helper class to help with sending multiple WS messages r=necko-reviewers,kershaw It also fixes an issue in NodeWebSocketServerCode.messageHandler where the response would always be sent to the last websocket client. Instead, we change the signature to be messageHandler(data, ws) and global.wsInputHandler(data, ws) where ws is the websocket that sent the data. Differential Revision: https://phabricator.services.mozilla.com/D161678 --- netwerk/test/unit/head_servers.js | 101 +++++++++++++- netwerk/test/unit/test_websocket_server.js | 155 ++++++++++----------- 2 files changed, 171 insertions(+), 85 deletions(-) diff --git a/netwerk/test/unit/head_servers.js b/netwerk/test/unit/head_servers.js index 8b29a96cb1fc..03bc2d624420 100644 --- a/netwerk/test/unit/head_servers.js +++ b/netwerk/test/unit/head_servers.js @@ -608,13 +608,13 @@ class NodeHTTP2ProxyServer extends BaseHTTPProxy { // websocket server class NodeWebSocketServerCode extends BaseNodeHTTPServerCode { - static messageHandler(data) { + static messageHandler(data, ws) { if (global.wsInputHandler) { - global.wsInputHandler(data); + global.wsInputHandler(data, ws); return; } - global.ws.send("test"); + ws.send("test"); } static async startServer(port) { @@ -634,8 +634,9 @@ class NodeWebSocketServerCode extends BaseNodeHTTPServerCode { WebSocket.Server = require(`${node_ws_root}/lib/websocket-server`); global.webSocketServer = new WebSocket.Server({ server: global.server }); global.webSocketServer.on("connection", function connection(ws) { - global.ws = ws; - ws.on("message", NodeWebSocketServerCode.messageHandler); + ws.on("message", data => + NodeWebSocketServerCode.messageHandler(data, ws) + ); }); await global.server.listen(port); @@ -695,11 +696,10 @@ class NodeWebSocketHttp2ServerCode extends BaseNodeHTTPServerCode { const ws = new WebSocket(null); stream.setNoDelay = () => {}; ws.setSocket(stream, Buffer.from(""), 100 * 1024 * 1024); - global.ws = ws; ws.on("message", data => { if (global.wsInputHandler) { - global.wsInputHandler(data); + global.wsInputHandler(data, ws); return; } @@ -773,3 +773,90 @@ function getTestServerCertificate() { } return null; } + +class WebSocketConnection { + constructor() { + this._openPromise = new Promise(resolve => { + this._openCallback = resolve; + }); + + this._stopPromise = new Promise(resolve => { + this._stopCallback = resolve; + }); + + this._msgPromise = new Promise(resolve => { + this._msgCallback = resolve; + }); + this._messages = []; + this._ws = null; + } + + get QueryInterface() { + return ChromeUtils.generateQI(["nsIWebSocketListener"]); + } + + onAcknowledge(aContext, aSize) {} + onBinaryMessageAvailable(aContext, aMsg) { + info(`received binary ${aMsg}`); + this._messages.push(aMsg); + this._msgCallback(); + } + onMessageAvailable(aContext, aMsg) {} + onServerClose(aContext, aCode, aReason) {} + onWebSocketListenerStart(aContext) {} + onStart(aContext) { + this._openCallback(); + } + onStop(aContext, aStatusCode) { + this._stopCallback({ status: aStatusCode }); + this._ws = null; + } + static makeWebSocketChan() { + let chan = Cc["@mozilla.org/network/protocol;1?name=wss"].createInstance( + Ci.nsIWebSocketChannel + ); + chan.initLoadInfo( + null, // aLoadingNode + Services.scriptSecurityManager.getSystemPrincipal(), + null, // aTriggeringPrincipal + Ci.nsILoadInfo.SEC_ALLOW_CROSS_ORIGIN_SEC_CONTEXT_IS_NULL, + Ci.nsIContentPolicy.TYPE_WEBSOCKET + ); + return chan; + } + // Returns a promise that resolves when the websocket channel is opened. + open(url) { + this._ws = WebSocketConnection.makeWebSocketChan(); + let uri = Services.io.newURI(url); + this._ws.asyncOpen(uri, url, {}, 0, this, null); + return this._openPromise; + } + // Closes the inner websocket. code and reason arguments are optional. + close(code, reason) { + this._ws.close(code || Ci.nsIWebSocketChannel.CLOSE_NORMAL, reason || ""); + } + // Sends a message to the server. + send(msg) { + this._ws.sendMsg(msg); + } + // Returns a promise that resolves when the channel's onStop is called. + // Promise resolves with an `{status}` object, where status is the + // result passed to onStop. + finished() { + return this._stopPromise; + } + + // Returned promise resolves with an array of received messages + // If messages have been received in the the past before calling + // receiveMessages, the promise will immediately resolve. Otherwise + // it will resolve when the first message is received. + async receiveMessages() { + await this._msgPromise; + this._msgPromise = new Promise(resolve => { + this._msgCallback = resolve; + }); + let messages = this._messages; + this._messages = []; + return messages; + } +} diff --git a/netwerk/test/unit/test_websocket_server.js b/netwerk/test/unit/test_websocket_server.js index 6eab69c96e2c..2d1533c10058 100644 --- a/netwerk/test/unit/test_websocket_server.js +++ b/netwerk/test/unit/test_websocket_server.js @@ -25,84 +25,85 @@ add_setup(async function setup() { }); }); -function WebSocketListener(closure, ws, sentMsg) { - this._closure = closure; - this._ws = ws; - this._sentMsg = sentMsg; -} - -WebSocketListener.prototype = { - _closure: null, - _ws: null, - _sentMsg: null, - _received: null, - QueryInterface: ChromeUtils.generateQI(["nsIWebSocketListener"]), - - onAcknowledge(aContext, aSize) {}, - onBinaryMessageAvailable(aContext, aMsg) { - this._received = aMsg; - this._ws.close(0, null); - }, - onMessageAvailable(aContext, aMsg) {}, - onServerClose(aContext, aCode, aReason) {}, - onSWebSocketListenertart(aContext) {}, - onStart(aContext) { - this._ws.sendMsg(this._sentMsg); - }, - onStop(aContext, aStatusCode) { - try { - this._closure(aStatusCode, this._received); - this._ws = null; - } catch (ex) { - do_throw("Error in closure function: " + ex); - } - }, -}; - -function makeWebSocketChan() { - let chan = Cc["@mozilla.org/network/protocol;1?name=wss"].createInstance( - Ci.nsIWebSocketChannel - ); - chan.initLoadInfo( - null, // aLoadingNode - Services.scriptSecurityManager.getSystemPrincipal(), - null, // aTriggeringPrincipal - Ci.nsILoadInfo.SEC_ALLOW_CROSS_ORIGIN_SEC_CONTEXT_IS_NULL, - Ci.nsIContentPolicy.TYPE_WEBSOCKET - ); - return chan; -} - -function channelOpenPromise(chan, url, msg) { - let uri = Services.io.newURI(url); - return new Promise(resolve => { - function finish(status, result) { - resolve([status, result]); - } - chan.asyncOpen( - uri, - url, - {}, - 0, - new WebSocketListener(finish, chan, msg), - null - ); - }); +async function channelOpenPromise(url, msg) { + let conn = new WebSocketConnection(); + await conn.open(url); + conn.send(msg); + let res = await conn.receiveMessages(); + conn.close(); + let { status } = await conn.finished(); + return [status, res]; } add_task(async function test_websocket() { let wss = new NodeWebSocketServer(); await wss.start(); + registerCleanupFunction(async () => wss.stop()); Assert.notEqual(wss.port(), null); - await wss.registerMessageHandler(data => { - global.ws.send(data); + await wss.registerMessageHandler((data, ws) => { + ws.send(data); }); - let chan = makeWebSocketChan(); let url = `wss://localhost:${wss.port()}`; const msg = "test websocket"; - let [status, res] = await channelOpenPromise(chan, url, msg); + + let conn = new WebSocketConnection(); + await conn.open(url); + conn.send(msg); + let mess1 = await conn.receiveMessages(); + Assert.deepEqual(mess1, [msg]); + + // Now send 3 more, and check that we received all of them + conn.send(msg); + conn.send(msg); + conn.send(msg); + let mess2 = []; + while (mess2.length < 3) { + // receive could return 1, 2 or all 3 replies. + mess2 = mess2.concat(await conn.receiveMessages()); + } + Assert.deepEqual(mess2, [msg, msg, msg]); + + conn.close(); + let { status } = await conn.finished(); + Assert.equal(status, Cr.NS_OK); - Assert.equal(res, msg); + await wss.stop(); +}); + +add_task(async function test_two_clients() { + let wss = new NodeWebSocketServer(); + await wss.start(); + registerCleanupFunction(async () => wss.stop()); + Assert.notEqual(wss.port(), null); + await wss.registerMessageHandler((data, ws) => { + ws.send(data); + }); + let url = `wss://localhost:${wss.port()}`; + + let conn1 = new WebSocketConnection(); + await conn1.open(url); + + let conn2 = new WebSocketConnection(); + await conn2.open(url); + + conn1.send("msg1"); + conn2.send("msg2"); + + let mess2 = await conn2.receiveMessages(); + Assert.deepEqual(mess2, ["msg2"]); + + conn1.send("msg1 again"); + let mess1 = []; + while (mess1.length < 2) { + // receive could return only the fist or both replies. + mess1 = mess1.concat(await conn1.receiveMessages()); + } + Assert.deepEqual(mess1, ["msg1", "msg1 again"]); + + conn1.close(); + conn2.close(); + Assert.deepEqual({ status: Cr.NS_OK }, await conn1.finished()); + Assert.deepEqual({ status: Cr.NS_OK }, await conn2.finished()); await wss.stop(); }); @@ -119,16 +120,15 @@ add_task(async function test_ws_through_https_proxy() { let wss = new NodeWebSocketServer(); await wss.start(); Assert.notEqual(wss.port(), null); - await wss.registerMessageHandler(data => { - global.ws.send(data); + await wss.registerMessageHandler((data, ws) => { + ws.send(data); }); - let chan = makeWebSocketChan(); let url = `wss://localhost:${wss.port()}`; const msg = "test websocket through proxy"; - let [status, res] = await channelOpenPromise(chan, url, msg); + let [status, res] = await channelOpenPromise(url, msg); Assert.equal(status, Cr.NS_OK); - Assert.equal(res, msg); + Assert.deepEqual(res, [msg]); await proxy.stop(); await wss.stop(); @@ -139,14 +139,13 @@ add_task(async function test_websocket_over_h2() { let wss = new NodeWebSocketHttp2Server(); await wss.start(); Assert.notEqual(wss.port(), null); - await wss.registerMessageHandler(data => { - global.ws.send(data); + await wss.registerMessageHandler((data, ws) => { + ws.send(data); }); - let chan = makeWebSocketChan(); let url = `wss://localhost:${wss.port()}`; const msg = "test websocket"; - let [status, res] = await channelOpenPromise(chan, url, msg); + let [status, res] = await channelOpenPromise(url, msg); Assert.equal(status, Cr.NS_OK); - Assert.equal(res, msg); + Assert.deepEqual(res, [msg]); await wss.stop(); });