зеркало из https://github.com/microsoft/nni.git
Merge command channel APIs and add unit tests (#5450)
This commit is contained in:
Родитель
1f6aedc48f
Коммит
cf5fabd968
|
@ -0,0 +1,4 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
from .channel import WsChannelClient
|
|
@ -0,0 +1,106 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
import nni
|
||||||
|
from ..base import Command, CommandChannel
|
||||||
|
from .connection import WsConnection
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class WsChannelClient(CommandChannel):
|
||||||
|
def __init__(self, url: str):
|
||||||
|
self._url: str = url
|
||||||
|
self._closing: bool = False
|
||||||
|
self._conn: WsConnection | None = None
|
||||||
|
|
||||||
|
def connect(self) -> None:
|
||||||
|
_logger.debug(f'Connect to {self._url}')
|
||||||
|
assert not self._closing
|
||||||
|
self._ensure_conn()
|
||||||
|
|
||||||
|
def disconnect(self) -> None:
|
||||||
|
_logger.debug(f'Disconnect from {self._url}')
|
||||||
|
self.send({'type': '_bye_'})
|
||||||
|
self._closing = True
|
||||||
|
self._close_conn('client intentionally close')
|
||||||
|
|
||||||
|
def send(self, command: Command) -> None:
|
||||||
|
if self._closing:
|
||||||
|
return
|
||||||
|
_logger.debug(f'Send {command}')
|
||||||
|
msg = nni.dump(command)
|
||||||
|
for i in range(5):
|
||||||
|
try:
|
||||||
|
conn = self._ensure_conn()
|
||||||
|
conn.send(msg)
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
_logger.exception(f'Failed to send command. Retry in {i}s')
|
||||||
|
self._terminate_conn('send fail')
|
||||||
|
time.sleep(i)
|
||||||
|
_logger.warning(f'Failed to send command {command}. Last retry')
|
||||||
|
conn = self._ensure_conn()
|
||||||
|
conn.send(msg)
|
||||||
|
|
||||||
|
def receive(self) -> Command | None:
|
||||||
|
while True:
|
||||||
|
if self._closing:
|
||||||
|
return None
|
||||||
|
msg = self._receive_msg()
|
||||||
|
if msg is None:
|
||||||
|
return None
|
||||||
|
command = nni.load(msg)
|
||||||
|
if command['type'] == '_nop_':
|
||||||
|
continue
|
||||||
|
if command['type'] == '_bye_':
|
||||||
|
reason = command.get('reason')
|
||||||
|
_logger.debug(f'Server close connection: {reason}')
|
||||||
|
self._closing = True
|
||||||
|
self._close_conn('server intentionally close')
|
||||||
|
return None
|
||||||
|
return command
|
||||||
|
|
||||||
|
def _ensure_conn(self) -> WsConnection:
|
||||||
|
if self._conn is None and not self._closing:
|
||||||
|
self._conn = WsConnection(self._url)
|
||||||
|
self._conn.connect()
|
||||||
|
_logger.debug('Connected')
|
||||||
|
return self._conn # type: ignore
|
||||||
|
|
||||||
|
def _close_conn(self, reason: str) -> None:
|
||||||
|
if self._conn is not None:
|
||||||
|
try:
|
||||||
|
self._conn.disconnect(reason)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._conn = None
|
||||||
|
|
||||||
|
def _terminate_conn(self, reason: str) -> None:
|
||||||
|
if self._conn is not None:
|
||||||
|
try:
|
||||||
|
self._conn.terminate(reason)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._conn = None
|
||||||
|
|
||||||
|
def _receive_msg(self) -> str | None:
|
||||||
|
for i in range(5):
|
||||||
|
try:
|
||||||
|
conn = self._ensure_conn()
|
||||||
|
msg = conn.receive()
|
||||||
|
_logger.debug(f'Receive {msg}')
|
||||||
|
if not self._closing:
|
||||||
|
assert msg is not None
|
||||||
|
return msg
|
||||||
|
except Exception:
|
||||||
|
_logger.exception(f'Failed to receive command. Retry in {i}s')
|
||||||
|
self._terminate_conn('receive fail')
|
||||||
|
time.sleep(i)
|
||||||
|
_logger.warning(f'Failed to receive command. Last retry')
|
||||||
|
conn = self._ensure_conn()
|
||||||
|
conn.receive()
|
|
@ -0,0 +1,139 @@
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Synchronized and object-oriented WebSocket class.
|
||||||
|
|
||||||
|
WebSocket guarantees that messages will not be divided at API level.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
__all__ = ['WsConnection']
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from threading import Lock, Thread
|
||||||
|
from typing import Any, Type
|
||||||
|
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# the singleton event loop
|
||||||
|
_event_loop: asyncio.AbstractEventLoop = None # type: ignore
|
||||||
|
_event_loop_lock: Lock = Lock()
|
||||||
|
_event_loop_refcnt: int = 0 # number of connected websockets
|
||||||
|
|
||||||
|
class WsConnection:
|
||||||
|
"""
|
||||||
|
A WebSocket connection.
|
||||||
|
|
||||||
|
Call :meth:`connect` before :meth:`send` and :meth:`receive`.
|
||||||
|
|
||||||
|
All methods are thread safe.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
url
|
||||||
|
The WebSocket URL.
|
||||||
|
For tuner command channel it should be something like ``ws://localhost:8080/tuner``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ConnectionClosed: Type[Exception] = websockets.ConnectionClosed # type: ignore
|
||||||
|
|
||||||
|
def __init__(self, url: str):
|
||||||
|
self._url: str = url
|
||||||
|
self._ws: Any = None # the library does not provide type hints
|
||||||
|
|
||||||
|
def connect(self) -> None:
|
||||||
|
global _event_loop, _event_loop_refcnt
|
||||||
|
with _event_loop_lock:
|
||||||
|
_event_loop_refcnt += 1
|
||||||
|
if _event_loop is None:
|
||||||
|
_logger.debug('Starting event loop.')
|
||||||
|
# following line must be outside _run_event_loop
|
||||||
|
# because _wait() might be executed before first line of the child thread
|
||||||
|
_event_loop = asyncio.new_event_loop()
|
||||||
|
thread = Thread(target=_run_event_loop, name='NNI-WebSocketEventLoop', daemon=True)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
_logger.debug(f'Connecting to {self._url}')
|
||||||
|
self._ws = _wait(_connect_async(self._url))
|
||||||
|
_logger.debug(f'Connected.')
|
||||||
|
|
||||||
|
def disconnect(self, reason: str | None = None, code: int | None = None) -> None:
|
||||||
|
if self._ws is None:
|
||||||
|
_logger.debug('disconnect: No connection.')
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
_wait(self._ws.close(code or 4000, reason))
|
||||||
|
_logger.debug('Connection closed by client.')
|
||||||
|
except Exception as e:
|
||||||
|
_logger.warning(f'Failed to close connection: {repr(e)}')
|
||||||
|
self._ws = None
|
||||||
|
_decrease_refcnt()
|
||||||
|
|
||||||
|
def terminate(self, reason: str | None = None) -> None:
|
||||||
|
if self._ws is None:
|
||||||
|
_logger.debug('terminate: No connection.')
|
||||||
|
return
|
||||||
|
self.disconnect(reason, 4001)
|
||||||
|
|
||||||
|
def send(self, message: str) -> None:
|
||||||
|
_logger.debug(f'Sending {message}')
|
||||||
|
try:
|
||||||
|
_wait(self._ws.send(message))
|
||||||
|
except websockets.ConnectionClosed: # type: ignore
|
||||||
|
_logger.debug('Connection closed by server.')
|
||||||
|
self._ws = None
|
||||||
|
_decrease_refcnt()
|
||||||
|
raise
|
||||||
|
|
||||||
|
def receive(self) -> str | None:
|
||||||
|
"""
|
||||||
|
Return received message;
|
||||||
|
or return ``None`` if the connection has been closed by peer.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
msg = _wait(self._ws.recv())
|
||||||
|
_logger.debug(f'Received {msg}')
|
||||||
|
except websockets.ConnectionClosed: # type: ignore
|
||||||
|
_logger.debug('Connection closed by server.')
|
||||||
|
self._ws = None
|
||||||
|
_decrease_refcnt()
|
||||||
|
raise
|
||||||
|
|
||||||
|
# seems the library will inference whether it's text or binary, so we don't have guarantee
|
||||||
|
if isinstance(msg, bytes):
|
||||||
|
return msg.decode()
|
||||||
|
else:
|
||||||
|
return msg
|
||||||
|
|
||||||
|
def _wait(coro):
|
||||||
|
# Synchronized version of "await".
|
||||||
|
future = asyncio.run_coroutine_threadsafe(coro, _event_loop)
|
||||||
|
return future.result()
|
||||||
|
|
||||||
|
def _run_event_loop() -> None:
|
||||||
|
# A separate thread to run the event loop.
|
||||||
|
# The event loop itself is blocking, and send/receive are also blocking,
|
||||||
|
# so they must run in different threads.
|
||||||
|
asyncio.set_event_loop(_event_loop)
|
||||||
|
_event_loop.run_forever()
|
||||||
|
_logger.debug('Event loop stopped.')
|
||||||
|
|
||||||
|
async def _connect_async(url):
|
||||||
|
# Theoretically this function is meaningless and one can directly use `websockets.connect(url)`,
|
||||||
|
# but it will not work, raising "TypeError: A coroutine object is required".
|
||||||
|
# Seems a design flaw in websockets library.
|
||||||
|
return await websockets.connect(url, max_size=None) # type: ignore
|
||||||
|
|
||||||
|
def _decrease_refcnt() -> None:
|
||||||
|
global _event_loop, _event_loop_refcnt
|
||||||
|
with _event_loop_lock:
|
||||||
|
_event_loop_refcnt -= 1
|
||||||
|
if _event_loop_refcnt == 0:
|
||||||
|
_event_loop.call_soon_threadsafe(_event_loop.stop)
|
||||||
|
_event_loop = None # type: ignore
|
|
@ -1,133 +1,4 @@
|
||||||
# Copyright (c) Microsoft Corporation.
|
# Copyright (c) Microsoft Corporation.
|
||||||
# Licensed under the MIT license.
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
"""
|
from ..command_channel.websocket.connection import WsConnection as WebSocket # pylint: disable=unused-import
|
||||||
Synchronized and object-oriented WebSocket class.
|
|
||||||
|
|
||||||
WebSocket guarantees that messages will not be divided at API level.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
__all__ = ['WebSocket']
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from threading import Lock, Thread
|
|
||||||
from typing import Any, Type
|
|
||||||
|
|
||||||
import websockets
|
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# the singleton event loop
|
|
||||||
_event_loop: asyncio.AbstractEventLoop = None # type: ignore
|
|
||||||
_event_loop_lock: Lock = Lock()
|
|
||||||
_event_loop_refcnt: int = 0 # number of connected websockets
|
|
||||||
|
|
||||||
class WebSocket:
|
|
||||||
"""
|
|
||||||
A WebSocket connection.
|
|
||||||
|
|
||||||
Call :meth:`connect` before :meth:`send` and :meth:`receive`.
|
|
||||||
|
|
||||||
All methods are thread safe.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
url
|
|
||||||
The WebSocket URL.
|
|
||||||
For tuner command channel it should be something like ``ws://localhost:8080/tuner``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
ConnectionClosed: Type[Exception] = websockets.ConnectionClosed # type: ignore
|
|
||||||
|
|
||||||
def __init__(self, url: str):
|
|
||||||
self._url: str = url
|
|
||||||
self._ws: Any = None # the library does not provide type hints
|
|
||||||
|
|
||||||
def connect(self) -> None:
|
|
||||||
global _event_loop, _event_loop_refcnt
|
|
||||||
with _event_loop_lock:
|
|
||||||
_event_loop_refcnt += 1
|
|
||||||
if _event_loop is None:
|
|
||||||
_logger.debug('Starting event loop.')
|
|
||||||
# following line must be outside _run_event_loop
|
|
||||||
# because _wait() might be executed before first line of the child thread
|
|
||||||
_event_loop = asyncio.new_event_loop()
|
|
||||||
thread = Thread(target=_run_event_loop, name='NNI-WebSocketEventLoop', daemon=True)
|
|
||||||
thread.start()
|
|
||||||
|
|
||||||
_logger.debug(f'Connecting to {self._url}')
|
|
||||||
self._ws = _wait(_connect_async(self._url))
|
|
||||||
_logger.debug(f'Connected.')
|
|
||||||
|
|
||||||
def disconnect(self) -> None:
|
|
||||||
if self._ws is None:
|
|
||||||
_logger.debug('disconnect: No connection.')
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
_wait(self._ws.close())
|
|
||||||
_logger.debug('Connection closed by client.')
|
|
||||||
except Exception as e:
|
|
||||||
_logger.warning(f'Failed to close connection: {repr(e)}')
|
|
||||||
self._ws = None
|
|
||||||
_decrease_refcnt()
|
|
||||||
|
|
||||||
def send(self, message: str) -> None:
|
|
||||||
_logger.debug(f'Sending {message}')
|
|
||||||
try:
|
|
||||||
_wait(self._ws.send(message))
|
|
||||||
except websockets.ConnectionClosed: # type: ignore
|
|
||||||
_logger.debug('Connection closed by server.')
|
|
||||||
self._ws = None
|
|
||||||
_decrease_refcnt()
|
|
||||||
raise
|
|
||||||
|
|
||||||
def receive(self) -> str | None:
|
|
||||||
"""
|
|
||||||
Return received message;
|
|
||||||
or return ``None`` if the connection has been closed by peer.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
msg = _wait(self._ws.recv())
|
|
||||||
_logger.debug(f'Received {msg}')
|
|
||||||
except websockets.ConnectionClosed: # type: ignore
|
|
||||||
_logger.debug('Connection closed by server.')
|
|
||||||
self._ws = None
|
|
||||||
_decrease_refcnt()
|
|
||||||
raise
|
|
||||||
|
|
||||||
# seems the library will inference whether it's text or binary, so we don't have guarantee
|
|
||||||
if isinstance(msg, bytes):
|
|
||||||
return msg.decode()
|
|
||||||
else:
|
|
||||||
return msg
|
|
||||||
|
|
||||||
def _wait(coro):
|
|
||||||
# Synchronized version of "await".
|
|
||||||
future = asyncio.run_coroutine_threadsafe(coro, _event_loop)
|
|
||||||
return future.result()
|
|
||||||
|
|
||||||
def _run_event_loop() -> None:
|
|
||||||
# A separate thread to run the event loop.
|
|
||||||
# The event loop itself is blocking, and send/receive are also blocking,
|
|
||||||
# so they must run in different threads.
|
|
||||||
asyncio.set_event_loop(_event_loop)
|
|
||||||
_event_loop.run_forever()
|
|
||||||
_logger.debug('Event loop stopped.')
|
|
||||||
|
|
||||||
async def _connect_async(url):
|
|
||||||
# Theoretically this function is meaningless and one can directly use `websockets.connect(url)`,
|
|
||||||
# but it will not work, raising "TypeError: A coroutine object is required".
|
|
||||||
# Seems a design flaw in websockets library.
|
|
||||||
return await websockets.connect(url, max_size=None) # type: ignore
|
|
||||||
|
|
||||||
def _decrease_refcnt() -> None:
|
|
||||||
global _event_loop, _event_loop_refcnt
|
|
||||||
with _event_loop_lock:
|
|
||||||
_event_loop_refcnt -= 1
|
|
||||||
if _event_loop_refcnt == 0:
|
|
||||||
_event_loop.call_soon_threadsafe(_event_loop.stop)
|
|
||||||
_event_loop = None # type: ignore
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ async def read_stdin():
|
||||||
line = line.decode().strip()
|
line = line.decode().strip()
|
||||||
_debug(f'read from stdin: {line}')
|
_debug(f'read from stdin: {line}')
|
||||||
if line == '_close_':
|
if line == '_close_':
|
||||||
exit()
|
break
|
||||||
await _ws.send(line)
|
await _ws.send(line)
|
||||||
|
|
||||||
async def ws_server():
|
async def ws_server():
|
||||||
|
@ -46,9 +46,12 @@ async def on_connect(ws):
|
||||||
global _ws
|
global _ws
|
||||||
_debug('connected')
|
_debug('connected')
|
||||||
_ws = ws
|
_ws = ws
|
||||||
async for msg in ws:
|
try:
|
||||||
_debug(f'received from websocket: {msg}')
|
async for msg in ws:
|
||||||
print(msg, flush=True)
|
_debug(f'received from websocket: {msg}')
|
||||||
|
print(msg, flush=True)
|
||||||
|
except websockets.exceptions.ConnectionClosedError:
|
||||||
|
pass
|
||||||
|
|
||||||
def _debug(msg):
|
def _debug(msg):
|
||||||
#sys.stderr.write(f'[server-debug] {msg}\n')
|
#sys.stderr.write(f'[server-debug] {msg}\n')
|
||||||
|
|
|
@ -11,21 +11,21 @@ from subprocess import Popen, PIPE
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from nni.runtime.tuner_command_channel.websocket import WebSocket
|
from nni.runtime.command_channel.websocket import WsChannelClient
|
||||||
|
|
||||||
# A helper server that connects its stdio to incoming WebSocket.
|
# A helper server that connects its stdio to incoming WebSocket.
|
||||||
_server = None
|
_server = None
|
||||||
_client = None
|
_client = None
|
||||||
|
|
||||||
_command1 = 'T_hello world'
|
_command1 = {'type': 'ut_command', 'value': 123}
|
||||||
_command2 = 'T_你好'
|
_command2 = {'type': 'ut_command', 'value': '你好'}
|
||||||
|
|
||||||
## test cases ##
|
## test cases ##
|
||||||
|
|
||||||
def test_connect():
|
def test_connect():
|
||||||
global _client
|
global _client
|
||||||
port = _init()
|
port = _init()
|
||||||
_client = WebSocket(f'ws://localhost:{port}')
|
_client = WsChannelClient(f'ws://localhost:{port}')
|
||||||
_client.connect()
|
_client.connect()
|
||||||
|
|
||||||
def test_send():
|
def test_send():
|
||||||
|
@ -34,16 +34,16 @@ def test_send():
|
||||||
_client.send(_command2)
|
_client.send(_command2)
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
|
|
||||||
sent1 = _server.stdout.readline().strip()
|
sent1 = json.loads(_server.stdout.readline())
|
||||||
assert sent1 == _command1, sent1
|
assert sent1 == _command1, sent1
|
||||||
|
|
||||||
sent2 = _server.stdout.readline().strip()
|
sent2 = json.loads(_server.stdout.readline().strip())
|
||||||
assert sent2 == _command2, sent2
|
assert sent2 == _command2, sent2
|
||||||
|
|
||||||
def test_receive():
|
def test_receive():
|
||||||
# Send commands to server via stdin, and get them back via channel.
|
# Send commands to server via stdin, and get them back via channel.
|
||||||
_server.stdin.write(_command1 + '\n')
|
_server.stdin.write(json.dumps(_command1) + '\n')
|
||||||
_server.stdin.write(_command2 + '\n')
|
_server.stdin.write(json.dumps(_command2) + '\n')
|
||||||
_server.stdin.flush()
|
_server.stdin.flush()
|
||||||
|
|
||||||
received1 = _client.receive()
|
received1 = _client.receive()
|
|
@ -51,8 +51,8 @@ export class HttpChannelServer implements CommandChannelServer {
|
||||||
this.outgoingQueues.forEach(queue => { queue.clear(); });
|
this.outgoingQueues.forEach(queue => { queue.clear(); });
|
||||||
}
|
}
|
||||||
|
|
||||||
public getChannelUrl(channelId: string): string {
|
public getChannelUrl(channelId: string, ip?: string): string {
|
||||||
return globals.rest.getFullUrl('http', 'localhost', this.path, channelId);
|
return globals.rest.getFullUrl('http', ip ?? 'localhost', this.path, channelId);
|
||||||
}
|
}
|
||||||
|
|
||||||
public send(channelId: string, command: Command): void {
|
public send(channelId: string, command: Command): void {
|
||||||
|
@ -63,6 +63,10 @@ export class HttpChannelServer implements CommandChannelServer {
|
||||||
this.emitter.on('receive', callback);
|
this.emitter.on('receive', callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public onConnection(_callback: (channelId: string, channel: any) => void): void {
|
||||||
|
throw new Error('Not implemented');
|
||||||
|
}
|
||||||
|
|
||||||
private handleGet(request: Request, response: Response): void {
|
private handleGet(request: Request, response: Response): void {
|
||||||
const channelId = request.params['channel'];
|
const channelId = request.params['channel'];
|
||||||
const promise = this.getOutgoingQueue(channelId).asyncPop(timeoutMilliseconds);
|
const promise = this.getOutgoingQueue(channelId).asyncPop(timeoutMilliseconds);
|
||||||
|
|
|
@ -1,26 +1,158 @@
|
||||||
// Copyright (c) Microsoft Corporation.
|
// Copyright (c) Microsoft Corporation.
|
||||||
// Licensed under the MIT license.
|
// Licensed under the MIT license.
|
||||||
|
|
||||||
//export interface Command {
|
/**
|
||||||
// type: string;
|
* Common interface of command channels.
|
||||||
// [key: string]: any;
|
*
|
||||||
//}
|
* A command channel is a duplex connection which supports sending and receiving JSON commands.
|
||||||
|
*
|
||||||
|
* Typically a command channel implementation consists of a server and a client.
|
||||||
|
*
|
||||||
|
* The server should listen to a URL prefix like `http://localhost:8080/example/`;
|
||||||
|
* and each client should connect to a unique URL containing the prefix, e.g. `http://localhost:8080/example/channel1`.
|
||||||
|
* The client's URL should be created with `server.getChannelUrl(channelId, serverIp)`.
|
||||||
|
*
|
||||||
|
* We currently have implemented one full feature command channel, the WebSocket channel,
|
||||||
|
* and a simplified one, the HTTP channel.
|
||||||
|
* In v3.1 release we might implement file command channel and AzureML command channel.
|
||||||
|
*
|
||||||
|
* The clients might have a Python version locates in `nni/runtime/command_channel`.
|
||||||
|
* The TypeScript and Python version should be interchangable.
|
||||||
|
**/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A command is a JSON object.
|
||||||
|
*
|
||||||
|
* The object has only one mandatory entry, `type`.
|
||||||
|
*
|
||||||
|
* The type string should not be surrounded by underscore (e.g. `_nop_`),
|
||||||
|
* unless you are dealing with the underlying implementation of a specific command channel;
|
||||||
|
* it should never starts with two underscores (e.g. `__command`) in any circumstance.
|
||||||
|
**/
|
||||||
export type Command = any;
|
export type Command = any;
|
||||||
|
|
||||||
|
// Maybe it's better to disable `noPropertyAccessFromIndexSignature` in tscofnig?
|
||||||
|
// export interface Command {
|
||||||
|
// type: string;
|
||||||
|
// [key: string]: any;
|
||||||
|
// }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A command channel server serves one or more command channels.
|
* `CommandChannel` is the base interface used by both the servers and the clients.
|
||||||
* Each channel is connected to a client.
|
|
||||||
*
|
*
|
||||||
* Normally each client has a unique channel URL,
|
* For servers, channels can be got with `onConnection()` event listener.
|
||||||
* which can be got with `server.getChannelUrl(id)`.
|
* For clients, a channel can be created with the client subclass' constructor.
|
||||||
*
|
*
|
||||||
* The APIs might be changed to return `Promise<void>` in future.
|
* The channel should be fault tolerant to some extend. It has three different types of closing related events:
|
||||||
|
*
|
||||||
|
* 1. Close: The channel is intentionally closed.
|
||||||
|
*
|
||||||
|
* 2. Lost: The channel is temporarily unavailable and is trying to recover.
|
||||||
|
* The user of this class should examine the peer's status out-of-band when receiving "lost" event.
|
||||||
|
*
|
||||||
|
* 3. Error: The channel is dead and cannot recover.
|
||||||
|
* A "close" event may or may not occur following this event. Do not rely on that.
|
||||||
|
**/
|
||||||
|
export interface CommandChannel {
|
||||||
|
readonly name: string; // for better logging
|
||||||
|
|
||||||
|
enableHeartbeat(intervalMilliseconds?: number): void;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Graceful (intentional) close.
|
||||||
|
* A "close" event will be emitted by `this` and the peer.
|
||||||
|
**/
|
||||||
|
close(reason: string): void;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Force close. Should only be used when the channel is not working.
|
||||||
|
* An "error" event may be emitted by `this`.
|
||||||
|
* A "lost" and/or "error" event will be emitted by the peer, if its process is still alive.
|
||||||
|
**/
|
||||||
|
terminate(reason: string): void;
|
||||||
|
|
||||||
|
send(command: Command): void;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The async version should try to ensures the command is successfully sent to the peer.
|
||||||
|
* But this is not guaranteed.
|
||||||
|
**/
|
||||||
|
sendAsync(command: Command): Promise<void>;
|
||||||
|
|
||||||
|
onReceive(callback: (command: Command) => void): void;
|
||||||
|
onCommand(commandType: string, callback: (command: Command) => void): void;
|
||||||
|
|
||||||
|
onClose(callback: (reason?: string) => void): void;
|
||||||
|
onError(callback: (error: Error) => void): void;
|
||||||
|
onLost(callback: () => void): void;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Client side of a command channel.
|
||||||
|
*
|
||||||
|
* The constructor should have no side effects.
|
||||||
|
*
|
||||||
|
* The listeners should be registered before calling `connect()`,
|
||||||
|
* or the first few commands might be missed.
|
||||||
|
*
|
||||||
|
* Example usage:
|
||||||
|
*
|
||||||
|
* const client = new WsChannelClient('example', 'ws://1.2.3.4:8080/server/channel_id');
|
||||||
|
* await client.connect();
|
||||||
|
* client.send(command);
|
||||||
|
**/
|
||||||
|
export interface CommandChannelClient extends CommandChannel {
|
||||||
|
// constructor(name: string, url: string);
|
||||||
|
|
||||||
|
connect(): Promise<void>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Typically an alias of `close()`.
|
||||||
|
**/
|
||||||
|
disconnect(reason?: string): Promise<void>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Server side of a command channel.
|
||||||
|
*
|
||||||
|
* The consructor should have no side effects.
|
||||||
|
*
|
||||||
|
* The listeners should be registered before calling `start()`.
|
||||||
|
*
|
||||||
|
* Example usage:
|
||||||
|
*
|
||||||
|
* const server = new WsChannelServer('example_server', '/server_prefix');
|
||||||
|
* const url = server.getChannelUrl('channel_id');
|
||||||
|
* const client = new WsChannelClient('example_client', url);
|
||||||
|
* await server.start();
|
||||||
|
* await client.connect();
|
||||||
|
*
|
||||||
|
* There two ways to listen to command:
|
||||||
|
*
|
||||||
|
* 1. Handle all clients' commands in one space:
|
||||||
|
*
|
||||||
|
* server.onReceive((channelId, command) => { ... });
|
||||||
|
* server.send(channelId, command);
|
||||||
|
*
|
||||||
|
* 2. Maintain a `WsChannel` instance for each client:
|
||||||
|
*
|
||||||
|
* server.onConnection((channelId, channel) => {
|
||||||
|
* channel.onCommand(command => { ... });
|
||||||
|
* channel.send(command);
|
||||||
|
* });
|
||||||
**/
|
**/
|
||||||
export interface CommandChannelServer {
|
export interface CommandChannelServer {
|
||||||
// constructor(name: string, urlPath: string)
|
// constructor(name: string, urlPath: string);
|
||||||
|
|
||||||
start(): Promise<void>;
|
start(): Promise<void>;
|
||||||
shutdown(): Promise<void>;
|
shutdown(): Promise<void>;
|
||||||
getChannelUrl(channelId: string): string;
|
|
||||||
|
/**
|
||||||
|
* When `ip` is missing, it should default to localhost.
|
||||||
|
**/
|
||||||
|
getChannelUrl(channelId: string, ip?: string): string;
|
||||||
|
|
||||||
send(channelId: string, command: Command): void;
|
send(channelId: string, command: Command): void;
|
||||||
onReceive(callback: (channelId: string, command: Command) => void): void;
|
onReceive(callback: (channelId: string, command: Command) => void): void;
|
||||||
|
onConnection(callback: (channelId: string, channel: CommandChannel) => void): void;
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,7 +50,7 @@ import { DefaultMap } from 'common/default_map';
|
||||||
import { Deferred } from 'common/deferred';
|
import { Deferred } from 'common/deferred';
|
||||||
import { Logger, getLogger } from 'common/log';
|
import { Logger, getLogger } from 'common/log';
|
||||||
import type { TrialKeeper } from 'common/trial_keeper/keeper';
|
import type { TrialKeeper } from 'common/trial_keeper/keeper';
|
||||||
import { WsChannel } from './websocket/channel';
|
import type { CommandChannel } from './interface';
|
||||||
|
|
||||||
interface RpcResponseCommand {
|
interface RpcResponseCommand {
|
||||||
type: 'rpc_response';
|
type: 'rpc_response';
|
||||||
|
@ -61,14 +61,14 @@ interface RpcResponseCommand {
|
||||||
|
|
||||||
type Class = { new(...args: any[]): any; };
|
type Class = { new(...args: any[]): any; };
|
||||||
|
|
||||||
const rpcHelpers: Map<WsChannel, RpcHelper> = new Map();
|
const rpcHelpers: Map<CommandChannel, RpcHelper> = new Map();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Enable RPC on a channel.
|
* Enable RPC on a channel.
|
||||||
*
|
*
|
||||||
* The channel does not need to be connected for calling this function.
|
* The channel does not need to be connected for calling this function.
|
||||||
**/
|
**/
|
||||||
export function getRpcHelper(channel: WsChannel): RpcHelper {
|
export function getRpcHelper(channel: CommandChannel): RpcHelper {
|
||||||
if (!rpcHelpers.has(channel)) {
|
if (!rpcHelpers.has(channel)) {
|
||||||
rpcHelpers.set(channel, new RpcHelper(channel));
|
rpcHelpers.set(channel, new RpcHelper(channel));
|
||||||
}
|
}
|
||||||
|
@ -76,7 +76,7 @@ export function getRpcHelper(channel: WsChannel): RpcHelper {
|
||||||
}
|
}
|
||||||
|
|
||||||
export class RpcHelper {
|
export class RpcHelper {
|
||||||
private channel: WsChannel;
|
private channel: CommandChannel;
|
||||||
private lastId: number = 0;
|
private lastId: number = 0;
|
||||||
private localCtors: Map<string, Class> = new Map();
|
private localCtors: Map<string, Class> = new Map();
|
||||||
private localObjs: Map<number, any> = new Map();
|
private localObjs: Map<number, any> = new Map();
|
||||||
|
@ -87,7 +87,7 @@ export class RpcHelper {
|
||||||
/**
|
/**
|
||||||
* NOTE: Don't use this constructor directly. Use `getRpcHelper()`.
|
* NOTE: Don't use this constructor directly. Use `getRpcHelper()`.
|
||||||
**/
|
**/
|
||||||
constructor(channel: WsChannel) {
|
constructor(channel: CommandChannel) {
|
||||||
this.log = getLogger(`RpcHelper.${channel.name}`);
|
this.log = getLogger(`RpcHelper.${channel.name}`);
|
||||||
this.channel = channel;
|
this.channel = channel;
|
||||||
this.channel.onCommand('rpc_constructor', command => {
|
this.channel.onCommand('rpc_constructor', command => {
|
||||||
|
|
|
@ -4,353 +4,240 @@
|
||||||
/**
|
/**
|
||||||
* WebSocket command channel.
|
* WebSocket command channel.
|
||||||
*
|
*
|
||||||
* This is the base class that used by both server and client.
|
* A WsChannel operates on one WebSocket connection at a time.
|
||||||
|
* But when the network is unstable, it may close the underlying connection and create a new one.
|
||||||
|
* This is generally transparent to the user of this class, except that a "lost" event will be emitted.
|
||||||
*
|
*
|
||||||
* For the server, channels can be got with `onConnection()` event listener.
|
* To distinguish intentional close from connection lost,
|
||||||
* For the client, a channel can be created with `new WsChannelClient()` subclass.
|
* a "_bye_" command will be sent when `close()` or `disconnect()` is invoked.
|
||||||
* Do not use the constructor directly.
|
|
||||||
*
|
*
|
||||||
* The channel is fault tolerant to some extend. It has three different types of closing related events:
|
* If the connection is closed before receiving "_bye_" command, a "lost" event will be emitted and:
|
||||||
*
|
*
|
||||||
* 1. "close": The channel is intentionally closed.
|
* * The client will try to reconnect for severaly times in around 15s.
|
||||||
|
* * The server will wait the client to reconnect for around 30s.
|
||||||
*
|
*
|
||||||
* This is caused either by "close()" or "disconnect()" call, or by receiving a "_bye_" command from the peer.
|
* If the reconnecting attempt failed, both side will emit an "error" event.
|
||||||
*
|
|
||||||
* 2. "lost": The channel is temporarily unavailable and is trying to recover.
|
|
||||||
* (The high level class should examine the peer's status out-of-band when receiving this event.)
|
|
||||||
*
|
|
||||||
* When the underlying socket is dead, this event is emitted.
|
|
||||||
* The client will try to reconnect in around 15s. If all attempts fail, an "error" event will be emitted.
|
|
||||||
* The server will wait the client for 30s. If it does not reconnect, an "error" event will be emitted.
|
|
||||||
* Successful recover will not emit command.
|
|
||||||
*
|
|
||||||
* 3. "error": The channel is dead and cannot recover.
|
|
||||||
*
|
|
||||||
* A "close" event may or may not follow this event. Do not rely on that.
|
|
||||||
**/
|
**/
|
||||||
|
|
||||||
import { EventEmitter } from 'node:events';
|
import { EventEmitter, once } from 'node:events';
|
||||||
import util from 'node:util';
|
import util from 'node:util';
|
||||||
|
|
||||||
import type { WebSocket } from 'ws';
|
import type { WebSocket } from 'ws';
|
||||||
|
|
||||||
|
import type { Command, CommandChannel } from 'common/command_channel/interface';
|
||||||
import { Deferred } from 'common/deferred';
|
import { Deferred } from 'common/deferred';
|
||||||
import { Logger, getLogger } from 'common/log';
|
import { Logger, getLogger } from 'common/log';
|
||||||
|
import { WsConnection } from './connection';
|
||||||
|
|
||||||
import type { Command } from '../interface';
|
interface QueuedCommand {
|
||||||
|
command: Command;
|
||||||
interface WsChannelEvents {
|
deferred?: Deferred<void>;
|
||||||
'command': (command: Command) => void;
|
|
||||||
'close': (reason: string) => void;
|
|
||||||
'lost': () => void;
|
|
||||||
'error': (error: Error) => void; // not used in base class
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export declare interface WsChannel {
|
export class WsChannel implements CommandChannel {
|
||||||
on<E extends keyof WsChannelEvents>(event: E, listener: WsChannelEvents[E]): this;
|
|
||||||
}
|
|
||||||
|
|
||||||
export class WsChannel extends EventEmitter {
|
|
||||||
private closing: boolean = false;
|
private closing: boolean = false;
|
||||||
private commandEmitter: EventEmitter = new EventEmitter();
|
private connection: WsConnection | null = null; // NOTE: used in unit test
|
||||||
private connection: WsConnection | null = null;
|
private epoch: number = -1;
|
||||||
private epoch: number = 0;
|
private heartbeatInterval: number | null = null;
|
||||||
private heartbeatInterval: number | null;
|
|
||||||
private log: Logger;
|
private log: Logger;
|
||||||
|
private queue: QueuedCommand[] = [];
|
||||||
|
private terminateTimer: NodeJS.Timer | null = null;
|
||||||
|
|
||||||
|
protected emitter: EventEmitter = new EventEmitter();
|
||||||
|
|
||||||
public readonly name: string;
|
public readonly name: string;
|
||||||
|
|
||||||
// internal, don't use
|
// internal, don't use
|
||||||
constructor(name: string, ws?: WebSocket, heartbeatInterval?: number) {
|
constructor(name: string) {
|
||||||
super()
|
|
||||||
this.log = getLogger(`WsChannel.${name}`);
|
this.log = getLogger(`WsChannel.${name}`);
|
||||||
this.name = name;
|
this.name = name;
|
||||||
this.heartbeatInterval = heartbeatInterval || null;
|
|
||||||
if (ws) {
|
|
||||||
this.setConnection(ws);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public enableHeartbeat(interval: number): void {
|
|
||||||
this.log.debug('## enable heartbeat');
|
|
||||||
this.heartbeatInterval = interval;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// internal, don't use
|
// internal, don't use
|
||||||
public setConnection(ws: WebSocket): void {
|
public async setConnection(ws: WebSocket, waitOpen: boolean): Promise<void> {
|
||||||
if (this.connection) {
|
if (this.terminateTimer) {
|
||||||
this.log.debug('Abandon previous connection');
|
clearTimeout(this.terminateTimer);
|
||||||
this.epoch += 1;
|
this.terminateTimer = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
this.connection?.terminate('new epoch start');
|
||||||
|
this.newEpoch();
|
||||||
this.log.debug(`Epoch ${this.epoch} start`);
|
this.log.debug(`Epoch ${this.epoch} start`);
|
||||||
|
|
||||||
this.connection = this.configConnection(ws);
|
this.connection = this.configConnection(ws);
|
||||||
|
if (waitOpen) {
|
||||||
|
await once(ws, 'open');
|
||||||
|
}
|
||||||
|
|
||||||
|
while (this.connection && this.queue.length > 0) {
|
||||||
|
const item = this.queue.shift()!;
|
||||||
|
try {
|
||||||
|
await this.connection.sendAsync(item.command);
|
||||||
|
item.deferred?.resolve();
|
||||||
|
} catch (error) {
|
||||||
|
this.log.error('Failed to send command on recovered channel:', error);
|
||||||
|
this.log.error('Dropped command:', item.command);
|
||||||
|
item.deferred?.reject(error as any);
|
||||||
|
// it should trigger connection's error event and this.connection will be set to null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public enableHeartbeat(interval?: number): void {
|
||||||
|
this.heartbeatInterval = interval ?? defaultHeartbeatInterval;
|
||||||
|
this.connection?.setHeartbeatInterval(this.heartbeatInterval);
|
||||||
}
|
}
|
||||||
|
|
||||||
public close(reason: string): void {
|
public close(reason: string): void {
|
||||||
this.log.debug('Close channel:', reason);
|
this.log.debug('Close channel:', reason);
|
||||||
if (this.connection) {
|
this.connection?.close(reason);
|
||||||
this.connection.close(reason);
|
if (this.setClosing()) {
|
||||||
this.endEpoch();
|
this.emitter.emit('__close', reason);
|
||||||
}
|
|
||||||
if (!this.closing) {
|
|
||||||
this.closing = true;
|
|
||||||
this.emit('close', reason);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public terminate(reason: string): void {
|
public terminate(reason: string): void {
|
||||||
this.log.info('Terminate channel:', reason);
|
this.log.info('Terminate channel:', reason);
|
||||||
this.closing = true;
|
this.connection?.terminate(reason);
|
||||||
if (this.connection) {
|
if (this.setClosing()) {
|
||||||
this.connection.terminate(reason);
|
this.emitter.emit('__error', new Error(`WsChannel terminated: ${reason}`));
|
||||||
this.endEpoch();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public send(command: Command): void {
|
public send(command: Command): void {
|
||||||
if (this.connection) {
|
if (this.closing) {
|
||||||
this.connection.send(command);
|
this.log.error('Channel closed. Ignored command', command);
|
||||||
} else {
|
return;
|
||||||
// TODO: add a queue?
|
|
||||||
this.log.error('Connection lost. Dropped command', command);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!this.connection) {
|
||||||
|
this.log.warning('Connection lost. Enqueue command', command);
|
||||||
|
this.queue.push({ command });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.connection.send(command);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Async version of `send()` that (partially) ensures the command is successfully sent to peer.
|
|
||||||
**/
|
|
||||||
public sendAsync(command: Command): Promise<void> {
|
public sendAsync(command: Command): Promise<void> {
|
||||||
if (this.connection) {
|
if (this.closing) {
|
||||||
return this.connection.sendAsync(command);
|
this.log.error('(async) Channel closed. Refused command', command);
|
||||||
} else {
|
return Promise.reject(new Error('WsChannel has been closed'));
|
||||||
this.log.error('Connection lost. Dropped command async', command);
|
|
||||||
return Promise.reject(new Error('Connection is lost and trying to recover, cannot send command now'));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!this.connection) {
|
||||||
|
this.log.warning('(async) Connection lost. Enqueue command', command);
|
||||||
|
const deferred = new Deferred<void>();
|
||||||
|
this.queue.push({ command, deferred });
|
||||||
|
return deferred.promise;
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.connection.sendAsync(command);
|
||||||
}
|
}
|
||||||
|
|
||||||
// the first overload listens to all commands, while the second listens to one command type
|
public onReceive(callback: (command: Command) => void): void {
|
||||||
public onCommand(callback: (command: Command) => void): void;
|
this.emitter.on('__receive', callback);
|
||||||
public onCommand(commandType: string, callback: (command: Command) => void): void;
|
}
|
||||||
|
|
||||||
public onCommand(commandTypeOrCallback: any, callbackOrNone?: any): void {
|
public onCommand(commandType: string, callback: (command: Command) => void): void {
|
||||||
if (callbackOrNone) {
|
this.emitter.on(commandType, callback);
|
||||||
this.commandEmitter.on(commandTypeOrCallback, callbackOrNone);
|
}
|
||||||
} else {
|
|
||||||
this.commandEmitter.on('__any', commandTypeOrCallback);
|
public onClose(callback: (reason?: string) => void): void {
|
||||||
}
|
this.emitter.on('__close', callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
public onError(callback: (error: Error) => void): void {
|
||||||
|
this.emitter.on('__error', callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
public onLost(callback: () => void): void {
|
||||||
|
this.emitter.on('__lost', callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
private newEpoch(): void {
|
||||||
|
this.connection = null;
|
||||||
|
this.epoch += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
private configConnection(ws: WebSocket): WsConnection {
|
private configConnection(ws: WebSocket): WsConnection {
|
||||||
this.log.debug('## config connection');
|
const connName = this.epoch ? `${this.name}.${this.epoch}` : this.name;
|
||||||
const epoch = this.epoch; // copy it to use in closure
|
const conn = new WsConnection(connName, ws, this.emitter);
|
||||||
const conn = new WsConnection(
|
if (this.heartbeatInterval) {
|
||||||
this.epoch ? `${this.name}.${epoch}` : this.name,
|
conn.setHeartbeatInterval(this.heartbeatInterval);
|
||||||
ws,
|
}
|
||||||
this.commandEmitter,
|
|
||||||
this.heartbeatInterval
|
|
||||||
);
|
|
||||||
|
|
||||||
conn.on('bye', reason => {
|
conn.on('bye', reason => {
|
||||||
this.log.debug('Peer intentionally close:', reason);
|
this.log.debug('Peer intentionally closing:', reason);
|
||||||
this.endEpoch();
|
if (this.setClosing()) {
|
||||||
if (!this.closing) {
|
this.emitter.emit('__close', reason);
|
||||||
this.closing = true;
|
|
||||||
this.emit('close', reason);
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
conn.on('close', (code, reason) => {
|
conn.on('close', (code, reason) => {
|
||||||
this.closeConnection(epoch, `Received closing handshake: ${code} ${reason}`);
|
this.log.debug('Peer closed:', reason);
|
||||||
|
this.dropConnection(conn, `Peer closed: ${code} ${reason}`);
|
||||||
});
|
});
|
||||||
|
|
||||||
conn.on('error', error => {
|
conn.on('error', error => {
|
||||||
this.closeConnection(epoch, `Error occurred: ${util.inspect(error)}`);
|
this.dropConnection(conn, `Connection error: ${util.inspect(error)}`);
|
||||||
});
|
});
|
||||||
|
|
||||||
return conn;
|
return conn;
|
||||||
}
|
}
|
||||||
|
|
||||||
private closeConnection(epoch: number, reason: string): void {
|
private setClosing(): boolean {
|
||||||
if (this.closing) {
|
if (this.closing) {
|
||||||
this.log.debug('Connection cleaned up:', reason);
|
return false;
|
||||||
|
}
|
||||||
|
this.closing = true;
|
||||||
|
this.newEpoch();
|
||||||
|
this.queue.forEach(item => {
|
||||||
|
item.deferred?.reject(new Error('WsChannel has been closed.'));
|
||||||
|
});
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private dropConnection(conn: WsConnection, reason: string): void {
|
||||||
|
if (this.closing) {
|
||||||
|
this.log.debug('Clean up:', reason);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (this.epoch !== epoch) { // the connection is already abandoned
|
if (this.connection !== conn) { // the connection is already abandoned
|
||||||
this.log.debug(`Previous connection closed ${epoch}: ${reason}`);
|
this.log.debug(`Previous connection closed: ${reason}`);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
this.log.warning('Connection closed unexpectedly:', reason);
|
this.log.warning('Connection closed unexpectedly:', reason);
|
||||||
this.emit('lost');
|
this.newEpoch();
|
||||||
this.endEpoch();
|
this.emitter.emit('__lost');
|
||||||
|
|
||||||
|
if (!this.terminateTimer) {
|
||||||
|
this.terminateTimer = setTimeout(() => {
|
||||||
|
if (!this.closing) {
|
||||||
|
this.terminate('have not reconnected in 30s');
|
||||||
|
}
|
||||||
|
}, terminateTimeout);
|
||||||
|
}
|
||||||
|
|
||||||
// the reconnect logic is in client subclass
|
// the reconnect logic is in client subclass
|
||||||
}
|
}
|
||||||
|
|
||||||
private endEpoch(): void {
|
|
||||||
this.connection = null;
|
|
||||||
this.epoch += 1;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
interface WsConnectionEvents {
|
let defaultHeartbeatInterval: number = 5000;
|
||||||
'command': (command: Command) => void;
|
let terminateTimeout: number = 30000;
|
||||||
'bye': (reason: string) => void;
|
|
||||||
'close': (code: number, reason: string) => void;
|
|
||||||
'error': (error: Error) => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
declare interface WsConnection {
|
export namespace UnitTestHelper {
|
||||||
on<E extends keyof WsConnectionEvents>(event: E, listener: WsConnectionEvents[E]): this;
|
export function setHeartbeatInterval(ms: number): void {
|
||||||
}
|
defaultHeartbeatInterval = ms;
|
||||||
|
|
||||||
class WsConnection extends EventEmitter {
|
|
||||||
private closing: boolean = false;
|
|
||||||
private commandEmitter: EventEmitter;
|
|
||||||
private heartbeatTimer: NodeJS.Timer | null = null;
|
|
||||||
private log: Logger;
|
|
||||||
private missingPongs: number = 0;
|
|
||||||
private ws: WebSocket;
|
|
||||||
|
|
||||||
constructor(name: string, ws: WebSocket, commandEmitter: EventEmitter, heartbeatInterval: number | null) {
|
|
||||||
super();
|
|
||||||
this.log = getLogger(`WsConnection.${name}`);
|
|
||||||
this.ws = ws;
|
|
||||||
this.commandEmitter = commandEmitter;
|
|
||||||
|
|
||||||
ws.on('close', this.handleClose.bind(this));
|
|
||||||
ws.on('error', this.handleError.bind(this));
|
|
||||||
ws.on('message', this.handleMessage.bind(this));
|
|
||||||
ws.on('pong', this.handlePong.bind(this));
|
|
||||||
|
|
||||||
if (heartbeatInterval) {
|
|
||||||
this.heartbeatTimer = setInterval(this.heartbeat.bind(this), heartbeatInterval);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public async close(reason: string): Promise<void> {
|
export function setTerminateTimeout(ms: number): void {
|
||||||
if (this.closing) {
|
terminateTimeout = ms;
|
||||||
this.log.debug('Close again:', reason);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.log.debug('Close:', reason);
|
|
||||||
this.closing = true;
|
|
||||||
if (this.heartbeatTimer) {
|
|
||||||
clearInterval(this.heartbeatTimer);
|
|
||||||
this.heartbeatTimer = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
await this.sendAsync({ type: '_bye_', reason });
|
|
||||||
} catch (error) {
|
|
||||||
this.log.error('Failed to send bye:', error);
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
this.ws.close(4000, reason);
|
|
||||||
} catch (error) {
|
|
||||||
this.log.error('Failed to close:', error);
|
|
||||||
this.ws.terminate();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public terminate(reason: string): void {
|
export function reset(): void {
|
||||||
this.log.debug('Terminate:', reason);
|
defaultHeartbeatInterval = 5000;
|
||||||
this.closing = true;
|
terminateTimeout = 30000;
|
||||||
if (this.heartbeatTimer) {
|
|
||||||
clearInterval(this.heartbeatTimer);
|
|
||||||
this.heartbeatTimer = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
this.ws.close(4001, reason);
|
|
||||||
} catch (error) {
|
|
||||||
this.log.debug('Failed to close:', error);
|
|
||||||
this.ws.terminate();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public send(command: Command): void {
|
|
||||||
this.log.trace('Sending command', command);
|
|
||||||
this.ws.send(JSON.stringify(command));
|
|
||||||
}
|
|
||||||
|
|
||||||
public sendAsync(command: Command): Promise<void> {
|
|
||||||
this.log.trace('Sending command async', command);
|
|
||||||
const deferred = new Deferred<void>();
|
|
||||||
this.ws.send(JSON.stringify(command), error => {
|
|
||||||
if (error) {
|
|
||||||
deferred.reject(error);
|
|
||||||
} else {
|
|
||||||
deferred.resolve();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
return deferred.promise;
|
|
||||||
}
|
|
||||||
|
|
||||||
private handleClose(code: number, reason: Buffer): void {
|
|
||||||
if (this.closing) {
|
|
||||||
this.log.debug('Connection closed');
|
|
||||||
} else {
|
|
||||||
this.log.debug('Connection closed by peer:', code, String(reason));
|
|
||||||
this.emit('close', code, String(reason));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private handleError(error: Error): void {
|
|
||||||
if (this.closing) {
|
|
||||||
this.log.warning('Error after closing:', error);
|
|
||||||
} else {
|
|
||||||
this.emit('error', error);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private handleMessage(data: Buffer, _isBinary: boolean): void {
|
|
||||||
const s = String(data);
|
|
||||||
if (this.closing) {
|
|
||||||
this.log.warning('Received message after closing:', s);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.log.trace('Received command', s);
|
|
||||||
const command = JSON.parse(s);
|
|
||||||
|
|
||||||
if (command.type === '_nop_') {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (command.type === '_bye_') {
|
|
||||||
this.closing = true;
|
|
||||||
this.emit('bye', command.reason);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const hasAnyListener = this.commandEmitter.emit('__any', command);
|
|
||||||
const hasTypeListener = this.commandEmitter.emit(command.type, command);
|
|
||||||
if (!hasAnyListener && !hasTypeListener) {
|
|
||||||
this.log.warning('No listener for command', s);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private handlePong(): void {
|
|
||||||
this.log.debug('receive pong'); // todo
|
|
||||||
this.missingPongs = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
private heartbeat(): void {
|
|
||||||
if (this.missingPongs > 0) {
|
|
||||||
this.log.warning('Missing pong');
|
|
||||||
}
|
|
||||||
if (this.missingPongs > 3) { // TODO: make it configurable?
|
|
||||||
// no response for ping, try real command
|
|
||||||
this.sendAsync({ type: '_nop_' }).then(() => {
|
|
||||||
this.missingPongs = 0;
|
|
||||||
}).catch(error => {
|
|
||||||
this.log.error('Failed sending command. Drop connection:', error);
|
|
||||||
this.terminate(`peer lost responsive: ${util.inspect(error)}`);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
this.missingPongs += 1;
|
|
||||||
this.log.debug('send ping');
|
|
||||||
this.ws.ping();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,15 +3,6 @@
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* WebSocket command channel client.
|
* WebSocket command channel client.
|
||||||
*
|
|
||||||
* Usage:
|
|
||||||
*
|
|
||||||
* const client = new WsChannelClient('ws://1.2.3.4:8080/server/channel_id');
|
|
||||||
* await client.connect();
|
|
||||||
* client.send(command);
|
|
||||||
*
|
|
||||||
* Most APIs are derived the base class `WsChannel`.
|
|
||||||
* See its doc for more details.
|
|
||||||
**/
|
**/
|
||||||
|
|
||||||
import events from 'node:events';
|
import events from 'node:events';
|
||||||
|
@ -26,7 +17,7 @@ import { WsChannel } from './channel';
|
||||||
const maxPayload: number = 1024 * 1024 * 1024;
|
const maxPayload: number = 1024 * 1024 * 1024;
|
||||||
|
|
||||||
export class WsChannelClient extends WsChannel {
|
export class WsChannelClient extends WsChannel {
|
||||||
private logger: Logger; // avoid name conflict with base class
|
private logger: Logger;
|
||||||
private reconnecting: boolean = false;
|
private reconnecting: boolean = false;
|
||||||
private url: string;
|
private url: string;
|
||||||
|
|
||||||
|
@ -34,19 +25,17 @@ export class WsChannelClient extends WsChannel {
|
||||||
* The url should start with "ws://".
|
* The url should start with "ws://".
|
||||||
* The name is used for better logging.
|
* The name is used for better logging.
|
||||||
**/
|
**/
|
||||||
constructor(url: string, name?: string) {
|
constructor(name: string, url: string) {
|
||||||
const name_ = name ?? generateName(url);
|
super(name);
|
||||||
super(name_);
|
this.logger = getLogger(`WsChannelClient.${name}`);
|
||||||
this.logger = getLogger(`WsChannelClient.${name_}`);
|
|
||||||
this.url = url;
|
this.url = url;
|
||||||
this.on('lost', this.reconnect.bind(this));
|
this.onLost(this.reconnect.bind(this));
|
||||||
}
|
}
|
||||||
|
|
||||||
public async connect(): Promise<void> {
|
public async connect(): Promise<void> {
|
||||||
this.logger.debug('Connecting to', this.url);
|
this.logger.debug('Connecting to', this.url);
|
||||||
const ws = new WebSocket(this.url, { maxPayload });
|
const ws = new WebSocket(this.url, { maxPayload });
|
||||||
this.setConnection(ws);
|
await this.setConnection(ws, true),
|
||||||
await events.once(ws, 'open');
|
|
||||||
this.logger.debug('Connected');
|
this.logger.debug('Connected');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,7 +43,7 @@ export class WsChannelClient extends WsChannel {
|
||||||
* Alias of `close()`.
|
* Alias of `close()`.
|
||||||
**/
|
**/
|
||||||
public async disconnect(reason?: string): Promise<void> {
|
public async disconnect(reason?: string): Promise<void> {
|
||||||
this.close(reason ?? 'client disconnecting');
|
this.close(reason ?? 'client intentionally disconnect');
|
||||||
}
|
}
|
||||||
|
|
||||||
private async reconnect(): Promise<void> {
|
private async reconnect(): Promise<void> {
|
||||||
|
@ -82,16 +71,6 @@ export class WsChannelClient extends WsChannel {
|
||||||
}
|
}
|
||||||
|
|
||||||
this.logger.error('Conenction lost. Cannot reconnect');
|
this.logger.error('Conenction lost. Cannot reconnect');
|
||||||
this.emit('error', new Error('Connection lost'));
|
this.emitter.emit('__error', new Error('Connection lost'));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function generateName(url: string): string {
|
|
||||||
const parts = url.split('/');
|
|
||||||
for (let i = parts.length - 1; i > 1; i--) {
|
|
||||||
if (parts[i]) {
|
|
||||||
return parts[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 'anonymous';
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,189 @@
|
||||||
|
// Copyright (c) Microsoft Corporation.
|
||||||
|
// Licensed under the MIT license.
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Internal helper class which handles one WebSocket connection.
|
||||||
|
**/
|
||||||
|
|
||||||
|
import { EventEmitter } from 'node:events';
|
||||||
|
import util from 'node:util';
|
||||||
|
|
||||||
|
import type { WebSocket } from 'ws';
|
||||||
|
|
||||||
|
import type { Command } from 'common/command_channel/interface';
|
||||||
|
import { Logger, getLogger } from 'common/log';
|
||||||
|
|
||||||
|
interface ConnectionEvents {
|
||||||
|
'bye': (reason: string) => void;
|
||||||
|
'close': (code: number, reason: string) => void;
|
||||||
|
'error': (error: Error) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export declare interface WsConnection {
|
||||||
|
on<E extends keyof ConnectionEvents>(event: E, listener: ConnectionEvents[E]): this;
|
||||||
|
}
|
||||||
|
|
||||||
|
export class WsConnection extends EventEmitter {
|
||||||
|
private closing: boolean = false;
|
||||||
|
private commandEmitter: EventEmitter;
|
||||||
|
private heartbeatTimer: NodeJS.Timer | null = null;
|
||||||
|
private log: Logger;
|
||||||
|
private missingPongs: number = 0;
|
||||||
|
private ws: WebSocket; // NOTE: used in unit test
|
||||||
|
|
||||||
|
constructor(name: string, ws: WebSocket, commandEmitter: EventEmitter) {
|
||||||
|
super();
|
||||||
|
this.log = getLogger(`WsConnection.${name}`);
|
||||||
|
this.ws = ws;
|
||||||
|
this.commandEmitter = commandEmitter;
|
||||||
|
|
||||||
|
ws.on('close', this.handleClose.bind(this));
|
||||||
|
ws.on('error', this.handleError.bind(this));
|
||||||
|
ws.on('message', this.handleMessage.bind(this));
|
||||||
|
ws.on('pong', this.handlePong.bind(this));
|
||||||
|
}
|
||||||
|
|
||||||
|
public setHeartbeatInterval(interval: number): void {
|
||||||
|
if (this.heartbeatTimer) {
|
||||||
|
clearTimeout(this.heartbeatTimer);
|
||||||
|
}
|
||||||
|
this.heartbeatTimer = setInterval(this.heartbeat.bind(this), interval);
|
||||||
|
}
|
||||||
|
|
||||||
|
public async close(reason: string): Promise<void> {
|
||||||
|
if (this.closing) {
|
||||||
|
this.log.debug('Close again:', reason);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.log.debug('Close connection:', reason);
|
||||||
|
this.closing = true;
|
||||||
|
if (this.heartbeatTimer) {
|
||||||
|
clearInterval(this.heartbeatTimer);
|
||||||
|
this.heartbeatTimer = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
await this.sendAsync({ type: '_bye_', reason });
|
||||||
|
} catch (error) {
|
||||||
|
this.log.error('Failed to send bye:', error);
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
this.ws.close(4000, reason);
|
||||||
|
return;
|
||||||
|
} catch (error) {
|
||||||
|
this.log.error('Failed to close socket:', error);
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
this.ws.terminate();
|
||||||
|
} catch (error) {
|
||||||
|
this.log.debug('Failed to terminate socket:', error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public terminate(reason: string): void {
|
||||||
|
this.log.debug('Terminate connection:', reason);
|
||||||
|
this.closing = true;
|
||||||
|
if (this.heartbeatTimer) {
|
||||||
|
clearInterval(this.heartbeatTimer);
|
||||||
|
this.heartbeatTimer = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
this.ws.close(4001, reason);
|
||||||
|
return;
|
||||||
|
} catch (error) {
|
||||||
|
this.log.debug('Failed to close socket:', error);
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
this.ws.terminate();
|
||||||
|
} catch (error) {
|
||||||
|
this.log.debug('Failed to terminate socket:', error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public send(command: Command): void {
|
||||||
|
this.log.trace('Send command', command);
|
||||||
|
this.ws.send(JSON.stringify(command));
|
||||||
|
}
|
||||||
|
|
||||||
|
public sendAsync(command: Command): Promise<void> {
|
||||||
|
this.log.trace('(async) Send command', command);
|
||||||
|
const send: any = util.promisify(this.ws.send.bind(this.ws));
|
||||||
|
return send(JSON.stringify(command));
|
||||||
|
}
|
||||||
|
|
||||||
|
private handleClose(code: number, reason: Buffer): void {
|
||||||
|
if (this.closing) {
|
||||||
|
this.log.debug('Connection closed');
|
||||||
|
} else {
|
||||||
|
this.log.debug('Connection closed by peer:', code, String(reason));
|
||||||
|
this.emit('close', code, String(reason));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private handleError(error: Error): void {
|
||||||
|
if (this.closing) {
|
||||||
|
this.log.warning('Error after closing:', error);
|
||||||
|
} else {
|
||||||
|
this.log.error('Connection error:', error);
|
||||||
|
this.emit('error', error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private handleMessage(data: Buffer, _isBinary: boolean): void {
|
||||||
|
const s = String(data);
|
||||||
|
if (this.closing) {
|
||||||
|
this.log.warning('Received message after closing:', s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.log.trace('Receive command', s);
|
||||||
|
const command = JSON.parse(s);
|
||||||
|
|
||||||
|
if (command.type === '_nop_') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (command.type === '_bye_') {
|
||||||
|
this.log.debug('Intentionally close connection:', s);
|
||||||
|
this.closing = true;
|
||||||
|
this.emit('bye', command.reason);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const hasReceiveListener = this.commandEmitter.emit('__receive', command);
|
||||||
|
const hasCommandListener = this.commandEmitter.emit(command.type, command);
|
||||||
|
if (!hasReceiveListener && !hasCommandListener) {
|
||||||
|
this.log.warning('No listener for command', s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private handlePong(): void {
|
||||||
|
this.log.trace('Receive pong');
|
||||||
|
this.missingPongs = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
private heartbeat(): void {
|
||||||
|
if (this.missingPongs > 0) {
|
||||||
|
this.log.warning('Missing pong');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.missingPongs > 3) { // TODO: make it configurable?
|
||||||
|
// no response for ping, try real command
|
||||||
|
this.sendAsync({ type: '_nop_' }).then(() => {
|
||||||
|
this.missingPongs = 0;
|
||||||
|
}).catch(error => {
|
||||||
|
this.log.error('Failed sending command. Drop connection:', error);
|
||||||
|
this.terminate(`peer lost responsive: ${util.inspect(error)}`);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
this.missingPongs += 1;
|
||||||
|
this.log.trace('Send ping');
|
||||||
|
this.ws.ping();
|
||||||
|
}
|
||||||
|
}
|
|
@ -3,29 +3,6 @@
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* WebSocket command channel server.
|
* WebSocket command channel server.
|
||||||
*
|
|
||||||
* The server will specify a URL prefix like `ws://1.2.3.4:8080/SERVER_PREFIX`,
|
|
||||||
* and each client will append a channel ID, like `ws://1.2.3.4:8080/SERVER_PREFIX/CHANNEL_ID`.
|
|
||||||
*
|
|
||||||
* const server = new WsChannelServer('example', 'SERVER_PREFIX');
|
|
||||||
* const url = server.getChannelUrl('CHANNEL_ID');
|
|
||||||
* const client = new WsChannelClient(url);
|
|
||||||
* await server.start();
|
|
||||||
* await client.connect();
|
|
||||||
*
|
|
||||||
* There two styles to use the server:
|
|
||||||
*
|
|
||||||
* 1. Handle all clients' commands in one space:
|
|
||||||
*
|
|
||||||
* server.onReceive((channelId, command) => { ... });
|
|
||||||
* server.send(channelId, command);
|
|
||||||
*
|
|
||||||
* 2. Maintain a `WsChannel` instance for each client:
|
|
||||||
*
|
|
||||||
* server.onConnection((channelId, channel) => {
|
|
||||||
* channel.onCommand(command => { ... });
|
|
||||||
* channel.send(command);
|
|
||||||
* });
|
|
||||||
**/
|
**/
|
||||||
|
|
||||||
import { EventEmitter } from 'events';
|
import { EventEmitter } from 'events';
|
||||||
|
@ -39,22 +16,18 @@ import globals from 'common/globals';
|
||||||
import { Logger, getLogger } from 'common/log';
|
import { Logger, getLogger } from 'common/log';
|
||||||
import { WsChannel } from './channel';
|
import { WsChannel } from './channel';
|
||||||
|
|
||||||
let heartbeatInterval: number = 5000;
|
|
||||||
|
|
||||||
type ReceiveCallback = (channelId: string, command: Command) => void;
|
type ReceiveCallback = (channelId: string, command: Command) => void;
|
||||||
|
|
||||||
export class WsChannelServer extends EventEmitter {
|
export class WsChannelServer extends EventEmitter {
|
||||||
private channels: Map<string, WsChannel> = new Map();
|
private channels: Map<string, WsChannel> = new Map();
|
||||||
private ip: string;
|
|
||||||
private log: Logger;
|
private log: Logger;
|
||||||
private path: string;
|
private path: string;
|
||||||
private receiveCallbacks: ReceiveCallback[] = [];
|
private receiveCallbacks: ReceiveCallback[] = [];
|
||||||
|
|
||||||
constructor(name: string, urlPath: string, ip?: string) {
|
constructor(name: string, urlPath: string) {
|
||||||
super();
|
super();
|
||||||
this.log = getLogger(`WsChannelServer.${name}`);
|
this.log = getLogger(`WsChannelServer.${name}`);
|
||||||
this.path = urlPath;
|
this.path = urlPath;
|
||||||
this.ip = ip ?? 'localhost';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public async start(): Promise<void> {
|
public async start(): Promise<void> {
|
||||||
|
@ -67,7 +40,7 @@ export class WsChannelServer extends EventEmitter {
|
||||||
const deferred = new Deferred<void>();
|
const deferred = new Deferred<void>();
|
||||||
|
|
||||||
this.channels.forEach((channel, channelId) => {
|
this.channels.forEach((channel, channelId) => {
|
||||||
channel.on('close', (_reason) => {
|
channel.onClose(_reason => {
|
||||||
this.channels.delete(channelId);
|
this.channels.delete(channelId);
|
||||||
if (this.channels.size === 0) {
|
if (this.channels.size === 0) {
|
||||||
deferred.resolve();
|
deferred.resolve();
|
||||||
|
@ -77,17 +50,16 @@ export class WsChannelServer extends EventEmitter {
|
||||||
});
|
});
|
||||||
|
|
||||||
// wait for at most 5 seconds
|
// wait for at most 5 seconds
|
||||||
// use heartbeatInterval here for easier unit test
|
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
this.log.debug('Shutdown timeout. Stop waiting following channels:', Array.from(this.channels.keys()));
|
this.log.debug('Shutdown timeout. Stop waiting following channels:', Array.from(this.channels.keys()));
|
||||||
deferred.resolve();
|
deferred.resolve();
|
||||||
}, heartbeatInterval);
|
}, 5000);
|
||||||
|
|
||||||
return deferred.promise;
|
return deferred.promise;
|
||||||
}
|
}
|
||||||
|
|
||||||
public getChannelUrl(channelId: string): string {
|
public getChannelUrl(channelId: string, ip?: string): string {
|
||||||
return globals.rest.getFullUrl('ws', this.ip, this.path, channelId);
|
return globals.rest.getFullUrl('ws', ip ?? 'localhost', this.path, channelId);
|
||||||
}
|
}
|
||||||
|
|
||||||
public send(channelId: string, command: Command): void {
|
public send(channelId: string, command: Command): void {
|
||||||
|
@ -104,7 +76,7 @@ export class WsChannelServer extends EventEmitter {
|
||||||
// because by this way it can detect and warning if a command is never listened
|
// because by this way it can detect and warning if a command is never listened
|
||||||
this.receiveCallbacks.push(callback);
|
this.receiveCallbacks.push(callback);
|
||||||
for (const [channelId, channel] of this.channels) {
|
for (const [channelId, channel] of this.channels) {
|
||||||
channel.onCommand(command => { callback(channelId, command); });
|
channel.onReceive(command => { callback(channelId, command); });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,33 +90,30 @@ export class WsChannelServer extends EventEmitter {
|
||||||
|
|
||||||
if (this.channels.has(channelId)) {
|
if (this.channels.has(channelId)) {
|
||||||
this.log.warning(`Channel ${channelId} reconnecting, drop previous connection`);
|
this.log.warning(`Channel ${channelId} reconnecting, drop previous connection`);
|
||||||
this.channels.get(channelId)!.setConnection(ws);
|
this.channels.get(channelId)!.setConnection(ws, false);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const channel = new WsChannel(channelId, ws, heartbeatInterval);
|
const channel = new WsChannel(channelId);
|
||||||
this.channels.set(channelId, channel);
|
this.channels.set(channelId, channel);
|
||||||
|
|
||||||
channel.on('close', reason => {
|
channel.onClose(reason => {
|
||||||
this.log.debug(`Connection ${channelId} closed:`, reason);
|
this.log.debug(`Connection ${channelId} closed:`, reason);
|
||||||
this.channels.delete(channelId);
|
this.channels.delete(channelId);
|
||||||
});
|
});
|
||||||
|
|
||||||
channel.on('error', error => {
|
channel.onError(error => {
|
||||||
this.log.error(`Connection ${channelId} error:`, error);
|
this.log.error(`Connection ${channelId} error:`, error);
|
||||||
this.channels.delete(channelId);
|
this.channels.delete(channelId);
|
||||||
});
|
});
|
||||||
|
|
||||||
for (const cb of this.receiveCallbacks) {
|
for (const cb of this.receiveCallbacks) {
|
||||||
channel.on('command', command => { cb(channelId, command); });
|
channel.onReceive(command => { cb(channelId, command); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
channel.enableHeartbeat();
|
||||||
|
channel.setConnection(ws, false);
|
||||||
|
|
||||||
this.emit('connection', channelId, channel);
|
this.emit('connection', channelId, channel);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export namespace UnitTestHelpers {
|
|
||||||
export function setHeartbeatInterval(ms: number): void {
|
|
||||||
heartbeatInterval = ms;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -98,5 +98,5 @@ if (isUnitTest()) {
|
||||||
resetGlobals();
|
resetGlobals();
|
||||||
}
|
}
|
||||||
|
|
||||||
const globals: MutableGlobals = (global as any).nni;
|
export const globals: MutableGlobals = (global as any).nni;
|
||||||
export default globals;
|
export default globals;
|
||||||
|
|
|
@ -72,13 +72,13 @@ async function main(): Promise<void> {
|
||||||
logger.debug('command:', process.argv);
|
logger.debug('command:', process.argv);
|
||||||
logger.debug('config:', config);
|
logger.debug('config:', config);
|
||||||
|
|
||||||
const client = new WsChannelClient(args.managerCommandChannel, args.environmentId);
|
const client = new WsChannelClient(args.environmentId, args.managerCommandChannel);
|
||||||
client.enableHeartbeat(5000);
|
client.enableHeartbeat();
|
||||||
client.on('close', reason => {
|
client.onClose(reason => {
|
||||||
logger.info('Manager closed connection:', reason);
|
logger.info('Manager closed connection:', reason);
|
||||||
globals.shutdown.initiate('Connection end');
|
globals.shutdown.initiate('Connection end');
|
||||||
});
|
});
|
||||||
client.on('error', error => {
|
client.onError(error => {
|
||||||
logger.info('Connection error:', error);
|
logger.info('Connection error:', error);
|
||||||
globals.shutdown.initiate('Connection error');
|
globals.shutdown.initiate('Connection error');
|
||||||
});
|
});
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"build": "tsc",
|
"build": "tsc",
|
||||||
"test": "nyc --reporter=cobertura --reporter=text mocha test/**/*.test.ts",
|
"test": "nyc --reporter=cobertura --reporter=text mocha \"test/**/*.test.ts\"",
|
||||||
"test_nnimanager": "nyc --reporter=cobertura --reporter=text mocha test/core/nnimanager.test.ts",
|
"test_nnimanager": "nyc --reporter=cobertura --reporter=text mocha test/core/nnimanager.test.ts",
|
||||||
"mocha": "mocha",
|
"mocha": "mocha",
|
||||||
"eslint": "eslint . --ext .ts"
|
"eslint": "eslint . --ext .ts"
|
||||||
|
|
|
@ -61,7 +61,7 @@ export class RestServerCore {
|
||||||
return deferred.promise;
|
return deferred.promise;
|
||||||
}
|
}
|
||||||
|
|
||||||
public shutdown(): Promise<void> {
|
public shutdown(timeoutMilliseconds?: number): Promise<void> {
|
||||||
logger.info('Stopping REST server.');
|
logger.info('Stopping REST server.');
|
||||||
if (this.server === null) {
|
if (this.server === null) {
|
||||||
logger.warning('REST server is not running.');
|
logger.warning('REST server is not running.');
|
||||||
|
@ -77,7 +77,7 @@ export class RestServerCore {
|
||||||
logger.debug('Killing connections');
|
logger.debug('Killing connections');
|
||||||
this.server?.closeAllConnections();
|
this.server?.closeAllConnections();
|
||||||
}
|
}
|
||||||
}, 5000);
|
}, timeoutMilliseconds ?? 5000);
|
||||||
return deferred.promise;
|
return deferred.promise;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,209 @@
|
||||||
|
// Copyright (c) Microsoft Corporation.
|
||||||
|
// Licensed under the MIT license.
|
||||||
|
|
||||||
|
import assert from 'node:assert/strict';
|
||||||
|
import { setTimeout } from 'node:timers/promises';
|
||||||
|
|
||||||
|
import type {
|
||||||
|
Command, CommandChannel, CommandChannelClient, CommandChannelServer
|
||||||
|
} from 'common/command_channel/interface';
|
||||||
|
import { UnitTestHelper as Helper } from 'common/command_channel/websocket/channel';
|
||||||
|
import { WsChannelClient, WsChannelServer } from 'common/command_channel/websocket/index';
|
||||||
|
import { globals } from 'common/globals/unittest';
|
||||||
|
import { RestServerCore } from 'rest_server/core';
|
||||||
|
|
||||||
|
describe('## websocket command channel ##', () => {
|
||||||
|
before(beforeHook);
|
||||||
|
|
||||||
|
it('start', testServerStart);
|
||||||
|
|
||||||
|
it('connect', testClientStart);
|
||||||
|
it('message', testMessage);
|
||||||
|
|
||||||
|
it('reconnect', testReconnect);
|
||||||
|
it('message', testMessage);
|
||||||
|
|
||||||
|
it('handle error', testError);
|
||||||
|
it('shutdown', testShutdown);
|
||||||
|
|
||||||
|
after(afterHook);
|
||||||
|
});
|
||||||
|
|
||||||
|
/* test cases */
|
||||||
|
|
||||||
|
async function testServerStart(): Promise<void> {
|
||||||
|
ut.server = new WsChannelServer('ut_server', 'ut');
|
||||||
|
|
||||||
|
ut.server.onReceive((channelId, command) => {
|
||||||
|
if (channelId === '1') {
|
||||||
|
ut.events.push({ event: 'server_receive_1', command });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
ut.server.onConnection((channelId, channel) => {
|
||||||
|
ut.events.push({ event: 'connect', channelId, channel });
|
||||||
|
|
||||||
|
channel.onClose(reason => {
|
||||||
|
ut.events.push({ event: `client_close_${channelId}`, reason });
|
||||||
|
});
|
||||||
|
channel.onError(error => {
|
||||||
|
ut.events.push({ event: `client_error_${channelId}`, error });
|
||||||
|
});
|
||||||
|
channel.onLost(() => {
|
||||||
|
ut.events.push({ event: `client_lost_${channelId}` });
|
||||||
|
});
|
||||||
|
|
||||||
|
if (channelId === '1') {
|
||||||
|
ut.serverChannel1 = channel;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (channelId === '2') {
|
||||||
|
ut.serverChannel2 = channel;
|
||||||
|
channel.onReceive(command => {
|
||||||
|
ut.events.push({ event: 'server_receive_2', command });
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
await ut.server.start();
|
||||||
|
}
|
||||||
|
|
||||||
|
async function testClientStart(): Promise<void> {
|
||||||
|
const url1 = ut.server.getChannelUrl('1');
|
||||||
|
const url2 = ut.server.getChannelUrl('2', '127.0.0.1');
|
||||||
|
assert.equal(url1, `ws://localhost:${globals.args.port}/ut/1`);
|
||||||
|
assert.equal(url2, `ws://127.0.0.1:${globals.args.port}/ut/2`);
|
||||||
|
|
||||||
|
ut.client1 = new WsChannelClient('ut_client_1', url1);
|
||||||
|
ut.client2 = new WsChannelClient('ut_client_2', url2);
|
||||||
|
|
||||||
|
ut.client1.onReceive(command => {
|
||||||
|
ut.events.push({ event: 'client_receive_1', command });
|
||||||
|
});
|
||||||
|
ut.client2.onCommand('ut_command', command => {
|
||||||
|
ut.events.push({ event: 'client_receive_2', command });
|
||||||
|
});
|
||||||
|
|
||||||
|
ut.client2.onClose(reason => {
|
||||||
|
ut.events.push({ event: 'server_close_2', reason });
|
||||||
|
});
|
||||||
|
|
||||||
|
await Promise.all([
|
||||||
|
ut.client1.connect(),
|
||||||
|
ut.client2.connect(),
|
||||||
|
]);
|
||||||
|
|
||||||
|
assert.equal(ut.events[0].event, 'connect');
|
||||||
|
assert.equal(ut.events[1].event, 'connect');
|
||||||
|
assert.equal(Number(ut.events[0].channelId) + Number(ut.events[1].channelId), 3);
|
||||||
|
assert.equal(ut.events.length, 2);
|
||||||
|
|
||||||
|
ut.events.length = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function testReconnect(): Promise<void> {
|
||||||
|
const ws = (ut.client1 as any).connection.ws; // NOTE: private api
|
||||||
|
ws.pause();
|
||||||
|
await setTimeout(heartbeatTimeout);
|
||||||
|
ws.terminate();
|
||||||
|
ws.resume();
|
||||||
|
|
||||||
|
// mac pipeline can be slow
|
||||||
|
for (let i = 0; i < 10; i++) {
|
||||||
|
await setTimeout(heartbeat);
|
||||||
|
if (ut.events.length > 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.ok(ut.countEvents('client_lost_1') >= 1);
|
||||||
|
assert.ok(ut.countEvents('client_close_1') == 0);
|
||||||
|
assert.ok(ut.countEvents('client_error_1') == 0);
|
||||||
|
assert.ok(ut.countEvents('connect') == 0); // reconnect is not connect
|
||||||
|
|
||||||
|
ut.events.length = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function testMessage(): Promise<void> {
|
||||||
|
ut.server.send('1', ut.packCommand(1));
|
||||||
|
await ut.client2.sendAsync(ut.packCommand(2));
|
||||||
|
ut.client2.send(ut.packCommand('三'));
|
||||||
|
ut.server.send('2', ut.packCommand('4'));
|
||||||
|
ut.client1.send(ut.packCommand(5));
|
||||||
|
ut.server.send('1', ut.packCommand(6));
|
||||||
|
|
||||||
|
await setTimeout(heartbeat);
|
||||||
|
|
||||||
|
assert.deepEqual(ut.filterCommands('client_receive_1'), [ 1, 6 ]);
|
||||||
|
assert.deepEqual(ut.filterCommands('client_receive_2'), [ '4' ]);
|
||||||
|
assert.deepEqual(ut.filterCommands('server_receive_1'), [ 5 ]);
|
||||||
|
assert.deepEqual(ut.filterCommands('server_receive_2'), [ 2, '三' ]);
|
||||||
|
|
||||||
|
ut.events.length = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function testError(): Promise<void> {
|
||||||
|
ut.client2.terminate('client 2 terminate');
|
||||||
|
await setTimeout(terminateTimeout * 1.1);
|
||||||
|
|
||||||
|
assert.ok(ut.countEvents('client_close_2') == 0);
|
||||||
|
assert.ok(ut.countEvents('client_error_2') == 1);
|
||||||
|
|
||||||
|
ut.events.length = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function testShutdown(): Promise<void> {
|
||||||
|
await ut.server.shutdown();
|
||||||
|
|
||||||
|
assert.equal(ut.countEvents('client_close_1'), 1);
|
||||||
|
|
||||||
|
ut.events.length = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* helpers and states */
|
||||||
|
|
||||||
|
// NOTE: Increase these numbers if it fails randomly
|
||||||
|
const heartbeat = 10;
|
||||||
|
const heartbeatTimeout = 50;
|
||||||
|
const terminateTimeout = 100;
|
||||||
|
|
||||||
|
async function beforeHook(): Promise<void> {
|
||||||
|
globals.reset();
|
||||||
|
|
||||||
|
ut.rest = new RestServerCore();
|
||||||
|
await ut.rest.start();
|
||||||
|
|
||||||
|
Helper.setHeartbeatInterval(heartbeat);
|
||||||
|
Helper.setTerminateTimeout(terminateTimeout);
|
||||||
|
}
|
||||||
|
|
||||||
|
async function afterHook(): Promise<void> {
|
||||||
|
Helper.reset();
|
||||||
|
await ut.rest?.shutdown();
|
||||||
|
globals.reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
class UnitTestStates {
|
||||||
|
server!: CommandChannelServer;
|
||||||
|
client1!: CommandChannelClient;
|
||||||
|
client2!: CommandChannelClient;
|
||||||
|
serverChannel1!: CommandChannel;
|
||||||
|
serverChannel2!: CommandChannel;
|
||||||
|
events: any[] = [];
|
||||||
|
|
||||||
|
rest!: RestServerCore;
|
||||||
|
|
||||||
|
countEvents(event: string): number {
|
||||||
|
return this.events.filter(e => (e.event === event)).length;
|
||||||
|
}
|
||||||
|
|
||||||
|
filterCommands(event: string): any[] {
|
||||||
|
return this.events.filter(e => (e.event === event)).map(e => e.command.value);
|
||||||
|
}
|
||||||
|
|
||||||
|
packCommand(value: any): Command {
|
||||||
|
return { type: 'ut_command', value };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const ut = new UnitTestStates();
|
|
@ -39,16 +39,16 @@ export class RemoteTrainingServiceV3 implements TrainingServiceV3 {
|
||||||
this.log = getLogger(`RemoteV3.${this.id}`);
|
this.log = getLogger(`RemoteV3.${this.id}`);
|
||||||
this.log.debug('Training sevice config:', config);
|
this.log.debug('Training sevice config:', config);
|
||||||
|
|
||||||
this.server = new WsChannelServer('RemoteTrialKeeper', `platform/${this.id}`, config.nniManagerIp);
|
this.server = new WsChannelServer(this.id, `/platform/${this.id}`);
|
||||||
|
|
||||||
this.server.on('connection', (channelId: string, channel: WsChannel) => {
|
this.server.on('connection', (channelId: string, channel: WsChannel) => {
|
||||||
const worker = this.workersByChannel.get(channelId);
|
const worker = this.workersByChannel.get(channelId);
|
||||||
if (worker) {
|
if (worker) {
|
||||||
worker.setChannel(channel);
|
worker.setChannel(channel);
|
||||||
channel.on('close', reason => {
|
channel.onClose(reason => {
|
||||||
this.log.error('Worker channel closed unexpectedly:', reason);
|
this.log.error('Worker channel closed unexpectedly:', reason);
|
||||||
});
|
});
|
||||||
channel.on('error', error => {
|
channel.onError(error => {
|
||||||
this.log.error('Worker channel error:', error);
|
this.log.error('Worker channel error:', error);
|
||||||
this.restartWorker(worker);
|
this.restartWorker(worker);
|
||||||
});
|
});
|
||||||
|
@ -190,7 +190,7 @@ export class RemoteTrainingServiceV3 implements TrainingServiceV3 {
|
||||||
this.id,
|
this.id,
|
||||||
channelId,
|
channelId,
|
||||||
config,
|
config,
|
||||||
this.server.getChannelUrl(channelId),
|
this.server.getChannelUrl(channelId, this.config.nniManagerIp),
|
||||||
Boolean(this.config.trialGpuNumber)
|
Boolean(this.config.trialGpuNumber)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -57,7 +57,7 @@ export class Worker {
|
||||||
public setChannel(channel: WsChannel): void {
|
public setChannel(channel: WsChannel): void {
|
||||||
this.channel = channel;
|
this.channel = channel;
|
||||||
this.trialKeeper.setChannel(channel);
|
this.trialKeeper.setChannel(channel);
|
||||||
channel.on('lost', async () => {
|
channel.onLost(async () => {
|
||||||
if (!await this.checkAlive()) {
|
if (!await this.checkAlive()) {
|
||||||
this.log.error('Trial keeper failed');
|
this.log.error('Trial keeper failed');
|
||||||
channel.terminate('Trial keeper failed'); // MARK
|
channel.terminate('Trial keeper failed'); // MARK
|
||||||
|
|
Загрузка…
Ссылка в новой задаче