diff --git a/nni/runtime/command_channel/websocket/channel.py b/nni/runtime/command_channel/websocket/channel.py index 90cc767eb..955fd26bc 100644 --- a/nni/runtime/command_channel/websocket/channel.py +++ b/nni/runtime/command_channel/websocket/channel.py @@ -6,7 +6,6 @@ from __future__ import annotations import logging import time -import nni from ..base import Command, CommandChannel from .connection import WsConnection @@ -25,19 +24,25 @@ class WsChannelClient(CommandChannel): def disconnect(self) -> None: _logger.debug(f'Disconnect from {self._url}') - self.send({'type': '_bye_'}) - self._closing = True - self._close_conn('client intentionally close') + if self._closing: + _logger.debug('Already closing') + else: + try: + if self._conn is not None: + self._conn.send({'type': '_bye_'}) + except Exception as e: + _logger.debug(f'Failed to send bye: {repr(e)}') + self._closing = True + self._close_conn('client intentionally close') def send(self, command: Command) -> None: if self._closing: return _logger.debug(f'Send {command}') - msg = nni.dump(command) for i in range(5): try: conn = self._ensure_conn() - conn.send(msg) + conn.send(command) return except Exception: _logger.exception(f'Failed to send command. Retry in {i}s') @@ -45,16 +50,15 @@ class WsChannelClient(CommandChannel): time.sleep(i) _logger.warning(f'Failed to send command {command}. Last retry') conn = self._ensure_conn() - conn.send(msg) + conn.send(command) def receive(self) -> Command | None: while True: if self._closing: return None - msg = self._receive_msg() - if msg is None: + command = self._receive_command() + if command is None: return None - command = nni.load(msg) if command['type'] == '_nop_': continue if command['type'] == '_bye_': @@ -88,15 +92,14 @@ class WsChannelClient(CommandChannel): pass self._conn = None - def _receive_msg(self) -> str | None: + def _receive_command(self) -> Command | None: for i in range(5): try: conn = self._ensure_conn() - msg = conn.receive() - _logger.debug(f'Receive {msg}') + command = conn.receive() if not self._closing: - assert msg is not None - return msg + assert command is not None + return command except Exception: _logger.exception(f'Failed to receive command. Retry in {i}s') self._terminate_conn('receive fail') diff --git a/nni/runtime/command_channel/websocket/connection.py b/nni/runtime/command_channel/websocket/connection.py index 099098090..d6e7d3353 100644 --- a/nni/runtime/command_channel/websocket/connection.py +++ b/nni/runtime/command_channel/websocket/connection.py @@ -18,6 +18,9 @@ from typing import Any, Type import websockets +import nni +from ..base import Command + _logger = logging.getLogger(__name__) # the singleton event loop @@ -81,17 +84,17 @@ class WsConnection: return self.disconnect(reason, 4001) - def send(self, message: str) -> None: + def send(self, message: Command) -> None: _logger.debug(f'Sending {message}') try: - _wait(self._ws.send(message)) + _wait(self._ws.send(nni.dump(message))) except websockets.ConnectionClosed: # type: ignore _logger.debug('Connection closed by server.') self._ws = None _decrease_refcnt() raise - def receive(self) -> str | None: + def receive(self) -> Command | None: """ Return received message; or return ``None`` if the connection has been closed by peer. @@ -105,11 +108,12 @@ class WsConnection: _decrease_refcnt() raise + if msg is None: + return None # seems the library will inference whether it's text or binary, so we don't have guarantee if isinstance(msg, bytes): - return msg.decode() - else: - return msg + msg = msg.decode() + return nni.load(msg) def _wait(coro): # Synchronized version of "await". diff --git a/nni/runtime/tuner_command_channel/channel.py b/nni/runtime/tuner_command_channel/channel.py index f8f600e6a..947de01f2 100644 --- a/nni/runtime/tuner_command_channel/channel.py +++ b/nni/runtime/tuner_command_channel/channel.py @@ -11,19 +11,18 @@ __all__ = ['TunerCommandChannel'] import logging import os -import time from collections import defaultdict from threading import Event from typing import Any, Callable from nni.common.serializer import dump, load, PayloadTooLarge +from nni.runtime.command_channel.websocket import WsChannelClient from nni.typehint import Parameters from .command_type import ( CommandType, TunerIncomingCommand, Initialize, RequestTrialJobs, UpdateSearchSpace, ReportMetricData, TrialEnd, Terminate ) -from .websocket import WebSocket _logger = logging.getLogger(__name__) @@ -52,10 +51,7 @@ class TunerCommandChannel: """ def __init__(self, url: str): - self._url = url - self._channel = WebSocket(url) - self._retry_intervals = [0, 1, 10] - + self._channel = WsChannelClient(url) self._callbacks: dict[CommandType, list[Callable[..., None]]] = defaultdict(list) def connect(self) -> None: @@ -268,61 +264,14 @@ class TunerCommandChannel: self._callbacks[TrialEnd.command_type].append(callback) def _send(self, command_type: CommandType, data: str) -> None: - command = command_type.value.decode() + data - try: - self._channel.send(command) - except Exception as e: - _logger.warning('Exception on sending: %r', e) - if not isinstance(e, WebSocket.ConnectionClosed): - _logger.exception(e) - self._retry_send(command) - - def _retry_send(self, command: str) -> None: - _logger.warning('Connection lost. Trying to reconnect...') - for i, interval in enumerate(self._retry_intervals): - _logger.info(f'Attempt #{i}, wait {interval} seconds...') - time.sleep(interval) - self._channel = WebSocket(self._url) - self._channel.connect() - try: - self._channel.send(command) - _logger.info('Reconnected.') - return - except Exception as e: - _logger.exception(e) - _logger.error('Failed to reconnect.') - raise RuntimeError('Connection lost') + self._channel.send({'type': command_type.value, 'content': data}) def _receive(self) -> tuple[CommandType, str] | tuple[None, None]: - try: - command = self._channel.receive() - except Exception as e: - _logger.warning('Exception on receiving: %r', e) - if not isinstance(e, WebSocket.ConnectionClosed): - _logger.exception(e) - command = None + command = self._channel.receive() if command is None: - command = self._retry_receive() - command_type = CommandType(command[:2].encode()) - return command_type, command[2:] - - def _retry_receive(self) -> str: - _logger.warning('Connection lost. Trying to reconnect...') - for i, interval in enumerate(self._retry_intervals): - _logger.info(f'Attempt #{i}, wait {interval} seconds...') - time.sleep(interval) - self._channel = WebSocket(self._url) - self._channel.connect() - try: - command = self._channel.receive() - except Exception as e: - _logger.exception(e) - command = None # for robustness - if command is not None: - _logger.info('Reconnected') - return command - _logger.error('Failed to reconnect.') - raise RuntimeError('Connection lost') + return None, None + else: + return CommandType(command['type']), command.get('content', '') def _validate_placement_constraint(placement_constraint): diff --git a/nni/runtime/tuner_command_channel/command_type.py b/nni/runtime/tuner_command_channel/command_type.py index b47b9d991..abfb2d97c 100644 --- a/nni/runtime/tuner_command_channel/command_type.py +++ b/nni/runtime/tuner_command_channel/command_type.py @@ -10,23 +10,23 @@ from nni.utils import MetricType class CommandType(Enum): # in - Initialize = b'IN' - RequestTrialJobs = b'GE' - ReportMetricData = b'ME' - UpdateSearchSpace = b'SS' - ImportData = b'FD' - AddCustomizedTrialJob = b'AD' - TrialEnd = b'EN' - Terminate = b'TE' - Ping = b'PI' + Initialize = 'IN' + RequestTrialJobs = 'GE' + ReportMetricData = 'ME' + UpdateSearchSpace = 'SS' + ImportData = 'FD' + AddCustomizedTrialJob = 'AD' + TrialEnd = 'EN' + Terminate = 'TE' + Ping = 'PI' # out - Initialized = b'ID' - NewTrialJob = b'TR' - SendTrialJobParameter = b'SP' - NoMoreTrialJobs = b'NO' - KillTrialJob = b'KI' - Error = b'ER' + Initialized = 'ID' + NewTrialJob = 'TR' + SendTrialJobParameter = 'SP' + NoMoreTrialJobs = 'NO' + KillTrialJob = 'KI' + Error = 'ER' class TunerIncomingCommand: # For type checking. diff --git a/nni/runtime/tuner_command_channel/legacy.py b/nni/runtime/tuner_command_channel/legacy.py index d4e8a8f36..f0498d255 100644 --- a/nni/runtime/tuner_command_channel/legacy.py +++ b/nni/runtime/tuner_command_channel/legacy.py @@ -61,7 +61,7 @@ def send(command, data): try: _lock.acquire() data = data.encode('utf8') - msg = b'%b%014d%b' % (command.value, len(data), data) + msg = b'%b%014d%b' % (command.value.encode(), len(data), data) _logger.debug('Sending command, data: [%s]', msg) _out_file.write(msg) _out_file.flush() @@ -81,7 +81,7 @@ def receive(): return None, None length = int(header[2:]) data = _in_file.read(length) - command = CommandType(header[:2]) + command = CommandType(header[:2].decode()) data = data.decode('utf8') _logger.debug('Received command, data: [%s]', data) return command, data diff --git a/nni/runtime/tuner_command_channel/websocket.py b/nni/runtime/tuner_command_channel/websocket.py deleted file mode 100644 index afc1790f6..000000000 --- a/nni/runtime/tuner_command_channel/websocket.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from ..command_channel.websocket.connection import WsConnection as WebSocket # pylint: disable=unused-import diff --git a/pipelines/fast-test.yml b/pipelines/fast-test.yml index 84749f50d..a3150503d 100644 --- a/pipelines/fast-test.yml +++ b/pipelines/fast-test.yml @@ -150,7 +150,7 @@ stages: - script: | set -e - npm --prefix ts/nni_manager run test -- --exclude test/core/nnimanager.test.ts + npm --prefix ts/nni_manager run test npm --prefix ts/nni_manager run test_nnimanager cp ts/nni_manager/coverage/cobertura-coverage.xml coverage/typescript.xml displayName: TypeScript unit test @@ -198,7 +198,7 @@ stages: - script: | export PATH=${PWD}/toolchain/node/bin:$PATH - npm --prefix ts/nni_manager run test -- --exclude test/core/nnimanager.test.ts + npm --prefix ts/nni_manager run test npm --prefix ts/nni_manager run test_nnimanager displayName: TypeScript unit test @@ -223,7 +223,7 @@ stages: # temporarily disable this test, add it back after bug fixed - script: | - npm --prefix ts/nni_manager run test -- --exclude test/core/nnimanager.test.ts + npm --prefix ts/nni_manager run test npm --prefix ts/nni_manager run test_nnimanager displayName: TypeScript unit test @@ -249,7 +249,7 @@ stages: displayName: Python unit test - script: | - CI=true npm --prefix ts/nni_manager run test -- --exclude test/core/nnimanager.test.ts + CI=true npm --prefix ts/nni_manager run test # # exclude nnimanager's ut because macos in pipeline is pretty slow # CI=true npm --prefix ts/nni_manager run test_nnimanager displayName: TypeScript unit test diff --git a/ts/nni_manager/.eslintrc b/ts/nni_manager/.eslintrc index 5cdba21a4..41682a18e 100644 --- a/ts/nni_manager/.eslintrc +++ b/ts/nni_manager/.eslintrc @@ -28,7 +28,7 @@ "@typescript-eslint/no-non-null-assertion": 0, "@typescript-eslint/no-unused-vars": [ - "off", + "error", { "argsIgnorePattern": "^_" } diff --git a/ts/nni_manager/common/command_channel/rpc_util.ts b/ts/nni_manager/common/command_channel/rpc_util.ts index f162f6e13..777632c19 100644 --- a/ts/nni_manager/common/command_channel/rpc_util.ts +++ b/ts/nni_manager/common/command_channel/rpc_util.ts @@ -45,11 +45,9 @@ import util from 'node:util'; -import type { Command } from 'common/command_channel/interface'; import { DefaultMap } from 'common/default_map'; import { Deferred } from 'common/deferred'; import { Logger, getLogger } from 'common/log'; -import type { TrialKeeper } from 'common/trial_keeper/keeper'; import type { CommandChannel } from './interface'; interface RpcResponseCommand { diff --git a/ts/nni_manager/common/command_channel/websocket/channel.ts b/ts/nni_manager/common/command_channel/websocket/channel.ts index b9591db58..c67ce49eb 100644 --- a/ts/nni_manager/common/command_channel/websocket/channel.ts +++ b/ts/nni_manager/common/command_channel/websocket/channel.ts @@ -155,6 +155,11 @@ export class WsChannel implements CommandChannel { this.emitter.on('__lost', callback); } + // TODO: temporary api for tuner command channel + public getBufferedAmount(): number { + return this.connection?.ws.bufferedAmount ?? 0; + } + private newEpoch(): void { this.connection = null; this.epoch += 1; diff --git a/ts/nni_manager/common/command_channel/websocket/client.ts b/ts/nni_manager/common/command_channel/websocket/client.ts index 3a2674020..1291b3aa0 100644 --- a/ts/nni_manager/common/command_channel/websocket/client.ts +++ b/ts/nni_manager/common/command_channel/websocket/client.ts @@ -5,7 +5,6 @@ * WebSocket command channel client. **/ -import events from 'node:events'; import { setTimeout } from 'node:timers/promises'; import { WebSocket } from 'ws'; diff --git a/ts/nni_manager/common/command_channel/websocket/connection.ts b/ts/nni_manager/common/command_channel/websocket/connection.ts index 5d1478523..84456284a 100644 --- a/ts/nni_manager/common/command_channel/websocket/connection.ts +++ b/ts/nni_manager/common/command_channel/websocket/connection.ts @@ -29,7 +29,8 @@ export class WsConnection extends EventEmitter { private heartbeatTimer: NodeJS.Timer | null = null; private log: Logger; private missingPongs: number = 0; - private ws: WebSocket; // NOTE: used in unit test + + public readonly ws: WebSocket; constructor(name: string, ws: WebSocket, commandEmitter: EventEmitter) { super(); diff --git a/ts/nni_manager/common/command_channel/websocket/server.ts b/ts/nni_manager/common/command_channel/websocket/server.ts index 133595f73..bb497f21b 100644 --- a/ts/nni_manager/common/command_channel/websocket/server.ts +++ b/ts/nni_manager/common/command_channel/websocket/server.ts @@ -7,7 +7,6 @@ import { EventEmitter } from 'events'; -import type { Request } from 'express'; import type { WebSocket } from 'ws'; import type { Command } from 'common/command_channel/interface'; @@ -32,7 +31,12 @@ export class WsChannelServer extends EventEmitter { public async start(): Promise { const channelPath = globals.rest.urlJoin(this.path, ':channel'); - globals.rest.registerWebSocketHandler(channelPath, this.handleConnection.bind(this)); + globals.rest.registerWebSocketHandler(this.path, (ws, _req) => { + this.handleConnection('__default__', ws); // TODO: only used by tuner + }); + globals.rest.registerWebSocketHandler(channelPath, (ws, req) => { + this.handleConnection(req.params['channel'], ws); + }); this.log.debug('Start listening', channelPath); } @@ -84,8 +88,7 @@ export class WsChannelServer extends EventEmitter { this.on('connection', callback); } - private handleConnection(ws: WebSocket, req: Request): void { - const channelId = req.params['channel']; + private handleConnection(channelId: string, ws: WebSocket): void { this.log.debug('Incoming connection', channelId); if (this.channels.has(channelId)) { diff --git a/ts/nni_manager/common/deferred.ts b/ts/nni_manager/common/deferred.ts index 8ea366fcd..15c1115b2 100644 --- a/ts/nni_manager/common/deferred.ts +++ b/ts/nni_manager/common/deferred.ts @@ -26,7 +26,7 @@ import util from 'util'; import { Logger, getLogger } from 'common/log'; -const logger = getLogger('common.deferred'); +const logger: Logger = getLogger('common.deferred'); export class Deferred { private resolveCallbacks: any[] = []; diff --git a/ts/nni_manager/common/experimentConfig.ts b/ts/nni_manager/common/experimentConfig.ts index 3d71886af..a28fe6d96 100644 --- a/ts/nni_manager/common/experimentConfig.ts +++ b/ts/nni_manager/common/experimentConfig.ts @@ -1,10 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -import assert from 'assert'; - import { KubeflowOperator, OperatorApiVersion } from '../training_service/kubernetes/kubeflow/kubeflowConfig' -import { KubernetesStorageKind } from '../training_service/kubernetes/kubernetesConfig'; export interface TrainingServiceConfig { platform: string; diff --git a/ts/nni_manager/common/globals/rest.ts b/ts/nni_manager/common/globals/rest.ts index acacebd52..236108294 100644 --- a/ts/nni_manager/common/globals/rest.ts +++ b/ts/nni_manager/common/globals/rest.ts @@ -6,8 +6,8 @@ * Functions will be added when used. **/ -import express, { Request, Response, Router } from 'express'; -import type { Router as WsRouter } from 'express-ws'; +import express, { Express, Request, Response, Router } from 'express'; +import expressWs, { Router as WsRouter } from 'express-ws'; import type { WebSocket } from 'ws'; type HttpMethod = 'GET' | 'PUT'; @@ -16,13 +16,23 @@ type ExpressCallback = (req: Request, res: Response) => void; type WebSocketCallback = (ws: WebSocket, req: Request) => void; export class RestManager { + private app: Express; private router: Router; constructor() { + // we don't actually need the app here, + // but expressWs() must be called before router.ws(), and it requires an app instance + this.app = express(); + expressWs(this.app, undefined, { wsOptions: { maxPayload: 4 * 1024 * 1024 * 1024 }}); + this.router = Router(); this.router.use(express.json({ limit: '50mb' })); } + public getExpressApp(): Express { + return this.app; + } + public getExpressRouter(): Router { return this.router; } diff --git a/ts/nni_manager/common/trial_keeper/rpc.ts b/ts/nni_manager/common/trial_keeper/rpc.ts index 3882e3217..83a5994ef 100644 --- a/ts/nni_manager/common/trial_keeper/rpc.ts +++ b/ts/nni_manager/common/trial_keeper/rpc.ts @@ -24,7 +24,6 @@ **/ import { EventEmitter } from 'node:events'; -import util from 'node:util'; import type { Command } from 'common/command_channel/interface'; import { RpcHelper, getRpcHelper } from 'common/command_channel/rpc_util'; diff --git a/ts/nni_manager/common/utils.ts b/ts/nni_manager/common/utils.ts index 765b407ba..4a372b150 100644 --- a/ts/nni_manager/common/utils.ts +++ b/ts/nni_manager/common/utils.ts @@ -9,7 +9,6 @@ import { ChildProcess, spawn, StdioOptions } from 'child_process'; import dgram from 'dgram'; import fs from 'fs'; import net from 'net'; -import os from 'os'; import path from 'path'; import * as timersPromises from 'timers/promises'; import { Deferred } from 'ts-deferred'; diff --git a/ts/nni_manager/core/ipcInterface.ts b/ts/nni_manager/core/ipcInterface.ts index 7b383ead9..7f266db68 100644 --- a/ts/nni_manager/core/ipcInterface.ts +++ b/ts/nni_manager/core/ipcInterface.ts @@ -1,15 +1,14 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -import { IpcInterface } from './tuner_command_channel/common'; -export { IpcInterface } from './tuner_command_channel/common'; -import * as shim from './tuner_command_channel/shim'; +import { IpcInterface, getTunerServer } from './tuner_command_channel'; +export { IpcInterface } from './tuner_command_channel'; let tunerDisabled: boolean = false; export async function createDispatcherInterface(): Promise { if (!tunerDisabled) { - return await shim.createDispatcherInterface(); + return getTunerServer(); } else { return new DummyIpcInterface(); } diff --git a/ts/nni_manager/core/nnimanager.ts b/ts/nni_manager/core/nnimanager.ts index b0392742b..149d20324 100644 --- a/ts/nni_manager/core/nnimanager.ts +++ b/ts/nni_manager/core/nnimanager.ts @@ -15,7 +15,7 @@ import { NNIManagerStatus, ProfileUpdateType, TrialJobStatistics } from '../common/manager'; import { - ExperimentConfig, LocalConfig, TrainingServiceConfig, toSeconds, toCudaVisibleDevices + ExperimentConfig, TrainingServiceConfig, toSeconds, toCudaVisibleDevices } from '../common/experimentConfig'; import { getExperimentsManager } from 'extensions/experiments_manager'; import { TensorboardManager } from '../common/tensorboardManager'; diff --git a/ts/nni_manager/core/tuner_command_channel.ts b/ts/nni_manager/core/tuner_command_channel.ts new file mode 100644 index 000000000..2d967e118 --- /dev/null +++ b/ts/nni_manager/core/tuner_command_channel.ts @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +import { EventEmitter } from 'node:events'; + +import { WsChannel, WsChannelServer } from 'common/command_channel/websocket'; +import { Deferred } from 'common/deferred'; +import { Logger, getLogger } from 'common/log'; + +export interface IpcInterface { + init(): Promise; + sendCommand(commandType: string, content?: string): void; + onCommand(listener: (commandType: string, content: string) => void): void; + onError(listener: (error: Error) => void): void; +} + +export function getTunerServer(): IpcInterface { + return server; +} + +const logger: Logger = getLogger('tuner_command_channel'); + +class TunerServer { + private channel!: WsChannel; + private connect: Deferred = new Deferred(); + private emitter: EventEmitter = new EventEmitter(); + private server: WsChannelServer; + + constructor() { + this.server = new WsChannelServer('tuner', 'tuner'); + this.server.onConnection((_channelId, channel) => { + this.channel = channel; + this.channel.onError(error => { + this.emitter.emit('error', error); + }); + this.channel.onReceive(command => { + if (command.type === 'ER') { + this.emitter.emit('error', new Error(command.content)); + } else { + this.emitter.emit('command', command.type, command.content ?? ''); + } + }); + this.connect.resolve(); + }); + this.server.start(); + } + + public init(): Promise { // wait connection + if (this.connect.settled) { + logger.debug('Initialized.'); + return Promise.resolve(); + } else { + logger.debug('Waiting connection...'); + // TODO: This is a quick fix. It should check tuner's process status instead. + setTimeout(() => { + if (!this.connect.settled) { + const msg = 'Tuner did not connect in 10 seconds. Please check tuner (dispatcher) log.'; + this.connect.reject(new Error('tuner_command_channel: ' + msg)); + } + }, 10000); + return this.connect.promise; + } + } + + // TODO: for unit test only + public async stop(): Promise { + await this.server.shutdown(); + } + + public sendCommand(commandType: string, content?: string): void { + if (commandType === 'PI') { // ping is handled with WebSocket protocol + return; + } + + if (this.channel.getBufferedAmount() > 1000) { + logger.warning('Sending too fast! Try to reduce the frequency of intermediate results.'); + } + + this.channel.send({ type: commandType, content }); + + if (commandType === 'TE') { + this.channel.close('TE command'); + this.server.shutdown(); + } + } + + public onCommand(listener: (commandType: string, content: string) => void): void { + this.emitter.on('command', listener); + } + + public onError(listener: (error: Error) => void): void { + this.emitter.on('error', listener); + } +} + +let server: TunerServer = new TunerServer(); + +export namespace UnitTestHelpers { + export function reset(): void { + server = new TunerServer(); + } + + export async function stop(): Promise { + await server.stop(); + } +} diff --git a/ts/nni_manager/core/tuner_command_channel/common.ts b/ts/nni_manager/core/tuner_command_channel/common.ts deleted file mode 100644 index 2f21b87dc..000000000 --- a/ts/nni_manager/core/tuner_command_channel/common.ts +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -export interface IpcInterface { - init(): Promise; - sendCommand(commandType: string, content?: string): void; - onCommand(listener: (commandType: string, content: string) => void): void; - onError(listener: (error: Error) => void): void; -} diff --git a/ts/nni_manager/core/tuner_command_channel/index.ts b/ts/nni_manager/core/tuner_command_channel/index.ts deleted file mode 100644 index 6acc3f5d2..000000000 --- a/ts/nni_manager/core/tuner_command_channel/index.ts +++ /dev/null @@ -1,4 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -export { getWebSocketChannel, serveWebSocket } from './websocket_channel'; diff --git a/ts/nni_manager/core/tuner_command_channel/shim.ts b/ts/nni_manager/core/tuner_command_channel/shim.ts deleted file mode 100644 index 16228ea74..000000000 --- a/ts/nni_manager/core/tuner_command_channel/shim.ts +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -import type { IpcInterface } from './common'; -import { WebSocketChannel, getWebSocketChannel } from './websocket_channel'; - -export async function createDispatcherInterface(): Promise { - return new WsIpcInterface(); -} - -class WsIpcInterface implements IpcInterface { - private channel: WebSocketChannel = getWebSocketChannel(); - private commandListener?: (commandType: string, content: string) => void; - private errorListener?: (error: Error) => void; - - constructor() { - this.channel.onCommand((command: string) => { - const commandType = command.slice(0, 2); - const content = command.slice(2); - if (commandType === 'ER') { - if (this.errorListener !== undefined) { - this.errorListener(new Error(content)); - } - } else { - if (this.commandListener !== undefined) { - this.commandListener(commandType, content); - } - } - }); - } - - public async init(): Promise { - await this.channel.init(); - } - - public sendCommand(commandType: string, content: string = ''): void { - if (commandType !== 'PI') { // ping is handled with WebSocket protocol - this.channel.sendCommand(commandType + content); - if (commandType === 'TE') { - this.channel.shutdown(); - } - } - } - - public onCommand(listener: (commandType: string, content: string) => void): void { - this.commandListener = listener; - } - - public onError(listener: (error: Error) => void): void { - this.errorListener = listener; - } -} diff --git a/ts/nni_manager/core/tuner_command_channel/websocket_channel.ts b/ts/nni_manager/core/tuner_command_channel/websocket_channel.ts deleted file mode 100644 index 33a7e1ec5..000000000 --- a/ts/nni_manager/core/tuner_command_channel/websocket_channel.ts +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -/** - * The IPC channel between NNI manager and tuner. - * - * TODO: - * 1. Merge with environment service's WebSocket channel. - * 2. Split import data command to avoid extremely long message. - * 3. Refactor message format. - **/ - -import assert from 'assert/strict'; -import { EventEmitter } from 'events'; - -import type WebSocket from 'ws'; - -import { Deferred } from 'common/deferred'; -import { Logger, getLogger } from 'common/log'; - -const logger: Logger = getLogger('tuner_command_channel.WebSocketChannel'); - -export interface WebSocketChannel { - init(): Promise; - shutdown(): Promise; - sendCommand(command: string): void; // maybe this should return Promise - onCommand(callback: (command: string) => void): void; - onError(callback: (error: Error) => void): void; -} - -/** - * Get the singleton tuner command channel. - * Remember to invoke ``await channel.init()`` before doing anything else. - **/ -export function getWebSocketChannel(): WebSocketChannel { - return channelSingleton; -} - -/** - * The callback to serve WebSocket connection request. Used by REST server module. - * If it is invoked more than once, the previous connection will be dropped. - **/ -export function serveWebSocket(ws: WebSocket): void { - channelSingleton.serveWebSocket(ws); -} - -class WebSocketChannelImpl implements WebSocketChannel { - private deferredInit: Deferred = new Deferred(); - private emitter: EventEmitter = new EventEmitter(); - private heartbeatTimer!: NodeJS.Timer; - private serving: boolean = false; - private waitingPong: boolean = false; - private ws!: WebSocket; - - public serveWebSocket(ws: WebSocket): void { - if (this.ws === undefined) { - logger.debug('Connected.'); - } else { - logger.warning('Reconnecting. Drop previous connection.'); - this.dropConnection('Reconnected'); - } - - this.serving = true; - - this.ws = ws; - this.ws.on('close', this.handleWsClose); - this.ws.on('error', this.handleWsError); - this.ws.on('message', this.handleWsMessage); - this.ws.on('pong', this.handleWsPong); - - this.heartbeatTimer = setInterval(this.heartbeat.bind(this), heartbeatInterval); - this.deferredInit.resolve(); - } - - public init(): Promise { - if (this.ws === undefined) { - logger.debug('Waiting connection...'); - // TODO: This is a quick fix. It should check tuner's process status instead. - setTimeout(() => { - if (!this.deferredInit.settled) { - const msg = 'Tuner did not connect in 10 seconds. Please check tuner (dispatcher) log.'; - this.deferredInit.reject(new Error('tuner_command_channel: ' + msg)); - } - }, 10000); - return this.deferredInit.promise; - - } else { - logger.debug('Initialized.'); - return Promise.resolve(); - } - } - - public async shutdown(): Promise { - if (this.ws === undefined) { - return; - } - clearInterval(this.heartbeatTimer); - this.serving = false; - this.emitter.removeAllListeners(); - } - - public sendCommand(command: string): void { - assert.ok(this.ws !== undefined); - - logger.debug('Sending', command); - this.ws.send(command); - - if (this.ws.bufferedAmount > command.length + 1000) { - logger.warning('Sending too fast! Try to reduce the frequency of intermediate results.'); - } - } - - public onCommand(callback: (command: string) => void): void { - this.emitter.on('command', callback); - } - - public onError(callback: (error: Error) => void): void { - this.emitter.on('error', callback); - } - - /* Following callbacks must be auto-binded arrow functions to be turned off */ - - private handleWsClose = (): void => { - this.handleError(new Error('tuner_command_channel: Tuner closed connection')); - } - - private handleWsError = (error: Error): void => { - this.handleError(error); - } - - private handleWsMessage = (data: Buffer, _isBinary: boolean): void => { - this.receive(data); - } - - private handleWsPong = (): void => { - this.waitingPong = false; - } - - private dropConnection(reason: string): void { - if (this.ws === undefined) { - return; - } - - this.serving = false; - this.waitingPong = false; - clearInterval(this.heartbeatTimer); - - this.ws.off('close', this.handleWsClose); - this.ws.off('error', this.handleWsError); - this.ws.off('message', this.handleWsMessage); - this.ws.off('pong', this.handleWsPong); - - this.ws.on('close', () => { - logger.info('Connection dropped'); - }); - this.ws.on('message', (data, _isBinary) => { - logger.error('Received message after reconnect:', data); - }); - this.ws.on('pong', () => { - logger.error('Received pong after reconnect.'); - }); - this.ws.close(1001, reason); - } - - private heartbeat(): void { - // if (this.waitingPong) { - // this.ws.terminate(); // this will trigger "close" event - // this.handleError(new Error('tuner_command_channel: Tuner loses responsive')); - // } - - this.waitingPong = true; - this.ws.ping(); - } - - private receive(data: Buffer): void { - logger.debug('Received', data); - this.emitter.emit('command', data.toString()); - } - - private handleError(error: Error): void { - if (!this.serving) { - logger.debug('Silent error:', error); - return; - } - logger.error('Error:', error); - - clearInterval(this.heartbeatTimer); - this.emitter.emit('error', error); - this.serving = false; - } -} - -let channelSingleton: WebSocketChannelImpl = new WebSocketChannelImpl(); - -let heartbeatInterval: number = 5000; - -export namespace UnitTestHelpers { - export function setHeartbeatInterval(ms: number): void { - heartbeatInterval = ms; - } - // NOTE: this function is only for unittest of nnimanager, - // because resuming an experiment should reset websocket channel. - export function resetChannelSingleton(): void { - channelSingleton = new WebSocketChannelImpl(); - } -} diff --git a/ts/nni_manager/extensions/experiments_manager/manager.ts b/ts/nni_manager/extensions/experiments_manager/manager.ts index f51aed971..be7081509 100644 --- a/ts/nni_manager/extensions/experiments_manager/manager.ts +++ b/ts/nni_manager/extensions/experiments_manager/manager.ts @@ -3,8 +3,6 @@ import assert from 'assert/strict'; import fs from 'fs'; -import os from 'os'; -import path from 'path'; import * as timersPromises from 'timers/promises'; import { Deferred } from 'ts-deferred'; diff --git a/ts/nni_manager/extensions/experiments_manager/utils.ts b/ts/nni_manager/extensions/experiments_manager/utils.ts index 4cd833017..a12c43e3c 100644 --- a/ts/nni_manager/extensions/experiments_manager/utils.ts +++ b/ts/nni_manager/extensions/experiments_manager/utils.ts @@ -8,8 +8,6 @@ import * as timersPromises from 'timers/promises'; import glob from 'glob'; import lockfile from 'lockfile'; -import globals from 'common/globals'; - const lockStale: number = 2000; const retry: number = 100; diff --git a/ts/nni_manager/main.ts b/ts/nni_manager/main.ts index 2ee8e92d9..c7c310c4a 100644 --- a/ts/nni_manager/main.ts +++ b/ts/nni_manager/main.ts @@ -21,13 +21,13 @@ import 'app-module-path/register'; // so we can use absolute path to import -import fs from 'fs'; - import { Container, Scope } from 'typescript-ioc'; +import { globals, initGlobals } from 'common/globals'; +initGlobals(); + import * as component from 'common/component'; import { Database, DataStore } from 'common/datastore'; -import globals, { initGlobals } from 'common/globals'; import { Logger, getLogger } from 'common/log'; import { Manager } from 'common/manager'; import { TensorboardManager } from 'common/tensorboardManager'; @@ -71,8 +71,6 @@ process.on('SIGINT', () => { globals.shutdown.initiate('SIGINT'); }); /* main */ -initGlobals(); - start().then(() => { logger.debug('start() returned.'); }).catch((error) => { diff --git a/ts/nni_manager/package.json b/ts/nni_manager/package.json index c8d7ad7b7..35f27133f 100644 --- a/ts/nni_manager/package.json +++ b/ts/nni_manager/package.json @@ -4,8 +4,8 @@ "license": "MIT", "scripts": { "build": "tsc", - "test": "nyc --reporter=cobertura --reporter=text mocha \"test/**/*.test.ts\"", - "test_nnimanager": "nyc --reporter=cobertura --reporter=text mocha test/core/nnimanager.test.ts", + "test": "nyc --reporter=cobertura --reporter=text mocha \"test/**/*.test.ts\" --exclude test/core/nnimanager.test.ts", + "test_nnimanager": "mocha test/core/nnimanager.test.ts", "mocha": "mocha", "eslint": "eslint . --ext .ts" }, diff --git a/ts/nni_manager/rest_server/core.ts b/ts/nni_manager/rest_server/core.ts index f524f5516..e5f647ea0 100644 --- a/ts/nni_manager/rest_server/core.ts +++ b/ts/nni_manager/rest_server/core.ts @@ -13,9 +13,6 @@ import assert from 'node:assert/strict'; import type { Server } from 'node:http'; import type { AddressInfo } from 'node:net'; -import express from 'express'; -import expressWs from 'express-ws'; - import { Deferred } from 'common/deferred'; import { globals } from 'common/globals'; import { Logger, getLogger } from 'common/log'; @@ -40,9 +37,7 @@ export class RestServerCore { public start(): Promise { logger.info(`Starting REST server at port ${this.port}, URL prefix: "/${this.urlPrefix}"`); - const app = express(); - expressWs(app, undefined, { wsOptions: { maxPayload: 4 * 1024 * 1024 * 1024 }}); - + const app = globals.rest.getExpressApp(); app.use('/' + this.urlPrefix, globals.rest.getExpressRouter()); app.all('/' + this.urlPrefix, (_req, res) => { res.status(404).send('Not Found'); }); app.all('*', (_req, res) => { res.status(404).send(`Outside prefix "/${this.urlPrefix}"`); }); diff --git a/ts/nni_manager/rest_server/index.ts b/ts/nni_manager/rest_server/index.ts index 3425e861d..cca777995 100644 --- a/ts/nni_manager/rest_server/index.ts +++ b/ts/nni_manager/rest_server/index.ts @@ -26,13 +26,11 @@ import type { AddressInfo } from 'net'; import path from 'path'; import express, { Request, Response, Router } from 'express'; -import expressWs from 'express-ws'; import httpProxy from 'http-proxy'; import { Deferred } from 'common/deferred'; import globals from 'common/globals'; import { Logger, getLogger } from 'common/log'; -import * as tunerCommandChannel from 'core/tuner_command_channel'; const logger: Logger = getLogger('RestServer'); @@ -60,9 +58,7 @@ export class RestServer { public start(): Promise { logger.info(`Starting REST server at port ${this.port}, URL prefix: "/${this.urlPrefix}"`); - const app = express(); - expressWs(app, undefined, { wsOptions: { maxPayload: 4 * 1024 * 1024 * 1024 }}); - + const app = globals.rest.getExpressApp(); app.use('/' + this.urlPrefix, mainRouter()); app.use('/' + this.urlPrefix, fallbackRouter()); app.all('*', (_req: Request, res: Response) => { res.status(404).send(`Outside prefix "/${this.urlPrefix}"`); }); @@ -110,10 +106,7 @@ export class RestServer { * In fact experiments management should have a separate prefix and module. **/ function mainRouter(): Router { - const router = globals.rest.getExpressRouter() as expressWs.Router; - - /* WebSocket APIs */ - router.ws('/tuner', (ws, _req, _next) => { tunerCommandChannel.serveWebSocket(ws); }); + const router = globals.rest.getExpressRouter(); /* Download log files */ // The REST API path "/logs" does not match file system path "/log". @@ -142,6 +135,9 @@ function fallbackRouter(): Router { /* 404 as catch-all */ router.all('*', (_req: Request, res: Response) => { res.status(404).send('Not Found'); }); + + // TODO: websocket 404 + return router; } diff --git a/ts/nni_manager/rest_server/restHandler.ts b/ts/nni_manager/rest_server/restHandler.ts index de1773e2b..27ab1e934 100644 --- a/ts/nni_manager/rest_server/restHandler.ts +++ b/ts/nni_manager/rest_server/restHandler.ts @@ -13,7 +13,6 @@ import { getLogger, Logger } from '../common/log'; import { ExperimentProfile, Manager, TrialJobStatistics } from '../common/manager'; import { getExperimentsManager } from 'extensions/experiments_manager'; import { TensorboardManager, TensorboardTaskInfo } from '../common/tensorboardManager'; -import { ValidationSchemas } from './restValidationSchemas'; import { getVersion } from '../common/utils'; import { MetricType } from '../common/datastore'; import { ProfileUpdateType } from '../common/manager'; @@ -294,11 +293,7 @@ class NNIRestHandler { private getTrialFile(router: Router): void { router.get('/trial-file/:id/:filename', async(req: Request, res: Response) => { - let encoding: string | null = null; const filename = req.params['filename']; - if (!filename.includes('.') || filename.match(/.*\.(txt|log)/g)) { - encoding = 'utf8'; - } this.nniManager.getTrialFile(req.params['id'], filename).then((content: Buffer | string) => { const contentType = content instanceof Buffer ? 'application/octet-stream' : 'text/plain'; res.header('Content-Type', contentType); diff --git a/ts/nni_manager/test/core/ipcInterface.test.ts b/ts/nni_manager/test/core/ipcInterface.test.ts deleted file mode 100644 index 1b13daea8..000000000 --- a/ts/nni_manager/test/core/ipcInterface.test.ts +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -'use strict'; - -import * as assert from 'assert'; -import { ChildProcess, spawn, StdioOptions } from 'child_process'; -import { Deferred } from 'ts-deferred'; -import { cleanupUnitTest, prepareUnitTest, getTunerProc } from '../../common/utils'; -import * as CommandType from '../../core/commands'; -import { createDispatcherInterface, IpcInterface } from '../../core/ipcInterface'; -import { NNIError } from '../../common/errors'; - -let sentCommands: { [key: string]: string }[] = []; -const receivedCommands: { [key: string]: string }[] = []; - -let rejectCommandType: Error | undefined; - -async function runProcess(): Promise { - // the process is intended to throw error, do not reject - const deferred: Deferred = new Deferred(); - - // create fake assessor process - const stdio: StdioOptions = ['ignore', 'pipe', process.stderr, 'pipe', 'pipe']; - const command: string[] = [ 'python', 'assessor.py' ]; - const proc: ChildProcess = getTunerProc(command, stdio, 'core/test', process.env); - // record its sent/received commands on exit - proc.on('error', (error: Error): void => { deferred.resolve(error); }); - proc.on('exit', (code: number): void => { - if (code !== 0) { - deferred.resolve(new Error(`return code: ${code}`)); - } else { - let str = proc.stdout!.read().toString(); - if(str.search("\r\n")!=-1){ - sentCommands = str.split("\r\n"); - } - else{ - sentCommands = str.split('\n'); - } - deferred.resolve(null); - } - }); - - // create IPC interface - const dispatcher: IpcInterface = await createDispatcherInterface(); - dispatcher.onCommand((commandType: string, content: string): void => { - receivedCommands.push({ commandType, content }); - }); - - // Command #1: ok - dispatcher.sendCommand('IN'); - - // Command #2: ok - dispatcher.sendCommand('ME', '123'); - - // Command #3: FE is not tuner/assessor command, test the exception type of send non-valid command - try { - dispatcher.sendCommand('FE', '1'); - } catch (error) { - rejectCommandType = error as Error; - } - - return deferred.promise; -} - -/* FIXME -describe('core/protocol', (): void => { - - before(async () => { - prepareUnitTest(); - await runProcess(); - }); - - after(() => { - cleanupUnitTest(); - }); - - it('should have sent 2 successful commands', (): void => { - assert.equal(sentCommands.length, 3); - assert.equal(sentCommands[2], ''); - }); - - it('sendCommand() should work without content', (): void => { - assert.equal(sentCommands[0], "('IN', '')"); - }); - - it('sendCommand() should work with content', (): void => { - assert.equal(sentCommands[1], "('ME', '123')"); - }); - - it('sendCommand() should throw on wrong command type', (): void => { - assert.equal((rejectCommandType).name.split(' ')[0], 'AssertionError'); - }); - - it('should have received 3 commands', (): void => { - assert.equal(receivedCommands.length, 3); - }); - - it('onCommand() should work without content', (): void => { - assert.deepStrictEqual(receivedCommands[0], { - commandType: 'KI', - content: '' - }); - }); - - it('onCommand() should work with content', (): void => { - assert.deepStrictEqual(receivedCommands[1], { - commandType: 'KI', - content: 'hello' - }); - }); - - it('onCommand() should work with Unicode content', (): void => { - assert.deepStrictEqual(receivedCommands[2], { - commandType: 'KI', - content: '世界' - }); - }); - -}); -*/ diff --git a/ts/nni_manager/test/core/ipcInterfaceTerminate.test.ts b/ts/nni_manager/test/core/ipcInterfaceTerminate.test.ts deleted file mode 100644 index 97a4aec13..000000000 --- a/ts/nni_manager/test/core/ipcInterfaceTerminate.test.ts +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -'use strict'; - -import * as assert from 'assert'; -import { ChildProcess, spawn, StdioOptions } from 'child_process'; -import { Deferred } from 'ts-deferred'; -import { cleanupUnitTest, prepareUnitTest, getMsgDispatcherCommand, getTunerProc } from '../../common/utils'; -import * as CommandType from '../../core/commands'; -import { createDispatcherInterface, IpcInterface } from '../../core/ipcInterface'; - -let dispatcher: IpcInterface | undefined; -let procExit: boolean = false; -let procError: boolean = false; - -async function startProcess(): Promise { - // create fake assessor process - const stdio: StdioOptions = ['ignore', 'pipe', process.stderr, 'pipe', 'pipe']; - - const dispatcherCmd: string[] = getMsgDispatcherCommand( - // Mock tuner config - { - experimentName: 'exp1', - maxExperimentDuration: '1h', - searchSpace: '', - trainingService: { - platform: 'local' - }, - trialConcurrency: 1, - maxTrialNumber: 5, - tuner: { - className: 'dummy_tuner.DummyTuner', - codeDirectory: '.' - }, - assessor: { - className: 'dummy_assessor.DummyAssessor', - codeDirectory: '.' - }, - trialCommand: '', - trialCodeDirectory: '', - debug: true - } - ); - const proc: ChildProcess = getTunerProc(dispatcherCmd, stdio, 'core/test', process.env); - proc.on('error', (_error: Error): void => { - procExit = true; - procError = true; - }); - proc.on('exit', (code: number): void => { - procExit = true; - procError = (code !== 0); - }); - - // create IPC interface - dispatcher = await createDispatcherInterface(); - (dispatcher).onCommand((commandType: string, content: string): void => { - console.log(commandType, content); - }); -} - -/* FIXME -describe('core/ipcInterface.terminate', (): void => { - before(() => { - prepareUnitTest(); - startProcess(); - }); - - after(() => { - cleanupUnitTest(); - }); - - it('normal', () => { - (dispatcher).sendCommand( - CommandType.REPORT_METRIC_DATA, - '{"trial_job_id":"A","type":"PERIODICAL","value":1,"sequence":123}'); - - const deferred: Deferred = new Deferred(); - setTimeout( - () => { - assert.ok(!procExit); - assert.ok(!procError); - deferred.resolve(); - }, - 1000); - - return deferred.promise; - }); - - it('terminate', () => { - (dispatcher).sendCommand(CommandType.TERMINATE); - - const deferred: Deferred = new Deferred(); - setTimeout( - () => { - assert.ok(procExit); - assert.ok(!procError); - deferred.resolve(); - }, - 10000); - - return deferred.promise; - }); -}); -*/ diff --git a/ts/nni_manager/test/core/nnimanager.test.ts b/ts/nni_manager/test/core/nnimanager.test.ts index afa8636dd..7dec61772 100644 --- a/ts/nni_manager/test/core/nnimanager.test.ts +++ b/ts/nni_manager/test/core/nnimanager.test.ts @@ -22,7 +22,7 @@ import { NNITensorboardManager } from '../../extensions/nniTensorboardManager'; import * as path from 'path'; import { RestServer } from '../../rest_server'; import globals from '../../common/globals/unittest'; -import { UnitTestHelpers } from '../../core/tuner_command_channel/websocket_channel'; +import { UnitTestHelpers } from '../../core/tuner_command_channel'; import * as timersPromises from 'timers/promises'; let nniManager: NNIManager; @@ -324,7 +324,7 @@ async function resumeExperiment(): Promise { // globals.showLog(); // explicitly reset the websocket channel because it is singleton, does not work when two experiments // (one is start and the other is resume) run in the same process. - UnitTestHelpers.resetChannelSingleton(); + UnitTestHelpers.reset(); await initContainer('resume'); nniManager = component.get(Manager); diff --git a/ts/nni_manager/test/core/tuner_command_channel.test.ts b/ts/nni_manager/test/core/tuner_command_channel.test.ts deleted file mode 100644 index 29b350194..000000000 --- a/ts/nni_manager/test/core/tuner_command_channel.test.ts +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -import assert from 'assert/strict'; -import { setTimeout } from 'timers/promises'; - -import WebSocket from 'ws'; - -import { Deferred } from 'common/deferred'; -import { getWebSocketChannel, serveWebSocket } from 'core/tuner_command_channel'; -import { UnitTestHelpers } from 'core/tuner_command_channel/websocket_channel'; - -const heartbeatInterval: number = 10; - -// for testError, must be set before serveWebSocket() -UnitTestHelpers.setHeartbeatInterval(heartbeatInterval); - -/* test cases */ - -// Start serving and let a client connect. -async function testInit(): Promise { - const channel = getWebSocketChannel(); - channel.onCommand(command => { serverReceived.push(command); }); - channel.onError(error => { catchedError = error; }); - - server.on('connection', serveWebSocket); - client1 = new Client('client1'); - await channel.init(); -} - -// Send commands from server to client. -async function testSend(client: Client): Promise { - const channel = getWebSocketChannel(); - - channel.sendCommand(command1); - channel.sendCommand(command2); - await setTimeout(heartbeatInterval); - - assert.deepEqual(client.received, [command1, command2]); -} - -// Send commands from client to server. -async function testReceive(client: Client): Promise { - serverReceived.length = 0; - - client.ws.send(command2); - client.ws.send(command1); - await setTimeout(heartbeatInterval); - - assert.deepEqual(serverReceived, [command2, command1]); -} - -// Simulate client side crash. -async function testError(): Promise { - if (process.platform !== 'linux') { - // it is performance sensitive for the test case to yield error, - // but windows & mac agents of devops are too slow - client1.ws.terminate(); - return; - } - - // we have set heartbeat interval to 10ms, so pause for 30ms should make it timeout - client1.ws.pause(); - await setTimeout(heartbeatInterval * 3); - client1.ws.resume(); - - assert.notEqual(catchedError, undefined); -} - -// If the client losses connection by accident but not crashed, it will reconnect. -async function testReconnect(): Promise { - client2 = new Client('client2'); - await client2.deferred.promise; -} - -// Clean up. -async function testShutdown(): Promise { - const channel = getWebSocketChannel(); - await channel.shutdown(); - - client1.ws.close(); - client2.ws.close(); - server.close(); -} - -/* register */ -describe('## tuner_command_channel ##', () => { - it('init', testInit); - - it('send', () => testSend(client1)); - it('receive', () => testReceive(client1)); - - // it('mock timeout', testError); - it('reconnect', testReconnect); - - it('send after reconnect', () => testSend(client2)); - it('receive after reconnect', () => testReceive(client2)); - - it('shutdown', testShutdown); -}); - -/** helpers **/ - -const command1 = 'T_hello world'; -const command2 = 'T_你好'; - -const server = new WebSocket.Server({ port: 0 }); -let client1!: Client; -let client2!: Client; - -const serverReceived: string[] = []; -let catchedError: Error | undefined; - -class Client { - name: string; - received: string[] = []; - ws!: WebSocket; - deferred: Deferred = new Deferred(); - - constructor(name: string) { - this.name = name; - const port = (server.address() as any).port; - this.ws = new WebSocket(`ws://localhost:${port}`); - this.ws.on('message', (data, _isBinary) => { - this.received.push(data.toString()); - }); - this.ws.on('open', () => { - this.deferred.resolve(); - }); - } -} diff --git a/ts/nni_manager/training_service/reusable/environments/kubernetes/frameworkcontrollerEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/kubernetes/frameworkcontrollerEnvironmentService.ts index 1d495d654..3990408dd 100644 --- a/ts/nni_manager/training_service/reusable/environments/kubernetes/frameworkcontrollerEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/kubernetes/frameworkcontrollerEnvironmentService.ts @@ -12,9 +12,8 @@ import { ExperimentStartupInfo } from '../../../../common/experimentStartupInfo' import { EnvironmentInformation } from '../../environment'; import { KubernetesEnvironmentService } from './kubernetesEnvironmentService'; import { FrameworkControllerClientFactory } from '../../../kubernetes/frameworkcontroller/frameworkcontrollerApiClient'; -import { FrameworkControllerClusterConfigAzure, FrameworkControllerJobStatus, FrameworkControllerTrialConfigTemplate, +import { FrameworkControllerJobStatus, FrameworkControllerTrialConfigTemplate, FrameworkControllerJobCompleteStatus } from '../../../kubernetes/frameworkcontroller/frameworkcontrollerConfig'; -import { KeyVaultConfig, AzureStorage } from '../../../kubernetes/kubernetesConfig'; @component.Singleton export class FrameworkControllerEnvironmentService extends KubernetesEnvironmentService { diff --git a/ts/nni_manager/training_service/reusable/environments/kubernetes/kubernetesEnvironmentService.ts b/ts/nni_manager/training_service/reusable/environments/kubernetes/kubernetesEnvironmentService.ts index 2cf892712..94ef3ae07 100644 --- a/ts/nni_manager/training_service/reusable/environments/kubernetes/kubernetesEnvironmentService.ts +++ b/ts/nni_manager/training_service/reusable/environments/kubernetes/kubernetesEnvironmentService.ts @@ -6,7 +6,6 @@ import path from 'path'; import azureStorage from 'azure-storage'; import {Base64} from 'js-base64'; import {String} from 'typescript-string-operations'; -import { ExperimentConfig } from 'common/experimentConfig'; import { ExperimentStartupInfo } from 'common/experimentStartupInfo'; import { getLogger, Logger } from 'common/log'; import { EnvironmentInformation, EnvironmentService } from 'training_service/reusable/environment'; diff --git a/ts/nni_manager/training_service/v3/compat.ts b/ts/nni_manager/training_service/v3/compat.ts index b0c50b14c..6f43da102 100644 --- a/ts/nni_manager/training_service/v3/compat.ts +++ b/ts/nni_manager/training_service/v3/compat.ts @@ -11,7 +11,7 @@ import type { TrainingServiceConfig } from 'common/experimentConfig'; import globals from 'common/globals'; import { getLogger } from 'common/log'; import { - TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus + TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric } from 'common/trainingService'; import type { EnvironmentInfo, Parameter, TrainingServiceV3 } from 'common/training_service_v3'; import { trainingServiceFactoryV3 } from './factory'; diff --git a/ts/nni_manager/training_service/v3/factory.ts b/ts/nni_manager/training_service/v3/factory.ts index f7e2da45c..913658a27 100644 --- a/ts/nni_manager/training_service/v3/factory.ts +++ b/ts/nni_manager/training_service/v3/factory.ts @@ -7,7 +7,7 @@ * For now we only have "local_v3" and "remote_v3" as PoC. **/ -import type { LocalConfig, RemoteConfig, TrainingServiceConfig } from 'common/experimentConfig'; +import type { TrainingServiceConfig } from 'common/experimentConfig'; import type { TrainingServiceV3 } from 'common/training_service_v3'; import { LocalTrainingServiceV3 } from '../local_v3'; import { RemoteTrainingServiceV3 } from '../remote_v3';