Migrate tuner command channel to v3 channel (#5475)

This commit is contained in:
liuzhe-lz 2023-03-24 13:58:02 +08:00 коммит произвёл GitHub
Родитель ec0ddbb711
Коммит f58f3ab3a7
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
40 изменённых файлов: 210 добавлений и 792 удалений

Просмотреть файл

@ -6,7 +6,6 @@ from __future__ import annotations
import logging
import time
import nni
from ..base import Command, CommandChannel
from .connection import WsConnection
@ -25,19 +24,25 @@ class WsChannelClient(CommandChannel):
def disconnect(self) -> None:
_logger.debug(f'Disconnect from {self._url}')
self.send({'type': '_bye_'})
self._closing = True
self._close_conn('client intentionally close')
if self._closing:
_logger.debug('Already closing')
else:
try:
if self._conn is not None:
self._conn.send({'type': '_bye_'})
except Exception as e:
_logger.debug(f'Failed to send bye: {repr(e)}')
self._closing = True
self._close_conn('client intentionally close')
def send(self, command: Command) -> None:
if self._closing:
return
_logger.debug(f'Send {command}')
msg = nni.dump(command)
for i in range(5):
try:
conn = self._ensure_conn()
conn.send(msg)
conn.send(command)
return
except Exception:
_logger.exception(f'Failed to send command. Retry in {i}s')
@ -45,16 +50,15 @@ class WsChannelClient(CommandChannel):
time.sleep(i)
_logger.warning(f'Failed to send command {command}. Last retry')
conn = self._ensure_conn()
conn.send(msg)
conn.send(command)
def receive(self) -> Command | None:
while True:
if self._closing:
return None
msg = self._receive_msg()
if msg is None:
command = self._receive_command()
if command is None:
return None
command = nni.load(msg)
if command['type'] == '_nop_':
continue
if command['type'] == '_bye_':
@ -88,15 +92,14 @@ class WsChannelClient(CommandChannel):
pass
self._conn = None
def _receive_msg(self) -> str | None:
def _receive_command(self) -> Command | None:
for i in range(5):
try:
conn = self._ensure_conn()
msg = conn.receive()
_logger.debug(f'Receive {msg}')
command = conn.receive()
if not self._closing:
assert msg is not None
return msg
assert command is not None
return command
except Exception:
_logger.exception(f'Failed to receive command. Retry in {i}s')
self._terminate_conn('receive fail')

Просмотреть файл

@ -18,6 +18,9 @@ from typing import Any, Type
import websockets
import nni
from ..base import Command
_logger = logging.getLogger(__name__)
# the singleton event loop
@ -81,17 +84,17 @@ class WsConnection:
return
self.disconnect(reason, 4001)
def send(self, message: str) -> None:
def send(self, message: Command) -> None:
_logger.debug(f'Sending {message}')
try:
_wait(self._ws.send(message))
_wait(self._ws.send(nni.dump(message)))
except websockets.ConnectionClosed: # type: ignore
_logger.debug('Connection closed by server.')
self._ws = None
_decrease_refcnt()
raise
def receive(self) -> str | None:
def receive(self) -> Command | None:
"""
Return received message;
or return ``None`` if the connection has been closed by peer.
@ -105,11 +108,12 @@ class WsConnection:
_decrease_refcnt()
raise
if msg is None:
return None
# seems the library will inference whether it's text or binary, so we don't have guarantee
if isinstance(msg, bytes):
return msg.decode()
else:
return msg
msg = msg.decode()
return nni.load(msg)
def _wait(coro):
# Synchronized version of "await".

Просмотреть файл

@ -11,19 +11,18 @@ __all__ = ['TunerCommandChannel']
import logging
import os
import time
from collections import defaultdict
from threading import Event
from typing import Any, Callable
from nni.common.serializer import dump, load, PayloadTooLarge
from nni.runtime.command_channel.websocket import WsChannelClient
from nni.typehint import Parameters
from .command_type import (
CommandType, TunerIncomingCommand,
Initialize, RequestTrialJobs, UpdateSearchSpace, ReportMetricData, TrialEnd, Terminate
)
from .websocket import WebSocket
_logger = logging.getLogger(__name__)
@ -52,10 +51,7 @@ class TunerCommandChannel:
"""
def __init__(self, url: str):
self._url = url
self._channel = WebSocket(url)
self._retry_intervals = [0, 1, 10]
self._channel = WsChannelClient(url)
self._callbacks: dict[CommandType, list[Callable[..., None]]] = defaultdict(list)
def connect(self) -> None:
@ -268,61 +264,14 @@ class TunerCommandChannel:
self._callbacks[TrialEnd.command_type].append(callback)
def _send(self, command_type: CommandType, data: str) -> None:
command = command_type.value.decode() + data
try:
self._channel.send(command)
except Exception as e:
_logger.warning('Exception on sending: %r', e)
if not isinstance(e, WebSocket.ConnectionClosed):
_logger.exception(e)
self._retry_send(command)
def _retry_send(self, command: str) -> None:
_logger.warning('Connection lost. Trying to reconnect...')
for i, interval in enumerate(self._retry_intervals):
_logger.info(f'Attempt #{i}, wait {interval} seconds...')
time.sleep(interval)
self._channel = WebSocket(self._url)
self._channel.connect()
try:
self._channel.send(command)
_logger.info('Reconnected.')
return
except Exception as e:
_logger.exception(e)
_logger.error('Failed to reconnect.')
raise RuntimeError('Connection lost')
self._channel.send({'type': command_type.value, 'content': data})
def _receive(self) -> tuple[CommandType, str] | tuple[None, None]:
try:
command = self._channel.receive()
except Exception as e:
_logger.warning('Exception on receiving: %r', e)
if not isinstance(e, WebSocket.ConnectionClosed):
_logger.exception(e)
command = None
command = self._channel.receive()
if command is None:
command = self._retry_receive()
command_type = CommandType(command[:2].encode())
return command_type, command[2:]
def _retry_receive(self) -> str:
_logger.warning('Connection lost. Trying to reconnect...')
for i, interval in enumerate(self._retry_intervals):
_logger.info(f'Attempt #{i}, wait {interval} seconds...')
time.sleep(interval)
self._channel = WebSocket(self._url)
self._channel.connect()
try:
command = self._channel.receive()
except Exception as e:
_logger.exception(e)
command = None # for robustness
if command is not None:
_logger.info('Reconnected')
return command
_logger.error('Failed to reconnect.')
raise RuntimeError('Connection lost')
return None, None
else:
return CommandType(command['type']), command.get('content', '')
def _validate_placement_constraint(placement_constraint):

Просмотреть файл

@ -10,23 +10,23 @@ from nni.utils import MetricType
class CommandType(Enum):
# in
Initialize = b'IN'
RequestTrialJobs = b'GE'
ReportMetricData = b'ME'
UpdateSearchSpace = b'SS'
ImportData = b'FD'
AddCustomizedTrialJob = b'AD'
TrialEnd = b'EN'
Terminate = b'TE'
Ping = b'PI'
Initialize = 'IN'
RequestTrialJobs = 'GE'
ReportMetricData = 'ME'
UpdateSearchSpace = 'SS'
ImportData = 'FD'
AddCustomizedTrialJob = 'AD'
TrialEnd = 'EN'
Terminate = 'TE'
Ping = 'PI'
# out
Initialized = b'ID'
NewTrialJob = b'TR'
SendTrialJobParameter = b'SP'
NoMoreTrialJobs = b'NO'
KillTrialJob = b'KI'
Error = b'ER'
Initialized = 'ID'
NewTrialJob = 'TR'
SendTrialJobParameter = 'SP'
NoMoreTrialJobs = 'NO'
KillTrialJob = 'KI'
Error = 'ER'
class TunerIncomingCommand:
# For type checking.

Просмотреть файл

@ -61,7 +61,7 @@ def send(command, data):
try:
_lock.acquire()
data = data.encode('utf8')
msg = b'%b%014d%b' % (command.value, len(data), data)
msg = b'%b%014d%b' % (command.value.encode(), len(data), data)
_logger.debug('Sending command, data: [%s]', msg)
_out_file.write(msg)
_out_file.flush()
@ -81,7 +81,7 @@ def receive():
return None, None
length = int(header[2:])
data = _in_file.read(length)
command = CommandType(header[:2])
command = CommandType(header[:2].decode())
data = data.decode('utf8')
_logger.debug('Received command, data: [%s]', data)
return command, data

Просмотреть файл

@ -1,4 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from ..command_channel.websocket.connection import WsConnection as WebSocket # pylint: disable=unused-import

Просмотреть файл

@ -150,7 +150,7 @@ stages:
- script: |
set -e
npm --prefix ts/nni_manager run test -- --exclude test/core/nnimanager.test.ts
npm --prefix ts/nni_manager run test
npm --prefix ts/nni_manager run test_nnimanager
cp ts/nni_manager/coverage/cobertura-coverage.xml coverage/typescript.xml
displayName: TypeScript unit test
@ -198,7 +198,7 @@ stages:
- script: |
export PATH=${PWD}/toolchain/node/bin:$PATH
npm --prefix ts/nni_manager run test -- --exclude test/core/nnimanager.test.ts
npm --prefix ts/nni_manager run test
npm --prefix ts/nni_manager run test_nnimanager
displayName: TypeScript unit test
@ -223,7 +223,7 @@ stages:
# temporarily disable this test, add it back after bug fixed
- script: |
npm --prefix ts/nni_manager run test -- --exclude test/core/nnimanager.test.ts
npm --prefix ts/nni_manager run test
npm --prefix ts/nni_manager run test_nnimanager
displayName: TypeScript unit test
@ -249,7 +249,7 @@ stages:
displayName: Python unit test
- script: |
CI=true npm --prefix ts/nni_manager run test -- --exclude test/core/nnimanager.test.ts
CI=true npm --prefix ts/nni_manager run test
# # exclude nnimanager's ut because macos in pipeline is pretty slow
# CI=true npm --prefix ts/nni_manager run test_nnimanager
displayName: TypeScript unit test

Просмотреть файл

@ -28,7 +28,7 @@
"@typescript-eslint/no-non-null-assertion": 0,
"@typescript-eslint/no-unused-vars": [
"off",
"error",
{
"argsIgnorePattern": "^_"
}

Просмотреть файл

@ -45,11 +45,9 @@
import util from 'node:util';
import type { Command } from 'common/command_channel/interface';
import { DefaultMap } from 'common/default_map';
import { Deferred } from 'common/deferred';
import { Logger, getLogger } from 'common/log';
import type { TrialKeeper } from 'common/trial_keeper/keeper';
import type { CommandChannel } from './interface';
interface RpcResponseCommand {

Просмотреть файл

@ -155,6 +155,11 @@ export class WsChannel implements CommandChannel {
this.emitter.on('__lost', callback);
}
// TODO: temporary api for tuner command channel
public getBufferedAmount(): number {
return this.connection?.ws.bufferedAmount ?? 0;
}
private newEpoch(): void {
this.connection = null;
this.epoch += 1;

Просмотреть файл

@ -5,7 +5,6 @@
* WebSocket command channel client.
**/
import events from 'node:events';
import { setTimeout } from 'node:timers/promises';
import { WebSocket } from 'ws';

Просмотреть файл

@ -29,7 +29,8 @@ export class WsConnection extends EventEmitter {
private heartbeatTimer: NodeJS.Timer | null = null;
private log: Logger;
private missingPongs: number = 0;
private ws: WebSocket; // NOTE: used in unit test
public readonly ws: WebSocket;
constructor(name: string, ws: WebSocket, commandEmitter: EventEmitter) {
super();

Просмотреть файл

@ -7,7 +7,6 @@
import { EventEmitter } from 'events';
import type { Request } from 'express';
import type { WebSocket } from 'ws';
import type { Command } from 'common/command_channel/interface';
@ -32,7 +31,12 @@ export class WsChannelServer extends EventEmitter {
public async start(): Promise<void> {
const channelPath = globals.rest.urlJoin(this.path, ':channel');
globals.rest.registerWebSocketHandler(channelPath, this.handleConnection.bind(this));
globals.rest.registerWebSocketHandler(this.path, (ws, _req) => {
this.handleConnection('__default__', ws); // TODO: only used by tuner
});
globals.rest.registerWebSocketHandler(channelPath, (ws, req) => {
this.handleConnection(req.params['channel'], ws);
});
this.log.debug('Start listening', channelPath);
}
@ -84,8 +88,7 @@ export class WsChannelServer extends EventEmitter {
this.on('connection', callback);
}
private handleConnection(ws: WebSocket, req: Request): void {
const channelId = req.params['channel'];
private handleConnection(channelId: string, ws: WebSocket): void {
this.log.debug('Incoming connection', channelId);
if (this.channels.has(channelId)) {

Просмотреть файл

@ -26,7 +26,7 @@ import util from 'util';
import { Logger, getLogger } from 'common/log';
const logger = getLogger('common.deferred');
const logger: Logger = getLogger('common.deferred');
export class Deferred<T> {
private resolveCallbacks: any[] = [];

Просмотреть файл

@ -1,10 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import assert from 'assert';
import { KubeflowOperator, OperatorApiVersion } from '../training_service/kubernetes/kubeflow/kubeflowConfig'
import { KubernetesStorageKind } from '../training_service/kubernetes/kubernetesConfig';
export interface TrainingServiceConfig {
platform: string;

Просмотреть файл

@ -6,8 +6,8 @@
* Functions will be added when used.
**/
import express, { Request, Response, Router } from 'express';
import type { Router as WsRouter } from 'express-ws';
import express, { Express, Request, Response, Router } from 'express';
import expressWs, { Router as WsRouter } from 'express-ws';
import type { WebSocket } from 'ws';
type HttpMethod = 'GET' | 'PUT';
@ -16,13 +16,23 @@ type ExpressCallback = (req: Request, res: Response) => void;
type WebSocketCallback = (ws: WebSocket, req: Request) => void;
export class RestManager {
private app: Express;
private router: Router;
constructor() {
// we don't actually need the app here,
// but expressWs() must be called before router.ws(), and it requires an app instance
this.app = express();
expressWs(this.app, undefined, { wsOptions: { maxPayload: 4 * 1024 * 1024 * 1024 }});
this.router = Router();
this.router.use(express.json({ limit: '50mb' }));
}
public getExpressApp(): Express {
return this.app;
}
public getExpressRouter(): Router {
return this.router;
}

Просмотреть файл

@ -24,7 +24,6 @@
**/
import { EventEmitter } from 'node:events';
import util from 'node:util';
import type { Command } from 'common/command_channel/interface';
import { RpcHelper, getRpcHelper } from 'common/command_channel/rpc_util';

Просмотреть файл

@ -9,7 +9,6 @@ import { ChildProcess, spawn, StdioOptions } from 'child_process';
import dgram from 'dgram';
import fs from 'fs';
import net from 'net';
import os from 'os';
import path from 'path';
import * as timersPromises from 'timers/promises';
import { Deferred } from 'ts-deferred';

Просмотреть файл

@ -1,15 +1,14 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import { IpcInterface } from './tuner_command_channel/common';
export { IpcInterface } from './tuner_command_channel/common';
import * as shim from './tuner_command_channel/shim';
import { IpcInterface, getTunerServer } from './tuner_command_channel';
export { IpcInterface } from './tuner_command_channel';
let tunerDisabled: boolean = false;
export async function createDispatcherInterface(): Promise<IpcInterface> {
if (!tunerDisabled) {
return await shim.createDispatcherInterface();
return getTunerServer();
} else {
return new DummyIpcInterface();
}

Просмотреть файл

@ -15,7 +15,7 @@ import {
NNIManagerStatus, ProfileUpdateType, TrialJobStatistics
} from '../common/manager';
import {
ExperimentConfig, LocalConfig, TrainingServiceConfig, toSeconds, toCudaVisibleDevices
ExperimentConfig, TrainingServiceConfig, toSeconds, toCudaVisibleDevices
} from '../common/experimentConfig';
import { getExperimentsManager } from 'extensions/experiments_manager';
import { TensorboardManager } from '../common/tensorboardManager';

Просмотреть файл

@ -0,0 +1,106 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import { EventEmitter } from 'node:events';
import { WsChannel, WsChannelServer } from 'common/command_channel/websocket';
import { Deferred } from 'common/deferred';
import { Logger, getLogger } from 'common/log';
export interface IpcInterface {
init(): Promise<void>;
sendCommand(commandType: string, content?: string): void;
onCommand(listener: (commandType: string, content: string) => void): void;
onError(listener: (error: Error) => void): void;
}
export function getTunerServer(): IpcInterface {
return server;
}
const logger: Logger = getLogger('tuner_command_channel');
class TunerServer {
private channel!: WsChannel;
private connect: Deferred<void> = new Deferred();
private emitter: EventEmitter = new EventEmitter();
private server: WsChannelServer;
constructor() {
this.server = new WsChannelServer('tuner', 'tuner');
this.server.onConnection((_channelId, channel) => {
this.channel = channel;
this.channel.onError(error => {
this.emitter.emit('error', error);
});
this.channel.onReceive(command => {
if (command.type === 'ER') {
this.emitter.emit('error', new Error(command.content));
} else {
this.emitter.emit('command', command.type, command.content ?? '');
}
});
this.connect.resolve();
});
this.server.start();
}
public init(): Promise<void> { // wait connection
if (this.connect.settled) {
logger.debug('Initialized.');
return Promise.resolve();
} else {
logger.debug('Waiting connection...');
// TODO: This is a quick fix. It should check tuner's process status instead.
setTimeout(() => {
if (!this.connect.settled) {
const msg = 'Tuner did not connect in 10 seconds. Please check tuner (dispatcher) log.';
this.connect.reject(new Error('tuner_command_channel: ' + msg));
}
}, 10000);
return this.connect.promise;
}
}
// TODO: for unit test only
public async stop(): Promise<void> {
await this.server.shutdown();
}
public sendCommand(commandType: string, content?: string): void {
if (commandType === 'PI') { // ping is handled with WebSocket protocol
return;
}
if (this.channel.getBufferedAmount() > 1000) {
logger.warning('Sending too fast! Try to reduce the frequency of intermediate results.');
}
this.channel.send({ type: commandType, content });
if (commandType === 'TE') {
this.channel.close('TE command');
this.server.shutdown();
}
}
public onCommand(listener: (commandType: string, content: string) => void): void {
this.emitter.on('command', listener);
}
public onError(listener: (error: Error) => void): void {
this.emitter.on('error', listener);
}
}
let server: TunerServer = new TunerServer();
export namespace UnitTestHelpers {
export function reset(): void {
server = new TunerServer();
}
export async function stop(): Promise<void> {
await server.stop();
}
}

Просмотреть файл

@ -1,9 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
export interface IpcInterface {
init(): Promise<void>;
sendCommand(commandType: string, content?: string): void;
onCommand(listener: (commandType: string, content: string) => void): void;
onError(listener: (error: Error) => void): void;
}

Просмотреть файл

@ -1,4 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
export { getWebSocketChannel, serveWebSocket } from './websocket_channel';

Просмотреть файл

@ -1,52 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import type { IpcInterface } from './common';
import { WebSocketChannel, getWebSocketChannel } from './websocket_channel';
export async function createDispatcherInterface(): Promise<IpcInterface> {
return new WsIpcInterface();
}
class WsIpcInterface implements IpcInterface {
private channel: WebSocketChannel = getWebSocketChannel();
private commandListener?: (commandType: string, content: string) => void;
private errorListener?: (error: Error) => void;
constructor() {
this.channel.onCommand((command: string) => {
const commandType = command.slice(0, 2);
const content = command.slice(2);
if (commandType === 'ER') {
if (this.errorListener !== undefined) {
this.errorListener(new Error(content));
}
} else {
if (this.commandListener !== undefined) {
this.commandListener(commandType, content);
}
}
});
}
public async init(): Promise<void> {
await this.channel.init();
}
public sendCommand(commandType: string, content: string = ''): void {
if (commandType !== 'PI') { // ping is handled with WebSocket protocol
this.channel.sendCommand(commandType + content);
if (commandType === 'TE') {
this.channel.shutdown();
}
}
}
public onCommand(listener: (commandType: string, content: string) => void): void {
this.commandListener = listener;
}
public onError(listener: (error: Error) => void): void {
this.errorListener = listener;
}
}

Просмотреть файл

@ -1,206 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
/**
* The IPC channel between NNI manager and tuner.
*
* TODO:
* 1. Merge with environment service's WebSocket channel.
* 2. Split import data command to avoid extremely long message.
* 3. Refactor message format.
**/
import assert from 'assert/strict';
import { EventEmitter } from 'events';
import type WebSocket from 'ws';
import { Deferred } from 'common/deferred';
import { Logger, getLogger } from 'common/log';
const logger: Logger = getLogger('tuner_command_channel.WebSocketChannel');
export interface WebSocketChannel {
init(): Promise<void>;
shutdown(): Promise<void>;
sendCommand(command: string): void; // maybe this should return Promise<void>
onCommand(callback: (command: string) => void): void;
onError(callback: (error: Error) => void): void;
}
/**
* Get the singleton tuner command channel.
* Remember to invoke ``await channel.init()`` before doing anything else.
**/
export function getWebSocketChannel(): WebSocketChannel {
return channelSingleton;
}
/**
* The callback to serve WebSocket connection request. Used by REST server module.
* If it is invoked more than once, the previous connection will be dropped.
**/
export function serveWebSocket(ws: WebSocket): void {
channelSingleton.serveWebSocket(ws);
}
class WebSocketChannelImpl implements WebSocketChannel {
private deferredInit: Deferred<void> = new Deferred<void>();
private emitter: EventEmitter = new EventEmitter();
private heartbeatTimer!: NodeJS.Timer;
private serving: boolean = false;
private waitingPong: boolean = false;
private ws!: WebSocket;
public serveWebSocket(ws: WebSocket): void {
if (this.ws === undefined) {
logger.debug('Connected.');
} else {
logger.warning('Reconnecting. Drop previous connection.');
this.dropConnection('Reconnected');
}
this.serving = true;
this.ws = ws;
this.ws.on('close', this.handleWsClose);
this.ws.on('error', this.handleWsError);
this.ws.on('message', this.handleWsMessage);
this.ws.on('pong', this.handleWsPong);
this.heartbeatTimer = setInterval(this.heartbeat.bind(this), heartbeatInterval);
this.deferredInit.resolve();
}
public init(): Promise<void> {
if (this.ws === undefined) {
logger.debug('Waiting connection...');
// TODO: This is a quick fix. It should check tuner's process status instead.
setTimeout(() => {
if (!this.deferredInit.settled) {
const msg = 'Tuner did not connect in 10 seconds. Please check tuner (dispatcher) log.';
this.deferredInit.reject(new Error('tuner_command_channel: ' + msg));
}
}, 10000);
return this.deferredInit.promise;
} else {
logger.debug('Initialized.');
return Promise.resolve();
}
}
public async shutdown(): Promise<void> {
if (this.ws === undefined) {
return;
}
clearInterval(this.heartbeatTimer);
this.serving = false;
this.emitter.removeAllListeners();
}
public sendCommand(command: string): void {
assert.ok(this.ws !== undefined);
logger.debug('Sending', command);
this.ws.send(command);
if (this.ws.bufferedAmount > command.length + 1000) {
logger.warning('Sending too fast! Try to reduce the frequency of intermediate results.');
}
}
public onCommand(callback: (command: string) => void): void {
this.emitter.on('command', callback);
}
public onError(callback: (error: Error) => void): void {
this.emitter.on('error', callback);
}
/* Following callbacks must be auto-binded arrow functions to be turned off */
private handleWsClose = (): void => {
this.handleError(new Error('tuner_command_channel: Tuner closed connection'));
}
private handleWsError = (error: Error): void => {
this.handleError(error);
}
private handleWsMessage = (data: Buffer, _isBinary: boolean): void => {
this.receive(data);
}
private handleWsPong = (): void => {
this.waitingPong = false;
}
private dropConnection(reason: string): void {
if (this.ws === undefined) {
return;
}
this.serving = false;
this.waitingPong = false;
clearInterval(this.heartbeatTimer);
this.ws.off('close', this.handleWsClose);
this.ws.off('error', this.handleWsError);
this.ws.off('message', this.handleWsMessage);
this.ws.off('pong', this.handleWsPong);
this.ws.on('close', () => {
logger.info('Connection dropped');
});
this.ws.on('message', (data, _isBinary) => {
logger.error('Received message after reconnect:', data);
});
this.ws.on('pong', () => {
logger.error('Received pong after reconnect.');
});
this.ws.close(1001, reason);
}
private heartbeat(): void {
// if (this.waitingPong) {
// this.ws.terminate(); // this will trigger "close" event
// this.handleError(new Error('tuner_command_channel: Tuner loses responsive'));
// }
this.waitingPong = true;
this.ws.ping();
}
private receive(data: Buffer): void {
logger.debug('Received', data);
this.emitter.emit('command', data.toString());
}
private handleError(error: Error): void {
if (!this.serving) {
logger.debug('Silent error:', error);
return;
}
logger.error('Error:', error);
clearInterval(this.heartbeatTimer);
this.emitter.emit('error', error);
this.serving = false;
}
}
let channelSingleton: WebSocketChannelImpl = new WebSocketChannelImpl();
let heartbeatInterval: number = 5000;
export namespace UnitTestHelpers {
export function setHeartbeatInterval(ms: number): void {
heartbeatInterval = ms;
}
// NOTE: this function is only for unittest of nnimanager,
// because resuming an experiment should reset websocket channel.
export function resetChannelSingleton(): void {
channelSingleton = new WebSocketChannelImpl();
}
}

Просмотреть файл

@ -3,8 +3,6 @@
import assert from 'assert/strict';
import fs from 'fs';
import os from 'os';
import path from 'path';
import * as timersPromises from 'timers/promises';
import { Deferred } from 'ts-deferred';

Просмотреть файл

@ -8,8 +8,6 @@ import * as timersPromises from 'timers/promises';
import glob from 'glob';
import lockfile from 'lockfile';
import globals from 'common/globals';
const lockStale: number = 2000;
const retry: number = 100;

Просмотреть файл

@ -21,13 +21,13 @@
import 'app-module-path/register'; // so we can use absolute path to import
import fs from 'fs';
import { Container, Scope } from 'typescript-ioc';
import { globals, initGlobals } from 'common/globals';
initGlobals();
import * as component from 'common/component';
import { Database, DataStore } from 'common/datastore';
import globals, { initGlobals } from 'common/globals';
import { Logger, getLogger } from 'common/log';
import { Manager } from 'common/manager';
import { TensorboardManager } from 'common/tensorboardManager';
@ -71,8 +71,6 @@ process.on('SIGINT', () => { globals.shutdown.initiate('SIGINT'); });
/* main */
initGlobals();
start().then(() => {
logger.debug('start() returned.');
}).catch((error) => {

Просмотреть файл

@ -4,8 +4,8 @@
"license": "MIT",
"scripts": {
"build": "tsc",
"test": "nyc --reporter=cobertura --reporter=text mocha \"test/**/*.test.ts\"",
"test_nnimanager": "nyc --reporter=cobertura --reporter=text mocha test/core/nnimanager.test.ts",
"test": "nyc --reporter=cobertura --reporter=text mocha \"test/**/*.test.ts\" --exclude test/core/nnimanager.test.ts",
"test_nnimanager": "mocha test/core/nnimanager.test.ts",
"mocha": "mocha",
"eslint": "eslint . --ext .ts"
},

Просмотреть файл

@ -13,9 +13,6 @@ import assert from 'node:assert/strict';
import type { Server } from 'node:http';
import type { AddressInfo } from 'node:net';
import express from 'express';
import expressWs from 'express-ws';
import { Deferred } from 'common/deferred';
import { globals } from 'common/globals';
import { Logger, getLogger } from 'common/log';
@ -40,9 +37,7 @@ export class RestServerCore {
public start(): Promise<void> {
logger.info(`Starting REST server at port ${this.port}, URL prefix: "/${this.urlPrefix}"`);
const app = express();
expressWs(app, undefined, { wsOptions: { maxPayload: 4 * 1024 * 1024 * 1024 }});
const app = globals.rest.getExpressApp();
app.use('/' + this.urlPrefix, globals.rest.getExpressRouter());
app.all('/' + this.urlPrefix, (_req, res) => { res.status(404).send('Not Found'); });
app.all('*', (_req, res) => { res.status(404).send(`Outside prefix "/${this.urlPrefix}"`); });

Просмотреть файл

@ -26,13 +26,11 @@ import type { AddressInfo } from 'net';
import path from 'path';
import express, { Request, Response, Router } from 'express';
import expressWs from 'express-ws';
import httpProxy from 'http-proxy';
import { Deferred } from 'common/deferred';
import globals from 'common/globals';
import { Logger, getLogger } from 'common/log';
import * as tunerCommandChannel from 'core/tuner_command_channel';
const logger: Logger = getLogger('RestServer');
@ -60,9 +58,7 @@ export class RestServer {
public start(): Promise<void> {
logger.info(`Starting REST server at port ${this.port}, URL prefix: "/${this.urlPrefix}"`);
const app = express();
expressWs(app, undefined, { wsOptions: { maxPayload: 4 * 1024 * 1024 * 1024 }});
const app = globals.rest.getExpressApp();
app.use('/' + this.urlPrefix, mainRouter());
app.use('/' + this.urlPrefix, fallbackRouter());
app.all('*', (_req: Request, res: Response) => { res.status(404).send(`Outside prefix "/${this.urlPrefix}"`); });
@ -110,10 +106,7 @@ export class RestServer {
* In fact experiments management should have a separate prefix and module.
**/
function mainRouter(): Router {
const router = globals.rest.getExpressRouter() as expressWs.Router;
/* WebSocket APIs */
router.ws('/tuner', (ws, _req, _next) => { tunerCommandChannel.serveWebSocket(ws); });
const router = globals.rest.getExpressRouter();
/* Download log files */
// The REST API path "/logs" does not match file system path "/log".
@ -142,6 +135,9 @@ function fallbackRouter(): Router {
/* 404 as catch-all */
router.all('*', (_req: Request, res: Response) => { res.status(404).send('Not Found'); });
// TODO: websocket 404
return router;
}

Просмотреть файл

@ -13,7 +13,6 @@ import { getLogger, Logger } from '../common/log';
import { ExperimentProfile, Manager, TrialJobStatistics } from '../common/manager';
import { getExperimentsManager } from 'extensions/experiments_manager';
import { TensorboardManager, TensorboardTaskInfo } from '../common/tensorboardManager';
import { ValidationSchemas } from './restValidationSchemas';
import { getVersion } from '../common/utils';
import { MetricType } from '../common/datastore';
import { ProfileUpdateType } from '../common/manager';
@ -294,11 +293,7 @@ class NNIRestHandler {
private getTrialFile(router: Router): void {
router.get('/trial-file/:id/:filename', async(req: Request, res: Response) => {
let encoding: string | null = null;
const filename = req.params['filename'];
if (!filename.includes('.') || filename.match(/.*\.(txt|log)/g)) {
encoding = 'utf8';
}
this.nniManager.getTrialFile(req.params['id'], filename).then((content: Buffer | string) => {
const contentType = content instanceof Buffer ? 'application/octet-stream' : 'text/plain';
res.header('Content-Type', contentType);

Просмотреть файл

@ -1,121 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as assert from 'assert';
import { ChildProcess, spawn, StdioOptions } from 'child_process';
import { Deferred } from 'ts-deferred';
import { cleanupUnitTest, prepareUnitTest, getTunerProc } from '../../common/utils';
import * as CommandType from '../../core/commands';
import { createDispatcherInterface, IpcInterface } from '../../core/ipcInterface';
import { NNIError } from '../../common/errors';
let sentCommands: { [key: string]: string }[] = [];
const receivedCommands: { [key: string]: string }[] = [];
let rejectCommandType: Error | undefined;
async function runProcess(): Promise<Error | null> {
// the process is intended to throw error, do not reject
const deferred: Deferred<Error | null> = new Deferred<Error | null>();
// create fake assessor process
const stdio: StdioOptions = ['ignore', 'pipe', process.stderr, 'pipe', 'pipe'];
const command: string[] = [ 'python', 'assessor.py' ];
const proc: ChildProcess = getTunerProc(command, stdio, 'core/test', process.env);
// record its sent/received commands on exit
proc.on('error', (error: Error): void => { deferred.resolve(error); });
proc.on('exit', (code: number): void => {
if (code !== 0) {
deferred.resolve(new Error(`return code: ${code}`));
} else {
let str = proc.stdout!.read().toString();
if(str.search("\r\n")!=-1){
sentCommands = str.split("\r\n");
}
else{
sentCommands = str.split('\n');
}
deferred.resolve(null);
}
});
// create IPC interface
const dispatcher: IpcInterface = await createDispatcherInterface();
dispatcher.onCommand((commandType: string, content: string): void => {
receivedCommands.push({ commandType, content });
});
// Command #1: ok
dispatcher.sendCommand('IN');
// Command #2: ok
dispatcher.sendCommand('ME', '123');
// Command #3: FE is not tuner/assessor command, test the exception type of send non-valid command
try {
dispatcher.sendCommand('FE', '1');
} catch (error) {
rejectCommandType = error as Error;
}
return deferred.promise;
}
/* FIXME
describe('core/protocol', (): void => {
before(async () => {
prepareUnitTest();
await runProcess();
});
after(() => {
cleanupUnitTest();
});
it('should have sent 2 successful commands', (): void => {
assert.equal(sentCommands.length, 3);
assert.equal(sentCommands[2], '');
});
it('sendCommand() should work without content', (): void => {
assert.equal(sentCommands[0], "('IN', '')");
});
it('sendCommand() should work with content', (): void => {
assert.equal(sentCommands[1], "('ME', '123')");
});
it('sendCommand() should throw on wrong command type', (): void => {
assert.equal((<Error>rejectCommandType).name.split(' ')[0], 'AssertionError');
});
it('should have received 3 commands', (): void => {
assert.equal(receivedCommands.length, 3);
});
it('onCommand() should work without content', (): void => {
assert.deepStrictEqual(receivedCommands[0], {
commandType: 'KI',
content: ''
});
});
it('onCommand() should work with content', (): void => {
assert.deepStrictEqual(receivedCommands[1], {
commandType: 'KI',
content: 'hello'
});
});
it('onCommand() should work with Unicode content', (): void => {
assert.deepStrictEqual(receivedCommands[2], {
commandType: 'KI',
content: '世界'
});
});
});
*/

Просмотреть файл

@ -1,105 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as assert from 'assert';
import { ChildProcess, spawn, StdioOptions } from 'child_process';
import { Deferred } from 'ts-deferred';
import { cleanupUnitTest, prepareUnitTest, getMsgDispatcherCommand, getTunerProc } from '../../common/utils';
import * as CommandType from '../../core/commands';
import { createDispatcherInterface, IpcInterface } from '../../core/ipcInterface';
let dispatcher: IpcInterface | undefined;
let procExit: boolean = false;
let procError: boolean = false;
async function startProcess(): Promise<void> {
// create fake assessor process
const stdio: StdioOptions = ['ignore', 'pipe', process.stderr, 'pipe', 'pipe'];
const dispatcherCmd: string[] = getMsgDispatcherCommand(
// Mock tuner config
<any>{
experimentName: 'exp1',
maxExperimentDuration: '1h',
searchSpace: '',
trainingService: {
platform: 'local'
},
trialConcurrency: 1,
maxTrialNumber: 5,
tuner: {
className: 'dummy_tuner.DummyTuner',
codeDirectory: '.'
},
assessor: {
className: 'dummy_assessor.DummyAssessor',
codeDirectory: '.'
},
trialCommand: '',
trialCodeDirectory: '',
debug: true
}
);
const proc: ChildProcess = getTunerProc(dispatcherCmd, stdio, 'core/test', process.env);
proc.on('error', (_error: Error): void => {
procExit = true;
procError = true;
});
proc.on('exit', (code: number): void => {
procExit = true;
procError = (code !== 0);
});
// create IPC interface
dispatcher = await createDispatcherInterface();
(<IpcInterface>dispatcher).onCommand((commandType: string, content: string): void => {
console.log(commandType, content);
});
}
/* FIXME
describe('core/ipcInterface.terminate', (): void => {
before(() => {
prepareUnitTest();
startProcess();
});
after(() => {
cleanupUnitTest();
});
it('normal', () => {
(<IpcInterface>dispatcher).sendCommand(
CommandType.REPORT_METRIC_DATA,
'{"trial_job_id":"A","type":"PERIODICAL","value":1,"sequence":123}');
const deferred: Deferred<void> = new Deferred<void>();
setTimeout(
() => {
assert.ok(!procExit);
assert.ok(!procError);
deferred.resolve();
},
1000);
return deferred.promise;
});
it('terminate', () => {
(<IpcInterface>dispatcher).sendCommand(CommandType.TERMINATE);
const deferred: Deferred<void> = new Deferred<void>();
setTimeout(
() => {
assert.ok(procExit);
assert.ok(!procError);
deferred.resolve();
},
10000);
return deferred.promise;
});
});
*/

Просмотреть файл

@ -22,7 +22,7 @@ import { NNITensorboardManager } from '../../extensions/nniTensorboardManager';
import * as path from 'path';
import { RestServer } from '../../rest_server';
import globals from '../../common/globals/unittest';
import { UnitTestHelpers } from '../../core/tuner_command_channel/websocket_channel';
import { UnitTestHelpers } from '../../core/tuner_command_channel';
import * as timersPromises from 'timers/promises';
let nniManager: NNIManager;
@ -324,7 +324,7 @@ async function resumeExperiment(): Promise<void> {
// globals.showLog();
// explicitly reset the websocket channel because it is singleton, does not work when two experiments
// (one is start and the other is resume) run in the same process.
UnitTestHelpers.resetChannelSingleton();
UnitTestHelpers.reset();
await initContainer('resume');
nniManager = component.get(Manager);

Просмотреть файл

@ -1,131 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import assert from 'assert/strict';
import { setTimeout } from 'timers/promises';
import WebSocket from 'ws';
import { Deferred } from 'common/deferred';
import { getWebSocketChannel, serveWebSocket } from 'core/tuner_command_channel';
import { UnitTestHelpers } from 'core/tuner_command_channel/websocket_channel';
const heartbeatInterval: number = 10;
// for testError, must be set before serveWebSocket()
UnitTestHelpers.setHeartbeatInterval(heartbeatInterval);
/* test cases */
// Start serving and let a client connect.
async function testInit(): Promise<void> {
const channel = getWebSocketChannel();
channel.onCommand(command => { serverReceived.push(command); });
channel.onError(error => { catchedError = error; });
server.on('connection', serveWebSocket);
client1 = new Client('client1');
await channel.init();
}
// Send commands from server to client.
async function testSend(client: Client): Promise<void> {
const channel = getWebSocketChannel();
channel.sendCommand(command1);
channel.sendCommand(command2);
await setTimeout(heartbeatInterval);
assert.deepEqual(client.received, [command1, command2]);
}
// Send commands from client to server.
async function testReceive(client: Client): Promise<void> {
serverReceived.length = 0;
client.ws.send(command2);
client.ws.send(command1);
await setTimeout(heartbeatInterval);
assert.deepEqual(serverReceived, [command2, command1]);
}
// Simulate client side crash.
async function testError(): Promise<void> {
if (process.platform !== 'linux') {
// it is performance sensitive for the test case to yield error,
// but windows & mac agents of devops are too slow
client1.ws.terminate();
return;
}
// we have set heartbeat interval to 10ms, so pause for 30ms should make it timeout
client1.ws.pause();
await setTimeout(heartbeatInterval * 3);
client1.ws.resume();
assert.notEqual(catchedError, undefined);
}
// If the client losses connection by accident but not crashed, it will reconnect.
async function testReconnect(): Promise<void> {
client2 = new Client('client2');
await client2.deferred.promise;
}
// Clean up.
async function testShutdown(): Promise<void> {
const channel = getWebSocketChannel();
await channel.shutdown();
client1.ws.close();
client2.ws.close();
server.close();
}
/* register */
describe('## tuner_command_channel ##', () => {
it('init', testInit);
it('send', () => testSend(client1));
it('receive', () => testReceive(client1));
// it('mock timeout', testError);
it('reconnect', testReconnect);
it('send after reconnect', () => testSend(client2));
it('receive after reconnect', () => testReceive(client2));
it('shutdown', testShutdown);
});
/** helpers **/
const command1 = 'T_hello world';
const command2 = 'T_你好';
const server = new WebSocket.Server({ port: 0 });
let client1!: Client;
let client2!: Client;
const serverReceived: string[] = [];
let catchedError: Error | undefined;
class Client {
name: string;
received: string[] = [];
ws!: WebSocket;
deferred: Deferred<void> = new Deferred();
constructor(name: string) {
this.name = name;
const port = (server.address() as any).port;
this.ws = new WebSocket(`ws://localhost:${port}`);
this.ws.on('message', (data, _isBinary) => {
this.received.push(data.toString());
});
this.ws.on('open', () => {
this.deferred.resolve();
});
}
}

Просмотреть файл

@ -12,9 +12,8 @@ import { ExperimentStartupInfo } from '../../../../common/experimentStartupInfo'
import { EnvironmentInformation } from '../../environment';
import { KubernetesEnvironmentService } from './kubernetesEnvironmentService';
import { FrameworkControllerClientFactory } from '../../../kubernetes/frameworkcontroller/frameworkcontrollerApiClient';
import { FrameworkControllerClusterConfigAzure, FrameworkControllerJobStatus, FrameworkControllerTrialConfigTemplate,
import { FrameworkControllerJobStatus, FrameworkControllerTrialConfigTemplate,
FrameworkControllerJobCompleteStatus } from '../../../kubernetes/frameworkcontroller/frameworkcontrollerConfig';
import { KeyVaultConfig, AzureStorage } from '../../../kubernetes/kubernetesConfig';
@component.Singleton
export class FrameworkControllerEnvironmentService extends KubernetesEnvironmentService {

Просмотреть файл

@ -6,7 +6,6 @@ import path from 'path';
import azureStorage from 'azure-storage';
import {Base64} from 'js-base64';
import {String} from 'typescript-string-operations';
import { ExperimentConfig } from 'common/experimentConfig';
import { ExperimentStartupInfo } from 'common/experimentStartupInfo';
import { getLogger, Logger } from 'common/log';
import { EnvironmentInformation, EnvironmentService } from 'training_service/reusable/environment';

Просмотреть файл

@ -11,7 +11,7 @@ import type { TrainingServiceConfig } from 'common/experimentConfig';
import globals from 'common/globals';
import { getLogger } from 'common/log';
import {
TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus
TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from 'common/trainingService';
import type { EnvironmentInfo, Parameter, TrainingServiceV3 } from 'common/training_service_v3';
import { trainingServiceFactoryV3 } from './factory';

Просмотреть файл

@ -7,7 +7,7 @@
* For now we only have "local_v3" and "remote_v3" as PoC.
**/
import type { LocalConfig, RemoteConfig, TrainingServiceConfig } from 'common/experimentConfig';
import type { TrainingServiceConfig } from 'common/experimentConfig';
import type { TrainingServiceV3 } from 'common/training_service_v3';
import { LocalTrainingServiceV3 } from '../local_v3';
import { RemoteTrainingServiceV3 } from '../remote_v3';