зеркало из https://github.com/microsoft/nni.git
Merge command channel APIs and add unit tests (#5450)
This commit is contained in:
Родитель
1f6aedc48f
Коммит
cf5fabd968
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .channel import WsChannelClient
|
|
@ -0,0 +1,106 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import nni
|
||||
from ..base import Command, CommandChannel
|
||||
from .connection import WsConnection
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
class WsChannelClient(CommandChannel):
    """
    Client side of the WebSocket command channel.

    Commands are serialized with ``nni.dump()`` / ``nni.load()`` and exchanged over a :class:`WsConnection`.
    When sending or receiving fails, the underlying connection is terminated and re-created;
    each operation is attempted five times with an increasing delay, followed by one last attempt
    whose exception (if any) propagates to the caller.

    Two command types are handled internally and never returned from :meth:`receive`:

    - ``_nop_``: silently skipped.
    - ``_bye_``: the peer closed the channel intentionally; the client closes as well.

    Parameters
    ----------
    url
        The WebSocket URL to connect to.
    """

    def __init__(self, url: str):
        self._url: str = url
        self._closing: bool = False
        self._conn: WsConnection | None = None

    def connect(self) -> None:
        """Establish the connection. Must not be called after the channel has been closed."""
        _logger.debug(f'Connect to {self._url}')
        assert not self._closing
        self._ensure_conn()

    def disconnect(self) -> None:
        """Tell the peer the close is intentional (``_bye_``), then close the connection."""
        _logger.debug(f'Disconnect from {self._url}')
        # "_bye_" has to be sent before the closing flag is set, because send() is a no-op once closing.
        self.send({'type': '_bye_'})
        self._closing = True
        self._close_conn('client intentionally close')

    def send(self, command: Command) -> None:
        """
        Send a command to the peer. Does nothing if the channel is closing.

        Raises whatever the last attempt raises if all retries fail.
        """
        if self._closing:
            return
        _logger.debug(f'Send {command}')
        msg = nni.dump(command)
        for i in range(5):
            try:
                conn = self._ensure_conn()
                conn.send(msg)
                return
            except Exception:
                _logger.exception(f'Failed to send command. Retry in {i}s')
                # drop the broken connection so that the next attempt creates a new one
                self._terminate_conn('send fail')
                time.sleep(i)
        _logger.warning(f'Failed to send command {command}. Last retry')
        conn = self._ensure_conn()
        conn.send(msg)

    def receive(self) -> Command | None:
        """
        Block until a command arrives and return it.

        Return ``None`` if the channel is closing, or has been closed by the peer with ``_bye_``.
        """
        while True:
            if self._closing:
                return None
            msg = self._receive_msg()
            if msg is None:
                return None
            command = nni.load(msg)
            if command['type'] == '_nop_':
                continue
            if command['type'] == '_bye_':
                reason = command.get('reason')
                _logger.debug(f'Server close connection: {reason}')
                self._closing = True
                self._close_conn('server intentionally close')
                return None
            return command

    def _ensure_conn(self) -> WsConnection:
        # Lazily (re)create the connection.
        # When the channel is closing and there is no connection, this returns None.
        if self._conn is None and not self._closing:
            conn = WsConnection(self._url)
            conn.connect()
            # only remember the connection once it is really connected,
            # so a failed attempt does not leave a half-initialized object behind
            self._conn = conn
            _logger.debug('Connected')
        return self._conn  # type: ignore

    def _close_conn(self, reason: str) -> None:
        # Graceful close. Best effort: failures are logged and otherwise ignored.
        if self._conn is not None:
            try:
                self._conn.disconnect(reason)
            except Exception:
                _logger.debug('Ignored exception when closing connection', exc_info=True)
            self._conn = None

    def _terminate_conn(self, reason: str) -> None:
        # Abandon a broken connection. Best effort: failures are logged and otherwise ignored.
        if self._conn is not None:
            try:
                self._conn.terminate(reason)
            except Exception:
                _logger.debug('Ignored exception when terminating connection', exc_info=True)
            self._conn = None

    def _receive_msg(self) -> str | None:
        # Receive one raw message, re-creating the connection on failure.
        for i in range(5):
            try:
                conn = self._ensure_conn()
                msg = conn.receive()
                _logger.debug(f'Receive {msg}')
                if not self._closing:
                    # an empty result is only acceptable while the channel is closing
                    assert msg is not None
                return msg
            except Exception:
                _logger.exception(f'Failed to receive command. Retry in {i}s')
                self._terminate_conn('receive fail')
                time.sleep(i)
        _logger.warning('Failed to receive command. Last retry')
        conn = self._ensure_conn()
        # The result of the last attempt must be returned;
        # otherwise the caller gets None and mistakes it for a closed channel.
        msg = conn.receive()
        _logger.debug(f'Receive {msg}')
        return msg
|
|
@ -0,0 +1,139 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""
|
||||
Synchronized and object-oriented WebSocket class.
|
||||
|
||||
WebSocket guarantees that messages will not be divided at API level.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ['WsConnection']
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from threading import Lock, Thread
|
||||
from typing import Any, Type
|
||||
|
||||
import websockets
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
# the singleton event loop
|
||||
_event_loop: asyncio.AbstractEventLoop = None # type: ignore
|
||||
_event_loop_lock: Lock = Lock()
|
||||
_event_loop_refcnt: int = 0 # number of connected websockets
|
||||
|
||||
class WsConnection:
    """
    A WebSocket connection.

    Call :meth:`connect` before :meth:`send` and :meth:`receive`.

    All methods block the calling thread; the actual I/O is executed on a module-wide
    event loop running in a background thread, which is shared by all connections
    and reference counted.

    Parameters
    ----------
    url
        The WebSocket URL.
        For tuner command channel it should be something like ``ws://localhost:8080/tuner``.
    """

    ConnectionClosed: Type[Exception] = websockets.ConnectionClosed  # type: ignore

    def __init__(self, url: str):
        self._url: str = url
        self._ws: Any = None  # the library does not provide type hints

    def connect(self) -> None:
        """
        Connect to the URL, starting the shared event loop thread if it is not running.

        If the connection attempt fails, the exception is re-raised
        and the reference taken on the event loop is released.
        """
        global _event_loop, _event_loop_refcnt
        with _event_loop_lock:
            _event_loop_refcnt += 1
            if _event_loop is None:
                _logger.debug('Starting event loop.')
                # following line must be outside _run_event_loop
                # because _wait() might be executed before first line of the child thread
                _event_loop = asyncio.new_event_loop()
                thread = Thread(target=_run_event_loop, name='NNI-WebSocketEventLoop', daemon=True)
                thread.start()

        _logger.debug(f'Connecting to {self._url}')
        try:
            self._ws = _wait(_connect_async(self._url))
        except Exception:
            # This object never became a connected websocket, and disconnect() will be a no-op for it,
            # so the refcount must be given back here or the event loop will never be stopped.
            _decrease_refcnt()
            raise
        _logger.debug('Connected.')

    def disconnect(self, reason: str | None = None, code: int | None = None) -> None:
        """
        Close the connection. Do nothing if it is not connected.

        Failures are logged as warnings and not raised.
        The close code defaults to 4000 when ``code`` is missing.
        """
        if self._ws is None:
            _logger.debug('disconnect: No connection.')
            return

        try:
            _wait(self._ws.close(code or 4000, reason))
            _logger.debug('Connection closed by client.')
        except Exception as e:
            _logger.warning(f'Failed to close connection: {repr(e)}')
        self._ws = None
        _decrease_refcnt()

    def terminate(self, reason: str | None = None) -> None:
        """Same as :meth:`disconnect`, but uses close code 4001."""
        if self._ws is None:
            _logger.debug('terminate: No connection.')
            return
        self.disconnect(reason, 4001)

    def send(self, message: str) -> None:
        """
        Send a message.

        If the connection has been closed, ``ConnectionClosed`` is re-raised
        after this object is reset to the disconnected state.
        """
        _logger.debug(f'Sending {message}')
        try:
            _wait(self._ws.send(message))
        except websockets.ConnectionClosed:  # type: ignore
            _logger.debug('Connection closed by server.')
            self._ws = None
            _decrease_refcnt()
            raise

    def receive(self) -> str | None:
        """
        Block until a message arrives and return it as a string.

        If the connection has been closed, ``ConnectionClosed`` is re-raised
        after this object is reset to the disconnected state.
        """
        try:
            msg = _wait(self._ws.recv())
            _logger.debug(f'Received {msg}')
        except websockets.ConnectionClosed:  # type: ignore
            _logger.debug('Connection closed by server.')
            self._ws = None
            _decrease_refcnt()
            raise

        # the library seems to infer whether a frame is text or binary, so both types must be handled
        if isinstance(msg, bytes):
            return msg.decode()
        else:
            return msg
|
||||
|
||||
def _wait(coro):
    """
    Synchronized version of "await".

    Schedule ``coro`` on the singleton event loop and block the calling thread
    until it finishes, returning its result or raising its exception.
    """
    return asyncio.run_coroutine_threadsafe(coro, _event_loop).result()
|
||||
|
||||
def _run_event_loop() -> None:
    """
    Thread target that runs the singleton event loop until it is stopped.

    Running the loop blocks, and so do send/receive,
    hence the loop needs a thread of its own.
    """
    asyncio.set_event_loop(_event_loop)
    _event_loop.run_forever()
    _logger.debug('Event loop stopped.')
|
||||
|
||||
async def _connect_async(url):
    """
    Open a websocket to ``url`` without a message size limit.

    This wrapper looks redundant, but passing ``websockets.connect(url)`` to ``_wait()`` directly
    raises "TypeError: A coroutine object is required", so a real coroutine is needed.
    """
    connection = await websockets.connect(url, max_size=None)  # type: ignore
    return connection
|
||||
|
||||
def _decrease_refcnt() -> None:
    """
    Release one reference to the singleton event loop.

    When the last reference is gone, ask the loop to stop and forget it,
    so that the next connection starts a new loop.
    """
    global _event_loop, _event_loop_refcnt
    with _event_loop_lock:
        _event_loop_refcnt -= 1
        if _event_loop_refcnt != 0:
            return
        # stop() has to be scheduled from this thread in a thread safe manner
        loop, _event_loop = _event_loop, None  # type: ignore
        loop.call_soon_threadsafe(loop.stop)
|
|
@ -1,133 +1,4 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""
|
||||
Synchronized and object-oriented WebSocket class.
|
||||
|
||||
WebSocket guarantees that messages will not be divided at API level.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ['WebSocket']
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from threading import Lock, Thread
|
||||
from typing import Any, Type
|
||||
|
||||
import websockets
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
# the singleton event loop
|
||||
_event_loop: asyncio.AbstractEventLoop = None # type: ignore
|
||||
_event_loop_lock: Lock = Lock()
|
||||
_event_loop_refcnt: int = 0 # number of connected websockets
|
||||
|
||||
class WebSocket:
    """
    A WebSocket connection.

    Call :meth:`connect` before :meth:`send` and :meth:`receive`.

    All methods block the calling thread; the actual I/O is executed on a module-wide
    event loop running in a background thread, which is shared by all connections
    and reference counted.

    Parameters
    ----------
    url
        The WebSocket URL.
        For tuner command channel it should be something like ``ws://localhost:8080/tuner``.
    """

    ConnectionClosed: Type[Exception] = websockets.ConnectionClosed  # type: ignore

    def __init__(self, url: str):
        self._url: str = url
        self._ws: Any = None  # the library does not provide type hints

    def connect(self) -> None:
        """
        Connect to the URL, starting the shared event loop thread if it is not running.

        If the connection attempt fails, the exception is re-raised
        and the reference taken on the event loop is released.
        """
        global _event_loop, _event_loop_refcnt
        with _event_loop_lock:
            _event_loop_refcnt += 1
            if _event_loop is None:
                _logger.debug('Starting event loop.')
                # following line must be outside _run_event_loop
                # because _wait() might be executed before first line of the child thread
                _event_loop = asyncio.new_event_loop()
                thread = Thread(target=_run_event_loop, name='NNI-WebSocketEventLoop', daemon=True)
                thread.start()

        _logger.debug(f'Connecting to {self._url}')
        try:
            self._ws = _wait(_connect_async(self._url))
        except Exception:
            # This object never became a connected websocket, and disconnect() will be a no-op for it,
            # so the refcount must be given back here or the event loop will never be stopped.
            _decrease_refcnt()
            raise
        _logger.debug('Connected.')

    def disconnect(self) -> None:
        """
        Close the connection. Do nothing if it is not connected.

        Failures are logged as warnings and not raised.
        """
        if self._ws is None:
            _logger.debug('disconnect: No connection.')
            return

        try:
            _wait(self._ws.close())
            _logger.debug('Connection closed by client.')
        except Exception as e:
            _logger.warning(f'Failed to close connection: {repr(e)}')
        self._ws = None
        _decrease_refcnt()

    def send(self, message: str) -> None:
        """
        Send a message.

        If the connection has been closed, ``ConnectionClosed`` is re-raised
        after this object is reset to the disconnected state.
        """
        _logger.debug(f'Sending {message}')
        try:
            _wait(self._ws.send(message))
        except websockets.ConnectionClosed:  # type: ignore
            _logger.debug('Connection closed by server.')
            self._ws = None
            _decrease_refcnt()
            raise

    def receive(self) -> str | None:
        """
        Block until a message arrives and return it as a string.

        If the connection has been closed, ``ConnectionClosed`` is re-raised
        after this object is reset to the disconnected state.
        """
        try:
            msg = _wait(self._ws.recv())
            _logger.debug(f'Received {msg}')
        except websockets.ConnectionClosed:  # type: ignore
            _logger.debug('Connection closed by server.')
            self._ws = None
            _decrease_refcnt()
            raise

        # the library seems to infer whether a frame is text or binary, so both types must be handled
        if isinstance(msg, bytes):
            return msg.decode()
        else:
            return msg
|
||||
|
||||
def _wait(coro):
    """
    Synchronized version of "await".

    Submit ``coro`` to the singleton event loop and block the calling thread
    until it completes, returning its result or raising its exception.
    """
    pending = asyncio.run_coroutine_threadsafe(coro, _event_loop)
    outcome = pending.result()
    return outcome
|
||||
|
||||
def _run_event_loop() -> None:
    """
    Entry point of the background thread that runs the singleton event loop.

    The event loop blocks while running and send/receive block as well,
    which is why they cannot share one thread.
    """
    asyncio.set_event_loop(_event_loop)
    _event_loop.run_forever()
    _logger.debug('Event loop stopped.')
|
||||
|
||||
async def _connect_async(url):
    """
    Open a websocket to ``url`` with no limit on message size.

    In theory ``websockets.connect(url)`` could be handed to ``_wait()`` directly,
    but that raises "TypeError: A coroutine object is required"; this coroutine works around it.
    """
    websocket = await websockets.connect(url, max_size=None)  # type: ignore
    return websocket
|
||||
|
||||
def _decrease_refcnt() -> None:
    """
    Drop one reference to the singleton event loop.

    Once no connected websocket is left, the loop is asked to stop
    and the module-level handle is cleared.
    """
    global _event_loop, _event_loop_refcnt
    with _event_loop_lock:
        _event_loop_refcnt -= 1
        no_connection_left = (_event_loop_refcnt == 0)
        if no_connection_left:
            stop_loop = _event_loop.stop
            _event_loop.call_soon_threadsafe(stop_loop)
            _event_loop = None  # type: ignore
|
||||
from ..command_channel.websocket.connection import WsConnection as WebSocket # pylint: disable=unused-import
|
||||
|
|
|
@ -32,7 +32,7 @@ async def read_stdin():
|
|||
line = line.decode().strip()
|
||||
_debug(f'read from stdin: {line}')
|
||||
if line == '_close_':
|
||||
exit()
|
||||
break
|
||||
await _ws.send(line)
|
||||
|
||||
async def ws_server():
|
||||
|
@ -46,9 +46,12 @@ async def on_connect(ws):
|
|||
global _ws
|
||||
_debug('connected')
|
||||
_ws = ws
|
||||
async for msg in ws:
|
||||
_debug(f'received from websocket: {msg}')
|
||||
print(msg, flush=True)
|
||||
try:
|
||||
async for msg in ws:
|
||||
_debug(f'received from websocket: {msg}')
|
||||
print(msg, flush=True)
|
||||
except websockets.exceptions.ConnectionClosedError:
|
||||
pass
|
||||
|
||||
def _debug(msg):
|
||||
#sys.stderr.write(f'[server-debug] {msg}\n')
|
||||
|
|
|
@ -11,21 +11,21 @@ from subprocess import Popen, PIPE
|
|||
import sys
|
||||
import time
|
||||
|
||||
from nni.runtime.tuner_command_channel.websocket import WebSocket
|
||||
from nni.runtime.command_channel.websocket import WsChannelClient
|
||||
|
||||
# A helper server that connects its stdio to incoming WebSocket.
|
||||
_server = None
|
||||
_client = None
|
||||
|
||||
_command1 = 'T_hello world'
|
||||
_command2 = 'T_你好'
|
||||
_command1 = {'type': 'ut_command', 'value': 123}
|
||||
_command2 = {'type': 'ut_command', 'value': '你好'}
|
||||
|
||||
## test cases ##
|
||||
|
||||
def test_connect():
|
||||
global _client
|
||||
port = _init()
|
||||
_client = WebSocket(f'ws://localhost:{port}')
|
||||
_client = WsChannelClient(f'ws://localhost:{port}')
|
||||
_client.connect()
|
||||
|
||||
def test_send():
|
||||
|
@ -34,16 +34,16 @@ def test_send():
|
|||
_client.send(_command2)
|
||||
time.sleep(0.01)
|
||||
|
||||
sent1 = _server.stdout.readline().strip()
|
||||
sent1 = json.loads(_server.stdout.readline())
|
||||
assert sent1 == _command1, sent1
|
||||
|
||||
sent2 = _server.stdout.readline().strip()
|
||||
sent2 = json.loads(_server.stdout.readline().strip())
|
||||
assert sent2 == _command2, sent2
|
||||
|
||||
def test_receive():
|
||||
# Send commands to server via stdin, and get them back via channel.
|
||||
_server.stdin.write(_command1 + '\n')
|
||||
_server.stdin.write(_command2 + '\n')
|
||||
_server.stdin.write(json.dumps(_command1) + '\n')
|
||||
_server.stdin.write(json.dumps(_command2) + '\n')
|
||||
_server.stdin.flush()
|
||||
|
||||
received1 = _client.receive()
|
|
@ -51,8 +51,8 @@ export class HttpChannelServer implements CommandChannelServer {
|
|||
this.outgoingQueues.forEach(queue => { queue.clear(); });
|
||||
}
|
||||
|
||||
public getChannelUrl(channelId: string): string {
|
||||
return globals.rest.getFullUrl('http', 'localhost', this.path, channelId);
|
||||
public getChannelUrl(channelId: string, ip?: string): string {
|
||||
return globals.rest.getFullUrl('http', ip ?? 'localhost', this.path, channelId);
|
||||
}
|
||||
|
||||
public send(channelId: string, command: Command): void {
|
||||
|
@ -63,6 +63,10 @@ export class HttpChannelServer implements CommandChannelServer {
|
|||
this.emitter.on('receive', callback);
|
||||
}
|
||||
|
||||
public onConnection(_callback: (channelId: string, channel: any) => void): void {
|
||||
throw new Error('Not implemented');
|
||||
}
|
||||
|
||||
private handleGet(request: Request, response: Response): void {
|
||||
const channelId = request.params['channel'];
|
||||
const promise = this.getOutgoingQueue(channelId).asyncPop(timeoutMilliseconds);
|
||||
|
|
|
@ -1,26 +1,158 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
//export interface Command {
|
||||
// type: string;
|
||||
// [key: string]: any;
|
||||
//}
|
||||
/**
|
||||
* Common interface of command channels.
|
||||
*
|
||||
* A command channel is a duplex connection which supports sending and receiving JSON commands.
|
||||
*
|
||||
* Typically a command channel implementation consists of a server and a client.
|
||||
*
|
||||
* The server should listen to a URL prefix like `http://localhost:8080/example/`;
|
||||
* and each client should connect to a unique URL containing the prefix, e.g. `http://localhost:8080/example/channel1`.
|
||||
* The client's URL should be created with `server.getChannelUrl(channelId, serverIp)`.
|
||||
*
|
||||
* We currently have implemented one full feature command channel, the WebSocket channel,
|
||||
* and a simplified one, the HTTP channel.
|
||||
* In v3.1 release we might implement file command channel and AzureML command channel.
|
||||
*
|
||||
* The clients might have a Python version located in `nni/runtime/command_channel`.
|
||||
* The TypeScript and Python versions should be interchangeable.
|
||||
**/
|
||||
|
||||
/**
|
||||
* A command is a JSON object.
|
||||
*
|
||||
* The object has only one mandatory entry, `type`.
|
||||
*
|
||||
* The type string should not be surrounded by underscore (e.g. `_nop_`),
|
||||
* unless you are dealing with the underlying implementation of a specific command channel;
|
||||
* it should never start with two underscores (e.g. `__command`) in any circumstance.
|
||||
**/
|
||||
export type Command = any;
|
||||
|
||||
// Maybe it's better to disable `noPropertyAccessFromIndexSignature` in tsconfig?
|
||||
// export interface Command {
|
||||
// type: string;
|
||||
// [key: string]: any;
|
||||
// }
|
||||
|
||||
/**
|
||||
* A command channel server serves one or more command channels.
|
||||
* Each channel is connected to a client.
|
||||
* `CommandChannel` is the base interface used by both the servers and the clients.
|
||||
*
|
||||
* Normally each client has a unique channel URL,
|
||||
* which can be got with `server.getChannelUrl(id)`.
|
||||
* For servers, channels can be got with `onConnection()` event listener.
|
||||
* For clients, a channel can be created with the client subclass' constructor.
|
||||
*
|
||||
* The APIs might be changed to return `Promise<void>` in future.
|
||||
* The channel should be fault tolerant to some extent. It has three different types of closing related events:
|
||||
*
|
||||
* 1. Close: The channel is intentionally closed.
|
||||
*
|
||||
* 2. Lost: The channel is temporarily unavailable and is trying to recover.
|
||||
* The user of this class should examine the peer's status out-of-band when receiving "lost" event.
|
||||
*
|
||||
* 3. Error: The channel is dead and cannot recover.
|
||||
* A "close" event may or may not occur following this event. Do not rely on that.
|
||||
**/
|
||||
export interface CommandChannel {
|
||||
readonly name: string; // for better logging
|
||||
|
||||
enableHeartbeat(intervalMilliseconds?: number): void;
|
||||
|
||||
/**
|
||||
* Graceful (intentional) close.
|
||||
* A "close" event will be emitted by `this` and the peer.
|
||||
**/
|
||||
close(reason: string): void;
|
||||
|
||||
/**
|
||||
* Force close. Should only be used when the channel is not working.
|
||||
* An "error" event may be emitted by `this`.
|
||||
* A "lost" and/or "error" event will be emitted by the peer, if its process is still alive.
|
||||
**/
|
||||
terminate(reason: string): void;
|
||||
|
||||
send(command: Command): void;
|
||||
|
||||
/**
|
||||
* The async version should try to ensure the command is successfully sent to the peer.
|
||||
* But this is not guaranteed.
|
||||
**/
|
||||
sendAsync(command: Command): Promise<void>;
|
||||
|
||||
onReceive(callback: (command: Command) => void): void;
|
||||
onCommand(commandType: string, callback: (command: Command) => void): void;
|
||||
|
||||
onClose(callback: (reason?: string) => void): void;
|
||||
onError(callback: (error: Error) => void): void;
|
||||
onLost(callback: () => void): void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Client side of a command channel.
|
||||
*
|
||||
* The constructor should have no side effects.
|
||||
*
|
||||
* The listeners should be registered before calling `connect()`,
|
||||
* or the first few commands might be missed.
|
||||
*
|
||||
* Example usage:
|
||||
*
|
||||
* const client = new WsChannelClient('example', 'ws://1.2.3.4:8080/server/channel_id');
|
||||
* await client.connect();
|
||||
* client.send(command);
|
||||
**/
|
||||
export interface CommandChannelClient extends CommandChannel {
|
||||
// constructor(name: string, url: string);
|
||||
|
||||
connect(): Promise<void>;
|
||||
|
||||
/**
|
||||
* Typically an alias of `close()`.
|
||||
**/
|
||||
disconnect(reason?: string): Promise<void>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Server side of a command channel.
|
||||
*
|
||||
* The constructor should have no side effects.
|
||||
*
|
||||
* The listeners should be registered before calling `start()`.
|
||||
*
|
||||
* Example usage:
|
||||
*
|
||||
* const server = new WsChannelServer('example_server', '/server_prefix');
|
||||
* const url = server.getChannelUrl('channel_id');
|
||||
* const client = new WsChannelClient('example_client', url);
|
||||
* await server.start();
|
||||
* await client.connect();
|
||||
*
|
||||
* There are two ways to listen to commands:
|
||||
*
|
||||
* 1. Handle all clients' commands in one space:
|
||||
*
|
||||
* server.onReceive((channelId, command) => { ... });
|
||||
* server.send(channelId, command);
|
||||
*
|
||||
* 2. Maintain a `WsChannel` instance for each client:
|
||||
*
|
||||
* server.onConnection((channelId, channel) => {
|
||||
* channel.onCommand(command => { ... });
|
||||
* channel.send(command);
|
||||
* });
|
||||
**/
|
||||
export interface CommandChannelServer {
|
||||
// constructor(name: string, urlPath: string)
|
||||
// constructor(name: string, urlPath: string);
|
||||
|
||||
start(): Promise<void>;
|
||||
shutdown(): Promise<void>;
|
||||
getChannelUrl(channelId: string): string;
|
||||
|
||||
/**
|
||||
* When `ip` is missing, it should default to localhost.
|
||||
**/
|
||||
getChannelUrl(channelId: string, ip?: string): string;
|
||||
|
||||
send(channelId: string, command: Command): void;
|
||||
onReceive(callback: (channelId: string, command: Command) => void): void;
|
||||
onConnection(callback: (channelId: string, channel: CommandChannel) => void): void;
|
||||
}
|
||||
|
|
|
@ -50,7 +50,7 @@ import { DefaultMap } from 'common/default_map';
|
|||
import { Deferred } from 'common/deferred';
|
||||
import { Logger, getLogger } from 'common/log';
|
||||
import type { TrialKeeper } from 'common/trial_keeper/keeper';
|
||||
import { WsChannel } from './websocket/channel';
|
||||
import type { CommandChannel } from './interface';
|
||||
|
||||
interface RpcResponseCommand {
|
||||
type: 'rpc_response';
|
||||
|
@ -61,14 +61,14 @@ interface RpcResponseCommand {
|
|||
|
||||
type Class = { new(...args: any[]): any; };
|
||||
|
||||
const rpcHelpers: Map<WsChannel, RpcHelper> = new Map();
|
||||
const rpcHelpers: Map<CommandChannel, RpcHelper> = new Map();
|
||||
|
||||
/**
|
||||
* Enable RPC on a channel.
|
||||
*
|
||||
* The channel does not need to be connected for calling this function.
|
||||
**/
|
||||
export function getRpcHelper(channel: WsChannel): RpcHelper {
|
||||
export function getRpcHelper(channel: CommandChannel): RpcHelper {
|
||||
if (!rpcHelpers.has(channel)) {
|
||||
rpcHelpers.set(channel, new RpcHelper(channel));
|
||||
}
|
||||
|
@ -76,7 +76,7 @@ export function getRpcHelper(channel: WsChannel): RpcHelper {
|
|||
}
|
||||
|
||||
export class RpcHelper {
|
||||
private channel: WsChannel;
|
||||
private channel: CommandChannel;
|
||||
private lastId: number = 0;
|
||||
private localCtors: Map<string, Class> = new Map();
|
||||
private localObjs: Map<number, any> = new Map();
|
||||
|
@ -87,7 +87,7 @@ export class RpcHelper {
|
|||
/**
|
||||
* NOTE: Don't use this constructor directly. Use `getRpcHelper()`.
|
||||
**/
|
||||
constructor(channel: WsChannel) {
|
||||
constructor(channel: CommandChannel) {
|
||||
this.log = getLogger(`RpcHelper.${channel.name}`);
|
||||
this.channel = channel;
|
||||
this.channel.onCommand('rpc_constructor', command => {
|
||||
|
|
|
@ -4,353 +4,240 @@
|
|||
/**
|
||||
* WebSocket command channel.
|
||||
*
|
||||
* This is the base class that used by both server and client.
|
||||
* A WsChannel operates on one WebSocket connection at a time.
|
||||
* But when the network is unstable, it may close the underlying connection and create a new one.
|
||||
* This is generally transparent to the user of this class, except that a "lost" event will be emitted.
|
||||
*
|
||||
* For the server, channels can be got with `onConnection()` event listener.
|
||||
* For the client, a channel can be created with `new WsChannelClient()` subclass.
|
||||
* Do not use the constructor directly.
|
||||
* To distinguish intentional close from connection lost,
|
||||
* a "_bye_" command will be sent when `close()` or `disconnect()` is invoked.
|
||||
*
|
||||
* The channel is fault tolerant to some extend. It has three different types of closing related events:
|
||||
* If the connection is closed before receiving "_bye_" command, a "lost" event will be emitted and:
|
||||
*
|
||||
* 1. "close": The channel is intentionally closed.
|
||||
* * The client will try to reconnect several times in around 15s.
|
||||
* * The server will wait for the client to reconnect for around 30s.
|
||||
*
|
||||
* This is caused either by "close()" or "disconnect()" call, or by receiving a "_bye_" command from the peer.
|
||||
*
|
||||
* 2. "lost": The channel is temporarily unavailable and is trying to recover.
|
||||
* (The high level class should examine the peer's status out-of-band when receiving this event.)
|
||||
*
|
||||
* When the underlying socket is dead, this event is emitted.
|
||||
* The client will try to reconnect in around 15s. If all attempts fail, an "error" event will be emitted.
|
||||
* The server will wait the client for 30s. If it does not reconnect, an "error" event will be emitted.
|
||||
* Successful recover will not emit command.
|
||||
*
|
||||
* 3. "error": The channel is dead and cannot recover.
|
||||
*
|
||||
* A "close" event may or may not follow this event. Do not rely on that.
|
||||
* If the reconnecting attempt fails, both sides will emit an "error" event.
|
||||
**/
|
||||
|
||||
import { EventEmitter } from 'node:events';
|
||||
import { EventEmitter, once } from 'node:events';
|
||||
import util from 'node:util';
|
||||
|
||||
import type { WebSocket } from 'ws';
|
||||
|
||||
import type { Command, CommandChannel } from 'common/command_channel/interface';
|
||||
import { Deferred } from 'common/deferred';
|
||||
import { Logger, getLogger } from 'common/log';
|
||||
import { WsConnection } from './connection';
|
||||
|
||||
import type { Command } from '../interface';
|
||||
|
||||
interface WsChannelEvents {
|
||||
'command': (command: Command) => void;
|
||||
'close': (reason: string) => void;
|
||||
'lost': () => void;
|
||||
'error': (error: Error) => void; // not used in base class
|
||||
interface QueuedCommand {
|
||||
command: Command;
|
||||
deferred?: Deferred<void>;
|
||||
}
|
||||
|
||||
export declare interface WsChannel {
|
||||
on<E extends keyof WsChannelEvents>(event: E, listener: WsChannelEvents[E]): this;
|
||||
}
|
||||
|
||||
export class WsChannel extends EventEmitter {
|
||||
export class WsChannel implements CommandChannel {
|
||||
private closing: boolean = false;
|
||||
private commandEmitter: EventEmitter = new EventEmitter();
|
||||
private connection: WsConnection | null = null;
|
||||
private epoch: number = 0;
|
||||
private heartbeatInterval: number | null;
|
||||
private connection: WsConnection | null = null; // NOTE: used in unit test
|
||||
private epoch: number = -1;
|
||||
private heartbeatInterval: number | null = null;
|
||||
private log: Logger;
|
||||
private queue: QueuedCommand[] = [];
|
||||
private terminateTimer: NodeJS.Timer | null = null;
|
||||
|
||||
protected emitter: EventEmitter = new EventEmitter();
|
||||
|
||||
public readonly name: string;
|
||||
|
||||
// internal, don't use
|
||||
constructor(name: string, ws?: WebSocket, heartbeatInterval?: number) {
|
||||
super()
|
||||
constructor(name: string) {
|
||||
this.log = getLogger(`WsChannel.${name}`);
|
||||
this.name = name;
|
||||
this.heartbeatInterval = heartbeatInterval || null;
|
||||
if (ws) {
|
||||
this.setConnection(ws);
|
||||
}
|
||||
}
|
||||
|
||||
public enableHeartbeat(interval: number): void {
|
||||
this.log.debug('## enable heartbeat');
|
||||
this.heartbeatInterval = interval;
|
||||
}
|
||||
|
||||
// internal, don't use
|
||||
public setConnection(ws: WebSocket): void {
|
||||
if (this.connection) {
|
||||
this.log.debug('Abandon previous connection');
|
||||
this.epoch += 1;
|
||||
public async setConnection(ws: WebSocket, waitOpen: boolean): Promise<void> {
|
||||
if (this.terminateTimer) {
|
||||
clearTimeout(this.terminateTimer);
|
||||
this.terminateTimer = null;
|
||||
}
|
||||
|
||||
this.connection?.terminate('new epoch start');
|
||||
this.newEpoch();
|
||||
this.log.debug(`Epoch ${this.epoch} start`);
|
||||
|
||||
this.connection = this.configConnection(ws);
|
||||
if (waitOpen) {
|
||||
await once(ws, 'open');
|
||||
}
|
||||
|
||||
while (this.connection && this.queue.length > 0) {
|
||||
const item = this.queue.shift()!;
|
||||
try {
|
||||
await this.connection.sendAsync(item.command);
|
||||
item.deferred?.resolve();
|
||||
} catch (error) {
|
||||
this.log.error('Failed to send command on recovered channel:', error);
|
||||
this.log.error('Dropped command:', item.command);
|
||||
item.deferred?.reject(error as any);
|
||||
// it should trigger connection's error event and this.connection will be set to null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public enableHeartbeat(interval?: number): void {
|
||||
this.heartbeatInterval = interval ?? defaultHeartbeatInterval;
|
||||
this.connection?.setHeartbeatInterval(this.heartbeatInterval);
|
||||
}
|
||||
|
||||
public close(reason: string): void {
|
||||
this.log.debug('Close channel:', reason);
|
||||
if (this.connection) {
|
||||
this.connection.close(reason);
|
||||
this.endEpoch();
|
||||
}
|
||||
if (!this.closing) {
|
||||
this.closing = true;
|
||||
this.emit('close', reason);
|
||||
this.connection?.close(reason);
|
||||
if (this.setClosing()) {
|
||||
this.emitter.emit('__close', reason);
|
||||
}
|
||||
}
|
||||
|
||||
public terminate(reason: string): void {
|
||||
this.log.info('Terminate channel:', reason);
|
||||
this.closing = true;
|
||||
if (this.connection) {
|
||||
this.connection.terminate(reason);
|
||||
this.endEpoch();
|
||||
this.connection?.terminate(reason);
|
||||
if (this.setClosing()) {
|
||||
this.emitter.emit('__error', new Error(`WsChannel terminated: ${reason}`));
|
||||
}
|
||||
}
|
||||
|
||||
public send(command: Command): void {
|
||||
if (this.connection) {
|
||||
this.connection.send(command);
|
||||
} else {
|
||||
// TODO: add a queue?
|
||||
this.log.error('Connection lost. Dropped command', command);
|
||||
if (this.closing) {
|
||||
this.log.error('Channel closed. Ignored command', command);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!this.connection) {
|
||||
this.log.warning('Connection lost. Enqueue command', command);
|
||||
this.queue.push({ command });
|
||||
return;
|
||||
}
|
||||
|
||||
this.connection.send(command);
|
||||
}
|
||||
|
||||
/**
|
||||
* Async version of `send()` that (partially) ensures the command is successfully sent to peer.
|
||||
**/
|
||||
public sendAsync(command: Command): Promise<void> {
|
||||
if (this.connection) {
|
||||
return this.connection.sendAsync(command);
|
||||
} else {
|
||||
this.log.error('Connection lost. Dropped command async', command);
|
||||
return Promise.reject(new Error('Connection is lost and trying to recover, cannot send command now'));
|
||||
if (this.closing) {
|
||||
this.log.error('(async) Channel closed. Refused command', command);
|
||||
return Promise.reject(new Error('WsChannel has been closed'));
|
||||
}
|
||||
|
||||
if (!this.connection) {
|
||||
this.log.warning('(async) Connection lost. Enqueue command', command);
|
||||
const deferred = new Deferred<void>();
|
||||
this.queue.push({ command, deferred });
|
||||
return deferred.promise;
|
||||
}
|
||||
|
||||
return this.connection.sendAsync(command);
|
||||
}
|
||||
|
||||
// the first overload listens to all commands, while the second listens to one command type
|
||||
public onCommand(callback: (command: Command) => void): void;
|
||||
public onCommand(commandType: string, callback: (command: Command) => void): void;
|
||||
public onReceive(callback: (command: Command) => void): void {
|
||||
this.emitter.on('__receive', callback);
|
||||
}
|
||||
|
||||
public onCommand(commandTypeOrCallback: any, callbackOrNone?: any): void {
|
||||
if (callbackOrNone) {
|
||||
this.commandEmitter.on(commandTypeOrCallback, callbackOrNone);
|
||||
} else {
|
||||
this.commandEmitter.on('__any', commandTypeOrCallback);
|
||||
}
|
||||
public onCommand(commandType: string, callback: (command: Command) => void): void {
|
||||
this.emitter.on(commandType, callback);
|
||||
}
|
||||
|
||||
public onClose(callback: (reason?: string) => void): void {
|
||||
this.emitter.on('__close', callback);
|
||||
}
|
||||
|
||||
public onError(callback: (error: Error) => void): void {
|
||||
this.emitter.on('__error', callback);
|
||||
}
|
||||
|
||||
public onLost(callback: () => void): void {
|
||||
this.emitter.on('__lost', callback);
|
||||
}
|
||||
|
||||
private newEpoch(): void {
|
||||
this.connection = null;
|
||||
this.epoch += 1;
|
||||
}
|
||||
|
||||
private configConnection(ws: WebSocket): WsConnection {
|
||||
this.log.debug('## config connection');
|
||||
const epoch = this.epoch; // copy it to use in closure
|
||||
const conn = new WsConnection(
|
||||
this.epoch ? `${this.name}.${epoch}` : this.name,
|
||||
ws,
|
||||
this.commandEmitter,
|
||||
this.heartbeatInterval
|
||||
);
|
||||
const connName = this.epoch ? `${this.name}.${this.epoch}` : this.name;
|
||||
const conn = new WsConnection(connName, ws, this.emitter);
|
||||
if (this.heartbeatInterval) {
|
||||
conn.setHeartbeatInterval(this.heartbeatInterval);
|
||||
}
|
||||
|
||||
conn.on('bye', reason => {
|
||||
this.log.debug('Peer intentionally close:', reason);
|
||||
this.endEpoch();
|
||||
if (!this.closing) {
|
||||
this.closing = true;
|
||||
this.emit('close', reason);
|
||||
this.log.debug('Peer intentionally closing:', reason);
|
||||
if (this.setClosing()) {
|
||||
this.emitter.emit('__close', reason);
|
||||
}
|
||||
});
|
||||
|
||||
conn.on('close', (code, reason) => {
|
||||
this.closeConnection(epoch, `Received closing handshake: ${code} ${reason}`);
|
||||
this.log.debug('Peer closed:', reason);
|
||||
this.dropConnection(conn, `Peer closed: ${code} ${reason}`);
|
||||
});
|
||||
|
||||
conn.on('error', error => {
|
||||
this.closeConnection(epoch, `Error occurred: ${util.inspect(error)}`);
|
||||
this.dropConnection(conn, `Connection error: ${util.inspect(error)}`);
|
||||
});
|
||||
|
||||
return conn;
|
||||
}
|
||||
|
||||
private closeConnection(epoch: number, reason: string): void {
|
||||
private setClosing(): boolean {
|
||||
if (this.closing) {
|
||||
this.log.debug('Connection cleaned up:', reason);
|
||||
return false;
|
||||
}
|
||||
this.closing = true;
|
||||
this.newEpoch();
|
||||
this.queue.forEach(item => {
|
||||
item.deferred?.reject(new Error('WsChannel has been closed.'));
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
private dropConnection(conn: WsConnection, reason: string): void {
|
||||
if (this.closing) {
|
||||
this.log.debug('Clean up:', reason);
|
||||
return;
|
||||
}
|
||||
if (this.epoch !== epoch) { // the connection is already abandoned
|
||||
this.log.debug(`Previous connection closed ${epoch}: ${reason}`);
|
||||
if (this.connection !== conn) { // the connection is already abandoned
|
||||
this.log.debug(`Previous connection closed: ${reason}`);
|
||||
return;
|
||||
}
|
||||
|
||||
this.log.warning('Connection closed unexpectedly:', reason);
|
||||
this.emit('lost');
|
||||
this.endEpoch();
|
||||
this.newEpoch();
|
||||
this.emitter.emit('__lost');
|
||||
|
||||
if (!this.terminateTimer) {
|
||||
this.terminateTimer = setTimeout(() => {
|
||||
if (!this.closing) {
|
||||
this.terminate('have not reconnected in 30s');
|
||||
}
|
||||
}, terminateTimeout);
|
||||
}
|
||||
|
||||
// the reconnect logic is in client subclass
|
||||
}
|
||||
|
||||
private endEpoch(): void {
|
||||
this.connection = null;
|
||||
this.epoch += 1;
|
||||
}
|
||||
}
|
||||
|
||||
interface WsConnectionEvents {
|
||||
'command': (command: Command) => void;
|
||||
'bye': (reason: string) => void;
|
||||
'close': (code: number, reason: string) => void;
|
||||
'error': (error: Error) => void;
|
||||
}
|
||||
let defaultHeartbeatInterval: number = 5000;
|
||||
let terminateTimeout: number = 30000;
|
||||
|
||||
declare interface WsConnection {
|
||||
on<E extends keyof WsConnectionEvents>(event: E, listener: WsConnectionEvents[E]): this;
|
||||
}
|
||||
|
||||
class WsConnection extends EventEmitter {
|
||||
private closing: boolean = false;
|
||||
private commandEmitter: EventEmitter;
|
||||
private heartbeatTimer: NodeJS.Timer | null = null;
|
||||
private log: Logger;
|
||||
private missingPongs: number = 0;
|
||||
private ws: WebSocket;
|
||||
|
||||
constructor(name: string, ws: WebSocket, commandEmitter: EventEmitter, heartbeatInterval: number | null) {
|
||||
super();
|
||||
this.log = getLogger(`WsConnection.${name}`);
|
||||
this.ws = ws;
|
||||
this.commandEmitter = commandEmitter;
|
||||
|
||||
ws.on('close', this.handleClose.bind(this));
|
||||
ws.on('error', this.handleError.bind(this));
|
||||
ws.on('message', this.handleMessage.bind(this));
|
||||
ws.on('pong', this.handlePong.bind(this));
|
||||
|
||||
if (heartbeatInterval) {
|
||||
this.heartbeatTimer = setInterval(this.heartbeat.bind(this), heartbeatInterval);
|
||||
}
|
||||
export namespace UnitTestHelper {
|
||||
export function setHeartbeatInterval(ms: number): void {
|
||||
defaultHeartbeatInterval = ms;
|
||||
}
|
||||
|
||||
public async close(reason: string): Promise<void> {
|
||||
if (this.closing) {
|
||||
this.log.debug('Close again:', reason);
|
||||
return;
|
||||
}
|
||||
|
||||
this.log.debug('Close:', reason);
|
||||
this.closing = true;
|
||||
if (this.heartbeatTimer) {
|
||||
clearInterval(this.heartbeatTimer);
|
||||
this.heartbeatTimer = null;
|
||||
}
|
||||
|
||||
try {
|
||||
await this.sendAsync({ type: '_bye_', reason });
|
||||
} catch (error) {
|
||||
this.log.error('Failed to send bye:', error);
|
||||
}
|
||||
|
||||
try {
|
||||
this.ws.close(4000, reason);
|
||||
} catch (error) {
|
||||
this.log.error('Failed to close:', error);
|
||||
this.ws.terminate();
|
||||
}
|
||||
export function setTerminateTimeout(ms: number): void {
|
||||
terminateTimeout = ms;
|
||||
}
|
||||
|
||||
public terminate(reason: string): void {
|
||||
this.log.debug('Terminate:', reason);
|
||||
this.closing = true;
|
||||
if (this.heartbeatTimer) {
|
||||
clearInterval(this.heartbeatTimer);
|
||||
this.heartbeatTimer = null;
|
||||
}
|
||||
|
||||
try {
|
||||
this.ws.close(4001, reason);
|
||||
} catch (error) {
|
||||
this.log.debug('Failed to close:', error);
|
||||
this.ws.terminate();
|
||||
}
|
||||
}
|
||||
|
||||
public send(command: Command): void {
|
||||
this.log.trace('Sending command', command);
|
||||
this.ws.send(JSON.stringify(command));
|
||||
}
|
||||
|
||||
public sendAsync(command: Command): Promise<void> {
|
||||
this.log.trace('Sending command async', command);
|
||||
const deferred = new Deferred<void>();
|
||||
this.ws.send(JSON.stringify(command), error => {
|
||||
if (error) {
|
||||
deferred.reject(error);
|
||||
} else {
|
||||
deferred.resolve();
|
||||
}
|
||||
});
|
||||
return deferred.promise;
|
||||
}
|
||||
|
||||
private handleClose(code: number, reason: Buffer): void {
|
||||
if (this.closing) {
|
||||
this.log.debug('Connection closed');
|
||||
} else {
|
||||
this.log.debug('Connection closed by peer:', code, String(reason));
|
||||
this.emit('close', code, String(reason));
|
||||
}
|
||||
}
|
||||
|
||||
private handleError(error: Error): void {
|
||||
if (this.closing) {
|
||||
this.log.warning('Error after closing:', error);
|
||||
} else {
|
||||
this.emit('error', error);
|
||||
}
|
||||
}
|
||||
|
||||
private handleMessage(data: Buffer, _isBinary: boolean): void {
|
||||
const s = String(data);
|
||||
if (this.closing) {
|
||||
this.log.warning('Received message after closing:', s);
|
||||
return;
|
||||
}
|
||||
|
||||
this.log.trace('Received command', s);
|
||||
const command = JSON.parse(s);
|
||||
|
||||
if (command.type === '_nop_') {
|
||||
return;
|
||||
}
|
||||
if (command.type === '_bye_') {
|
||||
this.closing = true;
|
||||
this.emit('bye', command.reason);
|
||||
return;
|
||||
}
|
||||
|
||||
const hasAnyListener = this.commandEmitter.emit('__any', command);
|
||||
const hasTypeListener = this.commandEmitter.emit(command.type, command);
|
||||
if (!hasAnyListener && !hasTypeListener) {
|
||||
this.log.warning('No listener for command', s);
|
||||
}
|
||||
}
|
||||
|
||||
private handlePong(): void {
|
||||
this.log.debug('receive pong'); // todo
|
||||
this.missingPongs = 0;
|
||||
}
|
||||
|
||||
private heartbeat(): void {
|
||||
if (this.missingPongs > 0) {
|
||||
this.log.warning('Missing pong');
|
||||
}
|
||||
if (this.missingPongs > 3) { // TODO: make it configurable?
|
||||
// no response for ping, try real command
|
||||
this.sendAsync({ type: '_nop_' }).then(() => {
|
||||
this.missingPongs = 0;
|
||||
}).catch(error => {
|
||||
this.log.error('Failed sending command. Drop connection:', error);
|
||||
this.terminate(`peer lost responsive: ${util.inspect(error)}`);
|
||||
});
|
||||
}
|
||||
this.missingPongs += 1;
|
||||
this.log.debug('send ping');
|
||||
this.ws.ping();
|
||||
export function reset(): void {
|
||||
defaultHeartbeatInterval = 5000;
|
||||
terminateTimeout = 30000;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,15 +3,6 @@
|
|||
|
||||
/**
|
||||
* WebSocket command channel client.
|
||||
*
|
||||
* Usage:
|
||||
*
|
||||
* const client = new WsChannelClient('ws://1.2.3.4:8080/server/channel_id');
|
||||
* await client.connect();
|
||||
* client.send(command);
|
||||
*
|
||||
* Most APIs are derived from the base class `WsChannel`.
|
||||
* See its doc for more details.
|
||||
**/
|
||||
|
||||
import events from 'node:events';
|
||||
|
@ -26,7 +17,7 @@ import { WsChannel } from './channel';
|
|||
const maxPayload: number = 1024 * 1024 * 1024;
|
||||
|
||||
export class WsChannelClient extends WsChannel {
|
||||
private logger: Logger; // avoid name conflict with base class
|
||||
private logger: Logger;
|
||||
private reconnecting: boolean = false;
|
||||
private url: string;
|
||||
|
||||
|
@ -34,19 +25,17 @@ export class WsChannelClient extends WsChannel {
|
|||
* The url should start with "ws://".
|
||||
* The name is used for better logging.
|
||||
**/
|
||||
constructor(url: string, name?: string) {
|
||||
const name_ = name ?? generateName(url);
|
||||
super(name_);
|
||||
this.logger = getLogger(`WsChannelClient.${name_}`);
|
||||
constructor(name: string, url: string) {
|
||||
super(name);
|
||||
this.logger = getLogger(`WsChannelClient.${name}`);
|
||||
this.url = url;
|
||||
this.on('lost', this.reconnect.bind(this));
|
||||
this.onLost(this.reconnect.bind(this));
|
||||
}
|
||||
|
||||
public async connect(): Promise<void> {
|
||||
this.logger.debug('Connecting to', this.url);
|
||||
const ws = new WebSocket(this.url, { maxPayload });
|
||||
this.setConnection(ws);
|
||||
await events.once(ws, 'open');
|
||||
await this.setConnection(ws, true),
|
||||
this.logger.debug('Connected');
|
||||
}
|
||||
|
||||
|
@ -54,7 +43,7 @@ export class WsChannelClient extends WsChannel {
|
|||
* Alias of `close()`.
|
||||
**/
|
||||
public async disconnect(reason?: string): Promise<void> {
|
||||
this.close(reason ?? 'client disconnecting');
|
||||
this.close(reason ?? 'client intentionally disconnect');
|
||||
}
|
||||
|
||||
private async reconnect(): Promise<void> {
|
||||
|
@ -82,16 +71,6 @@ export class WsChannelClient extends WsChannel {
|
|||
}
|
||||
|
||||
this.logger.error('Conenction lost. Cannot reconnect');
|
||||
this.emit('error', new Error('Connection lost'));
|
||||
this.emitter.emit('__error', new Error('Connection lost'));
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Derive a default channel name from a URL: the last non-empty "/"-separated segment,
 * ignoring the first two segments (for "ws://host/..." these are "ws:" and the empty
 * string between the two slashes).
 * Falls back to "anonymous" when no such segment exists.
 **/
function generateName(url: string): string {
    const candidates = url.split('/').slice(2);
    const lastNonEmpty = candidates.reverse().find(segment => segment !== '');
    return lastNonEmpty ?? 'anonymous';
}
|
||||
|
|
|
@ -0,0 +1,189 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
/**
|
||||
* Internal helper class which handles one WebSocket connection.
|
||||
**/
|
||||
|
||||
import { EventEmitter } from 'node:events';
|
||||
import util from 'node:util';
|
||||
|
||||
import type { WebSocket } from 'ws';
|
||||
|
||||
import type { Command } from 'common/command_channel/interface';
|
||||
import { Logger, getLogger } from 'common/log';
|
||||
|
||||
interface ConnectionEvents {
|
||||
'bye': (reason: string) => void;
|
||||
'close': (code: number, reason: string) => void;
|
||||
'error': (error: Error) => void;
|
||||
}
|
||||
|
||||
export declare interface WsConnection {
|
||||
on<E extends keyof ConnectionEvents>(event: E, listener: ConnectionEvents[E]): this;
|
||||
}
|
||||
|
||||
/**
 * Wraps one WebSocket and translates its raw events into command-channel events.
 *
 * Events emitted on this object:
 *   - 'bye' (reason): the peer sent a `_bye_` command, i.e. it is closing intentionally.
 *   - 'close' (code, reason): the socket closed while this side was not closing.
 *   - 'error' (error): the socket reported an error while this side was not closing.
 *
 * Received commands (other than `_nop_` and `_bye_`) are emitted on the shared
 * `commandEmitter` twice: once as '__receive' and once under the command's own type.
 **/
export class WsConnection extends EventEmitter {
    private closing: boolean = false;
    private commandEmitter: EventEmitter;
    private heartbeatTimer: NodeJS.Timer | null = null;
    private log: Logger;
    private missingPongs: number = 0;
    private ws: WebSocket;  // NOTE: used in unit test

    /**
     * @param name            Used only to name the logger.
     * @param ws              The socket to manage; listeners are attached immediately.
     * @param commandEmitter  Emitter on which received commands are published.
     **/
    constructor(name: string, ws: WebSocket, commandEmitter: EventEmitter) {
        super();
        this.log = getLogger(`WsConnection.${name}`);
        this.ws = ws;
        this.commandEmitter = commandEmitter;

        ws.on('close', this.handleClose.bind(this));
        ws.on('error', this.handleError.bind(this));
        ws.on('message', this.handleMessage.bind(this));
        ws.on('pong', this.handlePong.bind(this));
    }

    /**
     * Start (or restart) the periodic heartbeat with the given interval in milliseconds.
     * Does nothing once the connection is closing, so no timer can outlive the socket.
     **/
    public setHeartbeatInterval(interval: number): void {
        this.stopHeartbeat();
        if (this.closing) {
            this.log.debug('Connection is closing. Heartbeat not started');
            return;
        }
        this.heartbeatTimer = setInterval(this.heartbeat.bind(this), interval);
    }

    /**
     * Gracefully close: send `_bye_` to the peer, then close the socket with code 4000.
     * Falls back to `terminate()` on the socket if closing throws.
     * Calling it again after the connection started closing is a no-op.
     **/
    public async close(reason: string): Promise<void> {
        if (this.closing) {
            this.log.debug('Close again:', reason);
            return;
        }

        this.log.debug('Close connection:', reason);
        this.closing = true;
        this.stopHeartbeat();

        try {
            await this.sendAsync({ type: '_bye_', reason });
        } catch (error) {
            this.log.error('Failed to send bye:', error);
        }

        try {
            this.ws.close(4000, reason);
            return;
        } catch (error) {
            this.log.error('Failed to close socket:', error);
        }

        try {
            this.ws.terminate();
        } catch (error) {
            this.log.debug('Failed to terminate socket:', error);
        }
    }

    /**
     * Abruptly drop the connection without sending `_bye_`: close the socket with code 4001,
     * falling back to `terminate()` on the socket if closing throws.
     **/
    public terminate(reason: string): void {
        this.log.debug('Terminate connection:', reason);
        this.closing = true;
        this.stopHeartbeat();

        try {
            this.ws.close(4001, reason);
            return;
        } catch (error) {
            this.log.debug('Failed to close socket:', error);
        }

        try {
            this.ws.terminate();
        } catch (error) {
            this.log.debug('Failed to terminate socket:', error);
        }
    }

    /** Serialize the command as JSON and send it without waiting for the result. */
    public send(command: Command): void {
        this.log.trace('Send command', command);
        this.ws.send(JSON.stringify(command));
    }

    /**
     * Serialize the command as JSON and send it.
     * The promise settles when the socket's send callback fires (rejects on error).
     **/
    public sendAsync(command: Command): Promise<void> {
        this.log.trace('(async) Send command', command);
        const send: any = util.promisify(this.ws.send.bind(this.ws));
        return send(JSON.stringify(command));
    }

    /** Cancel the heartbeat timer if one is running. */
    private stopHeartbeat(): void {
        if (this.heartbeatTimer) {
            clearInterval(this.heartbeatTimer);
            this.heartbeatTimer = null;
        }
    }

    private handleClose(code: number, reason: Buffer): void {
        // The socket is gone either way, so there is nothing left to ping.
        this.stopHeartbeat();
        if (this.closing) {
            this.log.debug('Connection closed');
        } else {
            this.log.debug('Connection closed by peer:', code, String(reason));
            this.emit('close', code, String(reason));
        }
    }

    private handleError(error: Error): void {
        if (this.closing) {
            this.log.warning('Error after closing:', error);
        } else {
            this.log.error('Connection error:', error);
            this.emit('error', error);
        }
    }

    private handleMessage(data: Buffer, _isBinary: boolean): void {
        const s = String(data);
        if (this.closing) {
            this.log.warning('Received message after closing:', s);
            return;
        }

        this.log.trace('Receive command', s);

        // A malformed payload must not throw out of the socket's event listener.
        let command: any;
        try {
            command = JSON.parse(s);
        } catch (error) {
            this.log.error('Failed to parse message. Ignored:', s, error);
            return;
        }
        if (command === null || typeof command !== 'object') {
            this.log.error('Received message is not a command. Ignored:', s);
            return;
        }

        if (command.type === '_nop_') {
            return;
        }

        if (command.type === '_bye_') {
            this.log.debug('Intentionally close connection:', s);
            this.closing = true;
            // The peer is leaving; stop pinging instead of waiting for pongs that never come.
            this.stopHeartbeat();
            this.emit('bye', command.reason);
            return;
        }

        const hasReceiveListener = this.commandEmitter.emit('__receive', command);
        const hasCommandListener = this.commandEmitter.emit(command.type, command);
        if (!hasReceiveListener && !hasCommandListener) {
            this.log.warning('No listener for command', s);
        }
    }

    private handlePong(): void {
        this.log.trace('Receive pong');
        this.missingPongs = 0;
    }

    /**
     * Timer callback: send a ping and track unanswered ones.
     * After more than 3 unanswered pings, probe with a real `_nop_` command and
     * terminate the connection if even that cannot be sent.
     **/
    private heartbeat(): void {
        if (this.missingPongs > 0) {
            this.log.warning('Missing pong');
        }

        if (this.missingPongs > 3) {  // TODO: make it configurable?
            // no response for ping, try real command
            this.sendAsync({ type: '_nop_' }).then(() => {
                this.missingPongs = 0;
            }).catch(error => {
                this.log.error('Failed sending command. Drop connection:', error);
                this.terminate(`peer lost responsive: ${util.inspect(error)}`);
            });
        }

        this.missingPongs += 1;
        this.log.trace('Send ping');
        try {
            this.ws.ping();
        } catch (error) {
            // ping() can throw synchronously (e.g. while the socket is still connecting);
            // an exception escaping a timer callback would otherwise be uncaught.
            this.log.warning('Failed to send ping:', error);
        }
    }
}
|
|
@ -3,29 +3,6 @@
|
|||
|
||||
/**
|
||||
* WebSocket command channel server.
|
||||
*
|
||||
* The server will specify a URL prefix like `ws://1.2.3.4:8080/SERVER_PREFIX`,
|
||||
* and each client will append a channel ID, like `ws://1.2.3.4:8080/SERVER_PREFIX/CHANNEL_ID`.
|
||||
*
|
||||
* const server = new WsChannelServer('example', 'SERVER_PREFIX');
|
||||
* const url = server.getChannelUrl('CHANNEL_ID');
|
||||
* const client = new WsChannelClient(url);
|
||||
* await server.start();
|
||||
* await client.connect();
|
||||
*
|
||||
* There are two styles to use the server:
|
||||
*
|
||||
* 1. Handle all clients' commands in one space:
|
||||
*
|
||||
* server.onReceive((channelId, command) => { ... });
|
||||
* server.send(channelId, command);
|
||||
*
|
||||
* 2. Maintain a `WsChannel` instance for each client:
|
||||
*
|
||||
* server.onConnection((channelId, channel) => {
|
||||
* channel.onCommand(command => { ... });
|
||||
* channel.send(command);
|
||||
* });
|
||||
**/
|
||||
|
||||
import { EventEmitter } from 'events';
|
||||
|
@ -39,22 +16,18 @@ import globals from 'common/globals';
|
|||
import { Logger, getLogger } from 'common/log';
|
||||
import { WsChannel } from './channel';
|
||||
|
||||
let heartbeatInterval: number = 5000;
|
||||
|
||||
type ReceiveCallback = (channelId: string, command: Command) => void;
|
||||
|
||||
export class WsChannelServer extends EventEmitter {
|
||||
private channels: Map<string, WsChannel> = new Map();
|
||||
private ip: string;
|
||||
private log: Logger;
|
||||
private path: string;
|
||||
private receiveCallbacks: ReceiveCallback[] = [];
|
||||
|
||||
constructor(name: string, urlPath: string, ip?: string) {
|
||||
constructor(name: string, urlPath: string) {
|
||||
super();
|
||||
this.log = getLogger(`WsChannelServer.${name}`);
|
||||
this.path = urlPath;
|
||||
this.ip = ip ?? 'localhost';
|
||||
}
|
||||
|
||||
public async start(): Promise<void> {
|
||||
|
@ -67,7 +40,7 @@ export class WsChannelServer extends EventEmitter {
|
|||
const deferred = new Deferred<void>();
|
||||
|
||||
this.channels.forEach((channel, channelId) => {
|
||||
channel.on('close', (_reason) => {
|
||||
channel.onClose(_reason => {
|
||||
this.channels.delete(channelId);
|
||||
if (this.channels.size === 0) {
|
||||
deferred.resolve();
|
||||
|
@ -77,17 +50,16 @@ export class WsChannelServer extends EventEmitter {
|
|||
});
|
||||
|
||||
// wait for at most 5 seconds
|
||||
// use heartbeatInterval here for easier unit test
|
||||
setTimeout(() => {
|
||||
this.log.debug('Shutdown timeout. Stop waiting following channels:', Array.from(this.channels.keys()));
|
||||
deferred.resolve();
|
||||
}, heartbeatInterval);
|
||||
}, 5000);
|
||||
|
||||
return deferred.promise;
|
||||
}
|
||||
|
||||
public getChannelUrl(channelId: string): string {
|
||||
return globals.rest.getFullUrl('ws', this.ip, this.path, channelId);
|
||||
public getChannelUrl(channelId: string, ip?: string): string {
|
||||
return globals.rest.getFullUrl('ws', ip ?? 'localhost', this.path, channelId);
|
||||
}
|
||||
|
||||
public send(channelId: string, command: Command): void {
|
||||
|
@ -104,7 +76,7 @@ export class WsChannelServer extends EventEmitter {
|
|||
// because this way it can detect and warn if a command is never listened to
|
||||
this.receiveCallbacks.push(callback);
|
||||
for (const [channelId, channel] of this.channels) {
|
||||
channel.onCommand(command => { callback(channelId, command); });
|
||||
channel.onReceive(command => { callback(channelId, command); });
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -118,33 +90,30 @@ export class WsChannelServer extends EventEmitter {
|
|||
|
||||
if (this.channels.has(channelId)) {
|
||||
this.log.warning(`Channel ${channelId} reconnecting, drop previous connection`);
|
||||
this.channels.get(channelId)!.setConnection(ws);
|
||||
this.channels.get(channelId)!.setConnection(ws, false);
|
||||
return;
|
||||
}
|
||||
|
||||
const channel = new WsChannel(channelId, ws, heartbeatInterval);
|
||||
const channel = new WsChannel(channelId);
|
||||
this.channels.set(channelId, channel);
|
||||
|
||||
channel.on('close', reason => {
|
||||
channel.onClose(reason => {
|
||||
this.log.debug(`Connection ${channelId} closed:`, reason);
|
||||
this.channels.delete(channelId);
|
||||
});
|
||||
|
||||
channel.on('error', error => {
|
||||
channel.onError(error => {
|
||||
this.log.error(`Connection ${channelId} error:`, error);
|
||||
this.channels.delete(channelId);
|
||||
});
|
||||
|
||||
for (const cb of this.receiveCallbacks) {
|
||||
channel.on('command', command => { cb(channelId, command); });
|
||||
channel.onReceive(command => { cb(channelId, command); });
|
||||
}
|
||||
|
||||
channel.enableHeartbeat();
|
||||
channel.setConnection(ws, false);
|
||||
|
||||
this.emit('connection', channelId, channel);
|
||||
}
|
||||
}
|
||||
|
||||
export namespace UnitTestHelpers {
|
||||
export function setHeartbeatInterval(ms: number): void {
|
||||
heartbeatInterval = ms;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -98,5 +98,5 @@ if (isUnitTest()) {
|
|||
resetGlobals();
|
||||
}
|
||||
|
||||
const globals: MutableGlobals = (global as any).nni;
|
||||
export const globals: MutableGlobals = (global as any).nni;
|
||||
export default globals;
|
||||
|
|
|
@ -72,13 +72,13 @@ async function main(): Promise<void> {
|
|||
logger.debug('command:', process.argv);
|
||||
logger.debug('config:', config);
|
||||
|
||||
const client = new WsChannelClient(args.managerCommandChannel, args.environmentId);
|
||||
client.enableHeartbeat(5000);
|
||||
client.on('close', reason => {
|
||||
const client = new WsChannelClient(args.environmentId, args.managerCommandChannel);
|
||||
client.enableHeartbeat();
|
||||
client.onClose(reason => {
|
||||
logger.info('Manager closed connection:', reason);
|
||||
globals.shutdown.initiate('Connection end');
|
||||
});
|
||||
client.on('error', error => {
|
||||
client.onError(error => {
|
||||
logger.info('Connection error:', error);
|
||||
globals.shutdown.initiate('Connection error');
|
||||
});
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
"license": "MIT",
|
||||
"scripts": {
|
||||
"build": "tsc",
|
||||
"test": "nyc --reporter=cobertura --reporter=text mocha test/**/*.test.ts",
|
||||
"test": "nyc --reporter=cobertura --reporter=text mocha \"test/**/*.test.ts\"",
|
||||
"test_nnimanager": "nyc --reporter=cobertura --reporter=text mocha test/core/nnimanager.test.ts",
|
||||
"mocha": "mocha",
|
||||
"eslint": "eslint . --ext .ts"
|
||||
|
|
|
@ -61,7 +61,7 @@ export class RestServerCore {
|
|||
return deferred.promise;
|
||||
}
|
||||
|
||||
public shutdown(): Promise<void> {
|
||||
public shutdown(timeoutMilliseconds?: number): Promise<void> {
|
||||
logger.info('Stopping REST server.');
|
||||
if (this.server === null) {
|
||||
logger.warning('REST server is not running.');
|
||||
|
@ -77,7 +77,7 @@ export class RestServerCore {
|
|||
logger.debug('Killing connections');
|
||||
this.server?.closeAllConnections();
|
||||
}
|
||||
}, 5000);
|
||||
}, timeoutMilliseconds ?? 5000);
|
||||
return deferred.promise;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,209 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
import assert from 'node:assert/strict';
|
||||
import { setTimeout } from 'node:timers/promises';
|
||||
|
||||
import type {
|
||||
Command, CommandChannel, CommandChannelClient, CommandChannelServer
|
||||
} from 'common/command_channel/interface';
|
||||
import { UnitTestHelper as Helper } from 'common/command_channel/websocket/channel';
|
||||
import { WsChannelClient, WsChannelServer } from 'common/command_channel/websocket/index';
|
||||
import { globals } from 'common/globals/unittest';
|
||||
import { RestServerCore } from 'rest_server/core';
|
||||
|
||||
// Suite definition. The cases share state through `ut` and must run in this exact order;
// 'message' is registered twice on purpose: once before and once after the reconnect.
describe('## websocket command channel ##', () => {
    before(beforeHook);

    const orderedCases: [string, () => Promise<void>][] = [
        ['start', testServerStart],

        ['connect', testClientStart],
        ['message', testMessage],

        ['reconnect', testReconnect],
        ['message', testMessage],

        ['handle error', testError],
        ['shutdown', testShutdown],
    ];
    for (const [title, testCase] of orderedCases) {
        it(title, testCase);
    }

    after(afterHook);
});
|
||||
|
||||
/* test cases */
|
||||
|
||||
/**
 * Create the channel server, hook every callback up to the shared event log, and start it.
 * Channel "1" is observed through the server-wide receive callback,
 * channel "2" through a per-channel receive callback.
 **/
async function testServerStart(): Promise<void> {
    ut.server = new WsChannelServer('ut_server', 'ut');

    ut.server.onReceive((channelId, command) => {
        if (channelId !== '1') {
            return;
        }
        ut.events.push({ event: 'server_receive_1', command });
    });

    ut.server.onConnection((channelId, channel) => {
        ut.events.push({ event: 'connect', channelId, channel });

        channel.onClose(reason => {
            ut.events.push({ event: `client_close_${channelId}`, reason });
        });
        channel.onError(error => {
            ut.events.push({ event: `client_error_${channelId}`, error });
        });
        channel.onLost(() => {
            ut.events.push({ event: `client_lost_${channelId}` });
        });

        switch (channelId) {
            case '1':
                ut.serverChannel1 = channel;
                break;
            case '2':
                ut.serverChannel2 = channel;
                channel.onReceive(command => {
                    ut.events.push({ event: 'server_receive_2', command });
                });
                break;
            default:
                break;
        }
    });

    await ut.server.start();
}
|
||||
|
||||
/**
 * Check the URLs produced by the server, connect two clients concurrently,
 * and verify that the server reported exactly one 'connect' event per client.
 **/
async function testClientStart(): Promise<void> {
    const port = globals.args.port;

    const localhostUrl = ut.server.getChannelUrl('1');
    const loopbackUrl = ut.server.getChannelUrl('2', '127.0.0.1');
    assert.equal(localhostUrl, `ws://localhost:${port}/ut/1`);
    assert.equal(loopbackUrl, `ws://127.0.0.1:${port}/ut/2`);

    ut.client1 = new WsChannelClient('ut_client_1', localhostUrl);
    ut.client2 = new WsChannelClient('ut_client_2', loopbackUrl);

    // client 1 listens to everything, client 2 to one command type only
    ut.client1.onReceive(command => {
        ut.events.push({ event: 'client_receive_1', command });
    });
    ut.client2.onCommand('ut_command', command => {
        ut.events.push({ event: 'client_receive_2', command });
    });

    ut.client2.onClose(reason => {
        ut.events.push({ event: 'server_close_2', reason });
    });

    const connecting = [ ut.client1.connect(), ut.client2.connect() ];
    await Promise.all(connecting);

    // the two connections may be reported in either order
    const [ first, second ] = ut.events;
    assert.equal(first.event, 'connect');
    assert.equal(second.event, 'connect');
    assert.equal(Number(first.channelId) + Number(second.channelId), 3);
    assert.equal(ut.events.length, 2);

    ut.events.length = 0;
}
|
||||
|
||||
/**
 * Simulate a broken connection on client 1 and verify the server reports it as "lost"
 * (neither closed nor errored) and does not report a new 'connect'.
 **/
async function testReconnect(): Promise<void> {
    const ws = (ut.client1 as any).connection.ws;  // NOTE: private api
    ws.pause();
    await setTimeout(heartbeatTimeout);
    ws.terminate();
    ws.resume();

    // mac pipeline can be slow
    // Poll for the event this test is actually about, so an unrelated earlier event
    // cannot end the wait before the loss has been detected.
    for (let i = 0; i < 10; i++) {
        await setTimeout(heartbeat);
        if (ut.countEvents('client_lost_1') > 0) {
            break;
        }
    }

    assert.ok(ut.countEvents('client_lost_1') >= 1);
    // assert.equal reports the actual count on failure
    assert.equal(ut.countEvents('client_close_1'), 0);
    assert.equal(ut.countEvents('client_error_1'), 0);
    assert.equal(ut.countEvents('connect'), 0);  // reconnect is not connect

    ut.events.length = 0;
}
|
||||
|
||||
/**
 * Exchange commands in both directions on both channels, wait one heartbeat,
 * and verify each side received exactly its own commands in sending order.
 **/
async function testMessage(): Promise<void> {
    const pack = (value: any): Command => ut.packCommand(value);

    ut.server.send('1', pack(1));
    await ut.client2.sendAsync(pack(2));
    ut.client2.send(pack('三'));
    ut.server.send('2', pack('4'));
    ut.client1.send(pack(5));
    ut.server.send('1', pack(6));

    await setTimeout(heartbeat);

    const expectations: [string, any[]][] = [
        ['client_receive_1', [ 1, 6 ]],
        ['client_receive_2', [ '4' ]],
        ['server_receive_1', [ 5 ]],
        ['server_receive_2', [ 2, '三' ]],
    ];
    for (const [event, values] of expectations) {
        assert.deepEqual(ut.filterCommands(event), values);
    }

    ut.events.length = 0;
}
|
||||
|
||||
/**
 * Terminate client 2 abruptly and verify that, once the terminate timeout has elapsed,
 * the server side has seen exactly one error and no close for that channel.
 **/
async function testError(): Promise<void> {
    ut.client2.terminate('client 2 terminate');
    await setTimeout(terminateTimeout * 1.1);

    // assert.equal reports the actual count on failure
    assert.equal(ut.countEvents('client_close_2'), 0);
    assert.equal(ut.countEvents('client_error_2'), 1);

    ut.events.length = 0;
}
|
||||
|
||||
/**
 * Shut the server down and verify that exactly one `client_close_1` event
 * was recorded. The recorded events are cleared before returning.
 */
async function testShutdown(): Promise<void> {
    await ut.server.shutdown();

    const closeCount = ut.countEvents('client_close_1');
    assert.equal(closeCount, 1);

    ut.events.length = 0;
}
|
||||
|
||||
/* helpers and states */
|
||||
|
||||
// Timing knobs shared by the test cases and hooks in this file.
// The values are passed to `setTimeout`, so they are presumably milliseconds.
// NOTE: Increase these numbers if it fails randomly
|
||||
// Handed to Helper.setHeartbeatInterval(); also the polling / settle delay used by the tests.
const heartbeat = 10;
|
||||
// How long testReconnect keeps the socket paused before terminating it.
const heartbeatTimeout = 50;
|
||||
// Handed to Helper.setTerminateTimeout(); testError waits 1.1x this value.
const terminateTimeout = 100;
|
||||
|
||||
/**
 * Suite set-up: reset the globals, start a fresh REST server (stored on
 * `ut.rest`), then apply the heartbeat interval and terminate timeout
 * defined in this file through `Helper`.
 */
async function beforeHook(): Promise<void> {
    globals.reset();

    const restServer = new RestServerCore();
    ut.rest = restServer;
    await restServer.start();

    Helper.setHeartbeatInterval(heartbeat);
    Helper.setTerminateTimeout(terminateTimeout);
}
|
||||
|
||||
/**
 * Suite tear-down: reset `Helper`, shut down the REST server if one was
 * created, and reset the globals.
 */
async function afterHook(): Promise<void> {
    Helper.reset();

    // `ut.rest` may be unset if set-up failed before creating it; the await still runs either way.
    const pendingShutdown = ut.rest?.shutdown();
    await pendingShutdown;

    globals.reset();
}
|
||||
|
||||
/**
 * Mutable state shared by the test cases in this file, plus small helpers
 * for inspecting the recorded events.
 */
class UnitTestStates {
    server!: CommandChannelServer;
    client1!: CommandChannelClient;
    client2!: CommandChannelClient;
    serverChannel1!: CommandChannel;
    serverChannel2!: CommandChannel;
    // Recorded events; the helpers below read each entry's `event` and `command.value`.
    events: any[] = [];

    rest!: RestServerCore;

    /** Return how many recorded events have exactly the given `event` name. */
    countEvents(event: string): number {
        let matches = 0;
        for (const entry of this.events) {
            if (entry.event === event) {
                matches += 1;
            }
        }
        return matches;
    }

    /** Return `command.value` of every recorded event with the given name, in recording order. */
    filterCommands(event: string): any[] {
        const values: any[] = [];
        for (const entry of this.events) {
            if (entry.event === event) {
                values.push(entry.command.value);
            }
        }
        return values;
    }

    /** Wrap `value` in a command object of type `ut_command`. */
    packCommand(value: any): Command {
        const command = { type: 'ut_command', value: value };
        return command;
    }
}

// Singleton state object used by every test case and hook.
const ut = new UnitTestStates();
|
|
@ -39,16 +39,16 @@ export class RemoteTrainingServiceV3 implements TrainingServiceV3 {
|
|||
this.log = getLogger(`RemoteV3.${this.id}`);
|
||||
this.log.debug('Training sevice config:', config);
|
||||
|
||||
this.server = new WsChannelServer('RemoteTrialKeeper', `platform/${this.id}`, config.nniManagerIp);
|
||||
this.server = new WsChannelServer(this.id, `/platform/${this.id}`);
|
||||
|
||||
this.server.on('connection', (channelId: string, channel: WsChannel) => {
|
||||
const worker = this.workersByChannel.get(channelId);
|
||||
if (worker) {
|
||||
worker.setChannel(channel);
|
||||
channel.on('close', reason => {
|
||||
channel.onClose(reason => {
|
||||
this.log.error('Worker channel closed unexpectedly:', reason);
|
||||
});
|
||||
channel.on('error', error => {
|
||||
channel.onError(error => {
|
||||
this.log.error('Worker channel error:', error);
|
||||
this.restartWorker(worker);
|
||||
});
|
||||
|
@ -190,7 +190,7 @@ export class RemoteTrainingServiceV3 implements TrainingServiceV3 {
|
|||
this.id,
|
||||
channelId,
|
||||
config,
|
||||
this.server.getChannelUrl(channelId),
|
||||
this.server.getChannelUrl(channelId, this.config.nniManagerIp),
|
||||
Boolean(this.config.trialGpuNumber)
|
||||
);
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ export class Worker {
|
|||
public setChannel(channel: WsChannel): void {
|
||||
this.channel = channel;
|
||||
this.trialKeeper.setChannel(channel);
|
||||
channel.on('lost', async () => {
|
||||
channel.onLost(async () => {
|
||||
if (!await this.checkAlive()) {
|
||||
this.log.error('Trial keeper failed');
|
||||
channel.terminate('Trial keeper failed'); // MARK
|
||||
|
|
Загрузка…
Ссылка в новой задаче