зеркало из https://github.com/mozilla/gecko-dev.git
Bug 1912276 - Allow using custom model hubs r=Mardak,atossou
Differential Revision: https://phabricator.services.mozilla.com/D218845
This commit is contained in:
Родитель
b007c241bc
Коммит
7d0018a6d3
|
@ -93,16 +93,17 @@ export class MLEngineChild extends JSWindowActorChild {
|
|||
*/
|
||||
async #onNewPortCreated({ port, pipelineOptions }) {
|
||||
try {
|
||||
// Override some options using prefs
|
||||
let options = new lazy.PipelineOptions(pipelineOptions);
|
||||
|
||||
options.updateOptions({
|
||||
// We get some default options from the prefs
|
||||
let options = new lazy.PipelineOptions({
|
||||
modelHubRootUrl: lazy.MODEL_HUB_ROOT_URL,
|
||||
modelHubUrlTemplate: lazy.MODEL_HUB_URL_TEMPLATE,
|
||||
timeoutMS: lazy.CACHE_TIMEOUT_MS,
|
||||
logLevel: lazy.LOG_LEVEL,
|
||||
});
|
||||
|
||||
// And then overwrite with the ones passed in the message
|
||||
options.updateOptions(pipelineOptions);
|
||||
|
||||
// Check if we already have an engine under this id.
|
||||
if (this.#engineDispatchers.has(options.engineId)) {
|
||||
let currentEngineDispatcher = this.#engineDispatchers.get(
|
||||
|
@ -484,16 +485,24 @@ class EngineDispatcher {
|
|||
* @param {string} config.taskName - name of the inference task.
|
||||
* @param {string} config.url - The URL of the model file to fetch. Can be a path relative to
|
||||
* the model hub root or an absolute URL.
|
||||
* @param {string} config.modelHubRootUrl - root url of the model hub. When not provided, uses the default from prefs.
|
||||
* @param {string} config.modefHubUrlTemplate - url template of the model hub. When not provided, uses the default from prefs.
|
||||
* @param {?function(object):Promise<[ArrayBuffer, object]>} config.getModelFileFn - A function that actually retrieves the model data and headers.
|
||||
* @returns {Promise} A promise that resolves to a Meta object containing the URL, response headers,
|
||||
* and data as an ArrayBuffer. The data is marked for transfer to avoid cloning.
|
||||
*/
|
||||
async function getModelFile({ taskName, url, getModelFileFn }) {
|
||||
async function getModelFile({
|
||||
taskName,
|
||||
url,
|
||||
getModelFileFn,
|
||||
modelHubRootUrl,
|
||||
modefHubUrlTemplate,
|
||||
}) {
|
||||
const [data, headers] = await getModelFileFn({
|
||||
taskName,
|
||||
url,
|
||||
rootUrl: lazy.MODEL_HUB_ROOT_URL,
|
||||
urlTemplate: lazy.MODEL_HUB_URL_TEMPLATE,
|
||||
rootUrl: modelHubRootUrl || lazy.MODEL_HUB_ROOT_URL,
|
||||
urlTemplate: modefHubUrlTemplate || lazy.MODEL_HUB_URL_TEMPLATE,
|
||||
});
|
||||
return new lazy.BasePromiseWorker.Meta([url, headers, data], {
|
||||
transfers: [data],
|
||||
|
@ -533,6 +542,8 @@ class InferenceEngine {
|
|||
url,
|
||||
taskName: pipelineOptions.taskName,
|
||||
getModelFileFn,
|
||||
modelHubRootUrl: pipelineOptions.modelHubRootUrl,
|
||||
modelHubUrlTemplate: pipelineOptions.modelHubUrlTemplate,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
|
|
@ -302,6 +302,8 @@ export class MLEngineParent extends JSWindowActorParent {
|
|||
const [data, headers] = await this.modelHub.getModelFileAsArrayBuffer({
|
||||
taskName,
|
||||
...parsedUrl,
|
||||
modelHubRootUrl: rootUrl,
|
||||
modelHubUrlTemplate: urlTemplate,
|
||||
progressCallback: this.notificationsCallback?.bind(this),
|
||||
});
|
||||
|
||||
|
|
|
@ -148,10 +148,19 @@ export class PipelineOptions {
|
|||
"runtimeFilename",
|
||||
];
|
||||
|
||||
if (options instanceof PipelineOptions) {
|
||||
options = options.getOptions();
|
||||
}
|
||||
|
||||
let optionsKeys = Object.keys(options);
|
||||
|
||||
allowedKeys.forEach(key => {
|
||||
if (options[key]) {
|
||||
this[key] = options[key];
|
||||
// If options does not have the key we can ignore it.
|
||||
// We also ignore `null` values.
|
||||
if (!optionsKeys.includes(key) || options[key] == null) {
|
||||
return;
|
||||
}
|
||||
this[key] = options[key];
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -867,6 +867,7 @@ export class ModelHub {
|
|||
* @param {string} config.urlTemplate - The template to retrieve the full URL using a model name and revision.
|
||||
*/
|
||||
constructor({ rootUrl, urlTemplate = DEFAULT_URL_TEMPLATE } = {}) {
|
||||
// Early error when the hub is created on a disallowed url - #fileUrl also checks this so API calls with custom hubs are also covered.
|
||||
if (!allowedHub(rootUrl)) {
|
||||
throw new Error(`Invalid model hub root url: ${rootUrl}`);
|
||||
}
|
||||
|
@ -956,13 +957,22 @@ export class ModelHub {
|
|||
|
||||
/** Creates the file URL from the organization, model, and version.
|
||||
*
|
||||
* @param {string} model
|
||||
* @param {string} revision
|
||||
* @param {string} file
|
||||
* @param {object} config - The configuration object to be updated.
|
||||
* @param {string} config.model - model name
|
||||
* @param {string} config.revision - model revision
|
||||
* @param {string} config.file - filename
|
||||
* @param {string} config.modelHubRootUrl - root url of the model hub
|
||||
* @param {string} config.modelHubUrlTemplate - url template of the model hub
|
||||
* @returns {string} The full URL
|
||||
*/
|
||||
#fileUrl(model, revision, file) {
|
||||
const baseUrl = new URL(this.rootUrl);
|
||||
#fileUrl({ model, revision, file, modelHubRootUrl, modelHubUrlTemplate }) {
|
||||
const rootUrl = modelHubRootUrl || this.rootUrl;
|
||||
if (!allowedHub(rootUrl)) {
|
||||
throw new Error(`Invalid model hub root url: ${rootUrl}`);
|
||||
}
|
||||
const urlTemplate = modelHubUrlTemplate || this.urlTemplate;
|
||||
const baseUrl = new URL(rootUrl);
|
||||
|
||||
if (!baseUrl.pathname.endsWith("/")) {
|
||||
baseUrl.pathname += "/";
|
||||
}
|
||||
|
@ -973,7 +983,7 @@ export class ModelHub {
|
|||
model,
|
||||
revision,
|
||||
};
|
||||
let path = this.urlTemplate.replace(
|
||||
let path = urlTemplate.replace(
|
||||
/\{(\w+)\}/g,
|
||||
(match, key) => data[key] || match
|
||||
);
|
||||
|
@ -1112,14 +1122,25 @@ export class ModelHub {
|
|||
* @param {string} config.model - The model name (organization/name).
|
||||
* @param {string} config.revision - The model revision.
|
||||
* @param {string} config.file - The file name.
|
||||
* @param {string} config.modelHubRootUrl - root url of the model hub
|
||||
* @param {string} config.modelHubUrlTemplate - url template of the model hub
|
||||
* @returns {Promise<Response>} The file content
|
||||
*/
|
||||
async getModelFileAsResponse({ taskName, model, revision, file }) {
|
||||
async getModelFileAsResponse({
|
||||
taskName,
|
||||
model,
|
||||
revision,
|
||||
file,
|
||||
modelHubRootUrl,
|
||||
modelHubUrlTemplate,
|
||||
}) {
|
||||
const [blob, headers] = await this.getModelFileAsBlob({
|
||||
taskName,
|
||||
model,
|
||||
revision,
|
||||
file,
|
||||
modelHubRootUrl,
|
||||
modelHubUrlTemplate,
|
||||
});
|
||||
|
||||
return new Response(blob, { headers });
|
||||
|
@ -1133,14 +1154,25 @@ export class ModelHub {
|
|||
* @param {string} config.model - The model name (organization/name).
|
||||
* @param {string} config.revision - The model revision.
|
||||
* @param {string} config.file - The file name.
|
||||
* @param {string} config.modelHubRootUrl - root url of the model hub
|
||||
* @param {string} config.modelHubUrlTemplate - url template of the model hub
|
||||
* @returns {Promise<[Blob, object]>} The file content
|
||||
*/
|
||||
async getModelFileAsBlob({ taskName, model, revision, file }) {
|
||||
async getModelFileAsBlob({
|
||||
taskName,
|
||||
model,
|
||||
revision,
|
||||
file,
|
||||
modelHubRootUrl,
|
||||
modelHubUrlTemplate,
|
||||
}) {
|
||||
const [buffer, headers] = await this.getModelFileAsArrayBuffer({
|
||||
taskName,
|
||||
model,
|
||||
revision,
|
||||
file,
|
||||
modelHubRootUrl,
|
||||
modelHubUrlTemplate,
|
||||
});
|
||||
return [new Blob([buffer]), headers];
|
||||
}
|
||||
|
@ -1154,6 +1186,8 @@ export class ModelHub {
|
|||
* @param {string} config.model - The model name (organization/name).
|
||||
* @param {string} config.revision - The model revision.
|
||||
* @param {string} config.file - The file name.
|
||||
* @param {string} config.modelHubRootUrl - root url of the model hub
|
||||
* @param {string} config.modelHubUrlTemplate - url template of the model hub
|
||||
* @param {?function(ProgressAndStatusCallbackParams):void} config.progressCallback A function to call to indicate progress status.
|
||||
* @returns {Promise<[ArrayBuffer, headers]>} The file content
|
||||
*/
|
||||
|
@ -1162,6 +1196,8 @@ export class ModelHub {
|
|||
model,
|
||||
revision,
|
||||
file,
|
||||
modelHubRootUrl,
|
||||
modelHubUrlTemplate,
|
||||
progressCallback,
|
||||
}) {
|
||||
// Make sure inputs are clean. We don't sanitize them but throw an exception
|
||||
|
@ -1169,7 +1205,13 @@ export class ModelHub {
|
|||
if (checkError) {
|
||||
throw checkError;
|
||||
}
|
||||
const url = this.#fileUrl(model, revision, file);
|
||||
const url = this.#fileUrl({
|
||||
model,
|
||||
revision,
|
||||
file,
|
||||
modelHubRootUrl,
|
||||
modelHubUrlTemplate,
|
||||
});
|
||||
lazy.console.debug(`Getting model file from ${url}`);
|
||||
|
||||
await this.#initCache();
|
||||
|
|
|
@ -393,7 +393,7 @@ export class Pipeline {
|
|||
|
||||
if (this.#genericPipelineFunction) {
|
||||
if (this.#config.modelId === "test-echo") {
|
||||
result = { output: request.args };
|
||||
result = { output: request.args, config: this.#config };
|
||||
} else {
|
||||
result = await this.#genericPipelineFunction(
|
||||
...request.args,
|
||||
|
|
|
@ -12,5 +12,7 @@ lineno = "7"
|
|||
skip-if = ["os == 'linux'"] # see bug 1911083
|
||||
lineno = "10"
|
||||
|
||||
["browser_ml_engine_process.js"]
|
||||
|
||||
["browser_ml_utils.js"]
|
||||
lineno = "13"
|
||||
|
|
|
@ -1440,3 +1440,101 @@ add_task(async function test_initDbFromExistingElseWhereStoreChanges() {
|
|||
|
||||
await deleteCache(cache2);
|
||||
});
|
||||
|
||||
/**
 * Verify that every file-fetching API accepts a per-call hub override
 * (root URL and URL template) that differs from the hub's own configuration.
 */
add_task(async function test_getting_file_custom_hub() {
  // The hub instance itself points at localhost with its own template...
  const localHub = new ModelHub({
    rootUrl: "https://localhost",
    urlTemplate: "{model}/boo/revision",
  });

  // ...while each request below targets another hub through its arguments.
  const requestArgs = {
    model: "acme/bert",
    revision: "main",
    file: "config.json",
    taskName: "task_model",
    modelHubRootUrl: FAKE_HUB,
    modelHubUrlTemplate: "{model}/{revision}",
  };

  const [buffer, headers] =
    await localHub.getModelFileAsArrayBuffer(requestArgs);
  Assert.equal(headers["Content-Type"], "application/json");

  // Decode the raw bytes and check the content of the file.
  const text = String.fromCharCode(...new Uint8Array(buffer));
  const parsedConfig = JSON.parse(text);
  Assert.equal(parsedConfig.hidden_size, 768);

  const blobResult = await localHub.getModelFileAsBlob(requestArgs);
  Assert.equal(blobResult[0].size, 562);

  const response = await localHub.getModelFileAsResponse(requestArgs);
  const responseBlob = await response.blob();
  Assert.equal(responseBlob.size, 562);
});
|
||||
|
||||
/**
 * Make sure that we can't pass a rootUrl that is not allowed when using the
 * API calls: each of the three file-fetching methods must reject with the
 * "Invalid model hub root url" error.
 */
add_task(async function test_getting_file_disallowed_custom_hub() {
  // The hub is configured to use localhost
  const hub = new ModelHub({
    rootUrl: "https://localhost",
    urlTemplate: "{model}/boo/revision",
  });

  // and we can't use APIs against another hub if it's not allowed
  const args = {
    model: "acme/bert",
    revision: "main",
    file: "config.json",
    taskName: "task_model",
    modelHubRootUrl: "https://forbidden.com",
    modelHubUrlTemplate: "{model}/{revision}",
  };

  const expectedError = new RegExp(
    `Error: Invalid model hub root url: https://forbidden.com`
  );

  // The expectation is identical for the three public entry points, so run
  // them through one loop. Assert.rejects fails the test both when the promise
  // resolves and when it rejects with an error that does not match.
  const methodNames = [
    "getModelFileAsArrayBuffer",
    "getModelFileAsBlob",
    "getModelFileAsResponse",
  ];
  for (const methodName of methodNames) {
    await Assert.rejects(
      hub[methodName](args),
      expectedError,
      `Should throw with https://forbidden.com`
    );
  }
});
|
||||
|
|
|
@ -451,3 +451,55 @@ add_task(async function test_ml_engine_override_options() {
|
|||
await EngineProcess.destroyMLEngine();
|
||||
await cleanup();
|
||||
});
|
||||
|
||||
/**
 * Runs the echo model through the engine with a custom model hub set in the
 * pipeline options, and checks that the hub root URL shows up in the config
 * returned with the result.
 */
add_task(async function test_ml_custom_hub() {
  const { cleanup, remoteClients } = await setup();

  info("Get the engine process");
  const mlEngineParent = await EngineProcess.getMLEngineParent();

  info("Get engineInstance");
  const customHubOptions = new PipelineOptions({
    taskName: "summarization",
    modelId: "test-echo",
    modelRevision: "main",
    modelHubRootUrl: "https://example.com",
    modelHubUrlTemplate: "models/{model}/{revision}",
  });
  const engineInstance = await mlEngineParent.getEngine(customHubOptions);

  info("Run the inference");
  const runRequest = { args: ["This gets echoed."] };
  const pendingInference = engineInstance.run(runRequest);

  info("Wait for the pending downloads.");
  const runtimeClient = remoteClients["ml-onnx-runtime"];
  await runtimeClient.resolvePendingDownloads(1);

  const result = await pendingInference;

  Assert.equal(
    result.output,
    "This gets echoed.",
    "The text get echoed exercising the whole flow."
  );

  // The echo pipeline returns its config alongside the output, which lets us
  // observe the hub that was passed in.
  Assert.equal(
    result.config.modelHubRootUrl,
    "https://example.com",
    "The pipeline used the custom hub"
  );

  const allTerminated = EngineProcess.areAllEnginesTerminated();
  ok(!allTerminated, "The engine process is still active.");

  await EngineProcess.destroyMLEngine();
  await cleanup();
});
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
/* Any copyright is dedicated to the Public Domain.
|
||||
http://creativecommons.org/publicdomain/zero/1.0/ */
|
||||
"use strict";
|
||||
|
||||
/**
 * updateOptions must replace a value that is provided, and must leave the
 * existing value alone when the incoming one is null.
 */
add_task(async function test_options_overwrite() {
  const pipelineOptions = new PipelineOptions({
    taskName: "summarization",
    modelId: "test-echo",
    modelRevision: "main",
  });
  Assert.equal(pipelineOptions.taskName, "summarization");

  // taskName carries a real value and has to win; modelId is null and has to
  // be ignored.
  const overrides = { taskName: "summarization2", modelId: null };
  pipelineOptions.updateOptions(overrides);

  Assert.equal(pipelineOptions.taskName, "summarization2");
  Assert.equal(pipelineOptions.modelId, "test-echo");
});
|
||||
|
||||
/**
 * updateOptions must also accept another PipelineOptions instance as its
 * argument, not only a plain object.
 */
add_task(async function test_options_updated_with_options() {
  // Both instances share everything except the task name.
  const sharedFields = { modelId: "test-echo", modelRevision: "main" };
  const target = new PipelineOptions({
    taskName: "summarization",
    ...sharedFields,
  });
  const source = new PipelineOptions({
    taskName: "summarization2",
    ...sharedFields,
  });
  Assert.equal(target.taskName, "summarization");

  target.updateOptions(source);

  Assert.equal(target.taskName, "summarization2");
});
|
Загрузка…
Ссылка в новой задаче