Get rid of IoC and remove unused training services (#5567)

This commit is contained in:
liuzhe-lz 2023-05-18 19:06:43 +08:00 коммит произвёл GitHub
Родитель f17385c0ac
Коммит 5676de40d7
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
53 изменённых файлов: 207 добавлений и 2958 удалений

Просмотреть файл

@ -1,15 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import * as ioc from 'typescript-ioc';
const Inject: (...args: any[]) => any = ioc.Inject;
const Singleton: (target: any) => void = ioc.Singleton;
const Container = ioc.Container;
const Provides = ioc.Provides;
function get<T>(source: any): T {
return ioc.Container.get(source) as T;
}
export { Provides, Container, Inject, Singleton, get };

Просмотреть файл

@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import assert from 'node:assert/strict';
type AbstractClass = {
name: string;
};
type Class = {
name: string;
new(): any;
};
class IocShimClass {
private singletons: Map<string, any> = new Map();
private snapshots: Map<string, any> = new Map();
public bind(keyClass: AbstractClass, valueClass: Class): void {
const key = keyClass.name;
assert.ok(!this.singletons.has(key));
this.singletons.set(key, new valueClass());
}
public bindInstance(keyClass: AbstractClass, value: any): void {
const key = keyClass.name;
assert.ok(!this.singletons.has(key));
this.singletons.set(key, value);
}
public get<T>(keyClass: AbstractClass): T {
const key = keyClass.name;
assert.ok(this.singletons.has(key));
return this.singletons.get(key);
}
public snapshot(keyClass: AbstractClass): void {
const key = keyClass.name;
const value = this.singletons.get(key);
this.snapshots.set(key, value);
}
public restore(keyClass: AbstractClass): void {
const key = keyClass.name;
const value = this.snapshots.get(key);
this.singletons.set(key, value);
}
// NOTE: for unit test only
public clear(): void {
this.singletons.clear();
}
}
export const IocShim: IocShimClass = new IocShimClass();

Просмотреть файл

@ -1,26 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import rx from 'rx';
import * as component from '../common/component';
@component.Singleton
class ObservableTimer {
private observableSource: rx.Observable<number>;
constructor() {
// TODO: move 100 and 1000 into constants class
this.observableSource = rx.Observable.timer(100, 1000).takeWhile(() => true);
}
public subscribe(onNext?: (value: any) => void, onError?: (exception: any) => void, onCompleted?: () => void): Rx.IDisposable {
return this.observableSource.subscribe(onNext, onError, onCompleted);
}
public unsubscribe( subscription: Rx.IDisposable): void {
if(typeof subscription !== 'undefined') {
subscription.dispose();
}
}
}
export { ObservableTimer };

Просмотреть файл

@ -12,11 +12,11 @@ import net from 'net';
import path from 'path';
import * as timersPromises from 'timers/promises';
import { Deferred } from 'ts-deferred';
import { Container } from 'typescript-ioc';
import { Database, DataStore } from './datastore';
import globals from './globals';
import { resetGlobals } from './globals/unittest'; // TODO: this file should not contain unittest helpers
import { IocShim } from './ioc_shim';
import { ExperimentConfig, Manager } from './manager';
import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService';
@ -132,10 +132,10 @@ function generateParamFileName(hyperParameters: HyperParameters): string {
* Must be paired with `cleanupUnitTest()`.
*/
function prepareUnitTest(): void {
Container.snapshot(Database);
Container.snapshot(DataStore);
Container.snapshot(TrainingService);
Container.snapshot(Manager);
IocShim.snapshot(Database);
IocShim.snapshot(DataStore);
IocShim.snapshot(TrainingService);
IocShim.snapshot(Manager);
resetGlobals();
@ -152,10 +152,10 @@ function prepareUnitTest(): void {
* Must be paired with `prepareUnitTest()`.
*/
function cleanupUnitTest(): void {
Container.restore(Manager);
Container.restore(TrainingService);
Container.restore(DataStore);
Container.restore(Database);
IocShim.restore(Manager);
IocShim.restore(TrainingService);
IocShim.restore(DataStore);
IocShim.restore(Database);
}
let cachedIpv4Address: string | null = null;

Просмотреть файл

@ -4,7 +4,7 @@
import assert from 'assert';
import { Deferred } from 'ts-deferred';
import * as component from '../common/component';
import { IocShim } from 'common/ioc_shim';
import { Database, DataStore, MetricData, MetricDataRecord, MetricType,
TrialJobEvent, TrialJobEventRecord, TrialJobInfo, HyperParameterFormat,
ExportedDataFormat } from '../common/datastore';
@ -16,7 +16,7 @@ import { TrialJobDetail, TrialJobStatus } from '../common/trainingService';
import { getDefaultDatabaseDir, mkDirP } from '../common/utils';
class NNIDataStore implements DataStore {
private db: Database = component.get(Database);
private db: Database = IocShim.get(Database);
private log: Logger = getLogger('NNIDataStore');
private initTask!: Deferred<void>;

Просмотреть файл

@ -4,7 +4,7 @@
import assert from 'assert';
import { ChildProcess, StdioOptions } from 'child_process';
import { Deferred } from 'ts-deferred';
import * as component from '../common/component';
import { IocShim } from 'common/ioc_shim';
import { DataStore, MetricDataRecord, MetricType, TrialJobInfo } from '../common/datastore';
import { NNIError } from '../common/errors';
import { getExperimentId } from '../common/experimentStartupInfo';
@ -64,7 +64,7 @@ class NNIManager implements Manager {
this.readonly = false;
this.log = getLogger('NNIManager');
this.dataStore = component.get(DataStore);
this.dataStore = IocShim.get(DataStore);
this.status = {
status: 'INITIALIZED',
errors: []
@ -315,11 +315,6 @@ class NNIManager implements Manager {
this.trainingService = new fcModule.FrameworkControllerTrainingService();
break;
}
case 'adl_config': {
const adlModule = await import('../training_service/kubernetes/adl/adlTrainingService');
this.trainingService = new adlModule.AdlTrainingService();
break;
}
default:
throw new Error("Setup training service failed.");
}
@ -395,7 +390,7 @@ class NNIManager implements Manager {
this.setStatus('STOPPED');
this.log.info('Experiment stopped.');
await component.get<TensorboardManager>(TensorboardManager).stop();
await IocShim.get<TensorboardManager>(TensorboardManager).stop();
await this.dataStore.close();
}
@ -492,9 +487,6 @@ class NNIManager implements Manager {
} else if (platform === 'frameworkcontroller') {
const module_ = await import('../training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService');
return new module_.FrameworkControllerTrainingService();
} else if (platform === 'adl') {
const module_ = await import('../training_service/kubernetes/adl/adlTrainingService');
return new module_.AdlTrainingService();
} else {
this.pollInterval = 0.5;
const module_ = await import('../training_service/v3/compat');

Просмотреть файл

@ -6,12 +6,12 @@ import cp from 'child_process';
import path from 'path';
import { ChildProcess } from 'child_process';
import * as component from '../common/component';
import { getLogger, Logger } from '../common/log';
import { getTunerProc, isAlive, uniqueString, mkDirPSync, getFreePort } from '../common/utils';
import { Manager } from '../common/manager';
import { TensorboardParams, TensorboardTaskStatus, TensorboardTaskInfo, TensorboardManager } from '../common/tensorboardManager';
import globals from 'common/globals';
import { globals } from 'common/globals';
import { IocShim } from 'common/ioc_shim';
class TensorboardTaskDetail implements TensorboardTaskInfo {
public id: string;
@ -39,7 +39,7 @@ class NNITensorboardManager implements TensorboardManager {
this.log = getLogger('NNITensorboardManager');
this.tensorboardTaskMap = new Map<string, TensorboardTaskDetail>();
this.setTensorboardVersion();
this.nniManager = component.get(Manager);
this.nniManager = IocShim.get(Manager);
}
public async startTensorboardTask(tensorboardParams: TensorboardParams): Promise<TensorboardTaskDetail> {

Просмотреть файл

@ -21,13 +21,11 @@
import 'app-module-path/register'; // so we can use absolute path to import
import { Container, Scope } from 'typescript-ioc';
import { globals, initGlobals } from 'common/globals';
initGlobals();
import * as component from 'common/component';
import { Database, DataStore } from 'common/datastore';
import { IocShim } from 'common/ioc_shim';
import { Logger, getLogger } from 'common/log';
import { Manager } from 'common/manager';
import { TensorboardManager } from 'common/tensorboardManager';
@ -47,12 +45,12 @@ async function start(): Promise<void> {
const restServer = new RestServer(globals.args.port, globals.args.urlPrefix);
await restServer.start();
Container.bind(Manager).to(NNIManager).scope(Scope.Singleton);
Container.bind(Database).to(SqlDB).scope(Scope.Singleton);
Container.bind(DataStore).to(NNIDataStore).scope(Scope.Singleton);
Container.bind(TensorboardManager).to(NNITensorboardManager).scope(Scope.Singleton);
IocShim.bind(Database, SqlDB);
IocShim.bind(DataStore, NNIDataStore);
IocShim.bind(Manager, NNIManager);
IocShim.bind(TensorboardManager, NNITensorboardManager);
const ds: DataStore = component.get(DataStore);
const ds: DataStore = IocShim.get(DataStore);
await ds.init();
globals.rest.registerExpressRouter('/api/v1/nni', createRestHandler());

145
ts/nni_manager/package-lock.json сгенерированный
Просмотреть файл

@ -31,7 +31,6 @@
"tar": "^6.1.13",
"tree-kill": "^1.2.2",
"ts-deferred": "^1.0.4",
"typescript-ioc": "^1.2.6",
"typescript-string-operations": "^1.4.1",
"ws": "^8.13.0",
"yargs": "^17.7.1"
@ -819,6 +818,7 @@
"version": "2.1.5",
"resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz",
"integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==",
"dev": true,
"dependencies": {
"@nodelib/fs.stat": "2.0.5",
"run-parallel": "^1.1.9"
@ -831,6 +831,7 @@
"version": "2.0.5",
"resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz",
"integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==",
"dev": true,
"engines": {
"node": ">= 8"
}
@ -839,6 +840,7 @@
"version": "1.2.8",
"resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz",
"integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==",
"dev": true,
"dependencies": {
"@nodelib/fs.scandir": "2.1.5",
"fastq": "^1.6.0"
@ -1737,6 +1739,7 @@
"version": "2.1.0",
"resolved": "https://registry.npmjs.org/array-union/-/array-union-2.1.0.tgz",
"integrity": "sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw==",
"dev": true,
"engines": {
"node": ">=8"
}
@ -1922,6 +1925,7 @@
"version": "3.0.2",
"resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz",
"integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==",
"dev": true,
"dependencies": {
"fill-range": "^7.0.1"
},
@ -2113,6 +2117,7 @@
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz",
"integrity": "sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==",
"dev": true,
"engines": {
"node": ">=6"
}
@ -2569,6 +2574,7 @@
"version": "3.0.1",
"resolved": "https://registry.npmjs.org/dir-glob/-/dir-glob-3.0.1.tgz",
"integrity": "sha512-WkrWp9GR4KXfKGYzOLmTuGVi1UWFfws377n9cc55/tb6DuqyF6pcQ5AbiHEshaDpY9v6oaSr2XCDidGmMwdzIA==",
"dev": true,
"dependencies": {
"path-type": "^4.0.0"
},
@ -3082,6 +3088,7 @@
"version": "3.2.12",
"resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.2.12.tgz",
"integrity": "sha512-DVj4CQIYYow0BlaelwK1pHl5n5cRSJfM60UA0zK891sVInoPri2Ekj7+e1CT3/3qxXenpI+nBBmQAcJPJgaj4w==",
"dev": true,
"dependencies": {
"@nodelib/fs.stat": "^2.0.2",
"@nodelib/fs.walk": "^1.2.3",
@ -3097,6 +3104,7 @@
"version": "5.1.2",
"resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz",
"integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==",
"dev": true,
"dependencies": {
"is-glob": "^4.0.1"
},
@ -3119,6 +3127,7 @@
"version": "1.13.0",
"resolved": "https://registry.npmjs.org/fastq/-/fastq-1.13.0.tgz",
"integrity": "sha512-YpkpUnK8od0o1hmeSc7UUs/eB/vIPWJYjKck2QKIzAf71Vm1AAQ3EbuZB3g2JIy+pg+ERD0vqI79KyZiB2e2Nw==",
"dev": true,
"dependencies": {
"reusify": "^1.0.4"
}
@ -3139,6 +3148,7 @@
"version": "7.0.1",
"resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz",
"integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==",
"dev": true,
"dependencies": {
"to-regex-range": "^5.0.1"
},
@ -3512,6 +3522,7 @@
"version": "6.0.2",
"resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz",
"integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==",
"dev": true,
"dependencies": {
"is-glob": "^4.0.3"
},
@ -3557,6 +3568,7 @@
"version": "11.1.0",
"resolved": "https://registry.npmjs.org/globby/-/globby-11.1.0.tgz",
"integrity": "sha512-jhIXaOzy1sb8IyocaruWSn1TjmnBVs8Ayhcy83rmxNJ8q2uWKCAj3CnJY+KpGSXCueAPc0i05kVvVKtP1t9S3g==",
"dev": true,
"dependencies": {
"array-union": "^2.1.0",
"dir-glob": "^3.0.1",
@ -3966,6 +3978,7 @@
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz",
"integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==",
"dev": true,
"engines": {
"node": ">=0.10.0"
}
@ -3982,6 +3995,7 @@
"version": "4.0.3",
"resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz",
"integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==",
"dev": true,
"dependencies": {
"is-extglob": "^2.1.1"
},
@ -3999,6 +4013,7 @@
"version": "7.0.0",
"resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz",
"integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==",
"dev": true,
"engines": {
"node": ">=0.12.0"
}
@ -4654,6 +4669,7 @@
"version": "1.4.1",
"resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz",
"integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==",
"dev": true,
"engines": {
"node": ">= 8"
}
@ -4670,6 +4686,7 @@
"version": "4.0.5",
"resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz",
"integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==",
"dev": true,
"dependencies": {
"braces": "^3.0.2",
"picomatch": "^2.3.1"
@ -8699,6 +8716,7 @@
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/path-type/-/path-type-4.0.0.tgz",
"integrity": "sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==",
"dev": true,
"engines": {
"node": ">=8"
}
@ -8727,6 +8745,7 @@
"version": "2.3.1",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz",
"integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==",
"dev": true,
"engines": {
"node": ">=8.6"
},
@ -8912,6 +8931,7 @@
"version": "1.2.3",
"resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz",
"integrity": "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==",
"dev": true,
"funding": [
{
"type": "github",
@ -9000,11 +9020,6 @@
"node": ">=8.10.0"
}
},
"node_modules/reflect-metadata": {
"version": "0.1.13",
"resolved": "https://registry.npmjs.org/reflect-metadata/-/reflect-metadata-0.1.13.tgz",
"integrity": "sha512-Ts1Y/anZELhSsjMcU605fU9RE4Oi3p5ORujwbIKXfWa+0Zxs510Qrmrce5/Jowq3cHSZSJqBjypxmHarc+vEWg=="
},
"node_modules/regexpp": {
"version": "3.2.0",
"resolved": "https://registry.npmjs.org/regexpp/-/regexpp-3.2.0.tgz",
@ -9089,30 +9104,6 @@
"node": ">=0.10.0"
}
},
"node_modules/require-glob": {
"version": "4.1.0",
"resolved": "https://registry.npmjs.org/require-glob/-/require-glob-4.1.0.tgz",
"integrity": "sha512-c66YRk0kDUUz9t+/nEG11dnVh6nLppztiE/TLBerRlAGd75AuCLXHQ6xauOPgZaw9T+6wfG8u8ibfMD9GwmDYw==",
"dependencies": {
"glob-parent": "^6.0.0",
"globby": "^11.0.3",
"parent-module": "^2.0.0"
},
"engines": {
"node": ">= 10"
}
},
"node_modules/require-glob/node_modules/parent-module": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/parent-module/-/parent-module-2.0.0.tgz",
"integrity": "sha512-uo0Z9JJeWzv8BG+tRcapBKNJ0dro9cLyczGzulS6EfeyAdeC9sbojtW6XwvYxJkEne9En+J2XEl4zyglVeIwFg==",
"dependencies": {
"callsites": "^3.1.0"
},
"engines": {
"node": ">=8"
}
},
"node_modules/require-main-filename": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/require-main-filename/-/require-main-filename-2.0.0.tgz",
@ -9165,6 +9156,7 @@
"version": "1.0.4",
"resolved": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz",
"integrity": "sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==",
"dev": true,
"engines": {
"iojs": ">=1.0.0",
"node": ">=0.10.0"
@ -9207,6 +9199,7 @@
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz",
"integrity": "sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==",
"dev": true,
"funding": [
{
"type": "github",
@ -9439,6 +9432,7 @@
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/slash/-/slash-3.0.0.tgz",
"integrity": "sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==",
"dev": true,
"engines": {
"node": ">=8"
}
@ -10052,6 +10046,7 @@
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz",
"integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==",
"dev": true,
"dependencies": {
"is-number": "^7.0.0"
},
@ -10265,15 +10260,6 @@
"node": ">=4.2.0"
}
},
"node_modules/typescript-ioc": {
"version": "1.2.6",
"resolved": "https://registry.npmjs.org/typescript-ioc/-/typescript-ioc-1.2.6.tgz",
"integrity": "sha512-ksyRctgYtHsjmKBceEgeifV3Zq3tnqLh6/q9HlWC08lnng9ZHA3IwXw8oQlv77TpHbs2J3GVUbxTuhmLLSWCTg==",
"dependencies": {
"reflect-metadata": "^0.1.13",
"require-glob": "^3.2.0"
}
},
"node_modules/typescript-string-operations": {
"version": "1.5.0",
"resolved": "https://registry.npmjs.org/typescript-string-operations/-/typescript-string-operations-1.5.0.tgz",
@ -11242,6 +11228,7 @@
"version": "2.1.5",
"resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz",
"integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==",
"dev": true,
"requires": {
"@nodelib/fs.stat": "2.0.5",
"run-parallel": "^1.1.9"
@ -11250,12 +11237,14 @@
"@nodelib/fs.stat": {
"version": "2.0.5",
"resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz",
"integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A=="
"integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==",
"dev": true
},
"@nodelib/fs.walk": {
"version": "1.2.8",
"resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz",
"integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==",
"dev": true,
"requires": {
"@nodelib/fs.scandir": "2.1.5",
"fastq": "^1.6.0"
@ -11989,7 +11978,8 @@
"array-union": {
"version": "2.1.0",
"resolved": "https://registry.npmjs.org/array-union/-/array-union-2.1.0.tgz",
"integrity": "sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw=="
"integrity": "sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw==",
"dev": true
},
"asn1": {
"version": "0.2.6",
@ -12133,6 +12123,7 @@
"version": "3.0.2",
"resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz",
"integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==",
"dev": true,
"requires": {
"fill-range": "^7.0.1"
}
@ -12264,7 +12255,8 @@
"callsites": {
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz",
"integrity": "sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ=="
"integrity": "sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==",
"dev": true
},
"camelcase": {
"version": "5.3.1",
@ -12593,6 +12585,7 @@
"version": "3.0.1",
"resolved": "https://registry.npmjs.org/dir-glob/-/dir-glob-3.0.1.tgz",
"integrity": "sha512-WkrWp9GR4KXfKGYzOLmTuGVi1UWFfws377n9cc55/tb6DuqyF6pcQ5AbiHEshaDpY9v6oaSr2XCDidGmMwdzIA==",
"dev": true,
"requires": {
"path-type": "^4.0.0"
}
@ -12990,6 +12983,7 @@
"version": "3.2.12",
"resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.2.12.tgz",
"integrity": "sha512-DVj4CQIYYow0BlaelwK1pHl5n5cRSJfM60UA0zK891sVInoPri2Ekj7+e1CT3/3qxXenpI+nBBmQAcJPJgaj4w==",
"dev": true,
"requires": {
"@nodelib/fs.stat": "^2.0.2",
"@nodelib/fs.walk": "^1.2.3",
@ -13002,6 +12996,7 @@
"version": "5.1.2",
"resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz",
"integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==",
"dev": true,
"requires": {
"is-glob": "^4.0.1"
}
@ -13023,6 +13018,7 @@
"version": "1.13.0",
"resolved": "https://registry.npmjs.org/fastq/-/fastq-1.13.0.tgz",
"integrity": "sha512-YpkpUnK8od0o1hmeSc7UUs/eB/vIPWJYjKck2QKIzAf71Vm1AAQ3EbuZB3g2JIy+pg+ERD0vqI79KyZiB2e2Nw==",
"dev": true,
"requires": {
"reusify": "^1.0.4"
}
@ -13040,6 +13036,7 @@
"version": "7.0.1",
"resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz",
"integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==",
"dev": true,
"requires": {
"to-regex-range": "^5.0.1"
}
@ -13317,6 +13314,7 @@
"version": "6.0.2",
"resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz",
"integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==",
"dev": true,
"requires": {
"is-glob": "^4.0.3"
}
@ -13334,6 +13332,7 @@
"version": "11.1.0",
"resolved": "https://registry.npmjs.org/globby/-/globby-11.1.0.tgz",
"integrity": "sha512-jhIXaOzy1sb8IyocaruWSn1TjmnBVs8Ayhcy83rmxNJ8q2uWKCAj3CnJY+KpGSXCueAPc0i05kVvVKtP1t9S3g==",
"dev": true,
"requires": {
"array-union": "^2.1.0",
"dir-glob": "^3.0.1",
@ -13636,7 +13635,8 @@
"is-extglob": {
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz",
"integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ=="
"integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==",
"dev": true
},
"is-fullwidth-code-point": {
"version": "3.0.0",
@ -13647,6 +13647,7 @@
"version": "4.0.3",
"resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz",
"integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==",
"dev": true,
"requires": {
"is-extglob": "^2.1.1"
}
@ -13660,7 +13661,8 @@
"is-number": {
"version": "7.0.0",
"resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz",
"integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng=="
"integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==",
"dev": true
},
"is-path-inside": {
"version": "3.0.3",
@ -14176,7 +14178,8 @@
"merge2": {
"version": "1.4.1",
"resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz",
"integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg=="
"integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==",
"dev": true
},
"methods": {
"version": "1.1.2",
@ -14187,6 +14190,7 @@
"version": "4.0.5",
"resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz",
"integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==",
"dev": true,
"requires": {
"braces": "^3.0.2",
"picomatch": "^2.3.1"
@ -17017,7 +17021,8 @@
"path-type": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/path-type/-/path-type-4.0.0.tgz",
"integrity": "sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw=="
"integrity": "sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==",
"dev": true
},
"pathval": {
"version": "1.1.1",
@ -17039,7 +17044,8 @@
"picomatch": {
"version": "2.3.1",
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz",
"integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA=="
"integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==",
"dev": true
},
"pkg-dir": {
"version": "4.2.0",
@ -17175,7 +17181,8 @@
"queue-microtask": {
"version": "1.2.3",
"resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz",
"integrity": "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A=="
"integrity": "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==",
"dev": true
},
"quick-lru": {
"version": "5.1.1",
@ -17237,11 +17244,6 @@
"picomatch": "^2.2.1"
}
},
"reflect-metadata": {
"version": "0.1.13",
"resolved": "https://registry.npmjs.org/reflect-metadata/-/reflect-metadata-0.1.13.tgz",
"integrity": "sha512-Ts1Y/anZELhSsjMcU605fU9RE4Oi3p5ORujwbIKXfWa+0Zxs510Qrmrce5/Jowq3cHSZSJqBjypxmHarc+vEWg=="
},
"regexpp": {
"version": "3.2.0",
"resolved": "https://registry.npmjs.org/regexpp/-/regexpp-3.2.0.tgz",
@ -17306,26 +17308,6 @@
"resolved": "https://registry.npmjs.org/require-directory/-/require-directory-2.1.1.tgz",
"integrity": "sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q=="
},
"require-glob": {
"version": "4.1.0",
"resolved": "https://registry.npmjs.org/require-glob/-/require-glob-4.1.0.tgz",
"integrity": "sha512-c66YRk0kDUUz9t+/nEG11dnVh6nLppztiE/TLBerRlAGd75AuCLXHQ6xauOPgZaw9T+6wfG8u8ibfMD9GwmDYw==",
"requires": {
"glob-parent": "^6.0.0",
"globby": "^11.0.3",
"parent-module": "^2.0.0"
},
"dependencies": {
"parent-module": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/parent-module/-/parent-module-2.0.0.tgz",
"integrity": "sha512-uo0Z9JJeWzv8BG+tRcapBKNJ0dro9cLyczGzulS6EfeyAdeC9sbojtW6XwvYxJkEne9En+J2XEl4zyglVeIwFg==",
"requires": {
"callsites": "^3.1.0"
}
}
}
},
"require-main-filename": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/require-main-filename/-/require-main-filename-2.0.0.tgz",
@ -17365,7 +17347,8 @@
"reusify": {
"version": "1.0.4",
"resolved": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz",
"integrity": "sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw=="
"integrity": "sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==",
"dev": true
},
"rimraf": {
"version": "3.0.2",
@ -17394,6 +17377,7 @@
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz",
"integrity": "sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==",
"dev": true,
"requires": {
"queue-microtask": "^1.2.2"
}
@ -17572,7 +17556,8 @@
"slash": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/slash/-/slash-3.0.0.tgz",
"integrity": "sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q=="
"integrity": "sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==",
"dev": true
},
"smart-buffer": {
"version": "4.2.0",
@ -18043,6 +18028,7 @@
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz",
"integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==",
"dev": true,
"requires": {
"is-number": "^7.0.0"
}
@ -18186,15 +18172,6 @@
"integrity": "sha512-QCh+85mCy+h0IGff8r5XWzOVSbBO+KfeYrMQh7NJ58QujwcE22u+NUSmUxqF+un70P9GXKxa2HCNiTTMJknyjQ==",
"dev": true
},
"typescript-ioc": {
"version": "1.2.6",
"resolved": "https://registry.npmjs.org/typescript-ioc/-/typescript-ioc-1.2.6.tgz",
"integrity": "sha512-ksyRctgYtHsjmKBceEgeifV3Zq3tnqLh6/q9HlWC08lnng9ZHA3IwXw8oQlv77TpHbs2J3GVUbxTuhmLLSWCTg==",
"requires": {
"reflect-metadata": "^0.1.13",
"require-glob": ">=4.0.1"
}
},
"typescript-string-operations": {
"version": "1.5.0",
"resolved": "https://registry.npmjs.org/typescript-string-operations/-/typescript-string-operations-1.5.0.tgz",

Просмотреть файл

@ -32,7 +32,6 @@
"tar": "^6.1.13",
"tree-kill": "^1.2.2",
"ts-deferred": "^1.0.4",
"typescript-ioc": "^1.2.6",
"typescript-string-operations": "^1.4.1",
"ws": "^8.13.0",
"yargs": "^17.7.1"

Просмотреть файл

@ -4,7 +4,7 @@
import { Request, Response, Router } from 'express';
import path from 'path';
import * as component from '../common/component';
import { IocShim } from 'common/ioc_shim';
import { DataStore, MetricDataRecord, TrialJobInfo } from '../common/datastore';
import { NNIError, NNIErrorNames } from '../common/errors';
import { isNewExperiment, isReadonly } from '../common/experimentStartupInfo';
@ -27,8 +27,8 @@ class NNIRestHandler {
private log: Logger;
constructor() {
this.nniManager = component.get(Manager);
this.tensorboardManager = component.get(TensorboardManager);
this.nniManager = IocShim.get(Manager);
this.tensorboardManager = IocShim.get(TensorboardManager);
this.log = getLogger('NNIRestHandler');
}
@ -113,7 +113,7 @@ class NNIRestHandler {
// TODO add validators for request params, query, body
private checkStatus(router: Router): void {
router.get('/check-status', (_req: Request, res: Response) => {
const ds: DataStore = component.get<DataStore>(DataStore);
const ds: DataStore = IocShim.get<DataStore>(DataStore);
ds.init().then(() => {
res.send(this.nniManager.getStatus());
}).catch(async (err: Error) => {

Просмотреть файл

@ -4,9 +4,8 @@
'use strict';
import { expect } from 'chai';
import { Container, Scope } from 'typescript-ioc';
import * as component from '../../common/component';
import { IocShim } from 'common/ioc_shim';
import { Database, DataStore, TrialJobInfo } from '../../common/datastore';
import { ExperimentProfile, TrialJobStatistics } from '../../common/manager';
import { TrialJobStatus } from '../../common/trainingService';
@ -18,9 +17,9 @@ describe('Unit test for dataStore', () => {
let ds: DataStore;
before(async () => {
prepareUnitTest();
Container.bind(Database).to(SqlDB).scope(Scope.Singleton);
Container.bind(DataStore).to(NNIDataStore).scope(Scope.Singleton);
ds = component.get(DataStore);
IocShim.bind(Database, SqlDB);
IocShim.bind(DataStore, NNIDataStore);
ds = IocShim.get(DataStore);
await ds.init();
});

Просмотреть файл

@ -6,9 +6,8 @@
import * as fs from 'fs';
import * as os from 'os';
import { assert } from 'chai';
import { Container, Scope } from 'typescript-ioc';
import * as component from '../../common/component';
import { IocShim } from 'common/ioc_shim';
import { Database, DataStore, TrialJobInfo } from '../../common/datastore';
import { Manager, TrialJobStatistics} from '../../common/manager';
import { TrialJobDetail } from '../../common/trainingService';
@ -79,11 +78,11 @@ async function initContainer(mode: string = 'create'): Promise<void> {
}
restServer = new RestServer(globals.args.port, globals.args.urlPrefix);
await restServer.start();
Container.bind(Manager).to(NNIManager).scope(Scope.Singleton);
Container.bind(Database).to(SqlDB).scope(Scope.Singleton);
Container.bind(DataStore).to(NNIDataStore).scope(Scope.Singleton);
Container.bind(TensorboardManager).to(NNITensorboardManager).scope(Scope.Singleton);
await component.get<DataStore>(DataStore).init();
IocShim.bind(Database, SqlDB);
IocShim.bind(DataStore, NNIDataStore);
IocShim.bind(Manager, NNIManager);
IocShim.bind(TensorboardManager, NNITensorboardManager);
await IocShim.get<DataStore>(DataStore).init();
}
async function prepareExperiment(): Promise<void> {
@ -107,7 +106,7 @@ async function prepareExperiment(): Promise<void> {
fs.writeFileSync(globals.paths.experimentsList, JSON.stringify(experimentsInformation, null, 4));
await initContainer();
nniManager = component.get(Manager);
nniManager = IocShim.get(Manager);
// if trainingService is assigned, startExperiment won't create training service again
const manager = nniManager as any;
@ -135,6 +134,7 @@ async function cleanExperiment(): Promise<void> {
await manager.stopExperimentTopHalf();
await manager.stopExperimentBottomHalf();
await restServer.shutdown();
IocShim.clear();
}
async function testListTrialJobs(): Promise<void> {
@ -326,7 +326,7 @@ async function resumeExperiment(): Promise<void> {
// (one is start and the other is resume) run in the same process.
UnitTestHelpers.reset();
await initContainer('resume');
nniManager = component.get(Manager);
nniManager = IocShim.get(Manager);
// if trainingService is assigned, startExperiment won't create training service again
const manager = nniManager as any;

Просмотреть файл

@ -6,8 +6,6 @@
import * as assert from 'assert';
import * as os from 'os';
import * as path from 'path';
import { Container } from 'typescript-ioc';
import * as component from '../../common/component';
import { Database, MetricDataRecord, TrialJobEvent, TrialJobEventRecord } from '../../common/datastore';
import { ExperimentConfig, ExperimentProfile } from '../../common/manager';
import { cleanupUnitTest, getDefaultDatabaseDir, mkDirP, prepareUnitTest } from '../../common/utils';

Просмотреть файл

@ -3,11 +3,9 @@
import { assert, expect } from 'chai';
import fs from 'fs';
import { Container, Scope } from 'typescript-ioc';
import os from 'os';
import path from 'path';
import * as component from '../../common/component';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { ExperimentsManager } from '../../extensions/experiments_manager';
import globals from '../../common/globals/unittest';

Просмотреть файл

@ -4,7 +4,6 @@
'use strict';
import { Deferred } from 'ts-deferred';
import { Provider } from 'typescript-ioc';
import { MetricDataRecord, MetricType, TrialJobInfo } from '../../common/datastore';
import { MethodNotImplementedError } from '../../common/errors';
@ -16,10 +15,6 @@ import {
TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
} from '../../common/trainingService';
export const testManagerProvider: Provider = {
get: (): Manager => { return new MockedNNIManager(); }
};
export class MockedNNIManager extends Manager {
public getStatus(): NNIManagerStatus {
return {

Просмотреть файл

@ -6,15 +6,10 @@
import { assert } from 'chai';
import { EventEmitter } from 'events';
import { Deferred } from 'ts-deferred';
import { Provider } from 'typescript-ioc';
import { MethodNotImplementedError } from '../../common/errors';
import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric } from '../../common/trainingService';
const testTrainingServiceProvider: Provider = {
get: () => { return new MockedTrainingService(''); }
};
const jobDetailTemplate: TrialJobDetail = {
id: 'xxxx',
status: 'WAITING',
@ -45,14 +40,14 @@ const idStatusListResume = [
{id: '1011', status: 'RUNNING'},
{id: '1112', status: 'RUNNING'}];
class MockedTrainingService implements TrainingService {
export class MockedTrainingService implements TrainingService {
private readonly eventEmitter: EventEmitter;
private mockedMetaDataValue: string = "default";
private jobDetailList: Map<string, TrialJobDetail>;
private mode: string;
private submittedCnt: number = 0;
constructor(mode: string) {
constructor(mode: string = '') {
this.eventEmitter = new EventEmitter();
this.mode = mode;
this.jobDetailList = new Map<string, TrialJobDetail>();
@ -198,5 +193,3 @@ class MockedTrainingService implements TrainingService {
throw new MethodNotImplementedError();
}
}
export{MockedTrainingService, testTrainingServiceProvider}

Просмотреть файл

@ -5,17 +5,17 @@
import { assert, expect } from 'chai';
import request from 'request';
import { Container } from 'typescript-ioc';
import * as component from '../../common/component';
import { DataStore } from '../../common/datastore';
import { IocShim } from 'common/ioc_shim';
import { Database, DataStore } from '../../common/datastore';
import { ExperimentProfile, Manager } from '../../common/manager';
import { TrainingService } from '../../common/trainingService';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { SqlDB } from '../../core/sqlDatabase';
import { MockedDataStore } from '../mock/datastore';
import { MockedTrainingService } from '../mock/trainingService';
import { RestServer, UnitTestHelpers } from 'rest_server';
import { testManagerProvider } from '../mock/nniManager';
import { MockedNNIManager } from '../mock/nniManager';
import { MockedExperimentManager } from '../mock/experimentManager';
import { TensorboardManager } from '../../common/tensorboardManager';
import { MockTensorboardManager } from '../mock/mockTensorboardManager';
@ -32,10 +32,12 @@ describe('Unit test for rest handler', () => {
before(async () => {
ExpsMgrHelpers.setExperimentsManager(new MockedExperimentManager());
prepareUnitTest();
Container.bind(Manager).provider(testManagerProvider);
Container.bind(DataStore).to(MockedDataStore);
Container.bind(TrainingService).to(MockedTrainingService);
Container.bind(TensorboardManager).to(MockTensorboardManager);
IocShim.clear();
IocShim.bind(Database, SqlDB);
IocShim.bind(DataStore, MockedDataStore);
IocShim.bind(TrainingService, MockedTrainingService);
IocShim.bind(Manager, MockedNNIManager);
IocShim.bind(TensorboardManager, MockTensorboardManager);
restServer = new RestServer(0, '');
await restServer.start();
const port = UnitTestHelpers.getPort(restServer);

Просмотреть файл

@ -1,138 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import chai from 'chai';
import chaiAsPromised from 'chai-as-promised';
import fs from 'fs';
import tmp from 'tmp';
import * as component from '../../common/component';
import { TrialJobApplicationForm, TrialJobDetail, TrainingService } from '../../common/trainingService';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { TrialConfigMetadataKey } from '../../training_service/common/trialConfigMetadataKey';
import { AdlTrainingService } from '../../training_service/kubernetes/adl/adlTrainingService';
const localCodeDir: string = tmp.dirSync().name
describe('Unit Test for AdlTrainingService', () => {
let skip: boolean = false;
try {
const testKubeflowConfig = fs.readFileSync('/home/vsts/.kube/config', 'utf8');
} catch (err) {
console.log('Please have kubernetes cluster to enable its training service unit test.');
skip = true;
}
let testAdlTrialConfig: any = JSON.stringify({
"command": "python3 /root/apps/nni_linear_regression/main.py",
"codeDir": ".",
"gpuNum": 0,
"image": "test.image:latest",
"imagePullSecrets": [
{
"name": "stagingsecrets"
}
],
"nfs": {
"server": "172.20.188.236",
"path": "/exports",
"containerMountPath": "/nfs"
},
"memorySize": "1Gi",
"cpuNum": 1
});
let testAdlTrialConfig2: any = JSON.stringify({
"command": "python3 /root/apps/nni_linear_regression/main.py",
"codeDir": ".",
"gpuNum": 0,
"image": "test.image:latest",
"imagePullSecrets": [
{
"name": "stagingsecrets"
}
],
"adaptive": true,
"checkpoint": {
"storageClass": "aws-efs",
"storageSize": "1Gi"
},
"nfs": {
"server": "172.20.188.236",
"path": "/exports",
"containerMountPath": "/nfs"
}
});
let testNniManagerIp: any = JSON.stringify({
"nniManagerIp": "0.0.0.0"
});
let adlTrainingService: AdlTrainingService;
console.log(tmp.dirSync().name);
before(() => {
chai.should();
chai.use(chaiAsPromised);
prepareUnitTest();
});
after(() => {
cleanupUnitTest();
});
beforeEach(() => {
if (skip) {
return;
}
adlTrainingService = component.get(AdlTrainingService);
adlTrainingService.run()
});
afterEach(() => {
if (skip) {
return;
}
adlTrainingService.cleanUp();
});
it('Set and get cluster metadata', async () => {
if (skip) {
return;
}
await adlTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, testAdlTrialConfig2);
await adlTrainingService.setClusterMetadata(TrialConfigMetadataKey.NNI_MANAGER_IP, testNniManagerIp);
let data:string = await adlTrainingService.getClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG);
chai.expect(data).to.be.equals(testAdlTrialConfig2);
});
it('Submit job', async () => {
if (skip) {
return;
}
// job without given checkpoint, with resource config
await adlTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, testAdlTrialConfig);
let form: TrialJobApplicationForm = {
sequenceId: 0,
hyperParameters: {
value: 'mock hyperparameters',
index: 0
}
};
let jobDetail: TrialJobDetail = await adlTrainingService.submitTrialJob(form);
chai.expect(jobDetail.status).to.be.equals('WAITING');
await adlTrainingService.cancelTrialJob(jobDetail.id);
chai.expect(jobDetail.status).to.be.equals('USER_CANCELED');
// job with given checkpoint
await adlTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, testAdlTrialConfig2);
form = {
sequenceId: 0,
hyperParameters: {
value: 'mock hyperparameters',
index: 0
}
};
jobDetail = await adlTrainingService.submitTrialJob(form);
chai.expect(jobDetail.status).to.be.equals('WAITING');
await adlTrainingService.cancelTrialJob(jobDetail.id);
chai.expect(jobDetail.status).to.be.equals('USER_CANCELED');
}).timeout(3000000);
});

Просмотреть файл

@ -7,7 +7,6 @@ import chai from 'chai';
import chaiAsPromised from 'chai-as-promised';
import fs from 'fs';
import tmp from 'tmp';
import * as component from '../../common/component';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { TrialConfigMetadataKey } from '../../training_service/common/trialConfigMetadataKey';
import { KubeflowTrainingService } from '../../training_service/kubernetes/kubeflow/kubeflowTrainingService';
@ -47,7 +46,7 @@ describe('Unit Test for KubeflowTrainingService', () => {
if (skip) {
return;
}
kubeflowTrainingService = component.get(KubeflowTrainingService);
kubeflowTrainingService = new KubeflowTrainingService();
});
afterEach(() => {

Просмотреть файл

@ -8,7 +8,6 @@ import chaiAsPromised from 'chai-as-promised';
import fs from 'fs';
import path from 'path';
import tmp from 'tmp';
import * as component from '../../common/component';
import { TrialJobApplicationForm, TrialJobDetail} from '../../common/trainingService';
import { cleanupUnitTest, delay, prepareUnitTest, getExperimentRootDir } from '../../common/utils';
import { TrialConfigMetadataKey } from '../../training_service/common/trialConfigMetadataKey';

Просмотреть файл

@ -5,7 +5,6 @@
import chai from 'chai';
import chaiAsPromised from 'chai-as-promised';
import * as component from '../../../common/component';
import { cleanupUnitTest, prepareUnitTest } from '../../../common/utils';
import { LinuxCommands } from '../../../training_service/remote_machine/extends/linuxCommands';
@ -25,7 +24,7 @@ describe('Unit Test for linuxCommands', () => {
});
beforeEach(() => {
linuxCommands = component.get(LinuxCommands);
linuxCommands = new LinuxCommands();
});
afterEach(() => {

Просмотреть файл

@ -5,7 +5,6 @@
import chai from 'chai';
import chaiAsPromised from 'chai-as-promised';
import * as component from '../../../common/component';
import { cleanupUnitTest, prepareUnitTest } from '../../../common/utils';
import { WindowsCommands } from '../../../training_service/remote_machine/extends/windowsCommands';
@ -25,7 +24,7 @@ describe('Unit Test for Windows Commands', () => {
});
beforeEach(() => {
windowsCommands = component.get(WindowsCommands);
windowsCommands = new WindowsCommands();
});
afterEach(() => {

Просмотреть файл

@ -1,158 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import assert from 'assert';
import chai from 'chai';
import chaiAsPromised from 'chai-as-promised';
import fs from 'fs';
import tmp from 'tmp';
import * as component from '../../common/component';
import { TrialJobApplicationForm, TrialJobDetail, TrainingService } from '../../common/trainingService';
import { cleanupUnitTest, delay, prepareUnitTest } from '../../common/utils';
import { TrialConfigMetadataKey } from '../../training_service/common/trialConfigMetadataKey';
import { RemoteMachineTrainingService } from '../../training_service/remote_machine/remoteMachineTrainingService';
// copy mockedTrail.py to local folder
const localCodeDir: string = tmp.dirSync().name
const mockedTrialPath: string = './test/mock/mockedTrial.py'
fs.copyFileSync(mockedTrialPath, localCodeDir + '/mockedTrial.py')
describe('Unit Test for RemoteMachineTrainingService', () => {
/*
To enable remote machine unit test, remote machine information needs to be configured in:
Default/.vscode/rminfo.json, whose content looks like:
{
"ip": "10.172.121.40",
"username": "user1",
"passwd": "mypassword"
}
*/
let skip: boolean = false;
let testRmInfo: any;
let machineList: any;
try {
testRmInfo = JSON.parse(fs.readFileSync('../../.vscode/rminfo.json', 'utf8'));
console.log(testRmInfo);
machineList = `[{\"ip\":\"${testRmInfo.ip}\",\"port\":22,\"username\":\"${testRmInfo.user}\",\"passwd\":\"${testRmInfo.password}\"}]`;
} catch (err) {
console.log('Please configure rminfo.json to enable remote machine unit test.');
skip = true;
}
let remoteMachineTrainingService: RemoteMachineTrainingService
before(() => {
chai.should();
chai.use(chaiAsPromised);
prepareUnitTest();
});
after(() => {
cleanupUnitTest();
});
beforeEach(() => {
if (skip) {
return;
}
remoteMachineTrainingService = component.get(RemoteMachineTrainingService);
remoteMachineTrainingService.run();
});
afterEach(() => {
if (skip) {
return;
}
remoteMachineTrainingService.cleanUp();
});
it('List trial jobs', async () => {
if (skip) {
return;
}
chai.expect(await remoteMachineTrainingService.listTrialJobs()).to.be.empty;
});
it('Set cluster metadata', async () => {
if (skip) {
return;
}
await remoteMachineTrainingService.setClusterMetadata(TrialConfigMetadataKey.MACHINE_LIST, machineList);
await remoteMachineTrainingService.setClusterMetadata(
TrialConfigMetadataKey.TRIAL_CONFIG, `{"command":"sleep 1h && echo ","codeDir":"${localCodeDir}","gpuNum":1}`);
const form: TrialJobApplicationForm = {
sequenceId: 0,
hyperParameters: {
value: 'mock hyperparameters',
index: 0
}
};
const trialJob = await remoteMachineTrainingService.submitTrialJob(form);
// After a job is cancelled, the status should be changed to 'USER_CANCELED'
await remoteMachineTrainingService.cancelTrialJob(trialJob.id);
// After a job is cancelled, the status should be changed to 'USER_CANCELED'
const trialJob2 = await remoteMachineTrainingService.getTrialJob(trialJob.id);
chai.expect(trialJob2.status).to.be.equals('USER_CANCELED');
//Expect rejected if passing invalid trial job id
await remoteMachineTrainingService.cancelTrialJob(trialJob.id + 'ddd').should.eventually.be.rejected;
});
it('Submit job test', async () => {
if (skip) {
return;
}
});
it('Submit job and read metrics data', async () => {
if (skip) {
return;
}
// set machine list'
await remoteMachineTrainingService.setClusterMetadata(TrialConfigMetadataKey.MACHINE_LIST, machineList);
// set meta data
const trialConfig: string = `{\"command\":\"python3 mockedTrial.py\", \"codeDir\":\"${localCodeDir}\",\"gpuNum\":0}`
await remoteMachineTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, trialConfig);
// submit job
const form: TrialJobApplicationForm = {
sequenceId: 0,
hyperParameters: {
value: 'mock hyperparameters',
index: 0
},
placementConstraint: {
type: "None",
gpus: []
}
};
const jobDetail: TrialJobDetail = await remoteMachineTrainingService.submitTrialJob(form);
// Add metrics listeners
const listener1 = function f1(_metric: any) {
}
const listener2 = function f1(_metric: any) {
}
remoteMachineTrainingService.addTrialJobMetricListener(listener1);
remoteMachineTrainingService.addTrialJobMetricListener(listener2);
await delay(10000);
// remove listender1
remoteMachineTrainingService.removeTrialJobMetricListener(listener1);
await delay(5000);
}).timeout(30000);
it('Test getTrialJob exception', async () => {
if (skip) {
return;
}
await remoteMachineTrainingService.getTrialJob('wrongid').catch((err) => {
assert(err !== undefined);
});
});
});

Просмотреть файл

@ -8,7 +8,6 @@ import fs from 'fs';
import path from 'path';
import { Writable } from 'stream';
import { String } from 'typescript-string-operations';
import * as component from 'common/component';
import { getBasePort, getExperimentId } from 'common/experimentStartupInfo';
import { LegacyRestServer } from 'common/restServer';
import { getExperimentRootDir, mkDirPSync } from 'common/utils';
@ -18,7 +17,6 @@ import { getExperimentRootDir, mkDirPSync } from 'common/utils';
*
* FIXME: This should be a router, not a separate REST server.
*/
@component.Singleton
export abstract class ClusterJobRestServer extends LegacyRestServer {
private readonly API_ROOT_URL: string = '/api/v1/nni-pai';
private readonly NNI_METRICS_PATTERN: string = `NNISDK_MEb'(?<metrics>.*?)'`;

Просмотреть файл

@ -1,57 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import fs from 'fs';
import { GeneralK8sClient, KubernetesCRDClient } from '../kubernetesApiClient';
/**
* Adl ClientV1
*/
class AdlClientV1 extends KubernetesCRDClient {
/**
* constructor, to initialize adl CRD definition
*/
public readonly namespace: string;
public constructor(namespace: string) {
super();
this.namespace = namespace;
this.crdSchema = JSON.parse(fs.readFileSync('./config/adl/adaptdl-crd-v1.json', 'utf8'));
this.client.addCustomResourceDefinition(this.crdSchema);
}
protected get operator(): any {
return this.client.apis['adaptdl.petuum.com'].v1.namespaces(this.namespace).adaptdljobs;
}
public get containerName(): string {
return 'main';
}
public async getKubernetesPods(jobName: string): Promise<any> {
let result: Promise<any>;
const response = await this.client.api.v1.namespaces(this.namespace).pods
.get({ qs: { labelSelector: `adaptdl/job=${jobName}` } });
if (response.statusCode && (response.statusCode >= 200 && response.statusCode <= 299)) {
result = Promise.resolve(response.body);
} else {
result = Promise.reject(`AdlClient getKubernetesPods failed, statusCode is ${response.statusCode}`);
}
return result;
}
}
/**
* Adl Client
*/
class AdlClientFactory {
/**
* Factory method to generate operator client
*/
public static createClient(namespace: string): KubernetesCRDClient {
return new AdlClientV1(namespace);
}
}
export { AdlClientFactory, GeneralK8sClient };
export { AdlClientV1 }

Просмотреть файл

@ -1,95 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import {KubernetesTrialConfig} from "../kubernetesConfig";
/**
* Checkpoint Config
*/
export class CheckpointConfig {
public readonly storageClass: string;
public readonly storageSize: string;
constructor(storageClass: string, storageSize: string) {
this.storageClass = storageClass;
this.storageSize = storageSize;
}
}
/**
* imagePullSecret Config
*/
export class ImagePullSecretConfig{
public readonly name: string;
constructor(name: string) {
this.name = name
}
}
/**
* NFS Config
*/
export class NFSConfig {
public readonly server: string;
public readonly path: string;
public readonly containerMountPath: string;
constructor(server: string, path: string, containerMountPath: string) {
this.server = server;
this.path = path;
this.containerMountPath = containerMountPath;
}
}
/**
* Trial job configuration for Adl
*/
export class AdlTrialConfig extends KubernetesTrialConfig {
public readonly command: string;
public readonly gpuNum: number;
public readonly image: string;
public readonly namespace?: string;
public readonly imagePullSecrets?: ImagePullSecretConfig[];
public readonly nfs?: NFSConfig;
public readonly checkpoint?: CheckpointConfig;
public readonly cpuNum?: number;
public readonly memorySize?: string;
public readonly adaptive?: boolean; // adaptive == preemptible
constructor(codeDir: string,
command: string, gpuNum: number,
image: string, namespace?: string,
imagePullSecrets?: ImagePullSecretConfig[],
nfs?: NFSConfig, checkpoint?: CheckpointConfig,
cpuNum?: number, memorySize?: string,
adaptive?: boolean
) {
super(codeDir);
this.command = command;
this.gpuNum = gpuNum;
this.image = image;
this.namespace = namespace;
this.imagePullSecrets = imagePullSecrets;
this.nfs = nfs;
this.checkpoint = checkpoint;
this.cpuNum = cpuNum;
this.memorySize = memorySize;
this.adaptive = adaptive;
}
}
export type AdlJobStatus = "Pending" | "Running" | "Starting" | "Stopping" | "Failed" | "Succeeded";

Просмотреть файл

@ -1,92 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import { AdlClientV1 } from './adlApiClient';
import { KubernetesTrialJobDetail} from '../kubernetesData';
import { KubernetesJobInfoCollector } from '../kubernetesJobInfoCollector';
import { AdlJobStatus } from './adlConfig';
/**
* Collector Adl jobs info from Kubernetes cluster, and update adl job status locally
*/
export class AdlJobInfoCollector extends KubernetesJobInfoCollector {
constructor(jobMap: Map<string, KubernetesTrialJobDetail>) {
super(jobMap);
}
protected async retrieveSingleTrialJobInfo(adlClient: AdlClientV1 | undefined,
kubernetesTrialJob: KubernetesTrialJobDetail): Promise<void> {
if (!this.statusesNeedToCheck.includes(kubernetesTrialJob.status)) {
return Promise.resolve();
}
if (adlClient === undefined) {
return Promise.reject('AdlClient is undefined');
}
let kubernetesJobInfo: any;
let kubernetesPodsInfo: any;
try {
kubernetesJobInfo = await adlClient.getKubernetesJob(kubernetesTrialJob.kubernetesJobName);
kubernetesPodsInfo = await adlClient.getKubernetesPods(kubernetesTrialJob.kubernetesJobName);
} catch (error) {
// Notice: it maynot be a 'real' error since cancel trial job can also cause getKubernetesJob failed.
this.log.error(`Get job ${kubernetesTrialJob.kubernetesJobName} info failed, error is ${error}`);
//This is not treat as a error status
return Promise.resolve();
}
/* eslint-disable require-atomic-updates */
if (kubernetesJobInfo.status) {
const phase: AdlJobStatus = <AdlJobStatus>kubernetesJobInfo.status.phase
switch (phase) {
case 'Pending':
case 'Starting':
kubernetesTrialJob.status = 'WAITING';
if (kubernetesPodsInfo.items.length > 0){
if (kubernetesPodsInfo.items[0].status.containerStatuses != undefined) {
const currState: any = kubernetesPodsInfo.items[0].status.containerStatuses[0].state
if (currState.waiting != undefined) {
const msg: string = currState.waiting.reason
if (msg == "ImagePullBackOff" || msg == "ErrImagePull") {
kubernetesTrialJob.status = 'FAILED';
}
}
}
kubernetesTrialJob.message = kubernetesPodsInfo.items
.map((pod: any) => JSON.stringify(pod.status.containerStatuses))
.join('\n');
}
kubernetesTrialJob.startTime = Date.parse(<string>kubernetesJobInfo.metadata.creationTimestamp);
break;
case 'Running':
case 'Stopping':
kubernetesTrialJob.status = 'RUNNING';
kubernetesTrialJob.message = `Use 'nnictl log trial --trial_id ${kubernetesTrialJob.id}' to check the log stream.`;
if (kubernetesTrialJob.startTime === undefined) {
kubernetesTrialJob.startTime = Date.parse(<string>kubernetesJobInfo.metadata.creationTimestamp);
}
break;
case 'Failed':
kubernetesTrialJob.status = 'FAILED';
kubernetesTrialJob.message = kubernetesJobInfo.status.message;
if (kubernetesPodsInfo.items.length > 0) {
kubernetesTrialJob.message += " ; ";
kubernetesTrialJob.message += `Use 'nnictl log trial --trial_id ${kubernetesTrialJob.id}' for the path of the collected logs.`;
}
// undefined => NaN as endTime here
kubernetesTrialJob.endTime = Date.parse(<string>kubernetesJobInfo.status.completionTimestamp);
break;
case 'Succeeded':
kubernetesTrialJob.status = 'SUCCEEDED';
kubernetesTrialJob.endTime = Date.parse(<string>kubernetesJobInfo.status.completionTimestamp);
kubernetesTrialJob.message = `Succeeded at ${kubernetesJobInfo.status.completionTimestamp}`
break;
default:
}
}
/* eslint-enable require-atomic-updates */
return Promise.resolve();
}
}

Просмотреть файл

@ -1,20 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import * as component from 'common/component';
import { KubernetesJobRestServer } from '../kubernetesJobRestServer';
import { AdlTrainingService } from './adlTrainingService';
/**
* Adl Training service Rest server, provides rest API to support adl job metrics update
*
*/
@component.Singleton
export class AdlJobRestServer extends KubernetesJobRestServer {
/**
* constructor to provide NNIRestServer's own rest property, e.g. port
*/
constructor() {
super(component.get(AdlTrainingService));
}
}

Просмотреть файл

@ -1,362 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import fs from 'fs';
import * as component from 'common/component';
import { String } from 'typescript-string-operations';
import { getExperimentId } from 'common/experimentStartupInfo';
import {
NNIManagerIpConfig, TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
} from 'common/trainingService';
import { delay, generateParamFileName, getVersion, uniqueString } from 'common/utils';
import { TrialConfigMetadataKey } from 'training_service/common/trialConfigMetadataKey';
import { KubernetesTrialJobDetail } from '../kubernetesData';
import { KubernetesTrainingService } from '../kubernetesTrainingService';
import { AdlClientFactory } from './adlApiClient'
import { AdlJobInfoCollector } from './adlJobInfoCollector';
import { AdlJobRestServer } from './adlJobRestServer';
import { AdlTrialConfig } from './adlConfig'
/**
* Training Service implementation for Adl
*/
@component.Singleton
class AdlTrainingService extends KubernetesTrainingService implements KubernetesTrainingService {
private adlTrialConfig?: AdlTrialConfig;
private readonly adlJobInfoCollector: AdlJobInfoCollector;
private configmapTemplateStr: string;
private jobTemplateStr: string;
private pvcTemplateStr: string;
private tensorboardPvcTemplate: any;
private tensorboardDeploymentTemplate: any;
//TODO: change the logic here when we want to support multiple tensorboard
private tensorboardName: string = "adaptdl-tensorboard-" + getExperimentId().toLowerCase();
constructor() {
super();
this.adlJobInfoCollector = new AdlJobInfoCollector(this.trialJobsMap);
this.experimentId = getExperimentId();
this.configmapTemplateStr = fs.readFileSync(
'./config/adl/adaptdl-nni-configmap-template.json', 'utf8');
this.jobTemplateStr = fs.readFileSync('./config/adl/adaptdljob-template.json', 'utf8');
this.pvcTemplateStr = fs.readFileSync('./config/adl/adaptdl-pvc-template.json', 'utf8');
this.tensorboardPvcTemplate = JSON.parse(
fs.readFileSync('./config/adl/adaptdl-tensorboard-pvc-template.json', 'utf8'));
this.tensorboardDeploymentTemplate = JSON.parse(
fs.readFileSync('./config/adl/adaptdl-tensorboard-deployment-template.json', 'utf8'));
this.log.info('Construct Adl training service.');
}
public async run(): Promise<void> {
this.log.info(this.tensorboardName);
this.log.info('Start tensorboard deployment.');
await this.launchTensorboard()
this.log.info('Run Adl training service.');
this.kubernetesJobRestServer = component.get(AdlJobRestServer);
if (this.kubernetesJobRestServer === undefined) {
throw new Error('kubernetesJobRestServer not initialized!');
}
await this.kubernetesJobRestServer.start();
this.kubernetesJobRestServer.setEnableVersionCheck = this.versionCheck;
this.log.info(`Adl Training service rest server listening on: ${this.kubernetesJobRestServer.endPoint}`);
while (!this.stopping) {
// collect metrics for Adl jobs by interacting with Kubernetes API server
await delay(3000);
await this.adlJobInfoCollector.retrieveTrialStatus(this.kubernetesCRDClient);
if (this.kubernetesJobRestServer.getErrorMessage !== undefined) {
throw new Error(this.kubernetesJobRestServer.getErrorMessage);
}
}
this.log.info('Adl training service exit.');
}
private async launchTensorboard(): Promise<void> {
// Start the tensorboard at the beginning of the experiment.
if (this.adlTrialConfig === undefined) {
throw new Error('Adl trial config is undefined');
}
// Create tensorboard deployment
this.tensorboardDeploymentTemplate.metadata.name = this.tensorboardName
this.tensorboardDeploymentTemplate.metadata.labels.expId = this.experimentId
this.tensorboardDeploymentTemplate.spec.selector.matchLabels.app = this.tensorboardName
this.tensorboardDeploymentTemplate.spec.template.metadata.labels.app = this.tensorboardName
this.tensorboardDeploymentTemplate.spec.template.spec.volumes[0]
.persistentVolumeClaim.claimName = this.tensorboardName
const deploymentUid: string = await this.genericK8sClient.createDeployment(this.tensorboardDeploymentTemplate);
// Create pvc
this.tensorboardPvcTemplate.metadata.name = this.tensorboardName;
this.tensorboardPvcTemplate.metadata.ownerReferences[0].name = this.tensorboardName;
this.tensorboardPvcTemplate.metadata.ownerReferences[0].uid = deploymentUid
if (this.adlTrialConfig.checkpoint != undefined) {
this.tensorboardPvcTemplate.spec.resources.requests.storage = this.adlTrialConfig.checkpoint.storageSize;
this.tensorboardPvcTemplate.spec.storageClassName = this.adlTrialConfig.checkpoint.storageClass;
}
else {
this.tensorboardPvcTemplate.spec.resources.requests.storage = "1Gi"
this.tensorboardPvcTemplate.spec.storageClassName = await this.genericK8sClient.getStorageClass();
}
await this.genericK8sClient.createPersistentVolumeClaim(this.tensorboardPvcTemplate);
return Promise.resolve()
}
public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.kubernetesCRDClient === undefined) {
throw new Error('Adl job operator client is undefined');
}
if (this.adlTrialConfig === undefined) {
throw new Error('Adl trial config is undefined');
}
if (this.kubernetesRestServerPort === undefined) {
const restServer: AdlJobRestServer = component.get(AdlJobRestServer);
this.kubernetesRestServerPort = restServer.clusterRestServerPort;
}
const trialJobId: string = form.id === undefined ? uniqueString(5) : form.id;
const adlJobName: string = `nni-exp-${this.experimentId}-trial-${trialJobId}`.toLowerCase();
const initStatus: TrialJobStatus = 'WAITING';
const codeDir = this.adlTrialConfig.codeDir;
const outputDir = "output"
const trialJobDetail: KubernetesTrialJobDetail = new KubernetesTrialJobDetail(
trialJobId,
initStatus,
Date.now(),
codeDir,
form,
adlJobName,
outputDir
);
// Create adljob
const job: any = JSON.parse(this.jobTemplateStr);
job.metadata.name = adlJobName
job.metadata.labels.app = this.NNI_KUBERNETES_TRIAL_LABEL
job.metadata.labels.expId = this.experimentId
job.metadata.labels.trialId = trialJobId
if (this.adlTrialConfig.adaptive !== undefined){
job.spec.preemptible = this.adlTrialConfig.adaptive
}
job.spec.template.spec.containers[0]
.image = this.adlTrialConfig.image;
job.spec.template.spec.volumes[0]
.persistentVolumeClaim.claimName = adlJobName
job.spec.template.spec.volumes[1]
.persistentVolumeClaim.claimName = this.tensorboardName
job.spec.template.spec.volumes[2]
.configMap.name = adlJobName
// Handle Pod Resource
let cpu: number = 1;
let memory: string = "1Gi";
if (this.adlTrialConfig.cpuNum !== undefined) {
cpu = this.adlTrialConfig.cpuNum;
}
if (this.adlTrialConfig.memorySize !== undefined) {
memory = this.adlTrialConfig.memorySize;
}
job.spec.template.spec.containers[0]
.resources.requests.memory = memory;
job.spec.template.spec.containers[0]
.resources.requests.cpu = cpu;
job.spec.template.spec.containers[0]
.resources.limits["nvidia.com/gpu"] = this.adlTrialConfig.gpuNum;
// Handle imagePullSecrets
if (this.adlTrialConfig.imagePullSecrets !== undefined) {
job.spec.template.spec.imagePullSecrets = job.spec.template.spec
.imagePullSecrets.concat(this.adlTrialConfig.imagePullSecrets);
}
// Handle NFS
if (this.adlTrialConfig.nfs !== undefined) {
job.spec.template.spec.volumes.push({
"name": "nfs",
"nfs": {
"server": this.adlTrialConfig.nfs.server,
"path": this.adlTrialConfig.nfs.path,
"readOnly": false
}
});
job.spec.template.spec.containers[0].volumeMounts.push({
"name": "nfs",
"mountPath": this.adlTrialConfig.nfs.containerMountPath
});
}
await this.kubernetesCRDClient.createKubernetesJob(job);
const k8sadlJob: any = await this.kubernetesCRDClient.getKubernetesJob(adlJobName);
// Create pvc
const pvc: any = JSON.parse(this.pvcTemplateStr);
pvc.metadata.name = adlJobName;
pvc.metadata.ownerReferences[0].name = adlJobName;
pvc.metadata.ownerReferences[0].uid = k8sadlJob.metadata.uid;
if (this.adlTrialConfig.checkpoint != undefined) {
pvc.spec.resources.requests.storage = this.adlTrialConfig
.checkpoint.storageSize;
pvc.spec.storageClassName = this.adlTrialConfig.checkpoint.storageClass;
}
else {
pvc.spec.resources.requests.storage = "1Gi"
pvc.spec.storageClassName = await this.genericK8sClient.getStorageClass();
}
await this.genericK8sClient.createPersistentVolumeClaim(pvc);
// prepare the runscript and convert it to configmap and mount it
const configmap: any = JSON.parse(this.configmapTemplateStr);
configmap.metadata.name = adlJobName;
configmap.metadata.ownerReferences[0].name = adlJobName;
configmap.metadata.ownerReferences[0].uid = k8sadlJob.metadata.uid;
configmap.data["run.sh"] = await this.prepareRunScript(
trialJobId, form, codeDir, outputDir)
const cleanupScriptTemplate: string =
`#!/bin/bash
ps aux | grep "python3 -m nni.tools.trial_tool.trial_keeper" | awk '{print $2}' | xargs kill -2
while true;
do
proc=\`ps aux | grep "python3 -m nni.tools.trial_tool.trial_keeper" | awk '{print $2}' | grep "" -c\`
if (( $proc == 1 )); then
exit 0
else
echo "waiting"
fi
sleep 1
done
`;
configmap.data["cleanup.sh"] = cleanupScriptTemplate
await this.genericK8sClient.createConfigMap(configmap)
// Set trial job detail until create Adl job successfully
this.trialJobsMap.set(trialJobId, trialJobDetail);
return Promise.resolve(trialJobDetail);
}
private async prepareRunScript(jobId: string,
form: TrialJobApplicationForm,
codeDir: string,
outputDir: string): Promise<string> {
if (this.adlTrialConfig === undefined) {
throw new Error('Adl trial config is undefined');
}
if (this.kubernetesRestServerPort === undefined) {
throw new Error('Adl rest server port is undefined');
}
if (this.nniManagerIpConfig === undefined) {
throw new Error('Adl nniManager ip config is undefined');
}
const expId: string = this.experimentId;
const seqId: string = form.sequenceId.toString();
const command: string = this.adlTrialConfig.command;
const hyperParameters: string = form.hyperParameters.value;
const hyperParametersFile: string = generateParamFileName(form.hyperParameters);
const nniManagerPort: string = this.kubernetesRestServerPort.toString();
const nniManagerIp: string = this.nniManagerIpConfig.nniManagerIp;
let nniManagerVersion: string = '';
if (this.versionCheck) {
nniManagerVersion = await getVersion();
}
let nvidiaScript: string = '';
if (this.adlTrialConfig.gpuNum == 0) {
nvidiaScript = 'export CUDA_VISIBLE_DEVICES=';
}
const runScriptTemplate: string =
`#!/bin/bash
export NNI_PLATFORM=adl
export MULTI_PHASE=false
export NNI_SYS_DIR={0}
export NNI_CODE_DIR={0}
export NNI_OUTPUT_DIR={1}
export NNI_TRIAL_JOB_ID={2}
export NNI_EXP_ID={3}
export NNI_TRIAL_SEQ_ID={4}
mkdir -p $NNI_OUTPUT_DIR
{5}
echo '{6}' > $NNI_CODE_DIR/{7}
python3 -m nni.tools.trial_tool.trial_keeper --trial_command '{8}' \
--nnimanager_ip {9} --nnimanager_port {10} \
--nni_manager_version '{11}' --log_collection '{12}'
`;
const runScript = String.Format(
runScriptTemplate, codeDir, outputDir,
jobId, expId, seqId, nvidiaScript,
hyperParameters, hyperParametersFile, command,
nniManagerIp, nniManagerPort, nniManagerVersion,
this.logCollection);
return Promise.resolve(runScript);
}
public async cleanUp(): Promise<void> {
super.cleanUp();
// Delete Tensorboard deployment
try {
await this.genericK8sClient.deleteDeployment("adaptdl-tensorboard-" + this.experimentId.toLowerCase());
this.log.info('tensorboard deployment deleted');
} catch (error) {
this.log.error(`tensorboard deployment deletion failed: ${(error as any).message}`);
}
}
public async setClusterMetadata(key: string, value: string): Promise<void> {
this.log.info('SetCluster ' + key + ', ' +value);
switch (key) {
case TrialConfigMetadataKey.NNI_MANAGER_IP:
this.nniManagerIpConfig = <NNIManagerIpConfig>JSON.parse(value);
break;
case TrialConfigMetadataKey.TRIAL_CONFIG: {
this.adlTrialConfig = <AdlTrialConfig>JSON.parse(value);
let namespace: string = 'default';
if (this.adlTrialConfig.namespace !== undefined) {
namespace = this.adlTrialConfig.namespace;
}
this.genericK8sClient.setNamespace = namespace;
this.kubernetesCRDClient = AdlClientFactory.createClient(namespace);
break;
}
case TrialConfigMetadataKey.VERSION_CHECK:
this.versionCheck = (value === 'true' || value === 'True');
break;
case TrialConfigMetadataKey.LOG_COLLECTION:
this.logCollection = value;
break;
default:
}
return Promise.resolve();
}
public getClusterMetadata(key: string): Promise<string> {
let result: string;
switch (key) {
case TrialConfigMetadataKey.TRIAL_CONFIG:
if (this.adlTrialConfig === undefined) {
return Promise.reject(`${key} is not set yet`);
}
result = JSON.stringify(this.adlTrialConfig);
break;
case TrialConfigMetadataKey.NNI_MANAGER_IP:
if (this.nniManagerIpConfig === undefined) {
return Promise.reject(`${key} is not set yet`);
}
result = JSON.stringify(this.nniManagerIpConfig);
break;
default:
return Promise.reject(`${key} not set`);
}
return Promise.resolve(result);
}
public async updateTrialJob(_1: any, _2: any): Promise<TrialJobDetail> {
throw new Error('not supported');
}
}
export { AdlTrainingService };

Просмотреть файл

@ -1,17 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import * as component from 'common/component';
import { KubernetesJobRestServer } from '../kubernetesJobRestServer';
import { FrameworkControllerTrainingService } from './frameworkcontrollerTrainingService';
/**
* frameworkcontroller Training service Rest server, provides rest API to support frameworkcontroller job metrics update
*
*/
@component.Singleton
export class FrameworkControllerJobRestServer extends KubernetesJobRestServer {
constructor() {
super(component.get(FrameworkControllerTrainingService));
}
}

Просмотреть файл

@ -5,7 +5,6 @@ import assert from 'assert';
import cpp from 'child-process-promise';
import fs from 'fs';
import path from 'path';
import * as component from 'common/component';
import {getExperimentId} from 'common/experimentStartupInfo';
import {
NNIManagerIpConfig, TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
@ -16,8 +15,8 @@ import {TrialConfigMetadataKey} from 'training_service/common/trialConfigMetadat
import {validateCodeDir} from 'training_service/common/util';
import {NFSConfig} from '../kubernetesConfig';
import {KubernetesTrialJobDetail} from '../kubernetesData';
import { KubernetesJobRestServer } from '../kubernetesJobRestServer';
import {KubernetesTrainingService} from '../kubernetesTrainingService';
import {KubernetesJobRestServer} from '../kubernetesJobRestServer';
import {FrameworkControllerClientFactory} from './frameworkcontrollerApiClient';
import {
FrameworkControllerClusterConfig,
@ -28,14 +27,12 @@ import {
FrameworkControllerTrialConfigTemplate,
} from './frameworkcontrollerConfig';
import {FrameworkControllerJobInfoCollector} from './frameworkcontrollerJobInfoCollector';
import {FrameworkControllerJobRestServer} from './frameworkcontrollerJobRestServer';
const yaml = require('js-yaml');
/**
* Training Service implementation for frameworkcontroller
*/
@component.Singleton
class FrameworkControllerTrainingService extends KubernetesTrainingService implements KubernetesTrainingService {
private fcTrialConfig?: FrameworkControllerTrialConfig; // frameworkcontroller trial configuration
private fcTemplate: any = undefined; // custom frameworkcontroller template
@ -122,8 +119,7 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
this.genericK8sClient.setNamespace = namespace;
if (this.kubernetesRestServerPort === undefined) {
const restServer: FrameworkControllerJobRestServer = component.get(FrameworkControllerJobRestServer);
this.kubernetesRestServerPort = restServer.clusterRestServerPort;
this.kubernetesRestServerPort = this.kubernetesJobRestServer!.clusterRestServerPort;
}
// wait upload of code Dir to finish

Просмотреть файл

@ -1,20 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import * as component from 'common/component';
import { KubernetesJobRestServer } from '../kubernetesJobRestServer';
import { KubeflowTrainingService } from './kubeflowTrainingService';
/**
* Kubeflow Training service Rest server, provides rest API to support kubeflow job metrics update
*
*/
@component.Singleton
export class KubeflowJobRestServer extends KubernetesJobRestServer {
/**
* constructor to provide NNIRestServer's own rest property, e.g. port
*/
constructor(kubeflowTrainingService: KubeflowTrainingService) {
super(kubeflowTrainingService);
}
}

Просмотреть файл

@ -5,7 +5,6 @@ import assert from 'assert';
import cpp from 'child-process-promise';
import fs from 'fs';
import path from 'path';
import * as component from 'common/component';
import { getExperimentId } from 'common/experimentStartupInfo';
import {
@ -24,13 +23,11 @@ import { KubeflowClusterConfig, KubeflowClusterConfigAzure, KubeflowClusterConfi
KubeflowTrialConfig, KubeflowTrialConfigFactory, KubeflowTrialConfigPytorch, KubeflowTrialConfigTensorflow
} from './kubeflowConfig';
import { KubeflowJobInfoCollector } from './kubeflowJobInfoCollector';
import { KubeflowJobRestServer } from './kubeflowJobRestServer';
/**
* Training Service implementation for Kubeflow
* Refer https://github.com/kubeflow/kubeflow for more info about Kubeflow
*/
@component.Singleton
class KubeflowTrainingService extends KubernetesTrainingService implements KubernetesTrainingService {
private kubeflowClusterConfig?: KubeflowClusterConfig;
private kubeflowTrialConfig?: KubeflowTrialConfig;
@ -69,8 +66,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
}
if (this.kubernetesRestServerPort === undefined) {
const restServer: KubeflowJobRestServer = new KubeflowJobRestServer(this);
this.kubernetesRestServerPort = restServer.clusterRestServerPort;
this.kubernetesRestServerPort = this.kubernetesJobRestServer!.clusterRestServerPort;
}
// upload code Dir to storage

Просмотреть файл

@ -1,8 +1,6 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import { Inject } from 'typescript-ioc';
import * as component from 'common/component';
import { ClusterJobRestServer } from '../common/clusterJobRestServer';
import { KubernetesTrainingService } from './kubernetesTrainingService';
@ -10,10 +8,8 @@ import { KubernetesTrainingService } from './kubernetesTrainingService';
* Kubeflow Training service Rest server, provides rest API to support kubeflow job metrics update
*
*/
@component.Singleton
export class KubernetesJobRestServer extends ClusterJobRestServer {
@Inject
private readonly kubernetesTrainingService? : KubernetesTrainingService;
private readonly kubernetesTrainingService: KubernetesTrainingService;
/**
* constructor to provide NNIRestServer's own rest property, e.g. port
*/

Просмотреть файл

@ -8,9 +8,9 @@ import azureStorage from 'azure-storage';
import {EventEmitter} from 'events';
import {Base64} from 'js-base64';
import {String} from 'typescript-string-operations';
import {MethodNotImplementedError} from 'common/errors';
import {getExperimentId} from 'common/experimentStartupInfo';
import {getLogger, Logger} from 'common/log';
import {MethodNotImplementedError} from 'common/errors';
import {
NNIManagerIpConfig, TrialJobDetail, TrialJobMetric
} from 'common/trainingService';

Просмотреть файл

@ -1,106 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from 'common/trainingService';
import {TrialConfig} from '../common/trialConfig';
export class PAIClusterConfig {
public readonly userName: string;
public readonly passWord?: string;
public host: string;
public readonly token?: string;
public readonly reuse?: boolean;
public cpuNum?: number;
public memoryMB?: number;
public gpuNum?: number;
public useActiveGpu?: boolean;
public maxTrialNumPerGpu?: number;
/**
* Constructor
* @param userName User name of PAI Cluster
* @param passWord password of PAI Cluster
* @param host Host IP of PAI Cluster
* @param token PAI token of PAI Cluster
* @param reuse If job is reusable for multiple trials
*/
constructor(userName: string, host: string, passWord?: string, token?: string, reuse?: boolean,
cpuNum?: number, memoryMB?: number, gpuNum?: number) {
this.userName = userName;
this.passWord = passWord;
this.host = host;
this.token = token;
this.reuse = reuse;
this.cpuNum = cpuNum;
this.memoryMB = memoryMB;
this.gpuNum = gpuNum;
}
}
/**
* PAI trial job detail
*/
export class PAITrialJobDetail implements TrialJobDetail {
public id: string;
public status: TrialJobStatus;
public paiJobName: string;
public submitTime: number;
public startTime?: number;
public endTime?: number;
public tags?: string[];
public url?: string;
public workingDirectory: string;
public form: TrialJobApplicationForm;
public logPath: string;
public isEarlyStopped?: boolean;
public paiJobDetailUrl?: string;
constructor(id: string, status: TrialJobStatus, paiJobName: string,
submitTime: number, workingDirectory: string, form: TrialJobApplicationForm, logPath: string, paiJobDetailUrl?: string) {
this.id = id;
this.status = status;
this.paiJobName = paiJobName;
this.submitTime = submitTime;
this.workingDirectory = workingDirectory;
this.form = form;
this.tags = [];
this.logPath = logPath;
this.paiJobDetailUrl = paiJobDetailUrl;
}
}
export const PAI_TRIAL_COMMAND_FORMAT: string =
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} MULTI_PHASE={5} \
&& NNI_CODE_DIR={6} && mkdir -p $NNI_SYS_DIR/code && cp -r $NNI_CODE_DIR/. $NNI_SYS_DIR/code && sh $NNI_SYS_DIR/install_nni.sh \
&& cd $NNI_SYS_DIR/code && python3 -m nni.tools.trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8}' --nnimanager_port '{9}' \
--nni_manager_version '{10}' --log_collection '{11}' | tee $NNI_OUTPUT_DIR/trial.log`;
/**
* PAI trial configuration
*/
export class NNIPAITrialConfig extends TrialConfig {
public readonly cpuNum: number;
public readonly memoryMB: number;
public readonly image: string;
public virtualCluster?: string;
public readonly nniManagerNFSMountPath: string;
public readonly containerNFSMountPath: string;
public readonly paiStorageConfigName: string;
public readonly paiConfigPath?: string;
constructor(command: string, codeDir: string, gpuNum: number, cpuNum: number, memoryMB: number,
image: string, nniManagerNFSMountPath: string, containerNFSMountPath: string,
paiStorageConfigName: string, virtualCluster?: string, paiConfigPath?: string) {
super(command, codeDir, gpuNum);
this.cpuNum = cpuNum;
this.memoryMB = memoryMB;
this.image = image;
this.virtualCluster = virtualCluster;
this.nniManagerNFSMountPath = nniManagerNFSMountPath;
this.containerNFSMountPath = containerNFSMountPath;
this.paiStorageConfigName = paiStorageConfigName;
this.paiConfigPath = paiConfigPath;
}
}

Просмотреть файл

@ -1,136 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import request from 'request';
import { Deferred } from 'ts-deferred';
import { NNIError, NNIErrorNames } from 'common/errors';
import { getLogger, Logger } from 'common/log';
import { TrialJobStatus } from 'common/trainingService';
import { OpenpaiConfig } from 'common/experimentConfig';
import { PAITrialJobDetail } from './paiConfig';
/**
* Collector PAI jobs info from PAI cluster, and update pai job status locally
*/
export class PAIJobInfoCollector {
private readonly trialJobsMap: Map<string, PAITrialJobDetail>;
private readonly log: Logger = getLogger('PAIJobInfoCollector');
private readonly statusesNeedToCheck: TrialJobStatus[];
private readonly finalStatuses: TrialJobStatus[];
constructor(jobMap: Map<string, PAITrialJobDetail>) {
this.trialJobsMap = jobMap;
this.statusesNeedToCheck = ['RUNNING', 'UNKNOWN', 'WAITING'];
this.finalStatuses = ['SUCCEEDED', 'FAILED', 'USER_CANCELED', 'SYS_CANCELED', 'EARLY_STOPPED'];
}
public async retrieveTrialStatus(protocol: string, token? : string, config?: OpenpaiConfig): Promise<void> {
if (config === undefined || token === undefined) {
return Promise.resolve();
}
const updatePaiTrialJobs: Promise<void>[] = [];
for (const [trialJobId, paiTrialJob] of this.trialJobsMap) {
if (paiTrialJob === undefined) {
throw new NNIError(NNIErrorNames.NOT_FOUND, `trial job id ${trialJobId} not found`);
}
updatePaiTrialJobs.push(this.getSinglePAITrialJobInfo(protocol, paiTrialJob, token, config));
}
await Promise.all(updatePaiTrialJobs);
}
private getSinglePAITrialJobInfo(_protocol: string, paiTrialJob: PAITrialJobDetail, paiToken: string, config: OpenpaiConfig): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
if (!this.statusesNeedToCheck.includes(paiTrialJob.status)) {
deferred.resolve();
return deferred.promise;
}
// Rest call to get PAI job info and update status
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
const getJobInfoRequest: request.Options = {
uri: `${config.host}/rest-server/api/v2/jobs/${config.username}~${paiTrialJob.paiJobName}`,
method: 'GET',
json: true,
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${paiToken}`
}
};
//TODO : pass in request timeout param?
request(getJobInfoRequest, (error: Error, response: request.Response, _body: any) => {
// Status code 200 for success
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
// The job refresh time could be ealier than job submission, so it might return 404 error code, need refactor
// Queried PAI job info failed, set job status to UNKNOWN
if (paiTrialJob.status === 'WAITING' || paiTrialJob.status === 'RUNNING') {
paiTrialJob.status = 'UNKNOWN';
}
} else {
if (response.body.jobStatus && response.body.jobStatus.state) {
switch (response.body.jobStatus.state) {
case 'WAITING':
paiTrialJob.status = 'WAITING';
break;
case 'RUNNING':
paiTrialJob.status = 'RUNNING';
if (paiTrialJob.startTime === undefined) {
paiTrialJob.startTime = response.body.jobStatus.appLaunchedTime;
}
if (paiTrialJob.url === undefined) {
if (response.body.jobStatus.appTrackingUrl) {
paiTrialJob.url = response.body.jobStatus.appTrackingUrl;
} else {
paiTrialJob.url = paiTrialJob.paiJobDetailUrl;
}
}
break;
case 'SUCCEEDED':
paiTrialJob.status = 'SUCCEEDED';
break;
case 'STOPPED':
case 'STOPPING':
if (paiTrialJob.isEarlyStopped !== undefined) {
paiTrialJob.status = paiTrialJob.isEarlyStopped === true ?
'EARLY_STOPPED' : 'USER_CANCELED';
} else {
/* if paiTrialJob's isEarlyStopped is undefined, that mean we didn't stop it via cancellation,
* mark it as SYS_CANCELLED by PAI
*/
paiTrialJob.status = 'SYS_CANCELED';
}
break;
case 'FAILED':
paiTrialJob.status = 'FAILED';
break;
default:
paiTrialJob.status = 'UNKNOWN';
}
// For final job statues, update startTime, endTime and url
if (this.finalStatuses.includes(paiTrialJob.status)) {
if (paiTrialJob.startTime === undefined) {
paiTrialJob.startTime = response.body.jobStatus.appLaunchedTime;
}
if (paiTrialJob.endTime === undefined) {
paiTrialJob.endTime = response.body.jobStatus.completedTime;
}
// Set pai trial job's url to WebHDFS output path
if (paiTrialJob.logPath !== undefined) {
if (paiTrialJob.url && paiTrialJob.url !== paiTrialJob.logPath) {
paiTrialJob.url += `,${paiTrialJob.logPath}`;
} else {
paiTrialJob.url = `${paiTrialJob.logPath}`;
}
}
}
}
}
deferred.resolve();
});
return deferred.promise;
}
}

Просмотреть файл

@ -1,70 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import { Request, Response, Router } from 'express';
import { ClusterJobRestServer } from '../common/clusterJobRestServer';
import { PAITrainingService } from './paiTrainingService';
export interface ParameterFileMeta {
readonly experimentId: string;
readonly trialId: string;
readonly filePath: string;
}
/**
* PAI Training service Rest server, provides rest API to support pai job metrics update
*
*/
export class PAIJobRestServer extends ClusterJobRestServer {
protected parameterFileMetaList: ParameterFileMeta[] = [];
protected readonly paiTrainingService: PAITrainingService;
/**
* constructor to provide NNIRestServer's own rest property, e.g. port
*/
constructor (paiTrainingService: PAITrainingService) {
super();
this.paiTrainingService = paiTrainingService;
}
protected handleTrialMetrics(jobId: string, metrics: any[]): void {
// Split metrics array into single metric, then emit
// Warning: If not split metrics into single ones, the behavior will be UNKNOWN
for (const singleMetric of metrics) {
this.paiTrainingService.MetricsEmitter.emit('metric', {
id : jobId,
data : singleMetric
});
}
}
protected createRestHandler(): Router {
const router: Router = super.createRestHandler();
router.post(`/parameter-file-meta`, (req: Request, res: Response) => {
try {
this.log.info('POST /parameter-file-meta, body is', req.body);
this.parameterFileMetaList.push(req.body);
res.send();
} catch (err) {
this.log.error(`POST parameter-file-meta error: ${err}`);
res.status(500);
res.send((err as any).message);
}
});
router.get(`/parameter-file-meta`, (_req: Request, res: Response) => {
try {
this.log.info(`GET /parameter-file-meta`);
res.send(this.parameterFileMetaList);
} catch (err) {
this.log.error(`GET parameter-file-meta error: ${err}`);
res.status(500);
res.send((err as any).message);
}
});
return router;
}
}

Просмотреть файл

@ -1,442 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import fs from 'fs';
import path from 'path';
import request from 'request';
import * as component from 'common/component';
import { EventEmitter } from 'events';
import { Deferred } from 'ts-deferred';
import { getExperimentId } from 'common/experimentStartupInfo';
import { getLogger, Logger } from 'common/log';
import { MethodNotImplementedError } from 'common/errors';
import {
HyperParameters, NNIManagerIpConfig, TrainingService,
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from 'common/trainingService';
import { delay } from 'common/utils';
import { OpenpaiConfig, toMegaBytes } from 'common/experimentConfig';
import { PAIJobInfoCollector } from './paiJobInfoCollector';
import { PAIJobRestServer } from './paiJobRestServer';
import { PAITrialJobDetail, PAI_TRIAL_COMMAND_FORMAT } from './paiConfig';
import { String } from 'typescript-string-operations';
import { generateParamFileName, getIPV4Address, uniqueString } from 'common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { execMkdir, validateCodeDir, execCopydir } from '../common/util';
const yaml = require('js-yaml');
/**
* Training Service implementation for OpenPAI (Open Platform for AI)
* Refer https://github.com/Microsoft/pai for more info about OpenPAI
*/
@component.Singleton
class PAITrainingService implements TrainingService {
private readonly log!: Logger;
private readonly metricsEmitter: EventEmitter;
private readonly trialJobsMap: Map<string, PAITrialJobDetail>;
private readonly expRootDir: string;
private readonly jobQueue: string[];
private stopping: boolean = false;
private paiToken?: string;
private paiTokenUpdateTime?: number;
private readonly paiTokenUpdateInterval: number;
private readonly experimentId!: string;
private readonly paiJobCollector: PAIJobInfoCollector;
private paiRestServerPort?: number;
private nniManagerIpConfig?: NNIManagerIpConfig;
private versionCheck: boolean = true;
private logCollection: string = 'none';
private paiJobRestServer?: PAIJobRestServer;
private protocol: string;
private copyExpCodeDirPromise?: Promise<void>;
private paiJobConfig: any;
private nniVersion: string | undefined;
private config: OpenpaiConfig;
constructor(config: OpenpaiConfig) {
this.log = getLogger('PAITrainingService');
this.metricsEmitter = new EventEmitter();
this.trialJobsMap = new Map<string, PAITrialJobDetail>();
this.jobQueue = [];
this.expRootDir = path.join('/nni-experiments', getExperimentId());
this.experimentId = getExperimentId();
this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);
this.paiTokenUpdateInterval = 7200000; //2hours
this.log.info('Construct paiBase training service.');
this.config = config;
this.versionCheck = !this.config.debug;
this.paiJobRestServer = new PAIJobRestServer(this);
this.paiToken = this.config.token;
this.protocol = this.config.host.toLowerCase().startsWith('https://') ? 'https' : 'http';
this.copyExpCodeDirPromise = this.copyTrialCode();
}
private async copyTrialCode(): Promise<void> {
await validateCodeDir(this.config.trialCodeDirectory);
const nniManagerNFSExpCodeDir = path.join(this.config.localStorageMountPoint, this.experimentId, 'nni-code');
await execMkdir(nniManagerNFSExpCodeDir);
this.log.info(`Starting copy codeDir data from ${this.config.trialCodeDirectory} to ${nniManagerNFSExpCodeDir}`);
await execCopydir(this.config.trialCodeDirectory, nniManagerNFSExpCodeDir);
}
public async run(): Promise<void> {
this.log.info('Run PAI training service.');
if (this.paiJobRestServer === undefined) {
throw new Error('paiJobRestServer not initialized!');
}
await this.paiJobRestServer.start();
this.paiJobRestServer.setEnableVersionCheck = this.versionCheck;
this.log.info(`PAI Training service rest server listening on: ${this.paiJobRestServer.endPoint}`);
await Promise.all([
this.statusCheckingLoop(),
this.submitJobLoop()]);
this.log.info('PAI training service exit.');
}
protected async submitJobLoop(): Promise<void> {
while (!this.stopping) {
while (!this.stopping && this.jobQueue.length > 0) {
const trialJobId: string = this.jobQueue[0];
if (await this.submitTrialJobToPAI(trialJobId)) {
// Remove trial job with trialJobId from job queue
this.jobQueue.shift();
} else {
// Break the while loop since failed to submitJob
break;
}
}
await delay(3000);
}
}
public async listTrialJobs(): Promise<TrialJobDetail[]> {
const jobs: TrialJobDetail[] = [];
for (const key of this.trialJobsMap.keys()) {
jobs.push(await this.getTrialJob(key));
}
return jobs;
}
public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
throw new MethodNotImplementedError();
}
public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
const paiTrialJob: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (paiTrialJob === undefined) {
throw new Error(`trial job ${trialJobId} not found`);
}
return paiTrialJob;
}
public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
this.metricsEmitter.on('metric', listener);
}
public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
this.metricsEmitter.off('metric', listener);
}
public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
return Promise.reject(new Error(`cancelTrialJob: trial job id ${trialJobId} not found`));
}
if (trialJobDetail.status === 'UNKNOWN') {
trialJobDetail.status = 'USER_CANCELED';
return Promise.resolve();
}
const stopJobRequest: request.Options = {
uri: `${this.config.host}/rest-server/api/v2/jobs/${this.config.username}~${trialJobDetail.paiJobName}/executionType`,
method: 'PUT',
json: true,
body: { value: 'STOP' },
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${this.paiToken}`
}
};
// Set trialjobDetail's early stopped field, to mark the job's cancellation source
trialJobDetail.isEarlyStopped = isEarlyStopped;
const deferred: Deferred<void> = new Deferred<void>();
request(stopJobRequest, (error: Error, response: request.Response, _body: any) => {
// Status code 202 for success.
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
deferred.reject((error !== undefined && error !== null) ? error.message :
`Stop trial failed, http code: ${response.statusCode}`);
} else {
deferred.resolve();
}
});
return deferred.promise;
}
public async cleanUp(): Promise<void> {
this.log.info('Stopping PAI training service...');
this.stopping = true;
if (this.paiJobRestServer === undefined) {
throw new Error('paiJobRestServer not initialized!');
}
try {
await this.paiJobRestServer.stop();
this.log.info('PAI Training service rest server stopped successfully.');
} catch (error) {
this.log.error(`PAI Training service rest server stopped failed, error: ${(error as any).message}`);
}
}
public get MetricsEmitter(): EventEmitter {
return this.metricsEmitter;
}
protected formatPAIHost(host: string): string {
// If users' host start with 'http://' or 'https://', use the original host,
// or format to 'http//${host}'
if (host.startsWith('http://')) {
this.protocol = 'http';
return host.replace('http://', '');
} else if (host.startsWith('https://')) {
this.protocol = 'https';
return host.replace('https://', '');
} else {
return host;
}
}
protected async statusCheckingLoop(): Promise<void> {
while (!this.stopping) {
await this.paiJobCollector.retrieveTrialStatus(this.protocol, this.paiToken, this.config);
if (this.paiJobRestServer === undefined) {
throw new Error('paiBaseJobRestServer not implemented!');
}
if (this.paiJobRestServer.getErrorMessage !== undefined) {
throw new Error(this.paiJobRestServer.getErrorMessage);
}
await delay(3000);
}
}
public async setClusterMetadata(_key: string, _value: string): Promise<void> { return; }
public async getClusterMetadata(_key: string): Promise<string> { return ''; }
// update trial parameters for multi-phase
public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
// Write file content ( parameter.cfg ) to working folders
await this.writeParameterFile(trialJobDetail.logPath, form.hyperParameters);
return trialJobDetail;
}
public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
this.log.info('submitTrialJob: form:', form);
const trialJobId: string = form.id === undefined ? uniqueString(5) : form.id;
//TODO: use HDFS working folder instead
const trialWorkingFolder: string = path.join(this.expRootDir, 'trials', trialJobId);
const paiJobName: string = `nni_exp_${this.experimentId}_trial_${trialJobId}`;
const logPath: string = path.join(this.config.localStorageMountPoint, this.experimentId, trialJobId);
const paiJobDetailUrl: string = `${this.config.host}/job-detail.html?username=${this.config.username}&jobName=${paiJobName}`;
const trialJobDetail: PAITrialJobDetail = new PAITrialJobDetail(
trialJobId,
'WAITING',
paiJobName,
Date.now(),
trialWorkingFolder,
form,
logPath,
paiJobDetailUrl);
this.trialJobsMap.set(trialJobId, trialJobDetail);
this.jobQueue.push(trialJobId);
return trialJobDetail;
}
private async generateNNITrialCommand(trialJobDetail: PAITrialJobDetail, command: string): Promise<string> {
const containerNFSExpCodeDir = `${this.config.containerStorageMountPoint}/${this.experimentId}/nni-code`;
const containerWorkingDir: string = `${this.config.containerStorageMountPoint}/${this.experimentId}/${trialJobDetail.id}`;
const nniPaiTrialCommand: string = String.Format(
PAI_TRIAL_COMMAND_FORMAT,
`${containerWorkingDir}`,
`${containerWorkingDir}/nnioutput`,
trialJobDetail.id,
this.experimentId,
trialJobDetail.form.sequenceId,
false, // multi-phase
containerNFSExpCodeDir,
command,
this.config.nniManagerIp || await getIPV4Address(),
this.paiRestServerPort,
this.nniVersion,
this.logCollection
)
.replace(/\r\n|\n|\r/gm, '');
return nniPaiTrialCommand;
}
private async generateJobConfigInYamlFormat(trialJobDetail: PAITrialJobDetail): Promise<any> {
const jobName = `nni_exp_${this.experimentId}_trial_${trialJobDetail.id}`
let nniJobConfig: any = undefined;
if (this.config.openpaiConfig !== undefined) {
nniJobConfig = JSON.parse(JSON.stringify(this.config.openpaiConfig)); //Trick for deep clone in Typescript
nniJobConfig.name = jobName;
// Each taskRole will generate new command in NNI's command format
// Each command will be formatted to NNI style
for (const taskRoleIndex in nniJobConfig.taskRoles) {
const commands = nniJobConfig.taskRoles[taskRoleIndex].commands
const nniTrialCommand = await this.generateNNITrialCommand(trialJobDetail, commands.join(" && ").replace(/(["'$`\\])/g, '\\$1'));
nniJobConfig.taskRoles[taskRoleIndex].commands = [nniTrialCommand]
}
} else {
nniJobConfig = {
protocolVersion: 2,
name: jobName,
type: 'job',
jobRetryCount: 0,
prerequisites: [
{
type: 'dockerimage',
uri: this.config.dockerImage,
name: 'docker_image_0'
}
],
taskRoles: {
taskrole: {
instances: 1,
completion: {
minFailedInstances: 1,
minSucceededInstances: -1
},
taskRetryCount: 0,
dockerImage: 'docker_image_0',
resourcePerInstance: {
gpu: this.config.trialGpuNumber,
cpu: this.config.trialCpuNumber,
memoryMB: toMegaBytes(this.config.trialMemorySize)
},
commands: [
await this.generateNNITrialCommand(trialJobDetail, this.config.trialCommand)
]
}
},
extras: {
'storages': [
{
name: this.config.storageConfigName
}
],
submitFrom: 'submit-job-v2'
}
}
if (this.config.virtualCluster) {
nniJobConfig.defaults = {
virtualCluster: this.config.virtualCluster
}
}
}
return yaml.safeDump(nniJobConfig);
}
protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
const deferred: Deferred<boolean> = new Deferred<boolean>();
const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`Failed to find PAITrialJobDetail for job ${trialJobId}`);
}
if (this.paiJobRestServer === undefined) {
throw new Error('paiJobRestServer is not initialized');
}
// Make sure experiment code files is copied from local to NFS
if (this.copyExpCodeDirPromise !== undefined) {
await this.copyExpCodeDirPromise;
this.log.info(`Copy codeDir data finished.`);
// All trials share same destination NFS code folder, only copy codeDir once for an experiment.
// After copy data finished, set copyExpCodeDirPromise be undefined to avoid log content duplicated.
this.copyExpCodeDirPromise = undefined;
}
this.paiRestServerPort = this.paiJobRestServer.clusterRestServerPort;
// Step 1. Prepare PAI job configuration
//create trial local working folder locally.
await execMkdir(trialJobDetail.logPath);
// Write NNI installation file to local files
await fs.promises.writeFile(path.join(trialJobDetail.logPath, 'install_nni.sh'), CONTAINER_INSTALL_NNI_SHELL_FORMAT, { encoding: 'utf8' });
// Write file content ( parameter.cfg ) to local working folders
if (trialJobDetail.form !== undefined) {
await this.writeParameterFile(trialJobDetail.logPath, trialJobDetail.form.hyperParameters);
}
//Generate Job Configuration in yaml format
const paiJobConfig = await this.generateJobConfigInYamlFormat(trialJobDetail);
this.log.debug(paiJobConfig);
// Step 2. Submit PAI job via Rest call
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
const submitJobRequest: request.Options = {
uri: `${this.config.host}/rest-server/api/v2/jobs`,
method: 'POST',
body: paiJobConfig,
followAllRedirects: true,
headers: {
'Content-Type': 'text/yaml',
Authorization: `Bearer ${this.paiToken}`
}
};
request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
// If submit success, will get status code 202. refer: https://github.com/microsoft/pai/blob/master/src/rest-server/docs/swagger.yaml
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
const errorMessage: string = (error !== undefined && error !== null) ? error.message :
`Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${body}`;
this.log.error(errorMessage);
trialJobDetail.status = 'FAILED';
deferred.reject(errorMessage);
} else {
trialJobDetail.submitTime = Date.now();
}
deferred.resolve(true);
});
return deferred.promise;
}
private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> {
const filepath: string = path.join(directory, generateParamFileName(hyperParameters));
await fs.promises.writeFile(filepath, hyperParameters.value, { encoding: 'utf8' });
}
public getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
throw new MethodNotImplementedError();
}
public fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
throw new MethodNotImplementedError();
}
}
export { PAITrainingService };

Просмотреть файл

@ -1,32 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import { ClusterJobRestServer } from '../common/clusterJobRestServer';
import { RemoteMachineTrainingService } from './remoteMachineTrainingService';
/**
* RemoteMachine Training service Rest server, provides rest RemoteMachine to support remotemachine job metrics update
*
*/
export class RemoteMachineJobRestServer extends ClusterJobRestServer {
private readonly remoteMachineTrainingService: RemoteMachineTrainingService;
/**
* constructor to provide NNIRestServer's own rest property, e.g. port
*/
constructor(remoteMachineTrainingService: RemoteMachineTrainingService) {
super();
this.remoteMachineTrainingService = remoteMachineTrainingService;
}
protected handleTrialMetrics(jobId: string, metrics: any[]): void {
// Split metrics array into single metric, then emit
// Warning: If not split metrics into single ones, the behavior will be UNKNOWNls
for (const singleMetric of metrics) {
this.remoteMachineTrainingService.MetricsEmitter.emit('metric', {
id : jobId,
data : singleMetric
});
}
}
}

Просмотреть файл

@ -1,604 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import assert from 'assert';
import { EventEmitter } from 'events';
import fs from 'fs';
import path from 'path';
import { ShellExecutor } from 'training_service/remote_machine/shellExecutor';
import { Deferred } from 'ts-deferred';
import * as component from 'common/component';
import { NNIError, NNIErrorNames, MethodNotImplementedError } from 'common/errors';
import { getExperimentId } from 'common/experimentStartupInfo';
import { getLogger, Logger } from 'common/log';
import { ObservableTimer } from 'common/observableTimer';
import {
HyperParameters, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric
} from 'common/trainingService';
import {
delay, generateParamFileName, getExperimentRootDir, getIPV4Address, getJobCancelStatus,
getVersion, uniqueString
} from 'common/utils';
import { RemoteConfig, RemoteMachineConfig } from 'common/experimentConfig';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { GPUSummary, ScheduleResultType } from '../common/gpuData';
import { execMkdir, validateCodeDir } from '../common/util';
import { GPUScheduler } from './gpuScheduler';
import {
ExecutorManager, RemoteMachineScheduleInfo, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail
} from './remoteMachineData';
import { RemoteMachineJobRestServer } from './remoteMachineJobRestServer';
import { createScriptFile } from 'common/shellUtils';
/**
* Training Service implementation for Remote Machine (Linux)
*/
@component.Singleton
class RemoteMachineTrainingService implements TrainingService {
private readonly initExecutorId = "initConnection";
private readonly machineExecutorManagerMap: Map<RemoteMachineConfig, ExecutorManager>; //machine excutor map
private readonly machineCopyExpCodeDirPromiseMap: Map<RemoteMachineConfig, Promise<void>>;
private readonly trialExecutorManagerMap: Map<string, ExecutorManager>; //trial excutor map
private readonly trialJobsMap: Map<string, RemoteMachineTrialJobDetail>;
private readonly expRootDir: string;
private gpuScheduler?: GPUScheduler;
private readonly jobQueue: string[];
private readonly timer: ObservableTimer;
private stopping: boolean = false;
private readonly metricsEmitter: EventEmitter;
private readonly log: Logger;
private remoteRestServerPort?: number;
private versionCheck: boolean = true;
private logCollection: string = 'none';
private sshConnectionPromises: any[];
private config: RemoteConfig;
constructor(config: RemoteConfig) {
this.metricsEmitter = new EventEmitter();
this.trialJobsMap = new Map<string, RemoteMachineTrialJobDetail>();
this.trialExecutorManagerMap = new Map<string, ExecutorManager>();
this.machineCopyExpCodeDirPromiseMap = new Map<RemoteMachineConfig, Promise<void>>();
this.machineExecutorManagerMap = new Map<RemoteMachineConfig, ExecutorManager>();
this.jobQueue = [];
this.sshConnectionPromises = [];
this.expRootDir = getExperimentRootDir();
this.timer = component.get(ObservableTimer);
this.log = getLogger('RemoteMachineTrainingService');
this.log.info('Construct remote machine training service.');
this.config = config;
if (!fs.lstatSync(this.config.trialCodeDirectory).isDirectory()) {
throw new Error(`codeDir ${this.config.trialCodeDirectory} is not a directory`);
}
validateCodeDir(this.config.trialCodeDirectory);
this.sshConnectionPromises = this.config.machineList.map(
machine => this.initRemoteMachineOnConnected(machine)
);
}
/**
* Loop to launch trial jobs and collect trial metrics
*/
public async run(): Promise<void> {
const restServer = new RemoteMachineJobRestServer(this);
await restServer.start();
restServer.setEnableVersionCheck = this.versionCheck;
this.log.info('Run remote machine training service.');
if (this.sshConnectionPromises.length > 0) {
await Promise.all(this.sshConnectionPromises);
this.log.info('ssh connection initialized!');
// set sshConnectionPromises to [] to avoid log information duplicated
this.sshConnectionPromises = [];
// initialize gpuScheduler
this.gpuScheduler = new GPUScheduler(this.machineExecutorManagerMap);
// Copy codeDir to remote machine
for (const [machineConfig, executorManager] of this.machineExecutorManagerMap.entries()) {
const executor: ShellExecutor = await executorManager.getExecutor(this.initExecutorId);
if (executor !== undefined) {
this.machineCopyExpCodeDirPromiseMap.set(
machineConfig,
executor.copyDirectoryToRemote(this.config.trialCodeDirectory, executor.getRemoteCodePath(getExperimentId()))
);
}
}
}
while (!this.stopping) {
while (this.jobQueue.length > 0) {
this.updateGpuReservation();
const trialJobId: string = this.jobQueue[0];
const prepareResult: boolean = await this.prepareTrialJob(trialJobId);
if (prepareResult) {
// Remove trial job with trialJobId from job queue
this.jobQueue.shift();
} else {
// Break the while loop since no GPU resource is available right now,
// Wait to schedule job in next time iteration
break;
}
}
if (restServer.getErrorMessage !== undefined) {
this.stopping = true;
throw new Error(restServer.getErrorMessage);
}
await delay(3000);
}
this.log.info('RemoteMachineTrainingService run loop exited.');
}
/**
* give trial an executor
* @param trial remote machine trial job detail
*/
public allocateExecutorManagerForTrial(trial: RemoteMachineTrialJobDetail): void {
if (trial.rmMeta === undefined) {
throw new Error(`rmMeta not set in trial ${trial.id}`);
}
const executorManager: ExecutorManager | undefined = this.machineExecutorManagerMap.get(trial.rmMeta.config);
if (executorManager === undefined) {
throw new Error(`executorManager not initialized`);
}
this.trialExecutorManagerMap.set(trial.id, executorManager);
}
/**
* If a trial is finished, release the connection resource
* @param trial remote machine trial job detail
*/
public releaseTrialResource(trial: RemoteMachineTrialJobDetail): void {
if (trial.rmMeta === undefined) {
throw new Error(`rmMeta not set in trial ${trial.id}`);
}
const executorManager = this.trialExecutorManagerMap.get(trial.id);
if (executorManager === undefined) {
throw new Error(`ExecutorManager is not assigned for trial ${trial.id}`);
}
// Note, it still keep reference in trialExecutorManagerMap, as there may be following requests from nni manager.
executorManager.releaseExecutor(trial.id);
}
/**
* List submitted trial jobs
*/
public async listTrialJobs(): Promise<TrialJobDetail[]> {
const jobs: TrialJobDetail[] = [];
const deferred: Deferred<TrialJobDetail[]> = new Deferred<TrialJobDetail[]>();
for (const [key,] of this.trialJobsMap) {
jobs.push(await this.getTrialJob(key));
}
deferred.resolve(jobs);
return deferred.promise;
}
/**
* Get trial job detail information
* @param trialJobId ID of trial job
*/
public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
const trialJob: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJob === undefined) {
throw new NNIError(NNIErrorNames.NOT_FOUND, `trial job id ${trialJobId} not found`);
}
//TO DO: add another job status, and design new job status change logic
if (trialJob.status === 'RUNNING' || trialJob.status === 'UNKNOWN') {
// Get executor where the job is running
if (trialJob.rmMeta === undefined) {
throw new Error(`rmMeta not set for submitted job ${trialJobId}`);
}
const executor = await this.getExecutor(trialJob.id);
return this.updateTrialJobStatus(trialJob, executor);
} else {
return trialJob;
}
}
/**
* Get trial job log
* @param _trialJobId ID of trial job
* @param _logType 'TRIAL_LOG' | 'TRIAL_STDERR'
*/
public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
throw new MethodNotImplementedError();
}
/**
* Add job metrics listener
* @param listener callback listener
*/
public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
this.metricsEmitter.on('metric', listener);
}
/**
* Remove job metrics listener
* @param listener callback listener
*/
public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
this.metricsEmitter.off('metric', listener);
}
/**
* Submit trial job
* @param form trial job description form
*/
public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
// Generate trial job id(random)
const trialJobId: string = form.id === undefined ? uniqueString(5) : form.id;
const trialJobDetail: RemoteMachineTrialJobDetail = new RemoteMachineTrialJobDetail(
trialJobId,
'WAITING',
Date.now(),
"unset",
form
);
this.jobQueue.push(trialJobId);
this.trialJobsMap.set(trialJobId, trialJobDetail);
return Promise.resolve(trialJobDetail);
}
/**
* Update trial job for multi-phase
* @param trialJobId trial job id
* @param form job application form
*/
public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
await this.writeParameterFile(trialJobId, form.hyperParameters);
return trialJobDetail;
}
/**
* Cancel trial job
* @param trialJobId ID of trial job
*/
public async cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
const trialJob: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJob === undefined) {
throw new Error(`trial job id ${trialJobId} not found`);
}
// Remove the job with trialJobId from job queue
const index: number = this.jobQueue.indexOf(trialJobId);
if (index >= 0) {
this.jobQueue.splice(index, 1);
}
// Get executor where the job is running
if (trialJob.rmMeta !== undefined) {
// If the trial job is already scheduled, check its status and kill the trial process in remote machine
const executor = await this.getExecutor(trialJob.id);
if (trialJob.status === 'UNKNOWN') {
trialJob.status = 'USER_CANCELED';
this.releaseTrialResource(trialJob);
return
}
const jobpidPath: string = this.getJobPidPath(executor, trialJob.id);
try {
// Mark the toEarlyStop tag here
trialJob.isEarlyStopped = isEarlyStopped;
await executor.killChildProcesses(jobpidPath);
this.releaseTrialResource(trialJob);
} catch (error) {
// Not handle the error since pkill failed will not impact trial job's current status
this.log.error(`remoteTrainingService.cancelTrialJob: ${error}`);
}
} else {
// Job is not scheduled yet, set status to 'USER_CANCELLED' directly
assert(isEarlyStopped === false, 'isEarlyStopped is not supposed to be true here.');
trialJob.status = getJobCancelStatus(isEarlyStopped);
}
}
public async setClusterMetadata(_key: string, _value: string): Promise<void> { return; }
public async getClusterMetadata(_key: string): Promise<string> { return ''; }
/**
* cleanup() has a time out of 10s to clean remote connections
*/
public async cleanUp(): Promise<void> {
this.log.info('Stopping remote machine training service...');
this.stopping = true;
await this.cleanupConnections();
}
private async getExecutor(trialId: string): Promise<ShellExecutor> {
const executorManager = this.trialExecutorManagerMap.get(trialId);
if (executorManager === undefined) {
throw new Error(`ExecutorManager is not assigned for trial ${trialId}`);
}
return await executorManager.getExecutor(trialId);
}
/**
* remove gpu reversion when job is not running
*/
private updateGpuReservation(): void {
if (this.gpuScheduler) {
for (const [key, value] of this.trialJobsMap) {
if (!['WAITING', 'RUNNING'].includes(value.status)) {
this.gpuScheduler.removeGpuReservation(key, this.trialJobsMap);
}
}
}
}
/**
* stop gpu_metric_collector process in remote machine and remove unused scripts
*/
private async cleanupConnections(): Promise<void> {
try {
for (const executorManager of this.machineExecutorManagerMap.values()) {
const executor = await executorManager.getExecutor(this.initExecutorId);
if (executor !== undefined) {
this.log.info(`killing gpu metric collector on ${executor.name}`);
const gpuJobPidPath: string = executor.joinPath(executor.getRemoteScriptsPath(getExperimentId()), 'pid');
await executor.killChildProcesses(gpuJobPidPath, true);
}
executorManager.releaseAllExecutor();
}
} catch (error) {
//ignore error, this function is called to cleanup remote connections when experiment is stopping
this.log.error(`Cleanup connection exception, error is ${error}`);
}
}
private async initRemoteMachineOnConnected(machineConfig: RemoteMachineConfig): Promise<void> {
const executorManager: ExecutorManager = new ExecutorManager(machineConfig);
this.log.info(`connecting to ${machineConfig.user}@${machineConfig.host}:${machineConfig.port}`);
const executor: ShellExecutor = await executorManager.getExecutor(this.initExecutorId);
this.log.debug(`reached ${executor.name}`);
this.machineExecutorManagerMap.set(machineConfig, executorManager);
this.log.debug(`initializing ${executor.name}`);
// Create root working directory after executor is ready
const nniRootDir: string = executor.joinPath(executor.getTempPath(), 'nni');
await executor.createFolder(executor.getRemoteExperimentRootDir(getExperimentId()));
// the directory to store temp scripts in remote machine
const remoteGpuScriptCollectorDir: string = executor.getRemoteScriptsPath(getExperimentId());
// clean up previous result.
await executor.createFolder(remoteGpuScriptCollectorDir, true);
await executor.allowPermission(true, nniRootDir);
//Begin to execute gpu_metrics_collection scripts
const script = executor.generateGpuStatsScript(getExperimentId());
executor.executeScript(script, false, true);
// the timer is trigger in 1 second, it causes multiple runs on server.
// So reduce it's freqeunce, only allow one of it run.
const collectingCount: boolean[] = [];
const disposable: Rx.IDisposable = this.timer.subscribe(
async () => {
if (collectingCount.length == 0) {
collectingCount.push(true);
const cmdresult = await executor.readLastLines(executor.joinPath(remoteGpuScriptCollectorDir, 'gpu_metrics'));
if (cmdresult !== "") {
executorManager.rmMeta.gpuSummary = <GPUSummary>JSON.parse(cmdresult);
if (executorManager.rmMeta.gpuSummary.gpuCount === 0) {
this.log.warning(`No GPU found on remote machine ${machineConfig.host}`);
this.timer.unsubscribe(disposable);
}
}
if (this.stopping) {
this.timer.unsubscribe(disposable);
this.log.debug(`Stopped GPU collector on ${machineConfig.host}, since experiment is exiting.`);
}
collectingCount.pop();
}
}
);
}
private async prepareTrialJob(trialJobId: string): Promise<boolean> {
const deferred: Deferred<boolean> = new Deferred<boolean>();
if (this.gpuScheduler === undefined) {
throw new Error('gpuScheduler is not initialized');
}
const trialJobDetail: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new NNIError(NNIErrorNames.INVALID_JOB_DETAIL, `Invalid job detail information for trial job ${trialJobId}`);
}
// If job is not WATIING, Don't prepare and resolve true immediately
if (trialJobDetail.status !== 'WAITING') {
deferred.resolve(true);
return deferred.promise;
}
// get an executor from scheduler
const rmScheduleResult: RemoteMachineScheduleResult = this.gpuScheduler.scheduleMachine(this.config.trialGpuNumber, trialJobDetail);
if (rmScheduleResult.resultType === ScheduleResultType.REQUIRE_EXCEED_TOTAL) {
const errorMessage: string = `Required GPU number ${this.config.trialGpuNumber} is too large, no machine can meet`;
this.log.error(errorMessage);
deferred.reject();
throw new NNIError(NNIErrorNames.RESOURCE_NOT_AVAILABLE, errorMessage);
} else if (rmScheduleResult.resultType === ScheduleResultType.SUCCEED
&& rmScheduleResult.scheduleInfo !== undefined) {
const rmScheduleInfo: RemoteMachineScheduleInfo = rmScheduleResult.scheduleInfo;
trialJobDetail.rmMeta = rmScheduleInfo.rmMeta;
const copyExpCodeDirPromise = this.machineCopyExpCodeDirPromiseMap.get(rmScheduleInfo.rmMeta.config);
if (copyExpCodeDirPromise !== undefined) {
await copyExpCodeDirPromise;
}
this.allocateExecutorManagerForTrial(trialJobDetail);
const executor = await this.getExecutor(trialJobDetail.id);
trialJobDetail.workingDirectory = executor.joinPath(executor.getRemoteExperimentRootDir(getExperimentId()), 'trials', trialJobDetail.id);
await this.launchTrialOnScheduledMachine(
trialJobId, trialJobDetail.form, rmScheduleInfo);
trialJobDetail.status = 'RUNNING';
trialJobDetail.url = `file://${rmScheduleInfo.rmMeta.config.host}:${trialJobDetail.workingDirectory}`;
trialJobDetail.startTime = Date.now();
this.trialJobsMap.set(trialJobId, trialJobDetail);
deferred.resolve(true);
} else if (rmScheduleResult.resultType === ScheduleResultType.TMP_NO_AVAILABLE_GPU) {
this.log.info(`Right now no available GPU can be allocated for trial ${trialJobId}, will try to schedule later`);
deferred.resolve(false);
} else {
deferred.reject(`Invalid schedule resutl type: ${rmScheduleResult.resultType}`);
}
return deferred.promise;
}
private async launchTrialOnScheduledMachine(trialJobId: string, form: TrialJobApplicationForm,
rmScheduleInfo: RemoteMachineScheduleInfo): Promise<void> {
const cudaVisibleDevice: string = rmScheduleInfo.cudaVisibleDevice;
const executor = await this.getExecutor(trialJobId);
const trialJobDetail: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`Can not get trial job detail for job: ${trialJobId}`);
}
const trialLocalTempFolder: string = path.join(this.expRootDir, 'trials', trialJobId);
await executor.createFolder(executor.joinPath(trialJobDetail.workingDirectory, '.nni'));
// RemoteMachineRunShellFormat is the run shell format string,
// See definition in remoteMachineData.ts
let cudaVisible: string;
// Set CUDA_VISIBLE_DEVICES environment variable based on cudaVisibleDevice
// If no valid cudaVisibleDevice is defined, set CUDA_VISIBLE_DEVICES to empty string to hide GPU device
// If gpuNum is undefined, will not set CUDA_VISIBLE_DEVICES in script
if (this.config.trialGpuNumber === undefined) {
cudaVisible = ""
} else {
if (typeof cudaVisibleDevice === 'string' && cudaVisibleDevice.length > 0) {
cudaVisible = `CUDA_VISIBLE_DEVICES=${cudaVisibleDevice}`;
} else {
cudaVisible = `CUDA_VISIBLE_DEVICES=" "`;
}
}
const nniManagerIp: string = this.config.nniManagerIp ? this.config.nniManagerIp : await getIPV4Address();
if (this.remoteRestServerPort === undefined) {
const restServer: RemoteMachineJobRestServer = component.get(RemoteMachineJobRestServer);
this.remoteRestServerPort = restServer.clusterRestServerPort;
}
const version: string = this.versionCheck ? await getVersion() : '';
const runScriptTrialContent: string = executor.generateStartScript(
trialJobDetail.workingDirectory,
trialJobId,
getExperimentId(),
trialJobDetail.form.sequenceId.toString(),
false, // multi-phase
this.config.trialCommand,
nniManagerIp,
this.remoteRestServerPort,
version,
this.logCollection,
cudaVisible);
//create tmp trial working folder locally.
await execMkdir(path.join(trialLocalTempFolder, '.nni'));
// Write install_nni.sh, it's not used in Windows platform.
await createScriptFile(path.join(trialLocalTempFolder, executor.getScriptName("install_nni")), CONTAINER_INSTALL_NNI_SHELL_FORMAT);
// Write file content ( run.sh and parameter.cfg ) to local tmp files
await createScriptFile(path.join(trialLocalTempFolder, executor.getScriptName("run")), runScriptTrialContent);
await this.writeParameterFile(trialJobId, form.hyperParameters);
// Copy files in codeDir to remote working directory
await executor.copyDirectoryToRemote(trialLocalTempFolder, trialJobDetail.workingDirectory);
// Execute command in remote machine
executor.executeScript(executor.joinPath(trialJobDetail.workingDirectory, executor.getScriptName("run")), true, true);
}
private async updateTrialJobStatus(trialJob: RemoteMachineTrialJobDetail, executor: ShellExecutor): Promise<TrialJobDetail> {
const deferred: Deferred<TrialJobDetail> = new Deferred<TrialJobDetail>();
const jobpidPath: string = this.getJobPidPath(executor, trialJob.id);
const trialReturnCodeFilePath: string = executor.joinPath(executor.getRemoteExperimentRootDir(getExperimentId()), 'trials', trialJob.id, '.nni', 'code');
/* eslint-disable require-atomic-updates */
try {
const isAlive = await executor.isProcessAlive(jobpidPath);
// if the process of jobpid is not alive any more
if (!isAlive) {
const trialReturnCode: string = await executor.getRemoteFileContent(trialReturnCodeFilePath);
this.log.debug(`trailjob ${trialJob.id} return code: ${trialReturnCode}`);
const match: RegExpMatchArray | null = trialReturnCode.trim()
.match(/^-?(\d+)\s+(\d+)$/);
if (match !== null) {
const { 1: code, 2: timestamp } = match;
// Update trial job's status based on result code
if (parseInt(code, 10) === 0) {
trialJob.status = 'SUCCEEDED';
} else {
// isEarlyStopped is never set, mean it's not cancelled by NNI, so if the process's exit code >0, mark it as FAILED
if (trialJob.isEarlyStopped === undefined) {
trialJob.status = 'FAILED';
} else {
trialJob.status = getJobCancelStatus(trialJob.isEarlyStopped);
}
}
trialJob.endTime = parseInt(timestamp, 10);
this.releaseTrialResource(trialJob);
}
this.log.debug(`trailJob status update: ${trialJob.id}, ${trialJob.status}`);
}
deferred.resolve(trialJob);
} catch (error) {
this.log.debug(`(Ignorable mostly)Update job status exception, error is ${(error as any).message}`);
if (error instanceof NNIError && error.name === NNIErrorNames.NOT_FOUND) {
deferred.resolve(trialJob);
} else {
trialJob.status = 'UNKNOWN';
deferred.resolve(trialJob);
}
}
/* eslint-enable require-atomic-updates */
return deferred.promise;
}
public get MetricsEmitter(): EventEmitter {
return this.metricsEmitter;
}
private getJobPidPath(executor: ShellExecutor, jobId: string): string {
const trialJobDetail: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(jobId);
if (trialJobDetail === undefined) {
throw new NNIError(NNIErrorNames.INVALID_JOB_DETAIL, `Invalid job detail information for trial job ${jobId}`);
}
return executor.joinPath(trialJobDetail.workingDirectory, '.nni', 'jobpid');
}
private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters): Promise<void> {
const executor = await this.getExecutor(trialJobId);
const trialWorkingFolder: string = executor.joinPath(executor.getRemoteExperimentRootDir(getExperimentId()), 'trials', trialJobId);
const trialLocalTempFolder: string = path.join(this.expRootDir, 'trials', trialJobId);
const fileName: string = generateParamFileName(hyperParameters);
const localFilepath: string = path.join(trialLocalTempFolder, fileName);
await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' });
await executor.copyFileToRemote(localFilepath, executor.joinPath(trialWorkingFolder, fileName));
}
public getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
throw new MethodNotImplementedError();
}
public fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
throw new MethodNotImplementedError();
}
}
export { RemoteMachineTrainingService };

Просмотреть файл

@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import * as component from "common/component";
import { IocShim } from 'common/ioc_shim';
import { delay } from "common/utils";
import { CommandChannel, RunnerConnection } from "../commandChannel";
import { Channel, EnvironmentInformation } from "../environment";
@ -65,7 +65,7 @@ export class FileCommandChannel extends CommandChannel {
const start = new Date();
if (this.sendQueues.length > 0) {
const storageService = component.get<StorageService>(StorageService);
const storageService = IocShim.get<StorageService>(StorageService);
while (this.sendQueues.length > 0) {
const item = this.sendQueues.shift();
@ -90,7 +90,7 @@ export class FileCommandChannel extends CommandChannel {
private async receiveLoop(): Promise<void> {
const intervalSeconds = 2;
const storageService = component.get<StorageService>(StorageService);
const storageService = IocShim.get<StorageService>(StorageService);
while (!this.stopping) {
const start = new Date();

Просмотреть файл

@ -3,10 +3,10 @@
import fs from 'fs';
import path from 'path';
import * as component from 'common/component';
import { getLogger, Logger } from 'common/log';
import { AmlConfig } from 'common/experimentConfig';
import { ExperimentStartupInfo } from 'common/experimentStartupInfo';
import { IocShim } from 'common/ioc_shim';
import { getLogger, Logger } from 'common/log';
import { validateCodeDir } from 'training_service/common/util';
import { AMLClient } from '../aml/amlClient';
import { AMLEnvironmentInformation } from '../aml/amlConfig';
@ -18,7 +18,6 @@ import { SharedStorageService } from '../sharedStorage'
/**
* Collector AML jobs info from AML cluster, and update aml job status locally
*/
@component.Singleton
export class AMLEnvironmentService extends EnvironmentService {
private readonly log: Logger = getLogger('AMLEnvironmentService');
@ -89,8 +88,8 @@ export class AMLEnvironmentService extends EnvironmentService {
await fs.promises.mkdir(environmentLocalTempFolder, {recursive: true});
}
if (amlEnvironment.useSharedStorage) {
const environmentRoot = component.get<SharedStorageService>(SharedStorageService).remoteWorkingRoot;
const remoteMountCommand = component.get<SharedStorageService>(SharedStorageService).remoteMountCommand;
const environmentRoot = IocShim.get<SharedStorageService>(SharedStorageService).remoteWorkingRoot;
const remoteMountCommand = IocShim.get<SharedStorageService>(SharedStorageService).remoteMountCommand;
amlEnvironment.command = `${remoteMountCommand} && cd ${environmentRoot} && ${amlEnvironment.command}`.replace(/"/g, `\\"`);
} else {
amlEnvironment.command = `mv envs outputs/envs && cd outputs && ${amlEnvironment.command}`;

Просмотреть файл

@ -3,18 +3,17 @@
import fs from 'fs';
import path from 'path';
import * as component from 'common/component';
import { Deferred } from 'ts-deferred';
import { getLogger, Logger } from 'common/log';
import { DlcConfig } from 'common/experimentConfig';
import { ExperimentStartupInfo } from 'common/experimentStartupInfo';
import { IocShim } from 'common/ioc_shim';
import { getLogger, Logger } from 'common/log';
import { DlcClient } from '../dlc/dlcClient';
import { DlcEnvironmentInformation } from '../dlc/dlcConfig';
import { EnvironmentInformation, EnvironmentService } from '../environment';
import { EventEmitter } from "events";
import { FileCommandChannel } from '../channels/fileCommandChannel';
import { MountedStorageService } from '../storages/mountedStorageService';
import { Scope } from 'typescript-ioc';
import { StorageService } from '../storageService';
import { getLogDir } from 'common/utils';
import { setTimeout } from 'timers/promises';
@ -22,7 +21,6 @@ import { setTimeout } from 'timers/promises';
/**
* Collector DLC jobs info from DLC cluster, and update dlc job status locally
*/
@component.Singleton
export class DlcEnvironmentService extends EnvironmentService {
private readonly log: Logger = getLogger('dlcEnvironmentService');
@ -33,8 +31,8 @@ export class DlcEnvironmentService extends EnvironmentService {
super();
this.experimentId = info.experimentId;
this.config = config;
component.Container.bind(StorageService).to(MountedStorageService).scope(Scope.Singleton);
const storageService = component.get<StorageService>(StorageService)
IocShim.bind(StorageService, MountedStorageService);
const storageService = IocShim.get<StorageService>(StorageService)
const remoteRoot = storageService.joinPath(this.config.localStorageMountPoint, 'nni-experiments', this.experimentId);
const localRoot = storageService.joinPath(this.config.localStorageMountPoint, 'nni-experiments');
storageService.initialize(localRoot, remoteRoot);

Просмотреть файл

@ -1,5 +1,4 @@
import { AMLEnvironmentService } from './amlEnvironmentService';
import { OpenPaiEnvironmentService } from './openPaiEnvironmentService';
import { LocalEnvironmentService } from './localEnvironmentService';
import { RemoteEnvironmentService } from './remoteEnvironmentService';
import { KubeflowEnvironmentService } from './kubernetes/kubeflowEnvironmentService';
@ -22,8 +21,6 @@ export async function createEnvironmentService(config: TrainingServiceConfig): P
return new RemoteEnvironmentService(configAsAny, info);
case 'aml':
return new AMLEnvironmentService(configAsAny, info);
case 'openpai':
return new OpenPaiEnvironmentService(configAsAny, info);
case 'kubeflow':
return new KubeflowEnvironmentService(configAsAny, info);
case 'frameworkcontroller':

Просмотреть файл

@ -6,7 +6,6 @@
import cpp from 'child-process-promise';
import * as fs from 'fs';
import * as path from 'path';
import * as component from '../../../../common/component';
import { FrameworkControllerConfig, FrameworkControllerTaskRoleConfig, toMegaBytes } from '../../../../common/experimentConfig';
import { ExperimentStartupInfo } from '../../../../common/experimentStartupInfo';
import { EnvironmentInformation } from '../../environment';
@ -15,7 +14,6 @@ import { FrameworkControllerClientFactory } from '../../../kubernetes/frameworkc
import { FrameworkControllerJobStatus, FrameworkControllerTrialConfigTemplate,
FrameworkControllerJobCompleteStatus } from '../../../kubernetes/frameworkcontroller/frameworkcontrollerConfig';
@component.Singleton
export class FrameworkControllerEnvironmentService extends KubernetesEnvironmentService {
private config: FrameworkControllerConfig;

Просмотреть файл

@ -4,7 +4,6 @@
import cpp from 'child-process-promise';
import fs from 'fs';
import path from 'path';
import * as component from 'common/component';
import { KubeflowConfig, toMegaBytes } from 'common/experimentConfig';
import { ExperimentStartupInfo } from 'common/experimentStartupInfo';
import { EnvironmentInformation } from 'training_service/reusable/environment';
@ -13,7 +12,6 @@ import { KubeflowOperatorClientFactory } from 'training_service/kubernetes/kubef
import { KubeflowClusterConfigAzure } from 'training_service/kubernetes/kubeflow/kubeflowConfig';
import { KeyVaultConfig, AzureStorage } from 'training_service/kubernetes/kubernetesConfig';
@component.Singleton
export class KubeflowEnvironmentService extends KubernetesEnvironmentService {
private config: KubeflowConfig;

Просмотреть файл

@ -4,17 +4,16 @@
import fs from 'fs';
import path from 'path';
import tkill from 'tree-kill';
import * as component from 'common/component';
import { getLogger, Logger } from 'common/log';
import { ExperimentConfig } from 'common/experimentConfig';
import { ExperimentStartupInfo } from 'common/experimentStartupInfo';
import { IocShim } from 'common/ioc_shim';
import { getLogger, Logger } from 'common/log';
import { powershellString, createScriptFile } from 'common/shellUtils';
import { EnvironmentInformation, EnvironmentService } from '../environment';
import { isAlive, getNewLine } from 'common/utils';
import { execMkdir, runScript, getScriptName, execCopydir } from 'training_service/common/util';
import { SharedStorageService } from '../sharedStorage'
@component.Singleton
export class LocalEnvironmentService extends EnvironmentService {
private readonly log: Logger = getLogger('LocalEnvironmentService');
@ -106,7 +105,7 @@ export class LocalEnvironmentService extends EnvironmentService {
public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
// Need refactor, this temp folder path is not appropriate, there are two expId in this path
const sharedStorageService = component.get<SharedStorageService>(SharedStorageService);
const sharedStorageService = IocShim.get<SharedStorageService>(SharedStorageService);
if (environment.useSharedStorage && sharedStorageService.canLocalMounted) {
this.experimentRootDir = sharedStorageService.localWorkingRoot;
}

Просмотреть файл

@ -1,328 +0,0 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import yaml from 'js-yaml';
import request from 'request';
import { Container, Scope } from 'typescript-ioc';
import { Deferred } from 'ts-deferred';
import * as component from 'common/component';
import { OpenpaiConfig, toMegaBytes } from 'common/experimentConfig';
import { ExperimentStartupInfo } from 'common/experimentStartupInfo';
import { getLogger, Logger } from 'common/log';
import { PAIClusterConfig } from 'training_service/pai/paiConfig';
import { NNIPAITrialConfig } from 'training_service/pai/paiConfig';
import { EnvironmentInformation, EnvironmentService } from '../environment';
import { SharedStorageService } from '../sharedStorage';
import { MountedStorageService } from '../storages/mountedStorageService';
import { StorageService } from '../storageService';
/**
* Collector PAI jobs info from PAI cluster, and update pai job status locally
*/
@component.Singleton
export class OpenPaiEnvironmentService extends EnvironmentService {
private readonly log: Logger = getLogger('OpenPaiEnvironmentService');
private paiClusterConfig: PAIClusterConfig | undefined;
private paiTrialConfig: NNIPAITrialConfig | undefined;
private paiToken: string;
private protocol: string;
private experimentId: string;
private config: OpenpaiConfig;
constructor(config: OpenpaiConfig, info: ExperimentStartupInfo) {
super();
this.experimentId = info.experimentId;
this.config = config;
this.paiToken = this.config.token;
this.protocol = this.config.host.toLowerCase().startsWith('https://') ? 'https' : 'http';
Container.bind(StorageService)
.to(MountedStorageService)
.scope(Scope.Singleton);
const storageService = component.get<StorageService>(StorageService)
const remoteRoot = storageService.joinPath(this.config.localStorageMountPoint, this.experimentId);
storageService.initialize(this.config.localStorageMountPoint, remoteRoot);
}
public get environmentMaintenceLoopInterval(): number {
return 5000;
}
public get hasStorageService(): boolean {
return true;
}
public get getName(): string {
return 'pai';
}
public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
if (this.paiToken === undefined) {
throw new Error('PAI token is not initialized');
}
const getJobInfoRequest: request.Options = {
uri: `${this.config.host}/rest-server/api/v2/jobs?username=${this.config.username}`,
method: 'GET',
json: true,
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${this.paiToken}`
}
};
request(getJobInfoRequest, async (error: any, response: request.Response, body: any) => {
// Status code 200 for success
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
const errorMessage: string = (error !== undefined && error !== null) ? error.message :
`OpenPAI: get environment list from PAI Cluster failed!, http code:${response.statusCode}, http body:' ${JSON.stringify(body)}`;
this.log.error(`${errorMessage}`);
deferred.reject(errorMessage);
} else {
const jobInfos = new Map<string, any>();
body.forEach((jobInfo: any) => {
jobInfos.set(jobInfo.name, jobInfo);
});
environments.forEach((environment) => {
if (jobInfos.has(environment.envId)) {
const jobResponse = jobInfos.get(environment.envId);
if (jobResponse && jobResponse.state) {
const oldEnvironmentStatus = environment.status;
switch (jobResponse.state) {
case 'RUNNING':
case 'WAITING':
case 'SUCCEEDED':
environment.setStatus(jobResponse.state);
break;
case 'FAILED':
environment.setStatus(jobResponse.state);
deferred.reject(`OpenPAI: job ${environment.envId} is failed!`);
break;
case 'STOPPED':
case 'STOPPING':
environment.setStatus('USER_CANCELED');
break;
default:
this.log.error(`OpenPAI: job ${environment.envId} returns unknown state ${jobResponse.state}.`);
environment.setStatus('UNKNOWN');
}
if (oldEnvironmentStatus !== environment.status) {
this.log.debug(`OpenPAI: job ${environment.envId} change status ${oldEnvironmentStatus} to ${environment.status} due to job is ${jobResponse.state}.`)
}
} else {
this.log.error(`OpenPAI: job ${environment.envId} has no state returned. body:`, jobResponse);
// some error happens, and mark this environment
environment.status = 'FAILED';
}
} else {
this.log.error(`OpenPAI job ${environment.envId} is not found in job list.`);
environment.status = 'UNKNOWN';
}
});
deferred.resolve();
}
});
return deferred.promise;
}
public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
if (this.paiToken === undefined) {
throw new Error('PAI token is not initialized');
}
// Step 1. Prepare PAI job configuration
let environmentRoot: string;
if (environment.useSharedStorage) {
environmentRoot = component.get<SharedStorageService>(SharedStorageService).remoteWorkingRoot;
environment.command = `${component.get<SharedStorageService>(SharedStorageService).remoteMountCommand.replace(/echo -e /g, `echo `).replace(/echo /g, `echo -e `)} && cd ${environmentRoot} && ${environment.command}`;
} else {
environmentRoot = `${this.config.containerStorageMountPoint}/${this.experimentId}`;
environment.command = `cd ${environmentRoot} && ${environment.command}`;
}
environment.runnerWorkingFolder = `${environmentRoot}/envs/${environment.id}`;
environment.trackingUrl = `${this.config.host}/job-detail.html?username=${this.config.username}&jobName=${environment.envId}`;
environment.useActiveGpu = false; // does openpai supports these?
environment.maxTrialNumberPerGpu = 1;
// Step 2. Generate Job Configuration in yaml format
const paiJobConfig = this.generateJobConfigInYamlFormat(environment);
this.log.debug(`generated paiJobConfig: ${paiJobConfig}`);
// Step 3. Submit PAI job via Rest call
const submitJobRequest: request.Options = {
uri: `${this.config.host}/rest-server/api/v2/jobs`,
method: 'POST',
body: paiJobConfig,
followAllRedirects: true,
headers: {
'Content-Type': 'text/yaml',
Authorization: `Bearer ${this.paiToken}`
}
};
request(submitJobRequest, (error, response, body) => {
// Status code 202 for success, refer https://github.com/microsoft/pai/blob/master/src/rest-server/docs/swagger.yaml
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
const errorMessage: string = (error !== undefined && error !== null) ? error.message :
`start environment ${environment.envId} failed, http code:${response.statusCode}, http body: ${body}`;
this.log.error(errorMessage);
environment.status = 'FAILED';
deferred.reject(errorMessage);
}
deferred.resolve();
});
return deferred.promise;
}
public async stopEnvironment(environment: EnvironmentInformation): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
if (environment.isAlive === false) {
return Promise.resolve();
}
if (this.paiToken === undefined) {
return Promise.reject(Error('PAI token is not initialized'));
}
const stopJobRequest: request.Options = {
uri: `${this.config.host}/rest-server/api/v2/jobs/${this.config.username}~${environment.envId}/executionType`,
method: 'PUT',
json: true,
body: { value: 'STOP' },
time: true,
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${this.paiToken}`
}
};
this.log.debug(`stopping OpenPAI environment ${environment.envId}, ${stopJobRequest.uri}`);
try {
request(stopJobRequest, (error, response, _body) => {
try {
// Status code 202 for success.
if ((error !== undefined && error !== null) || (response && response.statusCode >= 400)) {
const errorMessage: string = (error !== undefined && error !== null) ? error.message :
`OpenPAI: stop job ${environment.envId} failed, http code:${response.statusCode}, http body: ${_body}`;
this.log.error(`${errorMessage}`);
deferred.reject((error !== undefined && error !== null) ? error :
`Stop trial failed, http code: ${response.statusCode}`);
} else {
this.log.info(`OpenPAI job ${environment.envId} stopped.`);
}
deferred.resolve();
} catch (error) {
this.log.error(`OpenPAI error when inner stopping environment ${error}`);
deferred.reject(error);
}
});
} catch (error) {
this.log.error(`OpenPAI error when stopping environment ${error}`);
return Promise.reject(error);
}
return deferred.promise;
}
private generateJobConfigInYamlFormat(environment: EnvironmentInformation): any {
const jobName = environment.envId;
let nniJobConfig: any = undefined;
if (this.config.openpaiConfig !== undefined) {
nniJobConfig = JSON.parse(JSON.stringify(this.config.openpaiConfig)); //Trick for deep clone in Typescript
nniJobConfig.name = jobName;
if (nniJobConfig.taskRoles) {
environment.nodeCount = 0;
// count instance
for (const taskRoleName in nniJobConfig.taskRoles) {
const taskRole = nniJobConfig.taskRoles[taskRoleName];
let instanceCount = 1;
if (taskRole.instances) {
instanceCount = taskRole.instances;
}
environment.nodeCount += instanceCount;
}
// Each taskRole will generate new command in NNI's command format
// Each command will be formatted to NNI style
for (const taskRoleName in nniJobConfig.taskRoles) {
const taskRole = nniJobConfig.taskRoles[taskRoleName];
// replace ' to '\''
const joinedCommand = taskRole.commands.join(" && ").replace("'", "'\\''").trim();
const nniTrialCommand = `${environment.command} --node_count ${environment.nodeCount} --trial_command '${joinedCommand}'`;
this.log.debug(`replace command ${taskRole.commands} to ${[nniTrialCommand]}`);
taskRole.commands = [nniTrialCommand];
}
}
} else {
nniJobConfig = {
protocolVersion: 2,
name: jobName,
type: 'job',
jobRetryCount: 0,
prerequisites: [
{
type: 'dockerimage',
uri: this.config.dockerImage,
name: 'docker_image_0'
}
],
taskRoles: {
taskrole: {
instances: 1,
completion: {
minFailedInstances: 1,
minSucceededInstances: -1
},
taskRetryCount: 0,
dockerImage: 'docker_image_0',
resourcePerInstance: {
gpu: this.config.trialGpuNumber === undefined? 0: this.config.trialGpuNumber,
cpu: this.config.trialCpuNumber,
memoryMB: toMegaBytes(this.config.trialMemorySize)
},
commands: [
environment.command
]
}
},
extras: {
'storages': [
{
name: this.config.storageConfigName
}
],
submitFrom: 'submit-job-v2'
}
}
if (this.config.virtualCluster) {
nniJobConfig.defaults = {
virtualCluster: this.config.virtualCluster
}
}
}
return yaml.dump(nniJobConfig);
}
protected formatPAIHost(host: string): string {
// If users' host start with 'http://' or 'https://', use the original host,
// or format to 'http//${host}'
if (host.startsWith('http://')) {
this.protocol = 'http';
return host.replace('http://', '');
} else if (host.startsWith('https://')) {
this.protocol = 'https';
return host.replace('https://', '');
} else {
return host;
}
}
}

Просмотреть файл

@ -3,7 +3,7 @@
import fs from 'fs';
import path from 'path';
import * as component from 'common/component';
import { IocShim } from 'common/ioc_shim';
import { getLogger, Logger } from 'common/log';
import { EnvironmentInformation, EnvironmentService } from '../environment';
import { getLogLevel } from 'common/utils';
@ -16,7 +16,6 @@ import { RemoteMachineEnvironmentInformation } from '../remote/remoteConfig';
import { SharedStorageService } from '../sharedStorage';
import { createScriptFile } from 'common/shellUtils';
@component.Singleton
export class RemoteEnvironmentService extends EnvironmentService {
private readonly initExecutorId = "initConnection";
@ -162,7 +161,7 @@ export class RemoteEnvironmentService extends EnvironmentService {
private async releaseEnvironmentResource(environment: EnvironmentInformation): Promise<void> {
if (environment.useSharedStorage) {
const executor = await this.getExecutor(environment.id);
const remoteUmountCommand = component.get<SharedStorageService>(SharedStorageService).remoteUmountCommand;
const remoteUmountCommand = IocShim.get<SharedStorageService>(SharedStorageService).remoteUmountCommand;
const result = await executor.executeScript(remoteUmountCommand, false, false);
if (result.exitCode !== 0) {
this.log.error(`Umount shared storage on remote machine failed.\n ERROR: ${result.stderr}`);
@ -231,11 +230,11 @@ export class RemoteEnvironmentService extends EnvironmentService {
this.environmentExecutorManagerMap.set(environment.id, executorManager);
const executor = await this.getExecutor(environment.id);
if (environment.useSharedStorage) {
this.remoteExperimentRootDir = component.get<SharedStorageService>(SharedStorageService).remoteWorkingRoot;
this.remoteExperimentRootDir = IocShim.get<SharedStorageService>(SharedStorageService).remoteWorkingRoot;
if (!this.remoteExperimentRootDir.startsWith('/')) {
this.remoteExperimentRootDir = executor.joinPath((await executor.getCurrentPath()).trim(), this.remoteExperimentRootDir);
}
const remoteMountCommand = component.get<SharedStorageService>(SharedStorageService).remoteMountCommand.replace(/echo -e /g, `echo `).replace(/echo /g, `echo -e `).replace(/\\\$/g, `\\\\\\$`);
const remoteMountCommand = IocShim.get<SharedStorageService>(SharedStorageService).remoteMountCommand.replace(/echo -e /g, `echo `).replace(/echo /g, `echo -e `).replace(/\\\$/g, `\\\\\\$`);
const result = await executor.executeScript(remoteMountCommand, false, false);
if (result.exitCode !== 0) {
throw new Error(`Mount shared storage on remote machine failed.\n ERROR: ${result.stderr}`);

Просмотреть файл

@ -3,11 +3,9 @@
import { getLogger, Logger } from 'common/log';
import { MethodNotImplementedError } from 'common/errors';
import { ExperimentConfig, RemoteConfig, OpenpaiConfig, KubeflowConfig, FrameworkControllerConfig } from 'common/experimentConfig';
import { ExperimentConfig, RemoteConfig, KubeflowConfig, FrameworkControllerConfig } from 'common/experimentConfig';
import { TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric } from 'common/trainingService';
import { delay } from 'common/utils';
import { PAITrainingService } from '../pai/paiTrainingService';
import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService';
import { KubeflowTrainingService } from '../kubernetes/kubeflow/kubeflowTrainingService';
import { FrameworkControllerTrainingService } from '../kubernetes/frameworkcontroller/frameworkcontrollerTrainingService';
import { TrialDispatcher } from './trialDispatcher';
@ -26,9 +24,7 @@ class RouterTrainingService implements TrainingService {
instance.log = getLogger('RouterTrainingService');
const platform = Array.isArray(config.trainingService) ? 'hybrid' : config.trainingService.platform;
if (platform === 'remote' && (<RemoteConfig>config.trainingService).reuseMode === false) {
instance.internalTrainingService = new RemoteMachineTrainingService(<RemoteConfig>config.trainingService);
} else if (platform === 'openpai' && (<OpenpaiConfig>config.trainingService).reuseMode === false) {
instance.internalTrainingService = new PAITrainingService(<OpenpaiConfig>config.trainingService);
throw new Error('Unexpected: non-reuse remote enters RouterTrainingService');
} else if (platform === 'kubeflow' && (<KubeflowConfig>config.trainingService).reuseMode === false) {
instance.internalTrainingService = new KubeflowTrainingService();
} else if (platform === 'frameworkcontroller' && (<FrameworkControllerConfig>config.trainingService).reuseMode === false) {

Просмотреть файл

@ -5,11 +5,10 @@ import { EventEmitter } from 'events';
import fs from 'fs';
import path from 'path';
import { Writable } from 'stream';
import { Container, Scope } from 'typescript-ioc';
import { String } from 'typescript-string-operations';
import * as component from 'common/component';
import { NNIError, NNIErrorNames, MethodNotImplementedError } from 'common/errors';
import { getBasePort, getExperimentId } from 'common/experimentStartupInfo';
import { IocShim } from 'common/ioc_shim';
import { getLogger, Logger } from 'common/log';
import { TrainingService, TrialJobApplicationForm, TrialJobMetric, TrialJobStatus } from 'common/trainingService';
import { delay, getExperimentRootDir, getIPV4Address, getLogLevel, getVersion, mkDirPSync, randomSelect, uniqueString } from 'common/utils';
@ -35,7 +34,7 @@ import { TrialDetail } from './trial';
* It uses to manage jobs on training platforms
* and expose trial as trial job to upper level.
**/
@component.Singleton
//@component.Singleton MARK
class TrialDispatcher implements TrainingService {
private log: Logger;
private isDeveloping: boolean = false;
@ -213,10 +212,10 @@ class TrialDispatcher implements TrainingService {
let storageService: StorageService;
if (this.useSharedStorage) {
this.log.debug(`TrialDispatcher: use shared storage service.`);
storageService = component.get<SharedStorageService>(SharedStorageService).storageService;
storageService = IocShim.get<SharedStorageService>(SharedStorageService).storageService;
} else if (environmentService.hasStorageService) {
this.log.debug(`TrialDispatcher: use existing storage service.`);
storageService = component.get<StorageService>(StorageService);
storageService = IocShim.get<StorageService>(StorageService);
} else {
this.log.debug(`TrialDispatcher: create temp storage service to temp folder.`);
storageService = new MountedStorageService();
@ -322,7 +321,7 @@ class TrialDispatcher implements TrainingService {
}
if (this.useSharedStorage) {
this.log.info(`stopping shared storage...`)
await component.get<SharedStorageService>(SharedStorageService).cleanUp();
await IocShim.get<SharedStorageService>(SharedStorageService).cleanUp();
this.log.info(`shared storage stopped.`)
}
}
@ -736,10 +735,10 @@ class TrialDispatcher implements TrainingService {
}
trial.message = `Platform: ${environment.environmentService.getName}, environment: ${environment.id}`;
if (this.useSharedStorage) {
const storageService = component.get<SharedStorageService>(SharedStorageService).storageService;
const storageService = IocShim.get<SharedStorageService>(SharedStorageService).storageService;
trial.workingDirectory = storageService.joinPath('trials', trial.id);
} else if (environment.environmentService.hasStorageService) {
const storageService = component.get<StorageService>(StorageService);
const storageService = IocShim.get<StorageService>(StorageService);
trial.workingDirectory = storageService.joinPath('trials', trial.id);
}
trial.settings = {
@ -919,14 +918,10 @@ class TrialDispatcher implements TrainingService {
private async initializeSharedStorage(config: SharedStorageConfig): Promise<void> {
switch (config.storageType) {
case 'NFS':
Container.bind(SharedStorageService)
.to(NFSSharedStorageService)
.scope(Scope.Singleton);
IocShim.bind(SharedStorageService, NFSSharedStorageService);
break;
case 'AzureBlob':
Container.bind(SharedStorageService)
.to(AzureBlobSharedStorageService)
.scope(Scope.Singleton);
IocShim.bind(SharedStorageService, AzureBlobSharedStorageService);
break;
default: {
const errorMessage = `Shared storage type '${config.storageType}' not support.`;
@ -934,7 +929,7 @@ class TrialDispatcher implements TrainingService {
return Promise.reject(errorMessage);
}
}
await component.get<SharedStorageService>(SharedStorageService).config(config);
await IocShim.get<SharedStorageService>(SharedStorageService).config(config);
this.useSharedStorage = true;
return Promise.resolve();
}
@ -942,7 +937,7 @@ class TrialDispatcher implements TrainingService {
public async getTrialOutputLocalPath(trialJobId: string): Promise<string> {
// TODO: support non shared storage
if (this.useSharedStorage) {
const localWorkingRoot = component.get<SharedStorageService>(SharedStorageService).localWorkingRoot;
const localWorkingRoot = IocShim.get<SharedStorageService>(SharedStorageService).localWorkingRoot;
return Promise.resolve(path.join(localWorkingRoot, 'trials', trialJobId));
} else {
return Promise.reject(new Error('Only support shared storage right now.'));