Merge command channel APIs and add unit tests (#5450)

This commit is contained in:
liuzhe-lz 2023-03-20 16:04:00 +08:00 коммит произвёл GitHub
Родитель 1f6aedc48f
Коммит cf5fabd968
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
20 изменённых файлов: 1004 добавлений и 512 удалений

Просмотреть файл

@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .channel import WsChannelClient

Просмотреть файл

@ -0,0 +1,106 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
import time
import nni
from ..base import Command, CommandChannel
from .connection import WsConnection
_logger = logging.getLogger(__name__)
class WsChannelClient(CommandChannel):
def __init__(self, url: str):
self._url: str = url
self._closing: bool = False
self._conn: WsConnection | None = None
def connect(self) -> None:
_logger.debug(f'Connect to {self._url}')
assert not self._closing
self._ensure_conn()
def disconnect(self) -> None:
_logger.debug(f'Disconnect from {self._url}')
self.send({'type': '_bye_'})
self._closing = True
self._close_conn('client intentionally close')
def send(self, command: Command) -> None:
if self._closing:
return
_logger.debug(f'Send {command}')
msg = nni.dump(command)
for i in range(5):
try:
conn = self._ensure_conn()
conn.send(msg)
return
except Exception:
_logger.exception(f'Failed to send command. Retry in {i}s')
self._terminate_conn('send fail')
time.sleep(i)
_logger.warning(f'Failed to send command {command}. Last retry')
conn = self._ensure_conn()
conn.send(msg)
def receive(self) -> Command | None:
while True:
if self._closing:
return None
msg = self._receive_msg()
if msg is None:
return None
command = nni.load(msg)
if command['type'] == '_nop_':
continue
if command['type'] == '_bye_':
reason = command.get('reason')
_logger.debug(f'Server close connection: {reason}')
self._closing = True
self._close_conn('server intentionally close')
return None
return command
def _ensure_conn(self) -> WsConnection:
if self._conn is None and not self._closing:
self._conn = WsConnection(self._url)
self._conn.connect()
_logger.debug('Connected')
return self._conn # type: ignore
def _close_conn(self, reason: str) -> None:
if self._conn is not None:
try:
self._conn.disconnect(reason)
except Exception:
pass
self._conn = None
def _terminate_conn(self, reason: str) -> None:
if self._conn is not None:
try:
self._conn.terminate(reason)
except Exception:
pass
self._conn = None
def _receive_msg(self) -> str | None:
for i in range(5):
try:
conn = self._ensure_conn()
msg = conn.receive()
_logger.debug(f'Receive {msg}')
if not self._closing:
assert msg is not None
return msg
except Exception:
_logger.exception(f'Failed to receive command. Retry in {i}s')
self._terminate_conn('receive fail')
time.sleep(i)
_logger.warning(f'Failed to receive command. Last retry')
conn = self._ensure_conn()
conn.receive()

Просмотреть файл

@ -0,0 +1,139 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Synchronized and object-oriented WebSocket class.
WebSocket guarantees that messages will not be divided at API level.
"""
from __future__ import annotations
__all__ = ['WsConnection']
import asyncio
import logging
from threading import Lock, Thread
from typing import Any, Type
import websockets
_logger = logging.getLogger(__name__)
# the singleton event loop
_event_loop: asyncio.AbstractEventLoop = None # type: ignore
_event_loop_lock: Lock = Lock()
_event_loop_refcnt: int = 0 # number of connected websockets
class WsConnection:
"""
A WebSocket connection.
Call :meth:`connect` before :meth:`send` and :meth:`receive`.
All methods are thread safe.
Parameters
----------
url
The WebSocket URL.
For tuner command channel it should be something like ``ws://localhost:8080/tuner``.
"""
ConnectionClosed: Type[Exception] = websockets.ConnectionClosed # type: ignore
def __init__(self, url: str):
self._url: str = url
self._ws: Any = None # the library does not provide type hints
def connect(self) -> None:
global _event_loop, _event_loop_refcnt
with _event_loop_lock:
_event_loop_refcnt += 1
if _event_loop is None:
_logger.debug('Starting event loop.')
# following line must be outside _run_event_loop
# because _wait() might be executed before first line of the child thread
_event_loop = asyncio.new_event_loop()
thread = Thread(target=_run_event_loop, name='NNI-WebSocketEventLoop', daemon=True)
thread.start()
_logger.debug(f'Connecting to {self._url}')
self._ws = _wait(_connect_async(self._url))
_logger.debug(f'Connected.')
def disconnect(self, reason: str | None = None, code: int | None = None) -> None:
if self._ws is None:
_logger.debug('disconnect: No connection.')
return
try:
_wait(self._ws.close(code or 4000, reason))
_logger.debug('Connection closed by client.')
except Exception as e:
_logger.warning(f'Failed to close connection: {repr(e)}')
self._ws = None
_decrease_refcnt()
def terminate(self, reason: str | None = None) -> None:
if self._ws is None:
_logger.debug('terminate: No connection.')
return
self.disconnect(reason, 4001)
def send(self, message: str) -> None:
_logger.debug(f'Sending {message}')
try:
_wait(self._ws.send(message))
except websockets.ConnectionClosed: # type: ignore
_logger.debug('Connection closed by server.')
self._ws = None
_decrease_refcnt()
raise
def receive(self) -> str | None:
"""
Return received message;
or return ``None`` if the connection has been closed by peer.
"""
try:
msg = _wait(self._ws.recv())
_logger.debug(f'Received {msg}')
except websockets.ConnectionClosed: # type: ignore
_logger.debug('Connection closed by server.')
self._ws = None
_decrease_refcnt()
raise
# seems the library will inference whether it's text or binary, so we don't have guarantee
if isinstance(msg, bytes):
return msg.decode()
else:
return msg
def _wait(coro):
# Synchronized version of "await".
future = asyncio.run_coroutine_threadsafe(coro, _event_loop)
return future.result()
def _run_event_loop() -> None:
# A separate thread to run the event loop.
# The event loop itself is blocking, and send/receive are also blocking,
# so they must run in different threads.
asyncio.set_event_loop(_event_loop)
_event_loop.run_forever()
_logger.debug('Event loop stopped.')
async def _connect_async(url):
# Theoretically this function is meaningless and one can directly use `websockets.connect(url)`,
# but it will not work, raising "TypeError: A coroutine object is required".
# Seems a design flaw in websockets library.
return await websockets.connect(url, max_size=None) # type: ignore
def _decrease_refcnt() -> None:
global _event_loop, _event_loop_refcnt
with _event_loop_lock:
_event_loop_refcnt -= 1
if _event_loop_refcnt == 0:
_event_loop.call_soon_threadsafe(_event_loop.stop)
_event_loop = None # type: ignore

Просмотреть файл

@ -1,133 +1,4 @@
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
""" from ..command_channel.websocket.connection import WsConnection as WebSocket # pylint: disable=unused-import
Synchronized and object-oriented WebSocket class.
WebSocket guarantees that messages will not be divided at API level.
"""
from __future__ import annotations
__all__ = ['WebSocket']
import asyncio
import logging
from threading import Lock, Thread
from typing import Any, Type
import websockets
_logger = logging.getLogger(__name__)
# the singleton event loop
_event_loop: asyncio.AbstractEventLoop = None # type: ignore
_event_loop_lock: Lock = Lock()
_event_loop_refcnt: int = 0 # number of connected websockets
class WebSocket:
"""
A WebSocket connection.
Call :meth:`connect` before :meth:`send` and :meth:`receive`.
All methods are thread safe.
Parameters
----------
url
The WebSocket URL.
For tuner command channel it should be something like ``ws://localhost:8080/tuner``.
"""
ConnectionClosed: Type[Exception] = websockets.ConnectionClosed # type: ignore
def __init__(self, url: str):
self._url: str = url
self._ws: Any = None # the library does not provide type hints
def connect(self) -> None:
global _event_loop, _event_loop_refcnt
with _event_loop_lock:
_event_loop_refcnt += 1
if _event_loop is None:
_logger.debug('Starting event loop.')
# following line must be outside _run_event_loop
# because _wait() might be executed before first line of the child thread
_event_loop = asyncio.new_event_loop()
thread = Thread(target=_run_event_loop, name='NNI-WebSocketEventLoop', daemon=True)
thread.start()
_logger.debug(f'Connecting to {self._url}')
self._ws = _wait(_connect_async(self._url))
_logger.debug(f'Connected.')
def disconnect(self) -> None:
if self._ws is None:
_logger.debug('disconnect: No connection.')
return
try:
_wait(self._ws.close())
_logger.debug('Connection closed by client.')
except Exception as e:
_logger.warning(f'Failed to close connection: {repr(e)}')
self._ws = None
_decrease_refcnt()
def send(self, message: str) -> None:
_logger.debug(f'Sending {message}')
try:
_wait(self._ws.send(message))
except websockets.ConnectionClosed: # type: ignore
_logger.debug('Connection closed by server.')
self._ws = None
_decrease_refcnt()
raise
def receive(self) -> str | None:
"""
Return received message;
or return ``None`` if the connection has been closed by peer.
"""
try:
msg = _wait(self._ws.recv())
_logger.debug(f'Received {msg}')
except websockets.ConnectionClosed: # type: ignore
_logger.debug('Connection closed by server.')
self._ws = None
_decrease_refcnt()
raise
# seems the library will inference whether it's text or binary, so we don't have guarantee
if isinstance(msg, bytes):
return msg.decode()
else:
return msg
def _wait(coro):
# Synchronized version of "await".
future = asyncio.run_coroutine_threadsafe(coro, _event_loop)
return future.result()
def _run_event_loop() -> None:
# A separate thread to run the event loop.
# The event loop itself is blocking, and send/receive are also blocking,
# so they must run in different threads.
asyncio.set_event_loop(_event_loop)
_event_loop.run_forever()
_logger.debug('Event loop stopped.')
async def _connect_async(url):
# Theoretically this function is meaningless and one can directly use `websockets.connect(url)`,
# but it will not work, raising "TypeError: A coroutine object is required".
# Seems a design flaw in websockets library.
return await websockets.connect(url, max_size=None) # type: ignore
def _decrease_refcnt() -> None:
global _event_loop, _event_loop_refcnt
with _event_loop_lock:
_event_loop_refcnt -= 1
if _event_loop_refcnt == 0:
_event_loop.call_soon_threadsafe(_event_loop.stop)
_event_loop = None # type: ignore

Просмотреть файл

@ -32,7 +32,7 @@ async def read_stdin():
line = line.decode().strip() line = line.decode().strip()
_debug(f'read from stdin: {line}') _debug(f'read from stdin: {line}')
if line == '_close_': if line == '_close_':
exit() break
await _ws.send(line) await _ws.send(line)
async def ws_server(): async def ws_server():
@ -46,9 +46,12 @@ async def on_connect(ws):
global _ws global _ws
_debug('connected') _debug('connected')
_ws = ws _ws = ws
try:
async for msg in ws: async for msg in ws:
_debug(f'received from websocket: {msg}') _debug(f'received from websocket: {msg}')
print(msg, flush=True) print(msg, flush=True)
except websockets.exceptions.ConnectionClosedError:
pass
def _debug(msg): def _debug(msg):
#sys.stderr.write(f'[server-debug] {msg}\n') #sys.stderr.write(f'[server-debug] {msg}\n')

Просмотреть файл

@ -11,21 +11,21 @@ from subprocess import Popen, PIPE
import sys import sys
import time import time
from nni.runtime.tuner_command_channel.websocket import WebSocket from nni.runtime.command_channel.websocket import WsChannelClient
# A helper server that connects its stdio to incoming WebSocket. # A helper server that connects its stdio to incoming WebSocket.
_server = None _server = None
_client = None _client = None
_command1 = 'T_hello world' _command1 = {'type': 'ut_command', 'value': 123}
_command2 = 'T_你好' _command2 = {'type': 'ut_command', 'value': '你好'}
## test cases ## ## test cases ##
def test_connect(): def test_connect():
global _client global _client
port = _init() port = _init()
_client = WebSocket(f'ws://localhost:{port}') _client = WsChannelClient(f'ws://localhost:{port}')
_client.connect() _client.connect()
def test_send(): def test_send():
@ -34,16 +34,16 @@ def test_send():
_client.send(_command2) _client.send(_command2)
time.sleep(0.01) time.sleep(0.01)
sent1 = _server.stdout.readline().strip() sent1 = json.loads(_server.stdout.readline())
assert sent1 == _command1, sent1 assert sent1 == _command1, sent1
sent2 = _server.stdout.readline().strip() sent2 = json.loads(_server.stdout.readline().strip())
assert sent2 == _command2, sent2 assert sent2 == _command2, sent2
def test_receive(): def test_receive():
# Send commands to server via stdin, and get them back via channel. # Send commands to server via stdin, and get them back via channel.
_server.stdin.write(_command1 + '\n') _server.stdin.write(json.dumps(_command1) + '\n')
_server.stdin.write(_command2 + '\n') _server.stdin.write(json.dumps(_command2) + '\n')
_server.stdin.flush() _server.stdin.flush()
received1 = _client.receive() received1 = _client.receive()

Просмотреть файл

@ -51,8 +51,8 @@ export class HttpChannelServer implements CommandChannelServer {
this.outgoingQueues.forEach(queue => { queue.clear(); }); this.outgoingQueues.forEach(queue => { queue.clear(); });
} }
public getChannelUrl(channelId: string): string { public getChannelUrl(channelId: string, ip?: string): string {
return globals.rest.getFullUrl('http', 'localhost', this.path, channelId); return globals.rest.getFullUrl('http', ip ?? 'localhost', this.path, channelId);
} }
public send(channelId: string, command: Command): void { public send(channelId: string, command: Command): void {
@ -63,6 +63,10 @@ export class HttpChannelServer implements CommandChannelServer {
this.emitter.on('receive', callback); this.emitter.on('receive', callback);
} }
public onConnection(_callback: (channelId: string, channel: any) => void): void {
throw new Error('Not implemented');
}
private handleGet(request: Request, response: Response): void { private handleGet(request: Request, response: Response): void {
const channelId = request.params['channel']; const channelId = request.params['channel'];
const promise = this.getOutgoingQueue(channelId).asyncPop(timeoutMilliseconds); const promise = this.getOutgoingQueue(channelId).asyncPop(timeoutMilliseconds);

Просмотреть файл

@ -1,26 +1,158 @@
// Copyright (c) Microsoft Corporation. // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license. // Licensed under the MIT license.
//export interface Command { /**
// type: string; * Common interface of command channels.
// [key: string]: any; *
//} * A command channel is a duplex connection which supports sending and receiving JSON commands.
export type Command = any; *
* Typically a command channel implementation consists of a server and a client.
*
* The server should listen to a URL prefix like `http://localhost:8080/example/`;
* and each client should connect to a unique URL containing the prefix, e.g. `http://localhost:8080/example/channel1`.
* The client's URL should be created with `server.getChannelUrl(channelId, serverIp)`.
*
* We currently have implemented one full feature command channel, the WebSocket channel,
* and a simplified one, the HTTP channel.
* In v3.1 release we might implement file command channel and AzureML command channel.
*
* The clients might have a Python version locates in `nni/runtime/command_channel`.
* The TypeScript and Python version should be interchangable.
**/
/** /**
* A command channel server serves one or more command channels. * A command is a JSON object.
* Each channel is connected to a client.
* *
* Normally each client has a unique channel URL, * The object has only one mandatory entry, `type`.
* which can be got with `server.getChannelUrl(id)`.
* *
* The APIs might be changed to return `Promise<void>` in future. * The type string should not be surrounded by underscore (e.g. `_nop_`),
* unless you are dealing with the underlying implementation of a specific command channel;
* it should never starts with two underscores (e.g. `__command`) in any circumstance.
**/
export type Command = any;
// Maybe it's better to disable `noPropertyAccessFromIndexSignature` in tscofnig?
// export interface Command {
// type: string;
// [key: string]: any;
// }
/**
* `CommandChannel` is the base interface used by both the servers and the clients.
*
* For servers, channels can be got with `onConnection()` event listener.
* For clients, a channel can be created with the client subclass' constructor.
*
* The channel should be fault tolerant to some extend. It has three different types of closing related events:
*
* 1. Close: The channel is intentionally closed.
*
* 2. Lost: The channel is temporarily unavailable and is trying to recover.
* The user of this class should examine the peer's status out-of-band when receiving "lost" event.
*
* 3. Error: The channel is dead and cannot recover.
* A "close" event may or may not occur following this event. Do not rely on that.
**/
export interface CommandChannel {
readonly name: string; // for better logging
enableHeartbeat(intervalMilliseconds?: number): void;
/**
* Graceful (intentional) close.
* A "close" event will be emitted by `this` and the peer.
**/
close(reason: string): void;
/**
* Force close. Should only be used when the channel is not working.
* An "error" event may be emitted by `this`.
* A "lost" and/or "error" event will be emitted by the peer, if its process is still alive.
**/
terminate(reason: string): void;
send(command: Command): void;
/**
* The async version should try to ensures the command is successfully sent to the peer.
* But this is not guaranteed.
**/
sendAsync(command: Command): Promise<void>;
onReceive(callback: (command: Command) => void): void;
onCommand(commandType: string, callback: (command: Command) => void): void;
onClose(callback: (reason?: string) => void): void;
onError(callback: (error: Error) => void): void;
onLost(callback: () => void): void;
}
/**
* Client side of a command channel.
*
* The constructor should have no side effects.
*
* The listeners should be registered before calling `connect()`,
* or the first few commands might be missed.
*
* Example usage:
*
* const client = new WsChannelClient('example', 'ws://1.2.3.4:8080/server/channel_id');
* await client.connect();
* client.send(command);
**/
export interface CommandChannelClient extends CommandChannel {
// constructor(name: string, url: string);
connect(): Promise<void>;
/**
* Typically an alias of `close()`.
**/
disconnect(reason?: string): Promise<void>;
}
/**
* Server side of a command channel.
*
* The consructor should have no side effects.
*
* The listeners should be registered before calling `start()`.
*
* Example usage:
*
* const server = new WsChannelServer('example_server', '/server_prefix');
* const url = server.getChannelUrl('channel_id');
* const client = new WsChannelClient('example_client', url);
* await server.start();
* await client.connect();
*
* There two ways to listen to command:
*
* 1. Handle all clients' commands in one space:
*
* server.onReceive((channelId, command) => { ... });
* server.send(channelId, command);
*
* 2. Maintain a `WsChannel` instance for each client:
*
* server.onConnection((channelId, channel) => {
* channel.onCommand(command => { ... });
* channel.send(command);
* });
**/ **/
export interface CommandChannelServer { export interface CommandChannelServer {
// constructor(name: string, urlPath: string) // constructor(name: string, urlPath: string);
start(): Promise<void>; start(): Promise<void>;
shutdown(): Promise<void>; shutdown(): Promise<void>;
getChannelUrl(channelId: string): string;
/**
* When `ip` is missing, it should default to localhost.
**/
getChannelUrl(channelId: string, ip?: string): string;
send(channelId: string, command: Command): void; send(channelId: string, command: Command): void;
onReceive(callback: (channelId: string, command: Command) => void): void; onReceive(callback: (channelId: string, command: Command) => void): void;
onConnection(callback: (channelId: string, channel: CommandChannel) => void): void;
} }

Просмотреть файл

@ -50,7 +50,7 @@ import { DefaultMap } from 'common/default_map';
import { Deferred } from 'common/deferred'; import { Deferred } from 'common/deferred';
import { Logger, getLogger } from 'common/log'; import { Logger, getLogger } from 'common/log';
import type { TrialKeeper } from 'common/trial_keeper/keeper'; import type { TrialKeeper } from 'common/trial_keeper/keeper';
import { WsChannel } from './websocket/channel'; import type { CommandChannel } from './interface';
interface RpcResponseCommand { interface RpcResponseCommand {
type: 'rpc_response'; type: 'rpc_response';
@ -61,14 +61,14 @@ interface RpcResponseCommand {
type Class = { new(...args: any[]): any; }; type Class = { new(...args: any[]): any; };
const rpcHelpers: Map<WsChannel, RpcHelper> = new Map(); const rpcHelpers: Map<CommandChannel, RpcHelper> = new Map();
/** /**
* Enable RPC on a channel. * Enable RPC on a channel.
* *
* The channel does not need to be connected for calling this function. * The channel does not need to be connected for calling this function.
**/ **/
export function getRpcHelper(channel: WsChannel): RpcHelper { export function getRpcHelper(channel: CommandChannel): RpcHelper {
if (!rpcHelpers.has(channel)) { if (!rpcHelpers.has(channel)) {
rpcHelpers.set(channel, new RpcHelper(channel)); rpcHelpers.set(channel, new RpcHelper(channel));
} }
@ -76,7 +76,7 @@ export function getRpcHelper(channel: WsChannel): RpcHelper {
} }
export class RpcHelper { export class RpcHelper {
private channel: WsChannel; private channel: CommandChannel;
private lastId: number = 0; private lastId: number = 0;
private localCtors: Map<string, Class> = new Map(); private localCtors: Map<string, Class> = new Map();
private localObjs: Map<number, any> = new Map(); private localObjs: Map<number, any> = new Map();
@ -87,7 +87,7 @@ export class RpcHelper {
/** /**
* NOTE: Don't use this constructor directly. Use `getRpcHelper()`. * NOTE: Don't use this constructor directly. Use `getRpcHelper()`.
**/ **/
constructor(channel: WsChannel) { constructor(channel: CommandChannel) {
this.log = getLogger(`RpcHelper.${channel.name}`); this.log = getLogger(`RpcHelper.${channel.name}`);
this.channel = channel; this.channel = channel;
this.channel.onCommand('rpc_constructor', command => { this.channel.onCommand('rpc_constructor', command => {

Просмотреть файл

@ -4,353 +4,240 @@
/** /**
* WebSocket command channel. * WebSocket command channel.
* *
* This is the base class that used by both server and client. * A WsChannel operates on one WebSocket connection at a time.
* But when the network is unstable, it may close the underlying connection and create a new one.
* This is generally transparent to the user of this class, except that a "lost" event will be emitted.
* *
* For the server, channels can be got with `onConnection()` event listener. * To distinguish intentional close from connection lost,
* For the client, a channel can be created with `new WsChannelClient()` subclass. * a "_bye_" command will be sent when `close()` or `disconnect()` is invoked.
* Do not use the constructor directly.
* *
* The channel is fault tolerant to some extend. It has three different types of closing related events: * If the connection is closed before receiving "_bye_" command, a "lost" event will be emitted and:
* *
* 1. "close": The channel is intentionally closed. * * The client will try to reconnect for severaly times in around 15s.
* * The server will wait the client to reconnect for around 30s.
* *
* This is caused either by "close()" or "disconnect()" call, or by receiving a "_bye_" command from the peer. * If the reconnecting attempt failed, both side will emit an "error" event.
*
* 2. "lost": The channel is temporarily unavailable and is trying to recover.
* (The high level class should examine the peer's status out-of-band when receiving this event.)
*
* When the underlying socket is dead, this event is emitted.
* The client will try to reconnect in around 15s. If all attempts fail, an "error" event will be emitted.
* The server will wait the client for 30s. If it does not reconnect, an "error" event will be emitted.
* Successful recover will not emit command.
*
* 3. "error": The channel is dead and cannot recover.
*
* A "close" event may or may not follow this event. Do not rely on that.
**/ **/
import { EventEmitter } from 'node:events'; import { EventEmitter, once } from 'node:events';
import util from 'node:util'; import util from 'node:util';
import type { WebSocket } from 'ws'; import type { WebSocket } from 'ws';
import type { Command, CommandChannel } from 'common/command_channel/interface';
import { Deferred } from 'common/deferred'; import { Deferred } from 'common/deferred';
import { Logger, getLogger } from 'common/log'; import { Logger, getLogger } from 'common/log';
import { WsConnection } from './connection';
import type { Command } from '../interface'; interface QueuedCommand {
command: Command;
interface WsChannelEvents { deferred?: Deferred<void>;
'command': (command: Command) => void;
'close': (reason: string) => void;
'lost': () => void;
'error': (error: Error) => void; // not used in base class
} }
export declare interface WsChannel { export class WsChannel implements CommandChannel {
on<E extends keyof WsChannelEvents>(event: E, listener: WsChannelEvents[E]): this;
}
export class WsChannel extends EventEmitter {
private closing: boolean = false; private closing: boolean = false;
private commandEmitter: EventEmitter = new EventEmitter(); private connection: WsConnection | null = null; // NOTE: used in unit test
private connection: WsConnection | null = null; private epoch: number = -1;
private epoch: number = 0; private heartbeatInterval: number | null = null;
private heartbeatInterval: number | null;
private log: Logger; private log: Logger;
private queue: QueuedCommand[] = [];
private terminateTimer: NodeJS.Timer | null = null;
protected emitter: EventEmitter = new EventEmitter();
public readonly name: string; public readonly name: string;
// internal, don't use // internal, don't use
constructor(name: string, ws?: WebSocket, heartbeatInterval?: number) { constructor(name: string) {
super()
this.log = getLogger(`WsChannel.${name}`); this.log = getLogger(`WsChannel.${name}`);
this.name = name; this.name = name;
this.heartbeatInterval = heartbeatInterval || null;
if (ws) {
this.setConnection(ws);
}
}
public enableHeartbeat(interval: number): void {
this.log.debug('## enable heartbeat');
this.heartbeatInterval = interval;
} }
// internal, don't use // internal, don't use
public setConnection(ws: WebSocket): void { public async setConnection(ws: WebSocket, waitOpen: boolean): Promise<void> {
if (this.connection) { if (this.terminateTimer) {
this.log.debug('Abandon previous connection'); clearTimeout(this.terminateTimer);
this.epoch += 1; this.terminateTimer = null;
} }
this.connection?.terminate('new epoch start');
this.newEpoch();
this.log.debug(`Epoch ${this.epoch} start`); this.log.debug(`Epoch ${this.epoch} start`);
this.connection = this.configConnection(ws); this.connection = this.configConnection(ws);
if (waitOpen) {
await once(ws, 'open');
}
while (this.connection && this.queue.length > 0) {
const item = this.queue.shift()!;
try {
await this.connection.sendAsync(item.command);
item.deferred?.resolve();
} catch (error) {
this.log.error('Failed to send command on recovered channel:', error);
this.log.error('Dropped command:', item.command);
item.deferred?.reject(error as any);
// it should trigger connection's error event and this.connection will be set to null
}
}
}
public enableHeartbeat(interval?: number): void {
this.heartbeatInterval = interval ?? defaultHeartbeatInterval;
this.connection?.setHeartbeatInterval(this.heartbeatInterval);
} }
public close(reason: string): void { public close(reason: string): void {
this.log.debug('Close channel:', reason); this.log.debug('Close channel:', reason);
if (this.connection) { this.connection?.close(reason);
this.connection.close(reason); if (this.setClosing()) {
this.endEpoch(); this.emitter.emit('__close', reason);
}
if (!this.closing) {
this.closing = true;
this.emit('close', reason);
} }
} }
public terminate(reason: string): void { public terminate(reason: string): void {
this.log.info('Terminate channel:', reason); this.log.info('Terminate channel:', reason);
this.closing = true; this.connection?.terminate(reason);
if (this.connection) { if (this.setClosing()) {
this.connection.terminate(reason); this.emitter.emit('__error', new Error(`WsChannel terminated: ${reason}`));
this.endEpoch();
} }
} }
public send(command: Command): void { public send(command: Command): void {
if (this.connection) { if (this.closing) {
this.log.error('Channel closed. Ignored command', command);
return;
}
if (!this.connection) {
this.log.warning('Connection lost. Enqueue command', command);
this.queue.push({ command });
return;
}
this.connection.send(command); this.connection.send(command);
} else {
// TODO: add a queue?
this.log.error('Connection lost. Dropped command', command);
}
} }
/**
* Async version of `send()` that (partially) ensures the command is successfully sent to peer.
**/
public sendAsync(command: Command): Promise<void> { public sendAsync(command: Command): Promise<void> {
if (this.connection) { if (this.closing) {
this.log.error('(async) Channel closed. Refused command', command);
return Promise.reject(new Error('WsChannel has been closed'));
}
if (!this.connection) {
this.log.warning('(async) Connection lost. Enqueue command', command);
const deferred = new Deferred<void>();
this.queue.push({ command, deferred });
return deferred.promise;
}
return this.connection.sendAsync(command); return this.connection.sendAsync(command);
} else {
this.log.error('Connection lost. Dropped command async', command);
return Promise.reject(new Error('Connection is lost and trying to recover, cannot send command now'));
}
} }
// the first overload listens to all commands, while the second listens to one command type public onReceive(callback: (command: Command) => void): void {
public onCommand(callback: (command: Command) => void): void; this.emitter.on('__receive', callback);
public onCommand(commandType: string, callback: (command: Command) => void): void;
public onCommand(commandTypeOrCallback: any, callbackOrNone?: any): void {
if (callbackOrNone) {
this.commandEmitter.on(commandTypeOrCallback, callbackOrNone);
} else {
this.commandEmitter.on('__any', commandTypeOrCallback);
} }
public onCommand(commandType: string, callback: (command: Command) => void): void {
this.emitter.on(commandType, callback);
}
public onClose(callback: (reason?: string) => void): void {
this.emitter.on('__close', callback);
}
public onError(callback: (error: Error) => void): void {
this.emitter.on('__error', callback);
}
public onLost(callback: () => void): void {
this.emitter.on('__lost', callback);
}
private newEpoch(): void {
this.connection = null;
this.epoch += 1;
} }
private configConnection(ws: WebSocket): WsConnection { private configConnection(ws: WebSocket): WsConnection {
this.log.debug('## config connection'); const connName = this.epoch ? `${this.name}.${this.epoch}` : this.name;
const epoch = this.epoch; // copy it to use in closure const conn = new WsConnection(connName, ws, this.emitter);
const conn = new WsConnection( if (this.heartbeatInterval) {
this.epoch ? `${this.name}.${epoch}` : this.name, conn.setHeartbeatInterval(this.heartbeatInterval);
ws, }
this.commandEmitter,
this.heartbeatInterval
);
conn.on('bye', reason => { conn.on('bye', reason => {
this.log.debug('Peer intentionally close:', reason); this.log.debug('Peer intentionally closing:', reason);
this.endEpoch(); if (this.setClosing()) {
if (!this.closing) { this.emitter.emit('__close', reason);
this.closing = true;
this.emit('close', reason);
} }
}); });
conn.on('close', (code, reason) => { conn.on('close', (code, reason) => {
this.closeConnection(epoch, `Received closing handshake: ${code} ${reason}`); this.log.debug('Peer closed:', reason);
this.dropConnection(conn, `Peer closed: ${code} ${reason}`);
}); });
conn.on('error', error => { conn.on('error', error => {
this.closeConnection(epoch, `Error occurred: ${util.inspect(error)}`); this.dropConnection(conn, `Connection error: ${util.inspect(error)}`);
}); });
return conn; return conn;
} }
private closeConnection(epoch: number, reason: string): void { private setClosing(): boolean {
if (this.closing) { if (this.closing) {
this.log.debug('Connection cleaned up:', reason); return false;
}
this.closing = true;
this.newEpoch();
this.queue.forEach(item => {
item.deferred?.reject(new Error('WsChannel has been closed.'));
});
return true;
}
private dropConnection(conn: WsConnection, reason: string): void {
if (this.closing) {
this.log.debug('Clean up:', reason);
return; return;
} }
if (this.epoch !== epoch) { // the connection is already abandoned if (this.connection !== conn) { // the connection is already abandoned
this.log.debug(`Previous connection closed ${epoch}: ${reason}`); this.log.debug(`Previous connection closed: ${reason}`);
return; return;
} }
this.log.warning('Connection closed unexpectedly:', reason); this.log.warning('Connection closed unexpectedly:', reason);
this.emit('lost'); this.newEpoch();
this.endEpoch(); this.emitter.emit('__lost');
if (!this.terminateTimer) {
this.terminateTimer = setTimeout(() => {
if (!this.closing) {
this.terminate('have not reconnected in 30s');
}
}, terminateTimeout);
}
// the reconnect logic is in client subclass // the reconnect logic is in client subclass
} }
private endEpoch(): void {
this.connection = null;
this.epoch += 1;
}
} }
interface WsConnectionEvents { let defaultHeartbeatInterval: number = 5000;
'command': (command: Command) => void; let terminateTimeout: number = 30000;
'bye': (reason: string) => void;
'close': (code: number, reason: string) => void;
'error': (error: Error) => void;
}
declare interface WsConnection { export namespace UnitTestHelper {
on<E extends keyof WsConnectionEvents>(event: E, listener: WsConnectionEvents[E]): this; export function setHeartbeatInterval(ms: number): void {
} defaultHeartbeatInterval = ms;
class WsConnection extends EventEmitter {
private closing: boolean = false;
private commandEmitter: EventEmitter;
private heartbeatTimer: NodeJS.Timer | null = null;
private log: Logger;
private missingPongs: number = 0;
private ws: WebSocket;
constructor(name: string, ws: WebSocket, commandEmitter: EventEmitter, heartbeatInterval: number | null) {
super();
this.log = getLogger(`WsConnection.${name}`);
this.ws = ws;
this.commandEmitter = commandEmitter;
ws.on('close', this.handleClose.bind(this));
ws.on('error', this.handleError.bind(this));
ws.on('message', this.handleMessage.bind(this));
ws.on('pong', this.handlePong.bind(this));
if (heartbeatInterval) {
this.heartbeatTimer = setInterval(this.heartbeat.bind(this), heartbeatInterval);
}
} }
public async close(reason: string): Promise<void> { export function setTerminateTimeout(ms: number): void {
if (this.closing) { terminateTimeout = ms;
this.log.debug('Close again:', reason);
return;
} }
this.log.debug('Close:', reason); export function reset(): void {
this.closing = true; defaultHeartbeatInterval = 5000;
if (this.heartbeatTimer) { terminateTimeout = 30000;
clearInterval(this.heartbeatTimer);
this.heartbeatTimer = null;
}
try {
await this.sendAsync({ type: '_bye_', reason });
} catch (error) {
this.log.error('Failed to send bye:', error);
}
try {
this.ws.close(4000, reason);
} catch (error) {
this.log.error('Failed to close:', error);
this.ws.terminate();
}
}
public terminate(reason: string): void {
this.log.debug('Terminate:', reason);
this.closing = true;
if (this.heartbeatTimer) {
clearInterval(this.heartbeatTimer);
this.heartbeatTimer = null;
}
try {
this.ws.close(4001, reason);
} catch (error) {
this.log.debug('Failed to close:', error);
this.ws.terminate();
}
}
public send(command: Command): void {
this.log.trace('Sending command', command);
this.ws.send(JSON.stringify(command));
}
public sendAsync(command: Command): Promise<void> {
this.log.trace('Sending command async', command);
const deferred = new Deferred<void>();
this.ws.send(JSON.stringify(command), error => {
if (error) {
deferred.reject(error);
} else {
deferred.resolve();
}
});
return deferred.promise;
}
private handleClose(code: number, reason: Buffer): void {
if (this.closing) {
this.log.debug('Connection closed');
} else {
this.log.debug('Connection closed by peer:', code, String(reason));
this.emit('close', code, String(reason));
}
}
private handleError(error: Error): void {
if (this.closing) {
this.log.warning('Error after closing:', error);
} else {
this.emit('error', error);
}
}
private handleMessage(data: Buffer, _isBinary: boolean): void {
const s = String(data);
if (this.closing) {
this.log.warning('Received message after closing:', s);
return;
}
this.log.trace('Received command', s);
const command = JSON.parse(s);
if (command.type === '_nop_') {
return;
}
if (command.type === '_bye_') {
this.closing = true;
this.emit('bye', command.reason);
return;
}
const hasAnyListener = this.commandEmitter.emit('__any', command);
const hasTypeListener = this.commandEmitter.emit(command.type, command);
if (!hasAnyListener && !hasTypeListener) {
this.log.warning('No listener for command', s);
}
}
private handlePong(): void {
this.log.debug('receive pong'); // todo
this.missingPongs = 0;
}
private heartbeat(): void {
if (this.missingPongs > 0) {
this.log.warning('Missing pong');
}
if (this.missingPongs > 3) { // TODO: make it configurable?
// no response for ping, try real command
this.sendAsync({ type: '_nop_' }).then(() => {
this.missingPongs = 0;
}).catch(error => {
this.log.error('Failed sending command. Drop connection:', error);
this.terminate(`peer lost responsive: ${util.inspect(error)}`);
});
}
this.missingPongs += 1;
this.log.debug('send ping');
this.ws.ping();
} }
} }

Просмотреть файл

@ -3,15 +3,6 @@
/** /**
* WebSocket command channel client. * WebSocket command channel client.
*
* Usage:
*
* const client = new WsChannelClient('ws://1.2.3.4:8080/server/channel_id');
* await client.connect();
* client.send(command);
*
* Most APIs are derived the base class `WsChannel`.
* See its doc for more details.
**/ **/
import events from 'node:events'; import events from 'node:events';
@ -26,7 +17,7 @@ import { WsChannel } from './channel';
const maxPayload: number = 1024 * 1024 * 1024; const maxPayload: number = 1024 * 1024 * 1024;
export class WsChannelClient extends WsChannel { export class WsChannelClient extends WsChannel {
private logger: Logger; // avoid name conflict with base class private logger: Logger;
private reconnecting: boolean = false; private reconnecting: boolean = false;
private url: string; private url: string;
@ -34,19 +25,17 @@ export class WsChannelClient extends WsChannel {
* The url should start with "ws://". * The url should start with "ws://".
* The name is used for better logging. * The name is used for better logging.
**/ **/
constructor(url: string, name?: string) { constructor(name: string, url: string) {
const name_ = name ?? generateName(url); super(name);
super(name_); this.logger = getLogger(`WsChannelClient.${name}`);
this.logger = getLogger(`WsChannelClient.${name_}`);
this.url = url; this.url = url;
this.on('lost', this.reconnect.bind(this)); this.onLost(this.reconnect.bind(this));
} }
public async connect(): Promise<void> { public async connect(): Promise<void> {
this.logger.debug('Connecting to', this.url); this.logger.debug('Connecting to', this.url);
const ws = new WebSocket(this.url, { maxPayload }); const ws = new WebSocket(this.url, { maxPayload });
this.setConnection(ws); await this.setConnection(ws, true),
await events.once(ws, 'open');
this.logger.debug('Connected'); this.logger.debug('Connected');
} }
@ -54,7 +43,7 @@ export class WsChannelClient extends WsChannel {
* Alias of `close()`. * Alias of `close()`.
**/ **/
public async disconnect(reason?: string): Promise<void> { public async disconnect(reason?: string): Promise<void> {
this.close(reason ?? 'client disconnecting'); this.close(reason ?? 'client intentionally disconnect');
} }
private async reconnect(): Promise<void> { private async reconnect(): Promise<void> {
@ -82,16 +71,6 @@ export class WsChannelClient extends WsChannel {
} }
this.logger.error('Conenction lost. Cannot reconnect'); this.logger.error('Conenction lost. Cannot reconnect');
this.emit('error', new Error('Connection lost')); this.emitter.emit('__error', new Error('Connection lost'));
} }
} }
function generateName(url: string): string {
const parts = url.split('/');
for (let i = parts.length - 1; i > 1; i--) {
if (parts[i]) {
return parts[i];
}
}
return 'anonymous';
}

Просмотреть файл

@ -0,0 +1,189 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
/**
* Internal helper class which handles one WebSocket connection.
**/
import { EventEmitter } from 'node:events';
import util from 'node:util';
import type { WebSocket } from 'ws';
import type { Command } from 'common/command_channel/interface';
import { Logger, getLogger } from 'common/log';
interface ConnectionEvents {
'bye': (reason: string) => void;
'close': (code: number, reason: string) => void;
'error': (error: Error) => void;
}
export declare interface WsConnection {
on<E extends keyof ConnectionEvents>(event: E, listener: ConnectionEvents[E]): this;
}
export class WsConnection extends EventEmitter {
private closing: boolean = false;
private commandEmitter: EventEmitter;
private heartbeatTimer: NodeJS.Timer | null = null;
private log: Logger;
private missingPongs: number = 0;
private ws: WebSocket; // NOTE: used in unit test
constructor(name: string, ws: WebSocket, commandEmitter: EventEmitter) {
super();
this.log = getLogger(`WsConnection.${name}`);
this.ws = ws;
this.commandEmitter = commandEmitter;
ws.on('close', this.handleClose.bind(this));
ws.on('error', this.handleError.bind(this));
ws.on('message', this.handleMessage.bind(this));
ws.on('pong', this.handlePong.bind(this));
}
public setHeartbeatInterval(interval: number): void {
if (this.heartbeatTimer) {
clearTimeout(this.heartbeatTimer);
}
this.heartbeatTimer = setInterval(this.heartbeat.bind(this), interval);
}
public async close(reason: string): Promise<void> {
if (this.closing) {
this.log.debug('Close again:', reason);
return;
}
this.log.debug('Close connection:', reason);
this.closing = true;
if (this.heartbeatTimer) {
clearInterval(this.heartbeatTimer);
this.heartbeatTimer = null;
}
try {
await this.sendAsync({ type: '_bye_', reason });
} catch (error) {
this.log.error('Failed to send bye:', error);
}
try {
this.ws.close(4000, reason);
return;
} catch (error) {
this.log.error('Failed to close socket:', error);
}
try {
this.ws.terminate();
} catch (error) {
this.log.debug('Failed to terminate socket:', error);
}
}
public terminate(reason: string): void {
this.log.debug('Terminate connection:', reason);
this.closing = true;
if (this.heartbeatTimer) {
clearInterval(this.heartbeatTimer);
this.heartbeatTimer = null;
}
try {
this.ws.close(4001, reason);
return;
} catch (error) {
this.log.debug('Failed to close socket:', error);
}
try {
this.ws.terminate();
} catch (error) {
this.log.debug('Failed to terminate socket:', error);
}
}
public send(command: Command): void {
this.log.trace('Send command', command);
this.ws.send(JSON.stringify(command));
}
public sendAsync(command: Command): Promise<void> {
this.log.trace('(async) Send command', command);
const send: any = util.promisify(this.ws.send.bind(this.ws));
return send(JSON.stringify(command));
}
private handleClose(code: number, reason: Buffer): void {
if (this.closing) {
this.log.debug('Connection closed');
} else {
this.log.debug('Connection closed by peer:', code, String(reason));
this.emit('close', code, String(reason));
}
}
private handleError(error: Error): void {
if (this.closing) {
this.log.warning('Error after closing:', error);
} else {
this.log.error('Connection error:', error);
this.emit('error', error);
}
}
private handleMessage(data: Buffer, _isBinary: boolean): void {
const s = String(data);
if (this.closing) {
this.log.warning('Received message after closing:', s);
return;
}
this.log.trace('Receive command', s);
const command = JSON.parse(s);
if (command.type === '_nop_') {
return;
}
if (command.type === '_bye_') {
this.log.debug('Intentionally close connection:', s);
this.closing = true;
this.emit('bye', command.reason);
return;
}
const hasReceiveListener = this.commandEmitter.emit('__receive', command);
const hasCommandListener = this.commandEmitter.emit(command.type, command);
if (!hasReceiveListener && !hasCommandListener) {
this.log.warning('No listener for command', s);
}
}
private handlePong(): void {
this.log.trace('Receive pong');
this.missingPongs = 0;
}
private heartbeat(): void {
if (this.missingPongs > 0) {
this.log.warning('Missing pong');
}
if (this.missingPongs > 3) { // TODO: make it configurable?
// no response for ping, try real command
this.sendAsync({ type: '_nop_' }).then(() => {
this.missingPongs = 0;
}).catch(error => {
this.log.error('Failed sending command. Drop connection:', error);
this.terminate(`peer lost responsive: ${util.inspect(error)}`);
});
}
this.missingPongs += 1;
this.log.trace('Send ping');
this.ws.ping();
}
}

Просмотреть файл

@ -3,29 +3,6 @@
/** /**
* WebSocket command channel server. * WebSocket command channel server.
*
* The server will specify a URL prefix like `ws://1.2.3.4:8080/SERVER_PREFIX`,
* and each client will append a channel ID, like `ws://1.2.3.4:8080/SERVER_PREFIX/CHANNEL_ID`.
*
* const server = new WsChannelServer('example', 'SERVER_PREFIX');
* const url = server.getChannelUrl('CHANNEL_ID');
* const client = new WsChannelClient(url);
* await server.start();
* await client.connect();
*
* There two styles to use the server:
*
* 1. Handle all clients' commands in one space:
*
* server.onReceive((channelId, command) => { ... });
* server.send(channelId, command);
*
* 2. Maintain a `WsChannel` instance for each client:
*
* server.onConnection((channelId, channel) => {
* channel.onCommand(command => { ... });
* channel.send(command);
* });
**/ **/
import { EventEmitter } from 'events'; import { EventEmitter } from 'events';
@ -39,22 +16,18 @@ import globals from 'common/globals';
import { Logger, getLogger } from 'common/log'; import { Logger, getLogger } from 'common/log';
import { WsChannel } from './channel'; import { WsChannel } from './channel';
let heartbeatInterval: number = 5000;
type ReceiveCallback = (channelId: string, command: Command) => void; type ReceiveCallback = (channelId: string, command: Command) => void;
export class WsChannelServer extends EventEmitter { export class WsChannelServer extends EventEmitter {
private channels: Map<string, WsChannel> = new Map(); private channels: Map<string, WsChannel> = new Map();
private ip: string;
private log: Logger; private log: Logger;
private path: string; private path: string;
private receiveCallbacks: ReceiveCallback[] = []; private receiveCallbacks: ReceiveCallback[] = [];
constructor(name: string, urlPath: string, ip?: string) { constructor(name: string, urlPath: string) {
super(); super();
this.log = getLogger(`WsChannelServer.${name}`); this.log = getLogger(`WsChannelServer.${name}`);
this.path = urlPath; this.path = urlPath;
this.ip = ip ?? 'localhost';
} }
public async start(): Promise<void> { public async start(): Promise<void> {
@ -67,7 +40,7 @@ export class WsChannelServer extends EventEmitter {
const deferred = new Deferred<void>(); const deferred = new Deferred<void>();
this.channels.forEach((channel, channelId) => { this.channels.forEach((channel, channelId) => {
channel.on('close', (_reason) => { channel.onClose(_reason => {
this.channels.delete(channelId); this.channels.delete(channelId);
if (this.channels.size === 0) { if (this.channels.size === 0) {
deferred.resolve(); deferred.resolve();
@ -77,17 +50,16 @@ export class WsChannelServer extends EventEmitter {
}); });
// wait for at most 5 seconds // wait for at most 5 seconds
// use heartbeatInterval here for easier unit test
setTimeout(() => { setTimeout(() => {
this.log.debug('Shutdown timeout. Stop waiting following channels:', Array.from(this.channels.keys())); this.log.debug('Shutdown timeout. Stop waiting following channels:', Array.from(this.channels.keys()));
deferred.resolve(); deferred.resolve();
}, heartbeatInterval); }, 5000);
return deferred.promise; return deferred.promise;
} }
public getChannelUrl(channelId: string): string { public getChannelUrl(channelId: string, ip?: string): string {
return globals.rest.getFullUrl('ws', this.ip, this.path, channelId); return globals.rest.getFullUrl('ws', ip ?? 'localhost', this.path, channelId);
} }
public send(channelId: string, command: Command): void { public send(channelId: string, command: Command): void {
@ -104,7 +76,7 @@ export class WsChannelServer extends EventEmitter {
// because by this way it can detect and warning if a command is never listened // because by this way it can detect and warning if a command is never listened
this.receiveCallbacks.push(callback); this.receiveCallbacks.push(callback);
for (const [channelId, channel] of this.channels) { for (const [channelId, channel] of this.channels) {
channel.onCommand(command => { callback(channelId, command); }); channel.onReceive(command => { callback(channelId, command); });
} }
} }
@ -118,33 +90,30 @@ export class WsChannelServer extends EventEmitter {
if (this.channels.has(channelId)) { if (this.channels.has(channelId)) {
this.log.warning(`Channel ${channelId} reconnecting, drop previous connection`); this.log.warning(`Channel ${channelId} reconnecting, drop previous connection`);
this.channels.get(channelId)!.setConnection(ws); this.channels.get(channelId)!.setConnection(ws, false);
return; return;
} }
const channel = new WsChannel(channelId, ws, heartbeatInterval); const channel = new WsChannel(channelId);
this.channels.set(channelId, channel); this.channels.set(channelId, channel);
channel.on('close', reason => { channel.onClose(reason => {
this.log.debug(`Connection ${channelId} closed:`, reason); this.log.debug(`Connection ${channelId} closed:`, reason);
this.channels.delete(channelId); this.channels.delete(channelId);
}); });
channel.on('error', error => { channel.onError(error => {
this.log.error(`Connection ${channelId} error:`, error); this.log.error(`Connection ${channelId} error:`, error);
this.channels.delete(channelId); this.channels.delete(channelId);
}); });
for (const cb of this.receiveCallbacks) { for (const cb of this.receiveCallbacks) {
channel.on('command', command => { cb(channelId, command); }); channel.onReceive(command => { cb(channelId, command); });
} }
channel.enableHeartbeat();
channel.setConnection(ws, false);
this.emit('connection', channelId, channel); this.emit('connection', channelId, channel);
} }
} }
export namespace UnitTestHelpers {
export function setHeartbeatInterval(ms: number): void {
heartbeatInterval = ms;
}
}

Просмотреть файл

@ -98,5 +98,5 @@ if (isUnitTest()) {
resetGlobals(); resetGlobals();
} }
const globals: MutableGlobals = (global as any).nni; export const globals: MutableGlobals = (global as any).nni;
export default globals; export default globals;

Просмотреть файл

@ -72,13 +72,13 @@ async function main(): Promise<void> {
logger.debug('command:', process.argv); logger.debug('command:', process.argv);
logger.debug('config:', config); logger.debug('config:', config);
const client = new WsChannelClient(args.managerCommandChannel, args.environmentId); const client = new WsChannelClient(args.environmentId, args.managerCommandChannel);
client.enableHeartbeat(5000); client.enableHeartbeat();
client.on('close', reason => { client.onClose(reason => {
logger.info('Manager closed connection:', reason); logger.info('Manager closed connection:', reason);
globals.shutdown.initiate('Connection end'); globals.shutdown.initiate('Connection end');
}); });
client.on('error', error => { client.onError(error => {
logger.info('Connection error:', error); logger.info('Connection error:', error);
globals.shutdown.initiate('Connection error'); globals.shutdown.initiate('Connection error');
}); });

Просмотреть файл

@ -4,7 +4,7 @@
"license": "MIT", "license": "MIT",
"scripts": { "scripts": {
"build": "tsc", "build": "tsc",
"test": "nyc --reporter=cobertura --reporter=text mocha test/**/*.test.ts", "test": "nyc --reporter=cobertura --reporter=text mocha \"test/**/*.test.ts\"",
"test_nnimanager": "nyc --reporter=cobertura --reporter=text mocha test/core/nnimanager.test.ts", "test_nnimanager": "nyc --reporter=cobertura --reporter=text mocha test/core/nnimanager.test.ts",
"mocha": "mocha", "mocha": "mocha",
"eslint": "eslint . --ext .ts" "eslint": "eslint . --ext .ts"

Просмотреть файл

@ -61,7 +61,7 @@ export class RestServerCore {
return deferred.promise; return deferred.promise;
} }
public shutdown(): Promise<void> { public shutdown(timeoutMilliseconds?: number): Promise<void> {
logger.info('Stopping REST server.'); logger.info('Stopping REST server.');
if (this.server === null) { if (this.server === null) {
logger.warning('REST server is not running.'); logger.warning('REST server is not running.');
@ -77,7 +77,7 @@ export class RestServerCore {
logger.debug('Killing connections'); logger.debug('Killing connections');
this.server?.closeAllConnections(); this.server?.closeAllConnections();
} }
}, 5000); }, timeoutMilliseconds ?? 5000);
return deferred.promise; return deferred.promise;
} }
} }

Просмотреть файл

@ -0,0 +1,209 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import assert from 'node:assert/strict';
import { setTimeout } from 'node:timers/promises';
import type {
Command, CommandChannel, CommandChannelClient, CommandChannelServer
} from 'common/command_channel/interface';
import { UnitTestHelper as Helper } from 'common/command_channel/websocket/channel';
import { WsChannelClient, WsChannelServer } from 'common/command_channel/websocket/index';
import { globals } from 'common/globals/unittest';
import { RestServerCore } from 'rest_server/core';
describe('## websocket command channel ##', () => {
before(beforeHook);
it('start', testServerStart);
it('connect', testClientStart);
it('message', testMessage);
it('reconnect', testReconnect);
it('message', testMessage);
it('handle error', testError);
it('shutdown', testShutdown);
after(afterHook);
});
/* test cases */
async function testServerStart(): Promise<void> {
ut.server = new WsChannelServer('ut_server', 'ut');
ut.server.onReceive((channelId, command) => {
if (channelId === '1') {
ut.events.push({ event: 'server_receive_1', command });
}
});
ut.server.onConnection((channelId, channel) => {
ut.events.push({ event: 'connect', channelId, channel });
channel.onClose(reason => {
ut.events.push({ event: `client_close_${channelId}`, reason });
});
channel.onError(error => {
ut.events.push({ event: `client_error_${channelId}`, error });
});
channel.onLost(() => {
ut.events.push({ event: `client_lost_${channelId}` });
});
if (channelId === '1') {
ut.serverChannel1 = channel;
}
if (channelId === '2') {
ut.serverChannel2 = channel;
channel.onReceive(command => {
ut.events.push({ event: 'server_receive_2', command });
});
}
});
await ut.server.start();
}
async function testClientStart(): Promise<void> {
const url1 = ut.server.getChannelUrl('1');
const url2 = ut.server.getChannelUrl('2', '127.0.0.1');
assert.equal(url1, `ws://localhost:${globals.args.port}/ut/1`);
assert.equal(url2, `ws://127.0.0.1:${globals.args.port}/ut/2`);
ut.client1 = new WsChannelClient('ut_client_1', url1);
ut.client2 = new WsChannelClient('ut_client_2', url2);
ut.client1.onReceive(command => {
ut.events.push({ event: 'client_receive_1', command });
});
ut.client2.onCommand('ut_command', command => {
ut.events.push({ event: 'client_receive_2', command });
});
ut.client2.onClose(reason => {
ut.events.push({ event: 'server_close_2', reason });
});
await Promise.all([
ut.client1.connect(),
ut.client2.connect(),
]);
assert.equal(ut.events[0].event, 'connect');
assert.equal(ut.events[1].event, 'connect');
assert.equal(Number(ut.events[0].channelId) + Number(ut.events[1].channelId), 3);
assert.equal(ut.events.length, 2);
ut.events.length = 0;
}
async function testReconnect(): Promise<void> {
const ws = (ut.client1 as any).connection.ws; // NOTE: private api
ws.pause();
await setTimeout(heartbeatTimeout);
ws.terminate();
ws.resume();
// mac pipeline can be slow
for (let i = 0; i < 10; i++) {
await setTimeout(heartbeat);
if (ut.events.length > 0) {
break;
}
}
assert.ok(ut.countEvents('client_lost_1') >= 1);
assert.ok(ut.countEvents('client_close_1') == 0);
assert.ok(ut.countEvents('client_error_1') == 0);
assert.ok(ut.countEvents('connect') == 0); // reconnect is not connect
ut.events.length = 0;
}
async function testMessage(): Promise<void> {
ut.server.send('1', ut.packCommand(1));
await ut.client2.sendAsync(ut.packCommand(2));
ut.client2.send(ut.packCommand('三'));
ut.server.send('2', ut.packCommand('4'));
ut.client1.send(ut.packCommand(5));
ut.server.send('1', ut.packCommand(6));
await setTimeout(heartbeat);
assert.deepEqual(ut.filterCommands('client_receive_1'), [ 1, 6 ]);
assert.deepEqual(ut.filterCommands('client_receive_2'), [ '4' ]);
assert.deepEqual(ut.filterCommands('server_receive_1'), [ 5 ]);
assert.deepEqual(ut.filterCommands('server_receive_2'), [ 2, '三' ]);
ut.events.length = 0;
}
async function testError(): Promise<void> {
ut.client2.terminate('client 2 terminate');
await setTimeout(terminateTimeout * 1.1);
assert.ok(ut.countEvents('client_close_2') == 0);
assert.ok(ut.countEvents('client_error_2') == 1);
ut.events.length = 0;
}
async function testShutdown(): Promise<void> {
await ut.server.shutdown();
assert.equal(ut.countEvents('client_close_1'), 1);
ut.events.length = 0;
}
/* helpers and states */
// NOTE: Increase these numbers if it fails randomly
const heartbeat = 10;
const heartbeatTimeout = 50;
const terminateTimeout = 100;
async function beforeHook(): Promise<void> {
globals.reset();
ut.rest = new RestServerCore();
await ut.rest.start();
Helper.setHeartbeatInterval(heartbeat);
Helper.setTerminateTimeout(terminateTimeout);
}
async function afterHook(): Promise<void> {
Helper.reset();
await ut.rest?.shutdown();
globals.reset();
}
class UnitTestStates {
server!: CommandChannelServer;
client1!: CommandChannelClient;
client2!: CommandChannelClient;
serverChannel1!: CommandChannel;
serverChannel2!: CommandChannel;
events: any[] = [];
rest!: RestServerCore;
countEvents(event: string): number {
return this.events.filter(e => (e.event === event)).length;
}
filterCommands(event: string): any[] {
return this.events.filter(e => (e.event === event)).map(e => e.command.value);
}
packCommand(value: any): Command {
return { type: 'ut_command', value };
}
}
const ut = new UnitTestStates();

Просмотреть файл

@ -39,16 +39,16 @@ export class RemoteTrainingServiceV3 implements TrainingServiceV3 {
this.log = getLogger(`RemoteV3.${this.id}`); this.log = getLogger(`RemoteV3.${this.id}`);
this.log.debug('Training sevice config:', config); this.log.debug('Training sevice config:', config);
this.server = new WsChannelServer('RemoteTrialKeeper', `platform/${this.id}`, config.nniManagerIp); this.server = new WsChannelServer(this.id, `/platform/${this.id}`);
this.server.on('connection', (channelId: string, channel: WsChannel) => { this.server.on('connection', (channelId: string, channel: WsChannel) => {
const worker = this.workersByChannel.get(channelId); const worker = this.workersByChannel.get(channelId);
if (worker) { if (worker) {
worker.setChannel(channel); worker.setChannel(channel);
channel.on('close', reason => { channel.onClose(reason => {
this.log.error('Worker channel closed unexpectedly:', reason); this.log.error('Worker channel closed unexpectedly:', reason);
}); });
channel.on('error', error => { channel.onError(error => {
this.log.error('Worker channel error:', error); this.log.error('Worker channel error:', error);
this.restartWorker(worker); this.restartWorker(worker);
}); });
@ -190,7 +190,7 @@ export class RemoteTrainingServiceV3 implements TrainingServiceV3 {
this.id, this.id,
channelId, channelId,
config, config,
this.server.getChannelUrl(channelId), this.server.getChannelUrl(channelId, this.config.nniManagerIp),
Boolean(this.config.trialGpuNumber) Boolean(this.config.trialGpuNumber)
); );

Просмотреть файл

@ -57,7 +57,7 @@ export class Worker {
public setChannel(channel: WsChannel): void { public setChannel(channel: WsChannel): void {
this.channel = channel; this.channel = channel;
this.trialKeeper.setChannel(channel); this.trialKeeper.setChannel(channel);
channel.on('lost', async () => { channel.onLost(async () => {
if (!await this.checkAlive()) { if (!await this.checkAlive()) {
this.log.error('Trial keeper failed'); this.log.error('Trial keeper failed');
channel.terminate('Trial keeper failed'); // MARK channel.terminate('Trial keeper failed'); // MARK