зеркало из https://github.com/microsoft/nni.git
Migrate tuner command channel to v3 channel (#5475)
This commit is contained in:
Родитель
ec0ddbb711
Коммит
f58f3ab3a7
|
@ -6,7 +6,6 @@ from __future__ import annotations
|
|||
import logging
|
||||
import time
|
||||
|
||||
import nni
|
||||
from ..base import Command, CommandChannel
|
||||
from .connection import WsConnection
|
||||
|
||||
|
@ -25,19 +24,25 @@ class WsChannelClient(CommandChannel):
|
|||
|
||||
def disconnect(self) -> None:
|
||||
_logger.debug(f'Disconnect from {self._url}')
|
||||
self.send({'type': '_bye_'})
|
||||
self._closing = True
|
||||
self._close_conn('client intentionally close')
|
||||
if self._closing:
|
||||
_logger.debug('Already closing')
|
||||
else:
|
||||
try:
|
||||
if self._conn is not None:
|
||||
self._conn.send({'type': '_bye_'})
|
||||
except Exception as e:
|
||||
_logger.debug(f'Failed to send bye: {repr(e)}')
|
||||
self._closing = True
|
||||
self._close_conn('client intentionally close')
|
||||
|
||||
def send(self, command: Command) -> None:
|
||||
if self._closing:
|
||||
return
|
||||
_logger.debug(f'Send {command}')
|
||||
msg = nni.dump(command)
|
||||
for i in range(5):
|
||||
try:
|
||||
conn = self._ensure_conn()
|
||||
conn.send(msg)
|
||||
conn.send(command)
|
||||
return
|
||||
except Exception:
|
||||
_logger.exception(f'Failed to send command. Retry in {i}s')
|
||||
|
@ -45,16 +50,15 @@ class WsChannelClient(CommandChannel):
|
|||
time.sleep(i)
|
||||
_logger.warning(f'Failed to send command {command}. Last retry')
|
||||
conn = self._ensure_conn()
|
||||
conn.send(msg)
|
||||
conn.send(command)
|
||||
|
||||
def receive(self) -> Command | None:
|
||||
while True:
|
||||
if self._closing:
|
||||
return None
|
||||
msg = self._receive_msg()
|
||||
if msg is None:
|
||||
command = self._receive_command()
|
||||
if command is None:
|
||||
return None
|
||||
command = nni.load(msg)
|
||||
if command['type'] == '_nop_':
|
||||
continue
|
||||
if command['type'] == '_bye_':
|
||||
|
@ -88,15 +92,14 @@ class WsChannelClient(CommandChannel):
|
|||
pass
|
||||
self._conn = None
|
||||
|
||||
def _receive_msg(self) -> str | None:
|
||||
def _receive_command(self) -> Command | None:
|
||||
for i in range(5):
|
||||
try:
|
||||
conn = self._ensure_conn()
|
||||
msg = conn.receive()
|
||||
_logger.debug(f'Receive {msg}')
|
||||
command = conn.receive()
|
||||
if not self._closing:
|
||||
assert msg is not None
|
||||
return msg
|
||||
assert command is not None
|
||||
return command
|
||||
except Exception:
|
||||
_logger.exception(f'Failed to receive command. Retry in {i}s')
|
||||
self._terminate_conn('receive fail')
|
||||
|
|
|
@ -18,6 +18,9 @@ from typing import Any, Type
|
|||
|
||||
import websockets
|
||||
|
||||
import nni
|
||||
from ..base import Command
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
# the singleton event loop
|
||||
|
@ -81,17 +84,17 @@ class WsConnection:
|
|||
return
|
||||
self.disconnect(reason, 4001)
|
||||
|
||||
def send(self, message: str) -> None:
|
||||
def send(self, message: Command) -> None:
|
||||
_logger.debug(f'Sending {message}')
|
||||
try:
|
||||
_wait(self._ws.send(message))
|
||||
_wait(self._ws.send(nni.dump(message)))
|
||||
except websockets.ConnectionClosed: # type: ignore
|
||||
_logger.debug('Connection closed by server.')
|
||||
self._ws = None
|
||||
_decrease_refcnt()
|
||||
raise
|
||||
|
||||
def receive(self) -> str | None:
|
||||
def receive(self) -> Command | None:
|
||||
"""
|
||||
Return received message;
|
||||
or return ``None`` if the connection has been closed by peer.
|
||||
|
@ -105,11 +108,12 @@ class WsConnection:
|
|||
_decrease_refcnt()
|
||||
raise
|
||||
|
||||
if msg is None:
|
||||
return None
|
||||
# The library seems to infer whether a message is text or binary, so we have no guarantee which type we receive
|
||||
if isinstance(msg, bytes):
|
||||
return msg.decode()
|
||||
else:
|
||||
return msg
|
||||
msg = msg.decode()
|
||||
return nni.load(msg)
|
||||
|
||||
def _wait(coro):
|
||||
# Synchronized version of "await".
|
||||
|
|
|
@ -11,19 +11,18 @@ __all__ = ['TunerCommandChannel']
|
|||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from threading import Event
|
||||
from typing import Any, Callable
|
||||
|
||||
from nni.common.serializer import dump, load, PayloadTooLarge
|
||||
from nni.runtime.command_channel.websocket import WsChannelClient
|
||||
from nni.typehint import Parameters
|
||||
|
||||
from .command_type import (
|
||||
CommandType, TunerIncomingCommand,
|
||||
Initialize, RequestTrialJobs, UpdateSearchSpace, ReportMetricData, TrialEnd, Terminate
|
||||
)
|
||||
from .websocket import WebSocket
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -52,10 +51,7 @@ class TunerCommandChannel:
|
|||
"""
|
||||
|
||||
def __init__(self, url: str):
|
||||
self._url = url
|
||||
self._channel = WebSocket(url)
|
||||
self._retry_intervals = [0, 1, 10]
|
||||
|
||||
self._channel = WsChannelClient(url)
|
||||
self._callbacks: dict[CommandType, list[Callable[..., None]]] = defaultdict(list)
|
||||
|
||||
def connect(self) -> None:
|
||||
|
@ -268,61 +264,14 @@ class TunerCommandChannel:
|
|||
self._callbacks[TrialEnd.command_type].append(callback)
|
||||
|
||||
def _send(self, command_type: CommandType, data: str) -> None:
|
||||
command = command_type.value.decode() + data
|
||||
try:
|
||||
self._channel.send(command)
|
||||
except Exception as e:
|
||||
_logger.warning('Exception on sending: %r', e)
|
||||
if not isinstance(e, WebSocket.ConnectionClosed):
|
||||
_logger.exception(e)
|
||||
self._retry_send(command)
|
||||
|
||||
def _retry_send(self, command: str) -> None:
|
||||
_logger.warning('Connection lost. Trying to reconnect...')
|
||||
for i, interval in enumerate(self._retry_intervals):
|
||||
_logger.info(f'Attempt #{i}, wait {interval} seconds...')
|
||||
time.sleep(interval)
|
||||
self._channel = WebSocket(self._url)
|
||||
self._channel.connect()
|
||||
try:
|
||||
self._channel.send(command)
|
||||
_logger.info('Reconnected.')
|
||||
return
|
||||
except Exception as e:
|
||||
_logger.exception(e)
|
||||
_logger.error('Failed to reconnect.')
|
||||
raise RuntimeError('Connection lost')
|
||||
self._channel.send({'type': command_type.value, 'content': data})
|
||||
|
||||
def _receive(self) -> tuple[CommandType, str] | tuple[None, None]:
|
||||
try:
|
||||
command = self._channel.receive()
|
||||
except Exception as e:
|
||||
_logger.warning('Exception on receiving: %r', e)
|
||||
if not isinstance(e, WebSocket.ConnectionClosed):
|
||||
_logger.exception(e)
|
||||
command = None
|
||||
command = self._channel.receive()
|
||||
if command is None:
|
||||
command = self._retry_receive()
|
||||
command_type = CommandType(command[:2].encode())
|
||||
return command_type, command[2:]
|
||||
|
||||
def _retry_receive(self) -> str:
|
||||
_logger.warning('Connection lost. Trying to reconnect...')
|
||||
for i, interval in enumerate(self._retry_intervals):
|
||||
_logger.info(f'Attempt #{i}, wait {interval} seconds...')
|
||||
time.sleep(interval)
|
||||
self._channel = WebSocket(self._url)
|
||||
self._channel.connect()
|
||||
try:
|
||||
command = self._channel.receive()
|
||||
except Exception as e:
|
||||
_logger.exception(e)
|
||||
command = None # for robustness
|
||||
if command is not None:
|
||||
_logger.info('Reconnected')
|
||||
return command
|
||||
_logger.error('Failed to reconnect.')
|
||||
raise RuntimeError('Connection lost')
|
||||
return None, None
|
||||
else:
|
||||
return CommandType(command['type']), command.get('content', '')
|
||||
|
||||
|
||||
def _validate_placement_constraint(placement_constraint):
|
||||
|
|
|
@ -10,23 +10,23 @@ from nni.utils import MetricType
|
|||
|
||||
class CommandType(Enum):
|
||||
# in
|
||||
Initialize = b'IN'
|
||||
RequestTrialJobs = b'GE'
|
||||
ReportMetricData = b'ME'
|
||||
UpdateSearchSpace = b'SS'
|
||||
ImportData = b'FD'
|
||||
AddCustomizedTrialJob = b'AD'
|
||||
TrialEnd = b'EN'
|
||||
Terminate = b'TE'
|
||||
Ping = b'PI'
|
||||
Initialize = 'IN'
|
||||
RequestTrialJobs = 'GE'
|
||||
ReportMetricData = 'ME'
|
||||
UpdateSearchSpace = 'SS'
|
||||
ImportData = 'FD'
|
||||
AddCustomizedTrialJob = 'AD'
|
||||
TrialEnd = 'EN'
|
||||
Terminate = 'TE'
|
||||
Ping = 'PI'
|
||||
|
||||
# out
|
||||
Initialized = b'ID'
|
||||
NewTrialJob = b'TR'
|
||||
SendTrialJobParameter = b'SP'
|
||||
NoMoreTrialJobs = b'NO'
|
||||
KillTrialJob = b'KI'
|
||||
Error = b'ER'
|
||||
Initialized = 'ID'
|
||||
NewTrialJob = 'TR'
|
||||
SendTrialJobParameter = 'SP'
|
||||
NoMoreTrialJobs = 'NO'
|
||||
KillTrialJob = 'KI'
|
||||
Error = 'ER'
|
||||
|
||||
class TunerIncomingCommand:
|
||||
# For type checking.
|
||||
|
|
|
@ -61,7 +61,7 @@ def send(command, data):
|
|||
try:
|
||||
_lock.acquire()
|
||||
data = data.encode('utf8')
|
||||
msg = b'%b%014d%b' % (command.value, len(data), data)
|
||||
msg = b'%b%014d%b' % (command.value.encode(), len(data), data)
|
||||
_logger.debug('Sending command, data: [%s]', msg)
|
||||
_out_file.write(msg)
|
||||
_out_file.flush()
|
||||
|
@ -81,7 +81,7 @@ def receive():
|
|||
return None, None
|
||||
length = int(header[2:])
|
||||
data = _in_file.read(length)
|
||||
command = CommandType(header[:2])
|
||||
command = CommandType(header[:2].decode())
|
||||
data = data.decode('utf8')
|
||||
_logger.debug('Received command, data: [%s]', data)
|
||||
return command, data
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from ..command_channel.websocket.connection import WsConnection as WebSocket # pylint: disable=unused-import
|
|
@ -150,7 +150,7 @@ stages:
|
|||
|
||||
- script: |
|
||||
set -e
|
||||
npm --prefix ts/nni_manager run test -- --exclude test/core/nnimanager.test.ts
|
||||
npm --prefix ts/nni_manager run test
|
||||
npm --prefix ts/nni_manager run test_nnimanager
|
||||
cp ts/nni_manager/coverage/cobertura-coverage.xml coverage/typescript.xml
|
||||
displayName: TypeScript unit test
|
||||
|
@ -198,7 +198,7 @@ stages:
|
|||
|
||||
- script: |
|
||||
export PATH=${PWD}/toolchain/node/bin:$PATH
|
||||
npm --prefix ts/nni_manager run test -- --exclude test/core/nnimanager.test.ts
|
||||
npm --prefix ts/nni_manager run test
|
||||
npm --prefix ts/nni_manager run test_nnimanager
|
||||
displayName: TypeScript unit test
|
||||
|
||||
|
@ -223,7 +223,7 @@ stages:
|
|||
|
||||
# temporarily disable this test, add it back after bug fixed
|
||||
- script: |
|
||||
npm --prefix ts/nni_manager run test -- --exclude test/core/nnimanager.test.ts
|
||||
npm --prefix ts/nni_manager run test
|
||||
npm --prefix ts/nni_manager run test_nnimanager
|
||||
displayName: TypeScript unit test
|
||||
|
||||
|
@ -249,7 +249,7 @@ stages:
|
|||
displayName: Python unit test
|
||||
|
||||
- script: |
|
||||
CI=true npm --prefix ts/nni_manager run test -- --exclude test/core/nnimanager.test.ts
|
||||
CI=true npm --prefix ts/nni_manager run test
|
||||
# # exclude nnimanager's ut because macos in pipeline is pretty slow
|
||||
# CI=true npm --prefix ts/nni_manager run test_nnimanager
|
||||
displayName: TypeScript unit test
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
"@typescript-eslint/no-non-null-assertion": 0,
|
||||
|
||||
"@typescript-eslint/no-unused-vars": [
|
||||
"off",
|
||||
"error",
|
||||
{
|
||||
"argsIgnorePattern": "^_"
|
||||
}
|
||||
|
|
|
@ -45,11 +45,9 @@
|
|||
|
||||
import util from 'node:util';
|
||||
|
||||
import type { Command } from 'common/command_channel/interface';
|
||||
import { DefaultMap } from 'common/default_map';
|
||||
import { Deferred } from 'common/deferred';
|
||||
import { Logger, getLogger } from 'common/log';
|
||||
import type { TrialKeeper } from 'common/trial_keeper/keeper';
|
||||
import type { CommandChannel } from './interface';
|
||||
|
||||
interface RpcResponseCommand {
|
||||
|
|
|
@ -155,6 +155,11 @@ export class WsChannel implements CommandChannel {
|
|||
this.emitter.on('__lost', callback);
|
||||
}
|
||||
|
||||
// TODO: temporary api for tuner command channel
|
||||
public getBufferedAmount(): number {
|
||||
return this.connection?.ws.bufferedAmount ?? 0;
|
||||
}
|
||||
|
||||
private newEpoch(): void {
|
||||
this.connection = null;
|
||||
this.epoch += 1;
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
* WebSocket command channel client.
|
||||
**/
|
||||
|
||||
import events from 'node:events';
|
||||
import { setTimeout } from 'node:timers/promises';
|
||||
|
||||
import { WebSocket } from 'ws';
|
||||
|
|
|
@ -29,7 +29,8 @@ export class WsConnection extends EventEmitter {
|
|||
private heartbeatTimer: NodeJS.Timer | null = null;
|
||||
private log: Logger;
|
||||
private missingPongs: number = 0;
|
||||
private ws: WebSocket; // NOTE: used in unit test
|
||||
|
||||
public readonly ws: WebSocket;
|
||||
|
||||
constructor(name: string, ws: WebSocket, commandEmitter: EventEmitter) {
|
||||
super();
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
|
||||
import { EventEmitter } from 'events';
|
||||
|
||||
import type { Request } from 'express';
|
||||
import type { WebSocket } from 'ws';
|
||||
|
||||
import type { Command } from 'common/command_channel/interface';
|
||||
|
@ -32,7 +31,12 @@ export class WsChannelServer extends EventEmitter {
|
|||
|
||||
public async start(): Promise<void> {
|
||||
const channelPath = globals.rest.urlJoin(this.path, ':channel');
|
||||
globals.rest.registerWebSocketHandler(channelPath, this.handleConnection.bind(this));
|
||||
globals.rest.registerWebSocketHandler(this.path, (ws, _req) => {
|
||||
this.handleConnection('__default__', ws); // TODO: only used by tuner
|
||||
});
|
||||
globals.rest.registerWebSocketHandler(channelPath, (ws, req) => {
|
||||
this.handleConnection(req.params['channel'], ws);
|
||||
});
|
||||
this.log.debug('Start listening', channelPath);
|
||||
}
|
||||
|
||||
|
@ -84,8 +88,7 @@ export class WsChannelServer extends EventEmitter {
|
|||
this.on('connection', callback);
|
||||
}
|
||||
|
||||
private handleConnection(ws: WebSocket, req: Request): void {
|
||||
const channelId = req.params['channel'];
|
||||
private handleConnection(channelId: string, ws: WebSocket): void {
|
||||
this.log.debug('Incoming connection', channelId);
|
||||
|
||||
if (this.channels.has(channelId)) {
|
||||
|
|
|
@ -26,7 +26,7 @@ import util from 'util';
|
|||
|
||||
import { Logger, getLogger } from 'common/log';
|
||||
|
||||
const logger = getLogger('common.deferred');
|
||||
const logger: Logger = getLogger('common.deferred');
|
||||
|
||||
export class Deferred<T> {
|
||||
private resolveCallbacks: any[] = [];
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
import assert from 'assert';
|
||||
|
||||
import { KubeflowOperator, OperatorApiVersion } from '../training_service/kubernetes/kubeflow/kubeflowConfig'
|
||||
import { KubernetesStorageKind } from '../training_service/kubernetes/kubernetesConfig';
|
||||
|
||||
export interface TrainingServiceConfig {
|
||||
platform: string;
|
||||
|
|
|
@ -6,8 +6,8 @@
|
|||
* Functions will be added when used.
|
||||
**/
|
||||
|
||||
import express, { Request, Response, Router } from 'express';
|
||||
import type { Router as WsRouter } from 'express-ws';
|
||||
import express, { Express, Request, Response, Router } from 'express';
|
||||
import expressWs, { Router as WsRouter } from 'express-ws';
|
||||
import type { WebSocket } from 'ws';
|
||||
|
||||
type HttpMethod = 'GET' | 'PUT';
|
||||
|
@ -16,13 +16,23 @@ type ExpressCallback = (req: Request, res: Response) => void;
|
|||
type WebSocketCallback = (ws: WebSocket, req: Request) => void;
|
||||
|
||||
export class RestManager {
|
||||
private app: Express;
|
||||
private router: Router;
|
||||
|
||||
constructor() {
|
||||
// we don't actually need the app here,
|
||||
// but expressWs() must be called before router.ws(), and it requires an app instance
|
||||
this.app = express();
|
||||
expressWs(this.app, undefined, { wsOptions: { maxPayload: 4 * 1024 * 1024 * 1024 }});
|
||||
|
||||
this.router = Router();
|
||||
this.router.use(express.json({ limit: '50mb' }));
|
||||
}
|
||||
|
||||
public getExpressApp(): Express {
|
||||
return this.app;
|
||||
}
|
||||
|
||||
public getExpressRouter(): Router {
|
||||
return this.router;
|
||||
}
|
||||
|
|
|
@ -24,7 +24,6 @@
|
|||
**/
|
||||
|
||||
import { EventEmitter } from 'node:events';
|
||||
import util from 'node:util';
|
||||
|
||||
import type { Command } from 'common/command_channel/interface';
|
||||
import { RpcHelper, getRpcHelper } from 'common/command_channel/rpc_util';
|
||||
|
|
|
@ -9,7 +9,6 @@ import { ChildProcess, spawn, StdioOptions } from 'child_process';
|
|||
import dgram from 'dgram';
|
||||
import fs from 'fs';
|
||||
import net from 'net';
|
||||
import os from 'os';
|
||||
import path from 'path';
|
||||
import * as timersPromises from 'timers/promises';
|
||||
import { Deferred } from 'ts-deferred';
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
import { IpcInterface } from './tuner_command_channel/common';
|
||||
export { IpcInterface } from './tuner_command_channel/common';
|
||||
import * as shim from './tuner_command_channel/shim';
|
||||
import { IpcInterface, getTunerServer } from './tuner_command_channel';
|
||||
export { IpcInterface } from './tuner_command_channel';
|
||||
|
||||
let tunerDisabled: boolean = false;
|
||||
|
||||
export async function createDispatcherInterface(): Promise<IpcInterface> {
|
||||
if (!tunerDisabled) {
|
||||
return await shim.createDispatcherInterface();
|
||||
return getTunerServer();
|
||||
} else {
|
||||
return new DummyIpcInterface();
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ import {
|
|||
NNIManagerStatus, ProfileUpdateType, TrialJobStatistics
|
||||
} from '../common/manager';
|
||||
import {
|
||||
ExperimentConfig, LocalConfig, TrainingServiceConfig, toSeconds, toCudaVisibleDevices
|
||||
ExperimentConfig, TrainingServiceConfig, toSeconds, toCudaVisibleDevices
|
||||
} from '../common/experimentConfig';
|
||||
import { getExperimentsManager } from 'extensions/experiments_manager';
|
||||
import { TensorboardManager } from '../common/tensorboardManager';
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
import { EventEmitter } from 'node:events';
|
||||
|
||||
import { WsChannel, WsChannelServer } from 'common/command_channel/websocket';
|
||||
import { Deferred } from 'common/deferred';
|
||||
import { Logger, getLogger } from 'common/log';
|
||||
|
||||
/**
 * Command interface between NNI manager and the tuner (dispatcher) process.
 *
 * Commands are identified by a short string type (for example 'PI', 'TE', 'ER')
 * and carry an optional string payload.
 **/
export interface IpcInterface {
    /** Prepare the channel; the returned promise settles when it is usable (or has failed). */
    init(): Promise<void>;

    /** Send one command of the given type, with an optional string payload. */
    sendCommand(
        commandType: string,
        content?: string
    ): void;

    /** Register a listener invoked with the type and payload of each incoming command. */
    onCommand(
        listener: (commandType: string, content: string) => void
    ): void;

    /** Register a listener invoked when the channel reports an error. */
    onError(
        listener: (error: Error) => void
    ): void;
}
|
||||
|
||||
/**
 * Get the tuner command channel server.
 *
 * Returns whatever the module-level `server` variable holds at call time,
 * exposed through the narrower `IpcInterface` type.
 **/
export function getTunerServer(): IpcInterface {
    const current: IpcInterface = server;
    return current;
}
|
||||
|
||||
const logger: Logger = getLogger('tuner_command_channel');
|
||||
|
||||
/**
 * Server side of the tuner command channel.
 *
 * Wraps a `WsChannelServer` and adapts its channel to the `IpcInterface` shape:
 * incoming commands are re-emitted as ('command', type, content) events, except
 * commands of type 'ER', which are converted to 'error' events.
 **/
class TunerServer {
    // Set by the connection callback; stays null until the tuner has connected.
    private channel: WsChannel | null = null;
    // Settled when the tuner connects (resolve) or the connection wait times out (reject).
    private connect: Deferred<void> = new Deferred();
    // Handle of the pending connection-timeout timer armed by init(), if any.
    private connectTimer: ReturnType<typeof setTimeout> | null = null;
    private emitter: EventEmitter = new EventEmitter();
    private server: WsChannelServer;

    constructor() {
        this.server = new WsChannelServer('tuner', 'tuner');

        this.server.onConnection((_channelId, channel) => {
            this.channel = channel;

            channel.onError(error => {
                this.emitter.emit('error', error);
            });

            channel.onReceive(command => {
                if (command.type === 'ER') {
                    // The tuner reports its own failures as 'ER' commands; surface them as errors.
                    this.emitter.emit('error', new Error(command.content));
                } else {
                    this.emitter.emit('command', command.type, command.content ?? '');
                }
            });

            // The tuner is here, so the pending timeout (if any) has nothing left to do.
            this.clearConnectTimer();
            this.connect.resolve();
        });

        // start() is asynchronous and a constructor cannot await it;
        // log a failure instead of leaving the rejection unhandled.
        this.server.start().catch(error => {
            logger.error('Failed to start server:', error);
        });
    }

    /**
     * Wait for the tuner to connect.
     *
     * The returned promise is always the connection deferred's own promise, so a call
     * made after the wait has timed out rejects as well instead of reporting success.
     * While waiting, a single 10 second timeout is armed no matter how many times
     * this method is called.
     **/
    public init(): Promise<void> {
        if (this.connect.settled) {
            logger.debug('Initialized.');
        } else {
            logger.debug('Waiting connection...');
            if (this.connectTimer === null) {
                // TODO: This is a quick fix. It should check tuner's process status instead.
                this.connectTimer = setTimeout(() => {
                    this.connectTimer = null;
                    if (!this.connect.settled) {
                        const msg = 'Tuner did not connect in 10 seconds. Please check tuner (dispatcher) log.';
                        this.connect.reject(new Error('tuner_command_channel: ' + msg));
                    }
                }, 10000);
            }
        }
        return this.connect.promise;
    }

    // TODO: for unit test only
    /** Shut down the underlying channel server and cancel any pending connection timeout. */
    public async stop(): Promise<void> {
        this.clearConnectTimer();
        await this.server.shutdown();
    }

    /**
     * Send a command to the tuner.
     *
     * 'PI' (ping) is silently dropped because heartbeat is handled by the WebSocket protocol.
     * After a 'TE' (terminate) command the channel is closed and the server shut down.
     *
     * Throws an Error if called with a non-ping command before the tuner has connected.
     **/
    public sendCommand(commandType: string, content?: string): void {
        if (commandType === 'PI') {  // ping is handled with WebSocket protocol
            return;
        }

        const channel = this.channel;
        if (channel === null) {
            throw new Error('tuner_command_channel: Cannot send command before tuner connects');
        }

        if (channel.getBufferedAmount() > 1000) {
            logger.warning('Sending too fast! Try to reduce the frequency of intermediate results.');
        }

        channel.send({ type: commandType, content });

        if (commandType === 'TE') {
            channel.close('TE command');
            // Fire and forget, but do not leave a rejection unhandled.
            this.stop().catch(error => {
                logger.error('Failed to shut down server:', error);
            });
        }
    }

    /** Register a listener for every incoming command other than 'ER'. */
    public onCommand(listener: (commandType: string, content: string) => void): void {
        this.emitter.on('command', listener);
    }

    /** Register a listener for channel errors and for 'ER' commands sent by the tuner. */
    public onError(listener: (error: Error) => void): void {
        this.emitter.on('error', listener);
    }

    /** Cancel the connection-timeout timer if one is pending. */
    private clearConnectTimer(): void {
        if (this.connectTimer !== null) {
            clearTimeout(this.connectTimer);
            this.connectTimer = null;
        }
    }
}
|
||||
|
||||
let server: TunerServer = new TunerServer();
|
||||
|
||||
/**
 * Hooks for unit tests to manipulate the module-level tuner server.
 **/
export namespace UnitTestHelpers {
    /**
     * Replace the module-level server with a newly constructed `TunerServer`.
     * The previous instance is not stopped here; call `stop()` first if that is needed.
     **/
    export function reset(): void {
        const fresh = new TunerServer();
        server = fresh;
    }

    /** Stop the server currently held by the module-level variable. */
    export async function stop(): Promise<void> {
        const current = server;
        await current.stop();
    }
}
|
|
@ -1,9 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
export interface IpcInterface {
|
||||
init(): Promise<void>;
|
||||
sendCommand(commandType: string, content?: string): void;
|
||||
onCommand(listener: (commandType: string, content: string) => void): void;
|
||||
onError(listener: (error: Error) => void): void;
|
||||
}
|
|
@ -1,4 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
export { getWebSocketChannel, serveWebSocket } from './websocket_channel';
|
|
@ -1,52 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
import type { IpcInterface } from './common';
|
||||
import { WebSocketChannel, getWebSocketChannel } from './websocket_channel';
|
||||
|
||||
export async function createDispatcherInterface(): Promise<IpcInterface> {
|
||||
return new WsIpcInterface();
|
||||
}
|
||||
|
||||
class WsIpcInterface implements IpcInterface {
|
||||
private channel: WebSocketChannel = getWebSocketChannel();
|
||||
private commandListener?: (commandType: string, content: string) => void;
|
||||
private errorListener?: (error: Error) => void;
|
||||
|
||||
constructor() {
|
||||
this.channel.onCommand((command: string) => {
|
||||
const commandType = command.slice(0, 2);
|
||||
const content = command.slice(2);
|
||||
if (commandType === 'ER') {
|
||||
if (this.errorListener !== undefined) {
|
||||
this.errorListener(new Error(content));
|
||||
}
|
||||
} else {
|
||||
if (this.commandListener !== undefined) {
|
||||
this.commandListener(commandType, content);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public async init(): Promise<void> {
|
||||
await this.channel.init();
|
||||
}
|
||||
|
||||
public sendCommand(commandType: string, content: string = ''): void {
|
||||
if (commandType !== 'PI') { // ping is handled with WebSocket protocol
|
||||
this.channel.sendCommand(commandType + content);
|
||||
if (commandType === 'TE') {
|
||||
this.channel.shutdown();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public onCommand(listener: (commandType: string, content: string) => void): void {
|
||||
this.commandListener = listener;
|
||||
}
|
||||
|
||||
public onError(listener: (error: Error) => void): void {
|
||||
this.errorListener = listener;
|
||||
}
|
||||
}
|
|
@ -1,206 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
/**
|
||||
* The IPC channel between NNI manager and tuner.
|
||||
*
|
||||
* TODO:
|
||||
* 1. Merge with environment service's WebSocket channel.
|
||||
* 2. Split import data command to avoid extremely long message.
|
||||
* 3. Refactor message format.
|
||||
**/
|
||||
|
||||
import assert from 'assert/strict';
|
||||
import { EventEmitter } from 'events';
|
||||
|
||||
import type WebSocket from 'ws';
|
||||
|
||||
import { Deferred } from 'common/deferred';
|
||||
import { Logger, getLogger } from 'common/log';
|
||||
|
||||
const logger: Logger = getLogger('tuner_command_channel.WebSocketChannel');
|
||||
|
||||
export interface WebSocketChannel {
|
||||
init(): Promise<void>;
|
||||
shutdown(): Promise<void>;
|
||||
sendCommand(command: string): void; // maybe this should return Promise<void>
|
||||
onCommand(callback: (command: string) => void): void;
|
||||
onError(callback: (error: Error) => void): void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the singleton tuner command channel.
|
||||
* Remember to invoke ``await channel.init()`` before doing anything else.
|
||||
**/
|
||||
export function getWebSocketChannel(): WebSocketChannel {
|
||||
return channelSingleton;
|
||||
}
|
||||
|
||||
/**
|
||||
* The callback to serve WebSocket connection request. Used by REST server module.
|
||||
* If it is invoked more than once, the previous connection will be dropped.
|
||||
**/
|
||||
export function serveWebSocket(ws: WebSocket): void {
|
||||
channelSingleton.serveWebSocket(ws);
|
||||
}
|
||||
|
||||
class WebSocketChannelImpl implements WebSocketChannel {
|
||||
private deferredInit: Deferred<void> = new Deferred<void>();
|
||||
private emitter: EventEmitter = new EventEmitter();
|
||||
private heartbeatTimer!: NodeJS.Timer;
|
||||
private serving: boolean = false;
|
||||
private waitingPong: boolean = false;
|
||||
private ws!: WebSocket;
|
||||
|
||||
public serveWebSocket(ws: WebSocket): void {
|
||||
if (this.ws === undefined) {
|
||||
logger.debug('Connected.');
|
||||
} else {
|
||||
logger.warning('Reconnecting. Drop previous connection.');
|
||||
this.dropConnection('Reconnected');
|
||||
}
|
||||
|
||||
this.serving = true;
|
||||
|
||||
this.ws = ws;
|
||||
this.ws.on('close', this.handleWsClose);
|
||||
this.ws.on('error', this.handleWsError);
|
||||
this.ws.on('message', this.handleWsMessage);
|
||||
this.ws.on('pong', this.handleWsPong);
|
||||
|
||||
this.heartbeatTimer = setInterval(this.heartbeat.bind(this), heartbeatInterval);
|
||||
this.deferredInit.resolve();
|
||||
}
|
||||
|
||||
public init(): Promise<void> {
|
||||
if (this.ws === undefined) {
|
||||
logger.debug('Waiting connection...');
|
||||
// TODO: This is a quick fix. It should check tuner's process status instead.
|
||||
setTimeout(() => {
|
||||
if (!this.deferredInit.settled) {
|
||||
const msg = 'Tuner did not connect in 10 seconds. Please check tuner (dispatcher) log.';
|
||||
this.deferredInit.reject(new Error('tuner_command_channel: ' + msg));
|
||||
}
|
||||
}, 10000);
|
||||
return this.deferredInit.promise;
|
||||
|
||||
} else {
|
||||
logger.debug('Initialized.');
|
||||
return Promise.resolve();
|
||||
}
|
||||
}
|
||||
|
||||
public async shutdown(): Promise<void> {
|
||||
if (this.ws === undefined) {
|
||||
return;
|
||||
}
|
||||
clearInterval(this.heartbeatTimer);
|
||||
this.serving = false;
|
||||
this.emitter.removeAllListeners();
|
||||
}
|
||||
|
||||
public sendCommand(command: string): void {
|
||||
assert.ok(this.ws !== undefined);
|
||||
|
||||
logger.debug('Sending', command);
|
||||
this.ws.send(command);
|
||||
|
||||
if (this.ws.bufferedAmount > command.length + 1000) {
|
||||
logger.warning('Sending too fast! Try to reduce the frequency of intermediate results.');
|
||||
}
|
||||
}
|
||||
|
||||
public onCommand(callback: (command: string) => void): void {
|
||||
this.emitter.on('command', callback);
|
||||
}
|
||||
|
||||
public onError(callback: (error: Error) => void): void {
|
||||
this.emitter.on('error', callback);
|
||||
}
|
||||
|
||||
/* The following callbacks must be auto-bound arrow functions so that they can be turned off */
|
||||
|
||||
private handleWsClose = (): void => {
|
||||
this.handleError(new Error('tuner_command_channel: Tuner closed connection'));
|
||||
}
|
||||
|
||||
private handleWsError = (error: Error): void => {
|
||||
this.handleError(error);
|
||||
}
|
||||
|
||||
private handleWsMessage = (data: Buffer, _isBinary: boolean): void => {
|
||||
this.receive(data);
|
||||
}
|
||||
|
||||
private handleWsPong = (): void => {
|
||||
this.waitingPong = false;
|
||||
}
|
||||
|
||||
private dropConnection(reason: string): void {
|
||||
if (this.ws === undefined) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.serving = false;
|
||||
this.waitingPong = false;
|
||||
clearInterval(this.heartbeatTimer);
|
||||
|
||||
this.ws.off('close', this.handleWsClose);
|
||||
this.ws.off('error', this.handleWsError);
|
||||
this.ws.off('message', this.handleWsMessage);
|
||||
this.ws.off('pong', this.handleWsPong);
|
||||
|
||||
this.ws.on('close', () => {
|
||||
logger.info('Connection dropped');
|
||||
});
|
||||
this.ws.on('message', (data, _isBinary) => {
|
||||
logger.error('Received message after reconnect:', data);
|
||||
});
|
||||
this.ws.on('pong', () => {
|
||||
logger.error('Received pong after reconnect.');
|
||||
});
|
||||
this.ws.close(1001, reason);
|
||||
}
|
||||
|
||||
private heartbeat(): void {
|
||||
// if (this.waitingPong) {
|
||||
// this.ws.terminate(); // this will trigger "close" event
|
||||
// this.handleError(new Error('tuner_command_channel: Tuner loses responsive'));
|
||||
// }
|
||||
|
||||
this.waitingPong = true;
|
||||
this.ws.ping();
|
||||
}
|
||||
|
||||
private receive(data: Buffer): void {
|
||||
logger.debug('Received', data);
|
||||
this.emitter.emit('command', data.toString());
|
||||
}
|
||||
|
||||
private handleError(error: Error): void {
|
||||
if (!this.serving) {
|
||||
logger.debug('Silent error:', error);
|
||||
return;
|
||||
}
|
||||
logger.error('Error:', error);
|
||||
|
||||
clearInterval(this.heartbeatTimer);
|
||||
this.emitter.emit('error', error);
|
||||
this.serving = false;
|
||||
}
|
||||
}
|
||||
|
||||
let channelSingleton: WebSocketChannelImpl = new WebSocketChannelImpl();
|
||||
|
||||
let heartbeatInterval: number = 5000;
|
||||
|
||||
export namespace UnitTestHelpers {
|
||||
export function setHeartbeatInterval(ms: number): void {
|
||||
heartbeatInterval = ms;
|
||||
}
|
||||
// NOTE: this function is only for unittest of nnimanager,
|
||||
// because resuming an experiment should reset websocket channel.
|
||||
export function resetChannelSingleton(): void {
|
||||
channelSingleton = new WebSocketChannelImpl();
|
||||
}
|
||||
}
|
|
@ -3,8 +3,6 @@
|
|||
|
||||
import assert from 'assert/strict';
|
||||
import fs from 'fs';
|
||||
import os from 'os';
|
||||
import path from 'path';
|
||||
import * as timersPromises from 'timers/promises';
|
||||
|
||||
import { Deferred } from 'ts-deferred';
|
||||
|
|
|
@ -8,8 +8,6 @@ import * as timersPromises from 'timers/promises';
|
|||
import glob from 'glob';
|
||||
import lockfile from 'lockfile';
|
||||
|
||||
import globals from 'common/globals';
|
||||
|
||||
const lockStale: number = 2000;
|
||||
const retry: number = 100;
|
||||
|
||||
|
|
|
@ -21,13 +21,13 @@
|
|||
|
||||
import 'app-module-path/register'; // so we can use absolute path to import
|
||||
|
||||
import fs from 'fs';
|
||||
|
||||
import { Container, Scope } from 'typescript-ioc';
|
||||
|
||||
import { globals, initGlobals } from 'common/globals';
|
||||
initGlobals();
|
||||
|
||||
import * as component from 'common/component';
|
||||
import { Database, DataStore } from 'common/datastore';
|
||||
import globals, { initGlobals } from 'common/globals';
|
||||
import { Logger, getLogger } from 'common/log';
|
||||
import { Manager } from 'common/manager';
|
||||
import { TensorboardManager } from 'common/tensorboardManager';
|
||||
|
@ -71,8 +71,6 @@ process.on('SIGINT', () => { globals.shutdown.initiate('SIGINT'); });
|
|||
|
||||
/* main */
|
||||
|
||||
initGlobals();
|
||||
|
||||
start().then(() => {
|
||||
logger.debug('start() returned.');
|
||||
}).catch((error) => {
|
||||
|
|
|
@ -4,8 +4,8 @@
|
|||
"license": "MIT",
|
||||
"scripts": {
|
||||
"build": "tsc",
|
||||
"test": "nyc --reporter=cobertura --reporter=text mocha \"test/**/*.test.ts\"",
|
||||
"test_nnimanager": "nyc --reporter=cobertura --reporter=text mocha test/core/nnimanager.test.ts",
|
||||
"test": "nyc --reporter=cobertura --reporter=text mocha \"test/**/*.test.ts\" --exclude test/core/nnimanager.test.ts",
|
||||
"test_nnimanager": "mocha test/core/nnimanager.test.ts",
|
||||
"mocha": "mocha",
|
||||
"eslint": "eslint . --ext .ts"
|
||||
},
|
||||
|
|
|
@ -13,9 +13,6 @@ import assert from 'node:assert/strict';
|
|||
import type { Server } from 'node:http';
|
||||
import type { AddressInfo } from 'node:net';
|
||||
|
||||
import express from 'express';
|
||||
import expressWs from 'express-ws';
|
||||
|
||||
import { Deferred } from 'common/deferred';
|
||||
import { globals } from 'common/globals';
|
||||
import { Logger, getLogger } from 'common/log';
|
||||
|
@ -40,9 +37,7 @@ export class RestServerCore {
|
|||
public start(): Promise<void> {
|
||||
logger.info(`Starting REST server at port ${this.port}, URL prefix: "/${this.urlPrefix}"`);
|
||||
|
||||
const app = express();
|
||||
expressWs(app, undefined, { wsOptions: { maxPayload: 4 * 1024 * 1024 * 1024 }});
|
||||
|
||||
const app = globals.rest.getExpressApp();
|
||||
app.use('/' + this.urlPrefix, globals.rest.getExpressRouter());
|
||||
app.all('/' + this.urlPrefix, (_req, res) => { res.status(404).send('Not Found'); });
|
||||
app.all('*', (_req, res) => { res.status(404).send(`Outside prefix "/${this.urlPrefix}"`); });
|
||||
|
|
|
@ -26,13 +26,11 @@ import type { AddressInfo } from 'net';
|
|||
import path from 'path';
|
||||
|
||||
import express, { Request, Response, Router } from 'express';
|
||||
import expressWs from 'express-ws';
|
||||
import httpProxy from 'http-proxy';
|
||||
|
||||
import { Deferred } from 'common/deferred';
|
||||
import globals from 'common/globals';
|
||||
import { Logger, getLogger } from 'common/log';
|
||||
import * as tunerCommandChannel from 'core/tuner_command_channel';
|
||||
|
||||
const logger: Logger = getLogger('RestServer');
|
||||
|
||||
|
@ -60,9 +58,7 @@ export class RestServer {
|
|||
public start(): Promise<void> {
|
||||
logger.info(`Starting REST server at port ${this.port}, URL prefix: "/${this.urlPrefix}"`);
|
||||
|
||||
const app = express();
|
||||
expressWs(app, undefined, { wsOptions: { maxPayload: 4 * 1024 * 1024 * 1024 }});
|
||||
|
||||
const app = globals.rest.getExpressApp();
|
||||
app.use('/' + this.urlPrefix, mainRouter());
|
||||
app.use('/' + this.urlPrefix, fallbackRouter());
|
||||
app.all('*', (_req: Request, res: Response) => { res.status(404).send(`Outside prefix "/${this.urlPrefix}"`); });
|
||||
|
@ -110,10 +106,7 @@ export class RestServer {
|
|||
* In fact experiments management should have a separate prefix and module.
|
||||
**/
|
||||
function mainRouter(): Router {
|
||||
const router = globals.rest.getExpressRouter() as expressWs.Router;
|
||||
|
||||
/* WebSocket APIs */
|
||||
router.ws('/tuner', (ws, _req, _next) => { tunerCommandChannel.serveWebSocket(ws); });
|
||||
const router = globals.rest.getExpressRouter();
|
||||
|
||||
/* Download log files */
|
||||
// The REST API path "/logs" does not match file system path "/log".
|
||||
|
@ -142,6 +135,9 @@ function fallbackRouter(): Router {
|
|||
|
||||
/* 404 as catch-all */
|
||||
router.all('*', (_req: Request, res: Response) => { res.status(404).send('Not Found'); });
|
||||
|
||||
// TODO: websocket 404
|
||||
|
||||
return router;
|
||||
}
|
||||
|
||||
|
|
|
@ -13,7 +13,6 @@ import { getLogger, Logger } from '../common/log';
|
|||
import { ExperimentProfile, Manager, TrialJobStatistics } from '../common/manager';
|
||||
import { getExperimentsManager } from 'extensions/experiments_manager';
|
||||
import { TensorboardManager, TensorboardTaskInfo } from '../common/tensorboardManager';
|
||||
import { ValidationSchemas } from './restValidationSchemas';
|
||||
import { getVersion } from '../common/utils';
|
||||
import { MetricType } from '../common/datastore';
|
||||
import { ProfileUpdateType } from '../common/manager';
|
||||
|
@ -294,11 +293,7 @@ class NNIRestHandler {
|
|||
|
||||
private getTrialFile(router: Router): void {
|
||||
router.get('/trial-file/:id/:filename', async(req: Request, res: Response) => {
|
||||
let encoding: string | null = null;
|
||||
const filename = req.params['filename'];
|
||||
if (!filename.includes('.') || filename.match(/.*\.(txt|log)/g)) {
|
||||
encoding = 'utf8';
|
||||
}
|
||||
this.nniManager.getTrialFile(req.params['id'], filename).then((content: Buffer | string) => {
|
||||
const contentType = content instanceof Buffer ? 'application/octet-stream' : 'text/plain';
|
||||
res.header('Content-Type', contentType);
|
||||
|
|
|
@ -1,121 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
'use strict';
|
||||
|
||||
import * as assert from 'assert';
|
||||
import { ChildProcess, spawn, StdioOptions } from 'child_process';
|
||||
import { Deferred } from 'ts-deferred';
|
||||
import { cleanupUnitTest, prepareUnitTest, getTunerProc } from '../../common/utils';
|
||||
import * as CommandType from '../../core/commands';
|
||||
import { createDispatcherInterface, IpcInterface } from '../../core/ipcInterface';
|
||||
import { NNIError } from '../../common/errors';
|
||||
|
||||
let sentCommands: { [key: string]: string }[] = [];
|
||||
const receivedCommands: { [key: string]: string }[] = [];
|
||||
|
||||
let rejectCommandType: Error | undefined;
|
||||
|
||||
async function runProcess(): Promise<Error | null> {
|
||||
// the process is intended to throw error, do not reject
|
||||
const deferred: Deferred<Error | null> = new Deferred<Error | null>();
|
||||
|
||||
// create fake assessor process
|
||||
const stdio: StdioOptions = ['ignore', 'pipe', process.stderr, 'pipe', 'pipe'];
|
||||
const command: string[] = [ 'python', 'assessor.py' ];
|
||||
const proc: ChildProcess = getTunerProc(command, stdio, 'core/test', process.env);
|
||||
// record its sent/received commands on exit
|
||||
proc.on('error', (error: Error): void => { deferred.resolve(error); });
|
||||
proc.on('exit', (code: number): void => {
|
||||
if (code !== 0) {
|
||||
deferred.resolve(new Error(`return code: ${code}`));
|
||||
} else {
|
||||
let str = proc.stdout!.read().toString();
|
||||
if(str.search("\r\n")!=-1){
|
||||
sentCommands = str.split("\r\n");
|
||||
}
|
||||
else{
|
||||
sentCommands = str.split('\n');
|
||||
}
|
||||
deferred.resolve(null);
|
||||
}
|
||||
});
|
||||
|
||||
// create IPC interface
|
||||
const dispatcher: IpcInterface = await createDispatcherInterface();
|
||||
dispatcher.onCommand((commandType: string, content: string): void => {
|
||||
receivedCommands.push({ commandType, content });
|
||||
});
|
||||
|
||||
// Command #1: ok
|
||||
dispatcher.sendCommand('IN');
|
||||
|
||||
// Command #2: ok
|
||||
dispatcher.sendCommand('ME', '123');
|
||||
|
||||
// Command #3: FE is not tuner/assessor command, test the exception type of send non-valid command
|
||||
try {
|
||||
dispatcher.sendCommand('FE', '1');
|
||||
} catch (error) {
|
||||
rejectCommandType = error as Error;
|
||||
}
|
||||
|
||||
return deferred.promise;
|
||||
}
|
||||
|
||||
/* FIXME
|
||||
describe('core/protocol', (): void => {
|
||||
|
||||
before(async () => {
|
||||
prepareUnitTest();
|
||||
await runProcess();
|
||||
});
|
||||
|
||||
after(() => {
|
||||
cleanupUnitTest();
|
||||
});
|
||||
|
||||
it('should have sent 2 successful commands', (): void => {
|
||||
assert.equal(sentCommands.length, 3);
|
||||
assert.equal(sentCommands[2], '');
|
||||
});
|
||||
|
||||
it('sendCommand() should work without content', (): void => {
|
||||
assert.equal(sentCommands[0], "('IN', '')");
|
||||
});
|
||||
|
||||
it('sendCommand() should work with content', (): void => {
|
||||
assert.equal(sentCommands[1], "('ME', '123')");
|
||||
});
|
||||
|
||||
it('sendCommand() should throw on wrong command type', (): void => {
|
||||
assert.equal((<Error>rejectCommandType).name.split(' ')[0], 'AssertionError');
|
||||
});
|
||||
|
||||
it('should have received 3 commands', (): void => {
|
||||
assert.equal(receivedCommands.length, 3);
|
||||
});
|
||||
|
||||
it('onCommand() should work without content', (): void => {
|
||||
assert.deepStrictEqual(receivedCommands[0], {
|
||||
commandType: 'KI',
|
||||
content: ''
|
||||
});
|
||||
});
|
||||
|
||||
it('onCommand() should work with content', (): void => {
|
||||
assert.deepStrictEqual(receivedCommands[1], {
|
||||
commandType: 'KI',
|
||||
content: 'hello'
|
||||
});
|
||||
});
|
||||
|
||||
it('onCommand() should work with Unicode content', (): void => {
|
||||
assert.deepStrictEqual(receivedCommands[2], {
|
||||
commandType: 'KI',
|
||||
content: '世界'
|
||||
});
|
||||
});
|
||||
|
||||
});
|
||||
*/
|
|
@ -1,105 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
'use strict';
|
||||
|
||||
import * as assert from 'assert';
|
||||
import { ChildProcess, spawn, StdioOptions } from 'child_process';
|
||||
import { Deferred } from 'ts-deferred';
|
||||
import { cleanupUnitTest, prepareUnitTest, getMsgDispatcherCommand, getTunerProc } from '../../common/utils';
|
||||
import * as CommandType from '../../core/commands';
|
||||
import { createDispatcherInterface, IpcInterface } from '../../core/ipcInterface';
|
||||
|
||||
let dispatcher: IpcInterface | undefined;
|
||||
let procExit: boolean = false;
|
||||
let procError: boolean = false;
|
||||
|
||||
async function startProcess(): Promise<void> {
|
||||
// create fake assessor process
|
||||
const stdio: StdioOptions = ['ignore', 'pipe', process.stderr, 'pipe', 'pipe'];
|
||||
|
||||
const dispatcherCmd: string[] = getMsgDispatcherCommand(
|
||||
// Mock tuner config
|
||||
<any>{
|
||||
experimentName: 'exp1',
|
||||
maxExperimentDuration: '1h',
|
||||
searchSpace: '',
|
||||
trainingService: {
|
||||
platform: 'local'
|
||||
},
|
||||
trialConcurrency: 1,
|
||||
maxTrialNumber: 5,
|
||||
tuner: {
|
||||
className: 'dummy_tuner.DummyTuner',
|
||||
codeDirectory: '.'
|
||||
},
|
||||
assessor: {
|
||||
className: 'dummy_assessor.DummyAssessor',
|
||||
codeDirectory: '.'
|
||||
},
|
||||
trialCommand: '',
|
||||
trialCodeDirectory: '',
|
||||
debug: true
|
||||
}
|
||||
);
|
||||
const proc: ChildProcess = getTunerProc(dispatcherCmd, stdio, 'core/test', process.env);
|
||||
proc.on('error', (_error: Error): void => {
|
||||
procExit = true;
|
||||
procError = true;
|
||||
});
|
||||
proc.on('exit', (code: number): void => {
|
||||
procExit = true;
|
||||
procError = (code !== 0);
|
||||
});
|
||||
|
||||
// create IPC interface
|
||||
dispatcher = await createDispatcherInterface();
|
||||
(<IpcInterface>dispatcher).onCommand((commandType: string, content: string): void => {
|
||||
console.log(commandType, content);
|
||||
});
|
||||
}
|
||||
|
||||
/* FIXME
|
||||
describe('core/ipcInterface.terminate', (): void => {
|
||||
before(() => {
|
||||
prepareUnitTest();
|
||||
startProcess();
|
||||
});
|
||||
|
||||
after(() => {
|
||||
cleanupUnitTest();
|
||||
});
|
||||
|
||||
it('normal', () => {
|
||||
(<IpcInterface>dispatcher).sendCommand(
|
||||
CommandType.REPORT_METRIC_DATA,
|
||||
'{"trial_job_id":"A","type":"PERIODICAL","value":1,"sequence":123}');
|
||||
|
||||
const deferred: Deferred<void> = new Deferred<void>();
|
||||
setTimeout(
|
||||
() => {
|
||||
assert.ok(!procExit);
|
||||
assert.ok(!procError);
|
||||
deferred.resolve();
|
||||
},
|
||||
1000);
|
||||
|
||||
return deferred.promise;
|
||||
});
|
||||
|
||||
it('terminate', () => {
|
||||
(<IpcInterface>dispatcher).sendCommand(CommandType.TERMINATE);
|
||||
|
||||
const deferred: Deferred<void> = new Deferred<void>();
|
||||
setTimeout(
|
||||
() => {
|
||||
assert.ok(procExit);
|
||||
assert.ok(!procError);
|
||||
deferred.resolve();
|
||||
},
|
||||
10000);
|
||||
|
||||
return deferred.promise;
|
||||
});
|
||||
});
|
||||
*/
|
|
@ -22,7 +22,7 @@ import { NNITensorboardManager } from '../../extensions/nniTensorboardManager';
|
|||
import * as path from 'path';
|
||||
import { RestServer } from '../../rest_server';
|
||||
import globals from '../../common/globals/unittest';
|
||||
import { UnitTestHelpers } from '../../core/tuner_command_channel/websocket_channel';
|
||||
import { UnitTestHelpers } from '../../core/tuner_command_channel';
|
||||
import * as timersPromises from 'timers/promises';
|
||||
|
||||
let nniManager: NNIManager;
|
||||
|
@ -324,7 +324,7 @@ async function resumeExperiment(): Promise<void> {
|
|||
// globals.showLog();
|
||||
// explicitly reset the websocket channel because it is singleton, does not work when two experiments
|
||||
// (one is start and the other is resume) run in the same process.
|
||||
UnitTestHelpers.resetChannelSingleton();
|
||||
UnitTestHelpers.reset();
|
||||
await initContainer('resume');
|
||||
nniManager = component.get(Manager);
|
||||
|
||||
|
|
|
@ -1,131 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
import assert from 'assert/strict';
|
||||
import { setTimeout } from 'timers/promises';
|
||||
|
||||
import WebSocket from 'ws';
|
||||
|
||||
import { Deferred } from 'common/deferred';
|
||||
import { getWebSocketChannel, serveWebSocket } from 'core/tuner_command_channel';
|
||||
import { UnitTestHelpers } from 'core/tuner_command_channel/websocket_channel';
|
||||
|
||||
const heartbeatInterval: number = 10;
|
||||
|
||||
// for testError, must be set before serveWebSocket()
|
||||
UnitTestHelpers.setHeartbeatInterval(heartbeatInterval);
|
||||
|
||||
/* test cases */
|
||||
|
||||
// Start serving and let a client connect.
|
||||
async function testInit(): Promise<void> {
|
||||
const channel = getWebSocketChannel();
|
||||
channel.onCommand(command => { serverReceived.push(command); });
|
||||
channel.onError(error => { catchedError = error; });
|
||||
|
||||
server.on('connection', serveWebSocket);
|
||||
client1 = new Client('client1');
|
||||
await channel.init();
|
||||
}
|
||||
|
||||
// Send commands from server to client.
|
||||
async function testSend(client: Client): Promise<void> {
|
||||
const channel = getWebSocketChannel();
|
||||
|
||||
channel.sendCommand(command1);
|
||||
channel.sendCommand(command2);
|
||||
await setTimeout(heartbeatInterval);
|
||||
|
||||
assert.deepEqual(client.received, [command1, command2]);
|
||||
}
|
||||
|
||||
// Send commands from client to server.
|
||||
async function testReceive(client: Client): Promise<void> {
|
||||
serverReceived.length = 0;
|
||||
|
||||
client.ws.send(command2);
|
||||
client.ws.send(command1);
|
||||
await setTimeout(heartbeatInterval);
|
||||
|
||||
assert.deepEqual(serverReceived, [command2, command1]);
|
||||
}
|
||||
|
||||
// Simulate client side crash.
|
||||
async function testError(): Promise<void> {
|
||||
if (process.platform !== 'linux') {
|
||||
// it is performance sensitive for the test case to yield error,
|
||||
// but windows & mac agents of devops are too slow
|
||||
client1.ws.terminate();
|
||||
return;
|
||||
}
|
||||
|
||||
// we have set heartbeat interval to 10ms, so pause for 30ms should make it timeout
|
||||
client1.ws.pause();
|
||||
await setTimeout(heartbeatInterval * 3);
|
||||
client1.ws.resume();
|
||||
|
||||
assert.notEqual(catchedError, undefined);
|
||||
}
|
||||
|
||||
// If the client losses connection by accident but not crashed, it will reconnect.
|
||||
async function testReconnect(): Promise<void> {
|
||||
client2 = new Client('client2');
|
||||
await client2.deferred.promise;
|
||||
}
|
||||
|
||||
// Clean up.
|
||||
async function testShutdown(): Promise<void> {
|
||||
const channel = getWebSocketChannel();
|
||||
await channel.shutdown();
|
||||
|
||||
client1.ws.close();
|
||||
client2.ws.close();
|
||||
server.close();
|
||||
}
|
||||
|
||||
/* register */
|
||||
describe('## tuner_command_channel ##', () => {
|
||||
it('init', testInit);
|
||||
|
||||
it('send', () => testSend(client1));
|
||||
it('receive', () => testReceive(client1));
|
||||
|
||||
// it('mock timeout', testError);
|
||||
it('reconnect', testReconnect);
|
||||
|
||||
it('send after reconnect', () => testSend(client2));
|
||||
it('receive after reconnect', () => testReceive(client2));
|
||||
|
||||
it('shutdown', testShutdown);
|
||||
});
|
||||
|
||||
/** helpers **/
|
||||
|
||||
const command1 = 'T_hello world';
|
||||
const command2 = 'T_你好';
|
||||
|
||||
const server = new WebSocket.Server({ port: 0 });
|
||||
let client1!: Client;
|
||||
let client2!: Client;
|
||||
|
||||
const serverReceived: string[] = [];
|
||||
let catchedError: Error | undefined;
|
||||
|
||||
class Client {
|
||||
name: string;
|
||||
received: string[] = [];
|
||||
ws!: WebSocket;
|
||||
deferred: Deferred<void> = new Deferred();
|
||||
|
||||
constructor(name: string) {
|
||||
this.name = name;
|
||||
const port = (server.address() as any).port;
|
||||
this.ws = new WebSocket(`ws://localhost:${port}`);
|
||||
this.ws.on('message', (data, _isBinary) => {
|
||||
this.received.push(data.toString());
|
||||
});
|
||||
this.ws.on('open', () => {
|
||||
this.deferred.resolve();
|
||||
});
|
||||
}
|
||||
}
|
|
@ -12,9 +12,8 @@ import { ExperimentStartupInfo } from '../../../../common/experimentStartupInfo'
|
|||
import { EnvironmentInformation } from '../../environment';
|
||||
import { KubernetesEnvironmentService } from './kubernetesEnvironmentService';
|
||||
import { FrameworkControllerClientFactory } from '../../../kubernetes/frameworkcontroller/frameworkcontrollerApiClient';
|
||||
import { FrameworkControllerClusterConfigAzure, FrameworkControllerJobStatus, FrameworkControllerTrialConfigTemplate,
|
||||
import { FrameworkControllerJobStatus, FrameworkControllerTrialConfigTemplate,
|
||||
FrameworkControllerJobCompleteStatus } from '../../../kubernetes/frameworkcontroller/frameworkcontrollerConfig';
|
||||
import { KeyVaultConfig, AzureStorage } from '../../../kubernetes/kubernetesConfig';
|
||||
|
||||
@component.Singleton
|
||||
export class FrameworkControllerEnvironmentService extends KubernetesEnvironmentService {
|
||||
|
|
|
@ -6,7 +6,6 @@ import path from 'path';
|
|||
import azureStorage from 'azure-storage';
|
||||
import {Base64} from 'js-base64';
|
||||
import {String} from 'typescript-string-operations';
|
||||
import { ExperimentConfig } from 'common/experimentConfig';
|
||||
import { ExperimentStartupInfo } from 'common/experimentStartupInfo';
|
||||
import { getLogger, Logger } from 'common/log';
|
||||
import { EnvironmentInformation, EnvironmentService } from 'training_service/reusable/environment';
|
||||
|
|
|
@ -11,7 +11,7 @@ import type { TrainingServiceConfig } from 'common/experimentConfig';
|
|||
import globals from 'common/globals';
|
||||
import { getLogger } from 'common/log';
|
||||
import {
|
||||
TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus
|
||||
TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
|
||||
} from 'common/trainingService';
|
||||
import type { EnvironmentInfo, Parameter, TrainingServiceV3 } from 'common/training_service_v3';
|
||||
import { trainingServiceFactoryV3 } from './factory';
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
* For now we only have "local_v3" and "remote_v3" as PoC.
|
||||
**/
|
||||
|
||||
import type { LocalConfig, RemoteConfig, TrainingServiceConfig } from 'common/experimentConfig';
|
||||
import type { TrainingServiceConfig } from 'common/experimentConfig';
|
||||
import type { TrainingServiceV3 } from 'common/training_service_v3';
|
||||
import { LocalTrainingServiceV3 } from '../local_v3';
|
||||
import { RemoteTrainingServiceV3 } from '../remote_v3';
|
||||
|
|
Загрузка…
Ссылка в новой задаче