diff --git a/nni/runtime/command_channel/websocket/__init__.py b/nni/runtime/command_channel/websocket/__init__.py new file mode 100644 index 000000000..617080339 --- /dev/null +++ b/nni/runtime/command_channel/websocket/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from .channel import WsChannelClient diff --git a/nni/runtime/command_channel/websocket/channel.py b/nni/runtime/command_channel/websocket/channel.py new file mode 100644 index 000000000..90cc767eb --- /dev/null +++ b/nni/runtime/command_channel/websocket/channel.py @@ -0,0 +1,106 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +import logging +import time + +import nni +from ..base import Command, CommandChannel +from .connection import WsConnection + +_logger = logging.getLogger(__name__) + +class WsChannelClient(CommandChannel): + def __init__(self, url: str): + self._url: str = url + self._closing: bool = False + self._conn: WsConnection | None = None + + def connect(self) -> None: + _logger.debug(f'Connect to {self._url}') + assert not self._closing + self._ensure_conn() + + def disconnect(self) -> None: + _logger.debug(f'Disconnect from {self._url}') + self.send({'type': '_bye_'}) + self._closing = True + self._close_conn('client intentionally close') + + def send(self, command: Command) -> None: + if self._closing: + return + _logger.debug(f'Send {command}') + msg = nni.dump(command) + for i in range(5): + try: + conn = self._ensure_conn() + conn.send(msg) + return + except Exception: + _logger.exception(f'Failed to send command. Retry in {i}s') + self._terminate_conn('send fail') + time.sleep(i) + _logger.warning(f'Failed to send command {command}. 
Last retry') + conn = self._ensure_conn() + conn.send(msg) + + def receive(self) -> Command | None: + while True: + if self._closing: + return None + msg = self._receive_msg() + if msg is None: + return None + command = nni.load(msg) + if command['type'] == '_nop_': + continue + if command['type'] == '_bye_': + reason = command.get('reason') + _logger.debug(f'Server close connection: {reason}') + self._closing = True + self._close_conn('server intentionally close') + return None + return command + + def _ensure_conn(self) -> WsConnection: + if self._conn is None and not self._closing: + self._conn = WsConnection(self._url) + self._conn.connect() + _logger.debug('Connected') + return self._conn # type: ignore + + def _close_conn(self, reason: str) -> None: + if self._conn is not None: + try: + self._conn.disconnect(reason) + except Exception: + pass + self._conn = None + + def _terminate_conn(self, reason: str) -> None: + if self._conn is not None: + try: + self._conn.terminate(reason) + except Exception: + pass + self._conn = None + + def _receive_msg(self) -> str | None: + for i in range(5): + try: + conn = self._ensure_conn() + msg = conn.receive() + _logger.debug(f'Receive {msg}') + if not self._closing: + assert msg is not None + return msg + except Exception: + _logger.exception(f'Failed to receive command. Retry in {i}s') + self._terminate_conn('receive fail') + time.sleep(i) + _logger.warning(f'Failed to receive command. Last retry') + conn = self._ensure_conn() # no retries left: errors from this attempt propagate to the caller + return conn.receive() diff --git a/nni/runtime/command_channel/websocket/connection.py b/nni/runtime/command_channel/websocket/connection.py new file mode 100644 index 000000000..099098090 --- /dev/null +++ b/nni/runtime/command_channel/websocket/connection.py @@ -0,0 +1,139 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Synchronized and object-oriented WebSocket class. + +WebSocket guarantees that messages will not be divided at API level. 
+""" + +from __future__ import annotations + +__all__ = ['WsConnection'] + +import asyncio +import logging +from threading import Lock, Thread +from typing import Any, Type + +import websockets + +_logger = logging.getLogger(__name__) + +# the singleton event loop +_event_loop: asyncio.AbstractEventLoop = None # type: ignore +_event_loop_lock: Lock = Lock() +_event_loop_refcnt: int = 0 # number of connected websockets + +class WsConnection: + """ + A WebSocket connection. + + Call :meth:`connect` before :meth:`send` and :meth:`receive`. + + All methods are thread safe. + + Parameters + ---------- + url + The WebSocket URL. + For tuner command channel it should be something like ``ws://localhost:8080/tuner``. + """ + + ConnectionClosed: Type[Exception] = websockets.ConnectionClosed # type: ignore + + def __init__(self, url: str): + self._url: str = url + self._ws: Any = None # the library does not provide type hints + + def connect(self) -> None: + global _event_loop, _event_loop_refcnt + with _event_loop_lock: + _event_loop_refcnt += 1 + if _event_loop is None: + _logger.debug('Starting event loop.') + # following line must be outside _run_event_loop + # because _wait() might be executed before first line of the child thread + _event_loop = asyncio.new_event_loop() + thread = Thread(target=_run_event_loop, name='NNI-WebSocketEventLoop', daemon=True) + thread.start() + + _logger.debug(f'Connecting to {self._url}') + self._ws = _wait(_connect_async(self._url)) + _logger.debug(f'Connected.') + + def disconnect(self, reason: str | None = None, code: int | None = None) -> None: + if self._ws is None: + _logger.debug('disconnect: No connection.') + return + + try: + _wait(self._ws.close(code or 4000, reason)) + _logger.debug('Connection closed by client.') + except Exception as e: + _logger.warning(f'Failed to close connection: {repr(e)}') + self._ws = None + _decrease_refcnt() + + def terminate(self, reason: str | None = None) -> None: + if self._ws is None: + 
_logger.debug('terminate: No connection.') + return + self.disconnect(reason, 4001) + + def send(self, message: str) -> None: + _logger.debug(f'Sending {message}') + try: + _wait(self._ws.send(message)) + except websockets.ConnectionClosed: # type: ignore + _logger.debug('Connection closed by server.') + self._ws = None + _decrease_refcnt() + raise + + def receive(self) -> str | None: + """ + Return received message; + or return ``None`` if the connection has been closed by peer. + """ + try: + msg = _wait(self._ws.recv()) + _logger.debug(f'Received {msg}') + except websockets.ConnectionClosed: # type: ignore + _logger.debug('Connection closed by server.') + self._ws = None + _decrease_refcnt() + raise + + # seems the library will inference whether it's text or binary, so we don't have guarantee + if isinstance(msg, bytes): + return msg.decode() + else: + return msg + +def _wait(coro): + # Synchronized version of "await". + future = asyncio.run_coroutine_threadsafe(coro, _event_loop) + return future.result() + +def _run_event_loop() -> None: + # A separate thread to run the event loop. + # The event loop itself is blocking, and send/receive are also blocking, + # so they must run in different threads. + asyncio.set_event_loop(_event_loop) + _event_loop.run_forever() + _logger.debug('Event loop stopped.') + +async def _connect_async(url): + # Theoretically this function is meaningless and one can directly use `websockets.connect(url)`, + # but it will not work, raising "TypeError: A coroutine object is required". + # Seems a design flaw in websockets library. 
+ return await websockets.connect(url, max_size=None) # type: ignore + +def _decrease_refcnt() -> None: + global _event_loop, _event_loop_refcnt + with _event_loop_lock: + _event_loop_refcnt -= 1 + if _event_loop_refcnt == 0: + _event_loop.call_soon_threadsafe(_event_loop.stop) + _event_loop = None # type: ignore diff --git a/nni/runtime/tuner_command_channel/websocket.py b/nni/runtime/tuner_command_channel/websocket.py index c4bef55fd..afc1790f6 100644 --- a/nni/runtime/tuner_command_channel/websocket.py +++ b/nni/runtime/tuner_command_channel/websocket.py @@ -1,133 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -""" -Synchronized and object-oriented WebSocket class. - -WebSocket guarantees that messages will not be divided at API level. -""" - -from __future__ import annotations - -__all__ = ['WebSocket'] - -import asyncio -import logging -from threading import Lock, Thread -from typing import Any, Type - -import websockets - -_logger = logging.getLogger(__name__) - -# the singleton event loop -_event_loop: asyncio.AbstractEventLoop = None # type: ignore -_event_loop_lock: Lock = Lock() -_event_loop_refcnt: int = 0 # number of connected websockets - -class WebSocket: - """ - A WebSocket connection. - - Call :meth:`connect` before :meth:`send` and :meth:`receive`. - - All methods are thread safe. - - Parameters - ---------- - url - The WebSocket URL. - For tuner command channel it should be something like ``ws://localhost:8080/tuner``. 
- """ - - ConnectionClosed: Type[Exception] = websockets.ConnectionClosed # type: ignore - - def __init__(self, url: str): - self._url: str = url - self._ws: Any = None # the library does not provide type hints - - def connect(self) -> None: - global _event_loop, _event_loop_refcnt - with _event_loop_lock: - _event_loop_refcnt += 1 - if _event_loop is None: - _logger.debug('Starting event loop.') - # following line must be outside _run_event_loop - # because _wait() might be executed before first line of the child thread - _event_loop = asyncio.new_event_loop() - thread = Thread(target=_run_event_loop, name='NNI-WebSocketEventLoop', daemon=True) - thread.start() - - _logger.debug(f'Connecting to {self._url}') - self._ws = _wait(_connect_async(self._url)) - _logger.debug(f'Connected.') - - def disconnect(self) -> None: - if self._ws is None: - _logger.debug('disconnect: No connection.') - return - - try: - _wait(self._ws.close()) - _logger.debug('Connection closed by client.') - except Exception as e: - _logger.warning(f'Failed to close connection: {repr(e)}') - self._ws = None - _decrease_refcnt() - - def send(self, message: str) -> None: - _logger.debug(f'Sending {message}') - try: - _wait(self._ws.send(message)) - except websockets.ConnectionClosed: # type: ignore - _logger.debug('Connection closed by server.') - self._ws = None - _decrease_refcnt() - raise - - def receive(self) -> str | None: - """ - Return received message; - or return ``None`` if the connection has been closed by peer. - """ - try: - msg = _wait(self._ws.recv()) - _logger.debug(f'Received {msg}') - except websockets.ConnectionClosed: # type: ignore - _logger.debug('Connection closed by server.') - self._ws = None - _decrease_refcnt() - raise - - # seems the library will inference whether it's text or binary, so we don't have guarantee - if isinstance(msg, bytes): - return msg.decode() - else: - return msg - -def _wait(coro): - # Synchronized version of "await". 
- future = asyncio.run_coroutine_threadsafe(coro, _event_loop) - return future.result() - -def _run_event_loop() -> None: - # A separate thread to run the event loop. - # The event loop itself is blocking, and send/receive are also blocking, - # so they must run in different threads. - asyncio.set_event_loop(_event_loop) - _event_loop.run_forever() - _logger.debug('Event loop stopped.') - -async def _connect_async(url): - # Theoretically this function is meaningless and one can directly use `websockets.connect(url)`, - # but it will not work, raising "TypeError: A coroutine object is required". - # Seems a design flaw in websockets library. - return await websockets.connect(url, max_size=None) # type: ignore - -def _decrease_refcnt() -> None: - global _event_loop, _event_loop_refcnt - with _event_loop_lock: - _event_loop_refcnt -= 1 - if _event_loop_refcnt == 0: - _event_loop.call_soon_threadsafe(_event_loop.stop) - _event_loop = None # type: ignore +from ..command_channel.websocket.connection import WsConnection as WebSocket # pylint: disable=unused-import diff --git a/test/ut/sdk/helper/websocket_server.py b/test/ut/sdk/helper/websocket_server.py index 82bc3c7a5..9ad3a4427 100644 --- a/test/ut/sdk/helper/websocket_server.py +++ b/test/ut/sdk/helper/websocket_server.py @@ -32,7 +32,7 @@ async def read_stdin(): line = line.decode().strip() _debug(f'read from stdin: {line}') if line == '_close_': - exit() + break await _ws.send(line) async def ws_server(): @@ -46,9 +46,12 @@ async def on_connect(ws): global _ws _debug('connected') _ws = ws - async for msg in ws: - _debug(f'received from websocket: {msg}') - print(msg, flush=True) + try: + async for msg in ws: + _debug(f'received from websocket: {msg}') + print(msg, flush=True) + except websockets.exceptions.ConnectionClosedError: + pass def _debug(msg): #sys.stderr.write(f'[server-debug] {msg}\n') diff --git a/test/ut/sdk/test_tuner_command_channel.py b/test/ut/sdk/test_ws_channel.py similarity index 80% rename from 
test/ut/sdk/test_tuner_command_channel.py rename to test/ut/sdk/test_ws_channel.py index 2320651bb..40aaf4685 100644 --- a/test/ut/sdk/test_tuner_command_channel.py +++ b/test/ut/sdk/test_ws_channel.py @@ -11,21 +11,21 @@ from subprocess import Popen, PIPE import sys import time -from nni.runtime.tuner_command_channel.websocket import WebSocket +from nni.runtime.command_channel.websocket import WsChannelClient # A helper server that connects its stdio to incoming WebSocket. _server = None _client = None -_command1 = 'T_hello world' -_command2 = 'T_你好' +_command1 = {'type': 'ut_command', 'value': 123} +_command2 = {'type': 'ut_command', 'value': '你好'} ## test cases ## def test_connect(): global _client port = _init() - _client = WebSocket(f'ws://localhost:{port}') + _client = WsChannelClient(f'ws://localhost:{port}') _client.connect() def test_send(): @@ -34,16 +34,16 @@ def test_send(): _client.send(_command2) time.sleep(0.01) - sent1 = _server.stdout.readline().strip() + sent1 = json.loads(_server.stdout.readline()) assert sent1 == _command1, sent1 - sent2 = _server.stdout.readline().strip() + sent2 = json.loads(_server.stdout.readline().strip()) assert sent2 == _command2, sent2 def test_receive(): # Send commands to server via stdin, and get them back via channel. 
- _server.stdin.write(_command1 + '\n') - _server.stdin.write(_command2 + '\n') + _server.stdin.write(json.dumps(_command1) + '\n') + _server.stdin.write(json.dumps(_command2) + '\n') _server.stdin.flush() received1 = _client.receive() diff --git a/ts/nni_manager/common/command_channel/http.ts b/ts/nni_manager/common/command_channel/http.ts index 5aa40fdee..df4eceb43 100644 --- a/ts/nni_manager/common/command_channel/http.ts +++ b/ts/nni_manager/common/command_channel/http.ts @@ -51,8 +51,8 @@ export class HttpChannelServer implements CommandChannelServer { this.outgoingQueues.forEach(queue => { queue.clear(); }); } - public getChannelUrl(channelId: string): string { - return globals.rest.getFullUrl('http', 'localhost', this.path, channelId); + public getChannelUrl(channelId: string, ip?: string): string { + return globals.rest.getFullUrl('http', ip ?? 'localhost', this.path, channelId); } public send(channelId: string, command: Command): void { @@ -63,6 +63,10 @@ export class HttpChannelServer implements CommandChannelServer { this.emitter.on('receive', callback); } + public onConnection(_callback: (channelId: string, channel: any) => void): void { + throw new Error('Not implemented'); + } + private handleGet(request: Request, response: Response): void { const channelId = request.params['channel']; const promise = this.getOutgoingQueue(channelId).asyncPop(timeoutMilliseconds); diff --git a/ts/nni_manager/common/command_channel/interface.ts b/ts/nni_manager/common/command_channel/interface.ts index 64f9b4a14..1b0ef4db1 100644 --- a/ts/nni_manager/common/command_channel/interface.ts +++ b/ts/nni_manager/common/command_channel/interface.ts @@ -1,26 +1,158 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -//export interface Command { -// type: string; -// [key: string]: any; -//} +/** + * Common interface of command channels. + * + * A command channel is a duplex connection which supports sending and receiving JSON commands. 
+ * + * Typically a command channel implementation consists of a server and a client. + * + * The server should listen to a URL prefix like `http://localhost:8080/example/`; + * and each client should connect to a unique URL containing the prefix, e.g. `http://localhost:8080/example/channel1`. + * The client's URL should be created with `server.getChannelUrl(channelId, serverIp)`. + * + * We have currently implemented one full-featured command channel, the WebSocket channel, + * and a simplified one, the HTTP channel. + * In the v3.1 release we might implement a file command channel and an AzureML command channel. + * + * The clients might have a Python version located in `nni/runtime/command_channel`. + * The TypeScript and Python versions should be interchangeable. + **/ + +/** + * A command is a JSON object. + * + * The object has only one mandatory entry, `type`. + * + * The type string should not be surrounded by underscores (e.g. `_nop_`), + * unless you are dealing with the underlying implementation of a specific command channel; + * it should never start with two underscores (e.g. `__command`) under any circumstances. + **/ export type Command = any; + +// Maybe it's better to disable `noPropertyAccessFromIndexSignature` in tsconfig? +// export interface Command { +// type: string; +// [key: string]: any; +// } /** - * A command channel server serves one or more command channels. - * Each channel is connected to a client. + * `CommandChannel` is the base interface used by both the servers and the clients. * - * Normally each client has a unique channel URL, - * which can be got with `server.getChannelUrl(id)`. + * For servers, channels can be obtained with the `onConnection()` event listener. + * For clients, a channel can be created with the client subclass' constructor. * - * The APIs might be changed to return `Promise` in future. + * The channel should be fault tolerant to some extent. It has three different types of closing related events: + * + * 1. 
Close: The channel is intentionally closed. + * + * 2. Lost: The channel is temporarily unavailable and is trying to recover. + * The user of this class should examine the peer's status out-of-band when receiving a "lost" event. + * + * 3. Error: The channel is dead and cannot recover. + * A "close" event may or may not occur following this event. Do not rely on that. + **/ +export interface CommandChannel { + readonly name: string; // for better logging + + enableHeartbeat(intervalMilliseconds?: number): void; + + /** + * Graceful (intentional) close. + * A "close" event will be emitted by `this` and the peer. + **/ + close(reason: string): void; + + /** + * Force close. Should only be used when the channel is not working. + * An "error" event may be emitted by `this`. + * A "lost" and/or "error" event will be emitted by the peer, if its process is still alive. + **/ + terminate(reason: string): void; + + send(command: Command): void; + + /** + * The async version should try to ensure the command is successfully sent to the peer. + * But this is not guaranteed. + **/ + sendAsync(command: Command): Promise; + + onReceive(callback: (command: Command) => void): void; + onCommand(commandType: string, callback: (command: Command) => void): void; + + onClose(callback: (reason?: string) => void): void; + onError(callback: (error: Error) => void): void; + onLost(callback: () => void): void; +} + +/** + * Client side of a command channel. + * + * The constructor should have no side effects. + * + * The listeners should be registered before calling `connect()`, + * or the first few commands might be missed. + * + * Example usage: + * + * const client = new WsChannelClient('example', 'ws://1.2.3.4:8080/server/channel_id'); + * await client.connect(); + * client.send(command); + **/ +export interface CommandChannelClient extends CommandChannel { + // constructor(name: string, url: string); + + connect(): Promise; + + /** + * Typically an alias of `close()`. 
+ **/ + disconnect(reason?: string): Promise; +} + +/** + * Server side of a command channel. + * + * The constructor should have no side effects. + * + * The listeners should be registered before calling `start()`. + * + * Example usage: + * + * const server = new WsChannelServer('example_server', '/server_prefix'); + * const url = server.getChannelUrl('channel_id'); + * const client = new WsChannelClient('example_client', url); + * await server.start(); + * await client.connect(); + * + * There are two ways to listen to commands: + * + * 1. Handle all clients' commands in one place: + * + * server.onReceive((channelId, command) => { ... }); + * server.send(channelId, command); + * + * 2. Maintain a `WsChannel` instance for each client: + * + * server.onConnection((channelId, channel) => { + * channel.onReceive(command => { ... }); + * channel.send(command); + * }); **/ export interface CommandChannelServer { - // constructor(name: string, urlPath: string) + // constructor(name: string, urlPath: string); + start(): Promise; shutdown(): Promise; - getChannelUrl(channelId: string): string; + + /** + * When `ip` is missing, it should default to localhost. 
+ **/ + getChannelUrl(channelId: string, ip?: string): string; + send(channelId: string, command: Command): void; onReceive(callback: (channelId: string, command: Command) => void): void; + onConnection(callback: (channelId: string, channel: CommandChannel) => void): void; } diff --git a/ts/nni_manager/common/command_channel/rpc_util.ts b/ts/nni_manager/common/command_channel/rpc_util.ts index 9cf7ee544..f162f6e13 100644 --- a/ts/nni_manager/common/command_channel/rpc_util.ts +++ b/ts/nni_manager/common/command_channel/rpc_util.ts @@ -50,7 +50,7 @@ import { DefaultMap } from 'common/default_map'; import { Deferred } from 'common/deferred'; import { Logger, getLogger } from 'common/log'; import type { TrialKeeper } from 'common/trial_keeper/keeper'; -import { WsChannel } from './websocket/channel'; +import type { CommandChannel } from './interface'; interface RpcResponseCommand { type: 'rpc_response'; @@ -61,14 +61,14 @@ interface RpcResponseCommand { type Class = { new(...args: any[]): any; }; -const rpcHelpers: Map = new Map(); +const rpcHelpers: Map = new Map(); /** * Enable RPC on a channel. * * The channel does not need to be connected for calling this function. **/ -export function getRpcHelper(channel: WsChannel): RpcHelper { +export function getRpcHelper(channel: CommandChannel): RpcHelper { if (!rpcHelpers.has(channel)) { rpcHelpers.set(channel, new RpcHelper(channel)); } @@ -76,7 +76,7 @@ export function getRpcHelper(channel: WsChannel): RpcHelper { } export class RpcHelper { - private channel: WsChannel; + private channel: CommandChannel; private lastId: number = 0; private localCtors: Map = new Map(); private localObjs: Map = new Map(); @@ -87,7 +87,7 @@ export class RpcHelper { /** * NOTE: Don't use this constructor directly. Use `getRpcHelper()`. 
**/ - constructor(channel: WsChannel) { + constructor(channel: CommandChannel) { this.log = getLogger(`RpcHelper.${channel.name}`); this.channel = channel; this.channel.onCommand('rpc_constructor', command => { diff --git a/ts/nni_manager/common/command_channel/websocket/channel.ts b/ts/nni_manager/common/command_channel/websocket/channel.ts index ee0b908b9..b9591db58 100644 --- a/ts/nni_manager/common/command_channel/websocket/channel.ts +++ b/ts/nni_manager/common/command_channel/websocket/channel.ts @@ -4,353 +4,240 @@ /** * WebSocket command channel. * - * This is the base class that used by both server and client. + * A WsChannel operates on one WebSocket connection at a time. + * But when the network is unstable, it may close the underlying connection and create a new one. + * This is generally transparent to the user of this class, except that a "lost" event will be emitted. * - * For the server, channels can be got with `onConnection()` event listener. - * For the client, a channel can be created with `new WsChannelClient()` subclass. - * Do not use the constructor directly. + * To distinguish intentional close from connection lost, + * a "_bye_" command will be sent when `close()` or `disconnect()` is invoked. * - * The channel is fault tolerant to some extend. It has three different types of closing related events: + * If the connection is closed before receiving "_bye_" command, a "lost" event will be emitted and: * - * 1. "close": The channel is intentionally closed. + * * The client will try to reconnect for severaly times in around 15s. + * * The server will wait the client to reconnect for around 30s. * - * This is caused either by "close()" or "disconnect()" call, or by receiving a "_bye_" command from the peer. - * - * 2. "lost": The channel is temporarily unavailable and is trying to recover. - * (The high level class should examine the peer's status out-of-band when receiving this event.) 
- * - * When the underlying socket is dead, this event is emitted. - * The client will try to reconnect in around 15s. If all attempts fail, an "error" event will be emitted. - * The server will wait the client for 30s. If it does not reconnect, an "error" event will be emitted. - * Successful recover will not emit command. - * - * 3. "error": The channel is dead and cannot recover. - * - * A "close" event may or may not follow this event. Do not rely on that. + * If the reconnecting attempt failed, both side will emit an "error" event. **/ -import { EventEmitter } from 'node:events'; +import { EventEmitter, once } from 'node:events'; import util from 'node:util'; import type { WebSocket } from 'ws'; +import type { Command, CommandChannel } from 'common/command_channel/interface'; import { Deferred } from 'common/deferred'; import { Logger, getLogger } from 'common/log'; +import { WsConnection } from './connection'; -import type { Command } from '../interface'; - -interface WsChannelEvents { - 'command': (command: Command) => void; - 'close': (reason: string) => void; - 'lost': () => void; - 'error': (error: Error) => void; // not used in base class +interface QueuedCommand { + command: Command; + deferred?: Deferred; } -export declare interface WsChannel { - on(event: E, listener: WsChannelEvents[E]): this; -} - -export class WsChannel extends EventEmitter { +export class WsChannel implements CommandChannel { private closing: boolean = false; - private commandEmitter: EventEmitter = new EventEmitter(); - private connection: WsConnection | null = null; - private epoch: number = 0; - private heartbeatInterval: number | null; + private connection: WsConnection | null = null; // NOTE: used in unit test + private epoch: number = -1; + private heartbeatInterval: number | null = null; private log: Logger; + private queue: QueuedCommand[] = []; + private terminateTimer: NodeJS.Timer | null = null; + + protected emitter: EventEmitter = new EventEmitter(); public readonly 
name: string; // internal, don't use - constructor(name: string, ws?: WebSocket, heartbeatInterval?: number) { - super() + constructor(name: string) { this.log = getLogger(`WsChannel.${name}`); this.name = name; - this.heartbeatInterval = heartbeatInterval || null; - if (ws) { - this.setConnection(ws); - } - } - - public enableHeartbeat(interval: number): void { - this.log.debug('## enable heartbeat'); - this.heartbeatInterval = interval; } // internal, don't use - public setConnection(ws: WebSocket): void { - if (this.connection) { - this.log.debug('Abandon previous connection'); - this.epoch += 1; + public async setConnection(ws: WebSocket, waitOpen: boolean): Promise { + if (this.terminateTimer) { + clearTimeout(this.terminateTimer); + this.terminateTimer = null; } + + this.connection?.terminate('new epoch start'); + this.newEpoch(); this.log.debug(`Epoch ${this.epoch} start`); + this.connection = this.configConnection(ws); + if (waitOpen) { + await once(ws, 'open'); + } + + while (this.connection && this.queue.length > 0) { + const item = this.queue.shift()!; + try { + await this.connection.sendAsync(item.command); + item.deferred?.resolve(); + } catch (error) { + this.log.error('Failed to send command on recovered channel:', error); + this.log.error('Dropped command:', item.command); + item.deferred?.reject(error as any); + // it should trigger connection's error event and this.connection will be set to null + } + } + } + + public enableHeartbeat(interval?: number): void { + this.heartbeatInterval = interval ?? 
defaultHeartbeatInterval; + this.connection?.setHeartbeatInterval(this.heartbeatInterval); } public close(reason: string): void { this.log.debug('Close channel:', reason); - if (this.connection) { - this.connection.close(reason); - this.endEpoch(); - } - if (!this.closing) { - this.closing = true; - this.emit('close', reason); + this.connection?.close(reason); + if (this.setClosing()) { + this.emitter.emit('__close', reason); } } public terminate(reason: string): void { this.log.info('Terminate channel:', reason); - this.closing = true; - if (this.connection) { - this.connection.terminate(reason); - this.endEpoch(); + this.connection?.terminate(reason); + if (this.setClosing()) { + this.emitter.emit('__error', new Error(`WsChannel terminated: ${reason}`)); } } public send(command: Command): void { - if (this.connection) { - this.connection.send(command); - } else { - // TODO: add a queue? - this.log.error('Connection lost. Dropped command', command); + if (this.closing) { + this.log.error('Channel closed. Ignored command', command); + return; } + + if (!this.connection) { + this.log.warning('Connection lost. Enqueue command', command); + this.queue.push({ command }); + return; + } + + this.connection.send(command); } - /** - * Async version of `send()` that (partially) ensures the command is successfully sent to peer. - **/ public sendAsync(command: Command): Promise { - if (this.connection) { - return this.connection.sendAsync(command); - } else { - this.log.error('Connection lost. Dropped command async', command); - return Promise.reject(new Error('Connection is lost and trying to recover, cannot send command now')); + if (this.closing) { + this.log.error('(async) Channel closed. Refused command', command); + return Promise.reject(new Error('WsChannel has been closed')); } + + if (!this.connection) { + this.log.warning('(async) Connection lost. 
Enqueue command', command); + const deferred = new Deferred(); + this.queue.push({ command, deferred }); + return deferred.promise; + } + + return this.connection.sendAsync(command); } - // the first overload listens to all commands, while the second listens to one command type - public onCommand(callback: (command: Command) => void): void; - public onCommand(commandType: string, callback: (command: Command) => void): void; + public onReceive(callback: (command: Command) => void): void { + this.emitter.on('__receive', callback); + } - public onCommand(commandTypeOrCallback: any, callbackOrNone?: any): void { - if (callbackOrNone) { - this.commandEmitter.on(commandTypeOrCallback, callbackOrNone); - } else { - this.commandEmitter.on('__any', commandTypeOrCallback); - } + public onCommand(commandType: string, callback: (command: Command) => void): void { + this.emitter.on(commandType, callback); + } + + public onClose(callback: (reason?: string) => void): void { + this.emitter.on('__close', callback); + } + + public onError(callback: (error: Error) => void): void { + this.emitter.on('__error', callback); + } + + public onLost(callback: () => void): void { + this.emitter.on('__lost', callback); + } + + private newEpoch(): void { + this.connection = null; + this.epoch += 1; } private configConnection(ws: WebSocket): WsConnection { - this.log.debug('## config connection'); - const epoch = this.epoch; // copy it to use in closure - const conn = new WsConnection( - this.epoch ? `${this.name}.${epoch}` : this.name, - ws, - this.commandEmitter, - this.heartbeatInterval - ); + const connName = this.epoch ? 
`${this.name}.${this.epoch}` : this.name; + const conn = new WsConnection(connName, ws, this.emitter); + if (this.heartbeatInterval) { + conn.setHeartbeatInterval(this.heartbeatInterval); + } conn.on('bye', reason => { - this.log.debug('Peer intentionally close:', reason); - this.endEpoch(); - if (!this.closing) { - this.closing = true; - this.emit('close', reason); + this.log.debug('Peer intentionally closing:', reason); + if (this.setClosing()) { + this.emitter.emit('__close', reason); } }); conn.on('close', (code, reason) => { - this.closeConnection(epoch, `Received closing handshake: ${code} ${reason}`); + this.log.debug('Peer closed:', reason); + this.dropConnection(conn, `Peer closed: ${code} ${reason}`); }); conn.on('error', error => { - this.closeConnection(epoch, `Error occurred: ${util.inspect(error)}`); + this.dropConnection(conn, `Connection error: ${util.inspect(error)}`); }); return conn; } - private closeConnection(epoch: number, reason: string): void { + private setClosing(): boolean { if (this.closing) { - this.log.debug('Connection cleaned up:', reason); + return false; + } + this.closing = true; + this.newEpoch(); + this.queue.forEach(item => { + item.deferred?.reject(new Error('WsChannel has been closed.')); + }); + return true; + } + + private dropConnection(conn: WsConnection, reason: string): void { + if (this.closing) { + this.log.debug('Clean up:', reason); return; } - if (this.epoch !== epoch) { // the connection is already abandoned - this.log.debug(`Previous connection closed ${epoch}: ${reason}`); + if (this.connection !== conn) { // the connection is already abandoned + this.log.debug(`Previous connection closed: ${reason}`); return; } this.log.warning('Connection closed unexpectedly:', reason); - this.emit('lost'); - this.endEpoch(); + this.newEpoch(); + this.emitter.emit('__lost'); + + if (!this.terminateTimer) { + this.terminateTimer = setTimeout(() => { + if (!this.closing) { + this.terminate('have not reconnected in 30s'); + } + 
}, terminateTimeout); + } + // the reconnect logic is in client subclass } - - private endEpoch(): void { - this.connection = null; - this.epoch += 1; - } } -interface WsConnectionEvents { - 'command': (command: Command) => void; - 'bye': (reason: string) => void; - 'close': (code: number, reason: string) => void; - 'error': (error: Error) => void; -} +let defaultHeartbeatInterval: number = 5000; +let terminateTimeout: number = 30000; -declare interface WsConnection { - on(event: E, listener: WsConnectionEvents[E]): this; -} - -class WsConnection extends EventEmitter { - private closing: boolean = false; - private commandEmitter: EventEmitter; - private heartbeatTimer: NodeJS.Timer | null = null; - private log: Logger; - private missingPongs: number = 0; - private ws: WebSocket; - - constructor(name: string, ws: WebSocket, commandEmitter: EventEmitter, heartbeatInterval: number | null) { - super(); - this.log = getLogger(`WsConnection.${name}`); - this.ws = ws; - this.commandEmitter = commandEmitter; - - ws.on('close', this.handleClose.bind(this)); - ws.on('error', this.handleError.bind(this)); - ws.on('message', this.handleMessage.bind(this)); - ws.on('pong', this.handlePong.bind(this)); - - if (heartbeatInterval) { - this.heartbeatTimer = setInterval(this.heartbeat.bind(this), heartbeatInterval); - } +export namespace UnitTestHelper { + export function setHeartbeatInterval(ms: number): void { + defaultHeartbeatInterval = ms; } - public async close(reason: string): Promise { - if (this.closing) { - this.log.debug('Close again:', reason); - return; - } - - this.log.debug('Close:', reason); - this.closing = true; - if (this.heartbeatTimer) { - clearInterval(this.heartbeatTimer); - this.heartbeatTimer = null; - } - - try { - await this.sendAsync({ type: '_bye_', reason }); - } catch (error) { - this.log.error('Failed to send bye:', error); - } - - try { - this.ws.close(4000, reason); - } catch (error) { - this.log.error('Failed to close:', error); - 
this.ws.terminate(); - } + export function setTerminateTimeout(ms: number): void { + terminateTimeout = ms; } - public terminate(reason: string): void { - this.log.debug('Terminate:', reason); - this.closing = true; - if (this.heartbeatTimer) { - clearInterval(this.heartbeatTimer); - this.heartbeatTimer = null; - } - - try { - this.ws.close(4001, reason); - } catch (error) { - this.log.debug('Failed to close:', error); - this.ws.terminate(); - } - } - - public send(command: Command): void { - this.log.trace('Sending command', command); - this.ws.send(JSON.stringify(command)); - } - - public sendAsync(command: Command): Promise { - this.log.trace('Sending command async', command); - const deferred = new Deferred(); - this.ws.send(JSON.stringify(command), error => { - if (error) { - deferred.reject(error); - } else { - deferred.resolve(); - } - }); - return deferred.promise; - } - - private handleClose(code: number, reason: Buffer): void { - if (this.closing) { - this.log.debug('Connection closed'); - } else { - this.log.debug('Connection closed by peer:', code, String(reason)); - this.emit('close', code, String(reason)); - } - } - - private handleError(error: Error): void { - if (this.closing) { - this.log.warning('Error after closing:', error); - } else { - this.emit('error', error); - } - } - - private handleMessage(data: Buffer, _isBinary: boolean): void { - const s = String(data); - if (this.closing) { - this.log.warning('Received message after closing:', s); - return; - } - - this.log.trace('Received command', s); - const command = JSON.parse(s); - - if (command.type === '_nop_') { - return; - } - if (command.type === '_bye_') { - this.closing = true; - this.emit('bye', command.reason); - return; - } - - const hasAnyListener = this.commandEmitter.emit('__any', command); - const hasTypeListener = this.commandEmitter.emit(command.type, command); - if (!hasAnyListener && !hasTypeListener) { - this.log.warning('No listener for command', s); - } - } - - private 
handlePong(): void { - this.log.debug('receive pong'); // todo - this.missingPongs = 0; - } - - private heartbeat(): void { - if (this.missingPongs > 0) { - this.log.warning('Missing pong'); - } - if (this.missingPongs > 3) { // TODO: make it configurable? - // no response for ping, try real command - this.sendAsync({ type: '_nop_' }).then(() => { - this.missingPongs = 0; - }).catch(error => { - this.log.error('Failed sending command. Drop connection:', error); - this.terminate(`peer lost responsive: ${util.inspect(error)}`); - }); - } - this.missingPongs += 1; - this.log.debug('send ping'); - this.ws.ping(); + export function reset(): void { + defaultHeartbeatInterval = 5000; + terminateTimeout = 30000; } } diff --git a/ts/nni_manager/common/command_channel/websocket/client.ts b/ts/nni_manager/common/command_channel/websocket/client.ts index 844ecb179..3a2674020 100644 --- a/ts/nni_manager/common/command_channel/websocket/client.ts +++ b/ts/nni_manager/common/command_channel/websocket/client.ts @@ -3,15 +3,6 @@ /** * WebSocket command channel client. - * - * Usage: - * - * const client = new WsChannelClient('ws://1.2.3.4:8080/server/channel_id'); - * await client.connect(); - * client.send(command); - * - * Most APIs are derived the base class `WsChannel`. - * See its doc for more details. **/ import events from 'node:events'; @@ -26,7 +17,7 @@ import { WsChannel } from './channel'; const maxPayload: number = 1024 * 1024 * 1024; export class WsChannelClient extends WsChannel { - private logger: Logger; // avoid name conflict with base class + private logger: Logger; private reconnecting: boolean = false; private url: string; @@ -34,19 +25,17 @@ export class WsChannelClient extends WsChannel { * The url should start with "ws://". * The name is used for better logging. **/ - constructor(url: string, name?: string) { - const name_ = name ?? 
generateName(url); - super(name_); - this.logger = getLogger(`WsChannelClient.${name_}`); + constructor(name: string, url: string) { + super(name); + this.logger = getLogger(`WsChannelClient.${name}`); this.url = url; - this.on('lost', this.reconnect.bind(this)); + this.onLost(this.reconnect.bind(this)); } public async connect(): Promise { this.logger.debug('Connecting to', this.url); const ws = new WebSocket(this.url, { maxPayload }); - this.setConnection(ws); - await events.once(ws, 'open'); + await this.setConnection(ws, true), this.logger.debug('Connected'); } @@ -54,7 +43,7 @@ export class WsChannelClient extends WsChannel { * Alias of `close()`. **/ public async disconnect(reason?: string): Promise { - this.close(reason ?? 'client disconnecting'); + this.close(reason ?? 'client intentionally disconnect'); } private async reconnect(): Promise { @@ -82,16 +71,6 @@ export class WsChannelClient extends WsChannel { } this.logger.error('Conenction lost. Cannot reconnect'); - this.emit('error', new Error('Connection lost')); + this.emitter.emit('__error', new Error('Connection lost')); } } - -function generateName(url: string): string { - const parts = url.split('/'); - for (let i = parts.length - 1; i > 1; i--) { - if (parts[i]) { - return parts[i]; - } - } - return 'anonymous'; -} diff --git a/ts/nni_manager/common/command_channel/websocket/connection.ts b/ts/nni_manager/common/command_channel/websocket/connection.ts new file mode 100644 index 000000000..5d1478523 --- /dev/null +++ b/ts/nni_manager/common/command_channel/websocket/connection.ts @@ -0,0 +1,189 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +/** + * Internal helper class which handles one WebSocket connection. 
+ **/ + +import { EventEmitter } from 'node:events'; +import util from 'node:util'; + +import type { WebSocket } from 'ws'; + +import type { Command } from 'common/command_channel/interface'; +import { Logger, getLogger } from 'common/log'; + +interface ConnectionEvents { + 'bye': (reason: string) => void; + 'close': (code: number, reason: string) => void; + 'error': (error: Error) => void; +} + +export declare interface WsConnection { + on(event: E, listener: ConnectionEvents[E]): this; +} + +export class WsConnection extends EventEmitter { + private closing: boolean = false; + private commandEmitter: EventEmitter; + private heartbeatTimer: NodeJS.Timer | null = null; + private log: Logger; + private missingPongs: number = 0; + private ws: WebSocket; // NOTE: used in unit test + + constructor(name: string, ws: WebSocket, commandEmitter: EventEmitter) { + super(); + this.log = getLogger(`WsConnection.${name}`); + this.ws = ws; + this.commandEmitter = commandEmitter; + + ws.on('close', this.handleClose.bind(this)); + ws.on('error', this.handleError.bind(this)); + ws.on('message', this.handleMessage.bind(this)); + ws.on('pong', this.handlePong.bind(this)); + } + + public setHeartbeatInterval(interval: number): void { + if (this.heartbeatTimer) { + clearTimeout(this.heartbeatTimer); + } + this.heartbeatTimer = setInterval(this.heartbeat.bind(this), interval); + } + + public async close(reason: string): Promise { + if (this.closing) { + this.log.debug('Close again:', reason); + return; + } + + this.log.debug('Close connection:', reason); + this.closing = true; + if (this.heartbeatTimer) { + clearInterval(this.heartbeatTimer); + this.heartbeatTimer = null; + } + + try { + await this.sendAsync({ type: '_bye_', reason }); + } catch (error) { + this.log.error('Failed to send bye:', error); + } + + try { + this.ws.close(4000, reason); + return; + } catch (error) { + this.log.error('Failed to close socket:', error); + } + + try { + this.ws.terminate(); + } catch (error) { + 
this.log.debug('Failed to terminate socket:', error); + } + } + + public terminate(reason: string): void { + this.log.debug('Terminate connection:', reason); + this.closing = true; + if (this.heartbeatTimer) { + clearInterval(this.heartbeatTimer); + this.heartbeatTimer = null; + } + + try { + this.ws.close(4001, reason); + return; + } catch (error) { + this.log.debug('Failed to close socket:', error); + } + + try { + this.ws.terminate(); + } catch (error) { + this.log.debug('Failed to terminate socket:', error); + } + } + + public send(command: Command): void { + this.log.trace('Send command', command); + this.ws.send(JSON.stringify(command)); + } + + public sendAsync(command: Command): Promise { + this.log.trace('(async) Send command', command); + const send: any = util.promisify(this.ws.send.bind(this.ws)); + return send(JSON.stringify(command)); + } + + private handleClose(code: number, reason: Buffer): void { + if (this.closing) { + this.log.debug('Connection closed'); + } else { + this.log.debug('Connection closed by peer:', code, String(reason)); + this.emit('close', code, String(reason)); + } + } + + private handleError(error: Error): void { + if (this.closing) { + this.log.warning('Error after closing:', error); + } else { + this.log.error('Connection error:', error); + this.emit('error', error); + } + } + + private handleMessage(data: Buffer, _isBinary: boolean): void { + const s = String(data); + if (this.closing) { + this.log.warning('Received message after closing:', s); + return; + } + + this.log.trace('Receive command', s); + const command = JSON.parse(s); + + if (command.type === '_nop_') { + return; + } + + if (command.type === '_bye_') { + this.log.debug('Intentionally close connection:', s); + this.closing = true; + this.emit('bye', command.reason); + return; + } + + const hasReceiveListener = this.commandEmitter.emit('__receive', command); + const hasCommandListener = this.commandEmitter.emit(command.type, command); + if (!hasReceiveListener && 
!hasCommandListener) { + this.log.warning('No listener for command', s); + } + } + + private handlePong(): void { + this.log.trace('Receive pong'); + this.missingPongs = 0; + } + + private heartbeat(): void { + if (this.missingPongs > 0) { + this.log.warning('Missing pong'); + } + + if (this.missingPongs > 3) { // TODO: make it configurable? + // no response for ping, try real command + this.sendAsync({ type: '_nop_' }).then(() => { + this.missingPongs = 0; + }).catch(error => { + this.log.error('Failed sending command. Drop connection:', error); + this.terminate(`peer lost responsive: ${util.inspect(error)}`); + }); + } + + this.missingPongs += 1; + this.log.trace('Send ping'); + this.ws.ping(); + } +} diff --git a/ts/nni_manager/common/command_channel/websocket/server.ts b/ts/nni_manager/common/command_channel/websocket/server.ts index 867a7cab6..133595f73 100644 --- a/ts/nni_manager/common/command_channel/websocket/server.ts +++ b/ts/nni_manager/common/command_channel/websocket/server.ts @@ -3,29 +3,6 @@ /** * WebSocket command channel server. - * - * The server will specify a URL prefix like `ws://1.2.3.4:8080/SERVER_PREFIX`, - * and each client will append a channel ID, like `ws://1.2.3.4:8080/SERVER_PREFIX/CHANNEL_ID`. - * - * const server = new WsChannelServer('example', 'SERVER_PREFIX'); - * const url = server.getChannelUrl('CHANNEL_ID'); - * const client = new WsChannelClient(url); - * await server.start(); - * await client.connect(); - * - * There two styles to use the server: - * - * 1. Handle all clients' commands in one space: - * - * server.onReceive((channelId, command) => { ... }); - * server.send(channelId, command); - * - * 2. Maintain a `WsChannel` instance for each client: - * - * server.onConnection((channelId, channel) => { - * channel.onCommand(command => { ... 
}); - * channel.send(command); - * }); **/ import { EventEmitter } from 'events'; @@ -39,22 +16,18 @@ import globals from 'common/globals'; import { Logger, getLogger } from 'common/log'; import { WsChannel } from './channel'; -let heartbeatInterval: number = 5000; - type ReceiveCallback = (channelId: string, command: Command) => void; export class WsChannelServer extends EventEmitter { private channels: Map = new Map(); - private ip: string; private log: Logger; private path: string; private receiveCallbacks: ReceiveCallback[] = []; - constructor(name: string, urlPath: string, ip?: string) { + constructor(name: string, urlPath: string) { super(); this.log = getLogger(`WsChannelServer.${name}`); this.path = urlPath; - this.ip = ip ?? 'localhost'; } public async start(): Promise { @@ -67,7 +40,7 @@ export class WsChannelServer extends EventEmitter { const deferred = new Deferred(); this.channels.forEach((channel, channelId) => { - channel.on('close', (_reason) => { + channel.onClose(_reason => { this.channels.delete(channelId); if (this.channels.size === 0) { deferred.resolve(); @@ -77,17 +50,16 @@ export class WsChannelServer extends EventEmitter { }); // wait for at most 5 seconds - // use heartbeatInterval here for easier unit test setTimeout(() => { this.log.debug('Shutdown timeout. Stop waiting following channels:', Array.from(this.channels.keys())); deferred.resolve(); - }, heartbeatInterval); + }, 5000); return deferred.promise; } - public getChannelUrl(channelId: string): string { - return globals.rest.getFullUrl('ws', this.ip, this.path, channelId); + public getChannelUrl(channelId: string, ip?: string): string { + return globals.rest.getFullUrl('ws', ip ?? 
'localhost', this.path, channelId); } public send(channelId: string, command: Command): void { @@ -104,7 +76,7 @@ export class WsChannelServer extends EventEmitter { // because by this way it can detect and warning if a command is never listened this.receiveCallbacks.push(callback); for (const [channelId, channel] of this.channels) { - channel.onCommand(command => { callback(channelId, command); }); + channel.onReceive(command => { callback(channelId, command); }); } } @@ -118,33 +90,30 @@ export class WsChannelServer extends EventEmitter { if (this.channels.has(channelId)) { this.log.warning(`Channel ${channelId} reconnecting, drop previous connection`); - this.channels.get(channelId)!.setConnection(ws); + this.channels.get(channelId)!.setConnection(ws, false); return; } - const channel = new WsChannel(channelId, ws, heartbeatInterval); + const channel = new WsChannel(channelId); this.channels.set(channelId, channel); - channel.on('close', reason => { + channel.onClose(reason => { this.log.debug(`Connection ${channelId} closed:`, reason); this.channels.delete(channelId); }); - channel.on('error', error => { + channel.onError(error => { this.log.error(`Connection ${channelId} error:`, error); this.channels.delete(channelId); }); for (const cb of this.receiveCallbacks) { - channel.on('command', command => { cb(channelId, command); }); + channel.onReceive(command => { cb(channelId, command); }); } + channel.enableHeartbeat(); + channel.setConnection(ws, false); + this.emit('connection', channelId, channel); } } - -export namespace UnitTestHelpers { - export function setHeartbeatInterval(ms: number): void { - heartbeatInterval = ms; - } -} diff --git a/ts/nni_manager/common/globals/unittest.ts b/ts/nni_manager/common/globals/unittest.ts index bfbd8c023..d11c408b2 100644 --- a/ts/nni_manager/common/globals/unittest.ts +++ b/ts/nni_manager/common/globals/unittest.ts @@ -98,5 +98,5 @@ if (isUnitTest()) { resetGlobals(); } -const globals: MutableGlobals = (global as 
any).nni; +export const globals: MutableGlobals = (global as any).nni; export default globals; diff --git a/ts/nni_manager/common/trial_keeper/main.ts b/ts/nni_manager/common/trial_keeper/main.ts index 56eeeb6e3..d3e950598 100644 --- a/ts/nni_manager/common/trial_keeper/main.ts +++ b/ts/nni_manager/common/trial_keeper/main.ts @@ -72,13 +72,13 @@ async function main(): Promise { logger.debug('command:', process.argv); logger.debug('config:', config); - const client = new WsChannelClient(args.managerCommandChannel, args.environmentId); - client.enableHeartbeat(5000); - client.on('close', reason => { + const client = new WsChannelClient(args.environmentId, args.managerCommandChannel); + client.enableHeartbeat(); + client.onClose(reason => { logger.info('Manager closed connection:', reason); globals.shutdown.initiate('Connection end'); }); - client.on('error', error => { + client.onError(error => { logger.info('Connection error:', error); globals.shutdown.initiate('Connection error'); }); diff --git a/ts/nni_manager/package.json b/ts/nni_manager/package.json index 56d22d8b5..c8d7ad7b7 100644 --- a/ts/nni_manager/package.json +++ b/ts/nni_manager/package.json @@ -4,7 +4,7 @@ "license": "MIT", "scripts": { "build": "tsc", - "test": "nyc --reporter=cobertura --reporter=text mocha test/**/*.test.ts", + "test": "nyc --reporter=cobertura --reporter=text mocha \"test/**/*.test.ts\"", "test_nnimanager": "nyc --reporter=cobertura --reporter=text mocha test/core/nnimanager.test.ts", "mocha": "mocha", "eslint": "eslint . 
--ext .ts" diff --git a/ts/nni_manager/rest_server/core.ts b/ts/nni_manager/rest_server/core.ts index df021f2dd..f524f5516 100644 --- a/ts/nni_manager/rest_server/core.ts +++ b/ts/nni_manager/rest_server/core.ts @@ -61,7 +61,7 @@ export class RestServerCore { return deferred.promise; } - public shutdown(): Promise { + public shutdown(timeoutMilliseconds?: number): Promise { logger.info('Stopping REST server.'); if (this.server === null) { logger.warning('REST server is not running.'); @@ -77,7 +77,7 @@ export class RestServerCore { logger.debug('Killing connections'); this.server?.closeAllConnections(); } - }, 5000); + }, timeoutMilliseconds ?? 5000); return deferred.promise; } } diff --git a/ts/nni_manager/test/common/command_channel/websocket.test.ts b/ts/nni_manager/test/common/command_channel/websocket.test.ts new file mode 100644 index 000000000..b7d05f2d4 --- /dev/null +++ b/ts/nni_manager/test/common/command_channel/websocket.test.ts @@ -0,0 +1,209 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +import assert from 'node:assert/strict'; +import { setTimeout } from 'node:timers/promises'; + +import type { + Command, CommandChannel, CommandChannelClient, CommandChannelServer +} from 'common/command_channel/interface'; +import { UnitTestHelper as Helper } from 'common/command_channel/websocket/channel'; +import { WsChannelClient, WsChannelServer } from 'common/command_channel/websocket/index'; +import { globals } from 'common/globals/unittest'; +import { RestServerCore } from 'rest_server/core'; + +describe('## websocket command channel ##', () => { + before(beforeHook); + + it('start', testServerStart); + + it('connect', testClientStart); + it('message', testMessage); + + it('reconnect', testReconnect); + it('message', testMessage); + + it('handle error', testError); + it('shutdown', testShutdown); + + after(afterHook); +}); + +/* test cases */ + +async function testServerStart(): Promise { + ut.server = new WsChannelServer('ut_server', 'ut'); + + ut.server.onReceive((channelId, command) => { + if (channelId === '1') { + ut.events.push({ event: 'server_receive_1', command }); + } + }); + + ut.server.onConnection((channelId, channel) => { + ut.events.push({ event: 'connect', channelId, channel }); + + channel.onClose(reason => { + ut.events.push({ event: `client_close_${channelId}`, reason }); + }); + channel.onError(error => { + ut.events.push({ event: `client_error_${channelId}`, error }); + }); + channel.onLost(() => { + ut.events.push({ event: `client_lost_${channelId}` }); + }); + + if (channelId === '1') { + ut.serverChannel1 = channel; + } + + if (channelId === '2') { + ut.serverChannel2 = channel; + channel.onReceive(command => { + ut.events.push({ event: 'server_receive_2', command }); + }); + } + }); + + await ut.server.start(); +} + +async function testClientStart(): Promise { + const url1 = ut.server.getChannelUrl('1'); + const url2 = ut.server.getChannelUrl('2', '127.0.0.1'); + assert.equal(url1, `ws://localhost:${globals.args.port}/ut/1`); + 
assert.equal(url2, `ws://127.0.0.1:${globals.args.port}/ut/2`); + + ut.client1 = new WsChannelClient('ut_client_1', url1); + ut.client2 = new WsChannelClient('ut_client_2', url2); + + ut.client1.onReceive(command => { + ut.events.push({ event: 'client_receive_1', command }); + }); + ut.client2.onCommand('ut_command', command => { + ut.events.push({ event: 'client_receive_2', command }); + }); + + ut.client2.onClose(reason => { + ut.events.push({ event: 'server_close_2', reason }); + }); + + await Promise.all([ + ut.client1.connect(), + ut.client2.connect(), + ]); + + assert.equal(ut.events[0].event, 'connect'); + assert.equal(ut.events[1].event, 'connect'); + assert.equal(Number(ut.events[0].channelId) + Number(ut.events[1].channelId), 3); + assert.equal(ut.events.length, 2); + + ut.events.length = 0; +} + +async function testReconnect(): Promise { + const ws = (ut.client1 as any).connection.ws; // NOTE: private api + ws.pause(); + await setTimeout(heartbeatTimeout); + ws.terminate(); + ws.resume(); + + // mac pipeline can be slow + for (let i = 0; i < 10; i++) { + await setTimeout(heartbeat); + if (ut.events.length > 0) { + break; + } + } + + assert.ok(ut.countEvents('client_lost_1') >= 1); + assert.ok(ut.countEvents('client_close_1') == 0); + assert.ok(ut.countEvents('client_error_1') == 0); + assert.ok(ut.countEvents('connect') == 0); // reconnect is not connect + + ut.events.length = 0; +} + +async function testMessage(): Promise { + ut.server.send('1', ut.packCommand(1)); + await ut.client2.sendAsync(ut.packCommand(2)); + ut.client2.send(ut.packCommand('三')); + ut.server.send('2', ut.packCommand('4')); + ut.client1.send(ut.packCommand(5)); + ut.server.send('1', ut.packCommand(6)); + + await setTimeout(heartbeat); + + assert.deepEqual(ut.filterCommands('client_receive_1'), [ 1, 6 ]); + assert.deepEqual(ut.filterCommands('client_receive_2'), [ '4' ]); + assert.deepEqual(ut.filterCommands('server_receive_1'), [ 5 ]); + 
assert.deepEqual(ut.filterCommands('server_receive_2'), [ 2, '三' ]); + + ut.events.length = 0; +} + +async function testError(): Promise { + ut.client2.terminate('client 2 terminate'); + await setTimeout(terminateTimeout * 1.1); + + assert.ok(ut.countEvents('client_close_2') == 0); + assert.ok(ut.countEvents('client_error_2') == 1); + + ut.events.length = 0; +} + +async function testShutdown(): Promise { + await ut.server.shutdown(); + + assert.equal(ut.countEvents('client_close_1'), 1); + + ut.events.length = 0; +} + +/* helpers and states */ + +// NOTE: Increase these numbers if it fails randomly +const heartbeat = 10; +const heartbeatTimeout = 50; +const terminateTimeout = 100; + +async function beforeHook(): Promise { + globals.reset(); + + ut.rest = new RestServerCore(); + await ut.rest.start(); + + Helper.setHeartbeatInterval(heartbeat); + Helper.setTerminateTimeout(terminateTimeout); +} + +async function afterHook(): Promise { + Helper.reset(); + await ut.rest?.shutdown(); + globals.reset(); +} + +class UnitTestStates { + server!: CommandChannelServer; + client1!: CommandChannelClient; + client2!: CommandChannelClient; + serverChannel1!: CommandChannel; + serverChannel2!: CommandChannel; + events: any[] = []; + + rest!: RestServerCore; + + countEvents(event: string): number { + return this.events.filter(e => (e.event === event)).length; + } + + filterCommands(event: string): any[] { + return this.events.filter(e => (e.event === event)).map(e => e.command.value); + } + + packCommand(value: any): Command { + return { type: 'ut_command', value }; + } +} + +const ut = new UnitTestStates(); diff --git a/ts/nni_manager/training_service/remote_v3/remote.ts b/ts/nni_manager/training_service/remote_v3/remote.ts index 23489d243..b731f484f 100644 --- a/ts/nni_manager/training_service/remote_v3/remote.ts +++ b/ts/nni_manager/training_service/remote_v3/remote.ts @@ -39,16 +39,16 @@ export class RemoteTrainingServiceV3 implements TrainingServiceV3 { this.log = 
getLogger(`RemoteV3.${this.id}`); this.log.debug('Training sevice config:', config); - this.server = new WsChannelServer('RemoteTrialKeeper', `platform/${this.id}`, config.nniManagerIp); + this.server = new WsChannelServer(this.id, `/platform/${this.id}`); this.server.on('connection', (channelId: string, channel: WsChannel) => { const worker = this.workersByChannel.get(channelId); if (worker) { worker.setChannel(channel); - channel.on('close', reason => { + channel.onClose(reason => { this.log.error('Worker channel closed unexpectedly:', reason); }); - channel.on('error', error => { + channel.onError(error => { this.log.error('Worker channel error:', error); this.restartWorker(worker); }); @@ -190,7 +190,7 @@ export class RemoteTrainingServiceV3 implements TrainingServiceV3 { this.id, channelId, config, - this.server.getChannelUrl(channelId), + this.server.getChannelUrl(channelId, this.config.nniManagerIp), Boolean(this.config.trialGpuNumber) ); diff --git a/ts/nni_manager/training_service/remote_v3/worker.ts b/ts/nni_manager/training_service/remote_v3/worker.ts index 5ca54b6e2..97a0f58e1 100644 --- a/ts/nni_manager/training_service/remote_v3/worker.ts +++ b/ts/nni_manager/training_service/remote_v3/worker.ts @@ -57,7 +57,7 @@ export class Worker { public setChannel(channel: WsChannel): void { this.channel = channel; this.trialKeeper.setChannel(channel); - channel.on('lost', async () => { + channel.onLost(async () => { if (!await this.checkAlive()) { this.log.error('Trial keeper failed'); channel.terminate('Trial keeper failed'); // MARK