Bug 1912276 - Allow using custom model hubs r=Mardak,atossou

Differential Revision: https://phabricator.services.mozilla.com/D218845
This commit is contained in:
Tarek Ziadé 2024-08-08 19:08:23 +00:00
Родитель b007c241bc
Коммит 7d0018a6d3
9 изменённых файлов: 274 добавлений и 19 удалений

Просмотреть файл

@ -93,16 +93,17 @@ export class MLEngineChild extends JSWindowActorChild {
*/
async #onNewPortCreated({ port, pipelineOptions }) {
try {
// Override some options using prefs
let options = new lazy.PipelineOptions(pipelineOptions);
options.updateOptions({
// We get some default options from the prefs
let options = new lazy.PipelineOptions({
modelHubRootUrl: lazy.MODEL_HUB_ROOT_URL,
modelHubUrlTemplate: lazy.MODEL_HUB_URL_TEMPLATE,
timeoutMS: lazy.CACHE_TIMEOUT_MS,
logLevel: lazy.LOG_LEVEL,
});
// And then overwrite with the ones passed in the message
options.updateOptions(pipelineOptions);
// Check if we already have an engine under this id.
if (this.#engineDispatchers.has(options.engineId)) {
let currentEngineDispatcher = this.#engineDispatchers.get(
@ -484,16 +485,24 @@ class EngineDispatcher {
* @param {string} config.taskName - name of the inference task.
* @param {string} config.url - The URL of the model file to fetch. Can be a path relative to
* the model hub root or an absolute URL.
* @param {string} config.modelHubRootUrl - root url of the model hub. When not provided, uses the default from prefs.
* @param {string} config.modefHubUrlTemplate - url template of the model hub. When not provided, uses the default from prefs.
* @param {?function(object):Promise<[ArrayBuffer, object]>} config.getModelFileFn - A function that actually retrieves the model data and headers.
* @returns {Promise} A promise that resolves to a Meta object containing the URL, response headers,
* and data as an ArrayBuffer. The data is marked for transfer to avoid cloning.
*/
async function getModelFile({ taskName, url, getModelFileFn }) {
async function getModelFile({
taskName,
url,
getModelFileFn,
modelHubRootUrl,
modefHubUrlTemplate,
}) {
const [data, headers] = await getModelFileFn({
taskName,
url,
rootUrl: lazy.MODEL_HUB_ROOT_URL,
urlTemplate: lazy.MODEL_HUB_URL_TEMPLATE,
rootUrl: modelHubRootUrl || lazy.MODEL_HUB_ROOT_URL,
urlTemplate: modefHubUrlTemplate || lazy.MODEL_HUB_URL_TEMPLATE,
});
return new lazy.BasePromiseWorker.Meta([url, headers, data], {
transfers: [data],
@ -533,6 +542,8 @@ class InferenceEngine {
url,
taskName: pipelineOptions.taskName,
getModelFileFn,
modelHubRootUrl: pipelineOptions.modelHubRootUrl,
modelHubUrlTemplate: pipelineOptions.modelHubUrlTemplate,
}),
}
);

Просмотреть файл

@ -302,6 +302,8 @@ export class MLEngineParent extends JSWindowActorParent {
const [data, headers] = await this.modelHub.getModelFileAsArrayBuffer({
taskName,
...parsedUrl,
modelHubRootUrl: rootUrl,
modelHubUrlTemplate: urlTemplate,
progressCallback: this.notificationsCallback?.bind(this),
});

Просмотреть файл

@ -148,10 +148,19 @@ export class PipelineOptions {
"runtimeFilename",
];
allowedKeys.forEach(key => {
if (options[key]) {
this[key] = options[key];
if (options instanceof PipelineOptions) {
options = options.getOptions();
}
let optionsKeys = Object.keys(options);
allowedKeys.forEach(key => {
// If options does not have the key we can ignore it.
// We also ignore `null` values.
if (!optionsKeys.includes(key) || options[key] == null) {
return;
}
this[key] = options[key];
});
}

Просмотреть файл

@ -867,6 +867,7 @@ export class ModelHub {
* @param {string} config.urlTemplate - The template to retrieve the full URL using a model name and revision.
*/
constructor({ rootUrl, urlTemplate = DEFAULT_URL_TEMPLATE } = {}) {
// Early error when the hub is created on a disallowed url - #fileURL also checks this so API calls with custom hubs are also covered.
if (!allowedHub(rootUrl)) {
throw new Error(`Invalid model hub root url: ${rootUrl}`);
}
@ -956,13 +957,22 @@ export class ModelHub {
/** Creates the file URL from the organization, model, and version.
*
* @param {string} model
* @param {string} revision
* @param {string} file
* @param {object} config - The configuration object to be updated.
* @param {string} config.model - model name
* @param {string} config.revision - model revision
* @param {string} config.file - filename
* @param {string} config.modelHubRootUrl - root url of the model hub
* @param {string} config.modelHubUrlTemplate - url template of the model hub
* @returns {string} The full URL
*/
#fileUrl(model, revision, file) {
const baseUrl = new URL(this.rootUrl);
#fileUrl({ model, revision, file, modelHubRootUrl, modelHubUrlTemplate }) {
const rootUrl = modelHubRootUrl || this.rootUrl;
if (!allowedHub(rootUrl)) {
throw new Error(`Invalid model hub root url: ${rootUrl}`);
}
const urlTemplate = modelHubUrlTemplate || this.urlTemplate;
const baseUrl = new URL(rootUrl);
if (!baseUrl.pathname.endsWith("/")) {
baseUrl.pathname += "/";
}
@ -973,7 +983,7 @@ export class ModelHub {
model,
revision,
};
let path = this.urlTemplate.replace(
let path = urlTemplate.replace(
/\{(\w+)\}/g,
(match, key) => data[key] || match
);
@ -1112,14 +1122,25 @@ export class ModelHub {
* @param {string} config.model - The model name (organization/name).
* @param {string} config.revision - The model revision.
* @param {string} config.file - The file name.
* @param {string} config.modelHubRootUrl - root url of the model hub
* @param {string} config.modelHubUrlTemplate - url template of the model hub
* @returns {Promise<Response>} The file content
*/
async getModelFileAsResponse({ taskName, model, revision, file }) {
async getModelFileAsResponse({
taskName,
model,
revision,
file,
modelHubRootUrl,
modelHubUrlTemplate,
}) {
const [blob, headers] = await this.getModelFileAsBlob({
taskName,
model,
revision,
file,
modelHubRootUrl,
modelHubUrlTemplate,
});
return new Response(blob, { headers });
@ -1133,14 +1154,25 @@ export class ModelHub {
* @param {string} config.model - The model name (organization/name).
* @param {string} config.revision - The model revision.
* @param {string} config.file - The file name.
* @param {string} config.modelHubRootUrl - root url of the model hub
* @param {string} config.modelHubUrlTemplate - url template of the model hub
* @returns {Promise<[Blob, object]>} The file content
*/
async getModelFileAsBlob({ taskName, model, revision, file }) {
async getModelFileAsBlob({
taskName,
model,
revision,
file,
modelHubRootUrl,
modelHubUrlTemplate,
}) {
const [buffer, headers] = await this.getModelFileAsArrayBuffer({
taskName,
model,
revision,
file,
modelHubRootUrl,
modelHubUrlTemplate,
});
return [new Blob([buffer]), headers];
}
@ -1154,6 +1186,8 @@ export class ModelHub {
* @param {string} config.model - The model name (organization/name).
* @param {string} config.revision - The model revision.
* @param {string} config.file - The file name.
* @param {string} config.modelHubRootUrl - root url of the model hub
* @param {string} config.modelHubUrlTemplate - url template of the model hub
* @param {?function(ProgressAndStatusCallbackParams):void} config.progressCallback A function to call to indicate progress status.
* @returns {Promise<[ArrayBuffer, headers]>} The file content
*/
@ -1162,6 +1196,8 @@ export class ModelHub {
model,
revision,
file,
modelHubRootUrl,
modelHubUrlTemplate,
progressCallback,
}) {
// Make sure inputs are clean. We don't sanitize them but throw an exception
@ -1169,7 +1205,13 @@ export class ModelHub {
if (checkError) {
throw checkError;
}
const url = this.#fileUrl(model, revision, file);
const url = this.#fileUrl({
model,
revision,
file,
modelHubRootUrl,
modelHubUrlTemplate,
});
lazy.console.debug(`Getting model file from ${url}`);
await this.#initCache();

Просмотреть файл

@ -393,7 +393,7 @@ export class Pipeline {
if (this.#genericPipelineFunction) {
if (this.#config.modelId === "test-echo") {
result = { output: request.args };
result = { output: request.args, config: this.#config };
} else {
result = await this.#genericPipelineFunction(
...request.args,

Просмотреть файл

@ -12,5 +12,7 @@ lineno = "7"
skip-if = ["os == 'linux'"] # see bug 1911083
lineno = "10"
["browser_ml_engine_process.js"]
["browser_ml_utils.js"]
lineno = "13"

Просмотреть файл

@ -1440,3 +1440,101 @@ add_task(async function test_initDbFromExistingElseWhereStoreChanges() {
await deleteCache(cache2);
});
/**
* Test that we can use a custom hub on every API call to get files.
*/
add_task(async function test_getting_file_custom_hub() {
// The hub is configured to use localhost
const hub = new ModelHub({
rootUrl: "https://localhost",
urlTemplate: "{model}/boo/revision",
});
// but we can use APIs against another hub
const args = {
model: "acme/bert",
revision: "main",
file: "config.json",
taskName: "task_model",
modelHubRootUrl: FAKE_HUB,
modelHubUrlTemplate: "{model}/{revision}",
};
let [array, headers] = await hub.getModelFileAsArrayBuffer(args);
Assert.equal(headers["Content-Type"], "application/json");
// check the content of the file.
let jsonData = JSON.parse(
String.fromCharCode.apply(null, new Uint8Array(array))
);
Assert.equal(jsonData.hidden_size, 768);
let res = await hub.getModelFileAsBlob(args);
Assert.equal(res[0].size, 562);
let response = await hub.getModelFileAsResponse(args);
Assert.equal((await response.blob()).size, 562);
});
/**
* Make sure that we can't pass a rootUrl that is not allowed when using the API calls
*/
add_task(async function test_getting_file_disallowed_custom_hub() {
// The hub is configured to use localhost
const hub = new ModelHub({
rootUrl: "https://localhost",
urlTemplate: "{model}/boo/revision",
});
// and we can't use APIs against another hub if it's not allowed
const args = {
model: "acme/bert",
revision: "main",
file: "config.json",
taskName: "task_model",
modelHubRootUrl: "https://forbidden.com",
modelHubUrlTemplate: "{model}/{revision}",
};
try {
await hub.getModelFileAsArrayBuffer(args);
throw new Error("Expected method to reject.");
} catch (error) {
Assert.throws(
() => {
throw error;
},
new RegExp(`Error: Invalid model hub root url: https://forbidden.com`),
`Should throw with https://forbidden.com`
);
}
try {
await hub.getModelFileAsBlob(args);
throw new Error("Expected method to reject.");
} catch (error) {
Assert.throws(
() => {
throw error;
},
new RegExp(`Error: Invalid model hub root url: https://forbidden.com`),
`Should throw with https://forbidden.com`
);
}
try {
await hub.getModelFileAsResponse(args);
throw new Error("Expected method to reject.");
} catch (error) {
Assert.throws(
() => {
throw error;
},
new RegExp(`Error: Invalid model hub root url: https://forbidden.com`),
`Should throw with https://forbidden.com`
);
}
});

Просмотреть файл

@ -451,3 +451,55 @@ add_task(async function test_ml_engine_override_options() {
await EngineProcess.destroyMLEngine();
await cleanup();
});
/**
* Tests a custom model hub
*/
add_task(async function test_ml_custom_hub() {
const { cleanup, remoteClients } = await setup();
info("Get the engine process");
const mlEngineParent = await EngineProcess.getMLEngineParent();
info("Get engineInstance");
const options = new PipelineOptions({
taskName: "summarization",
modelId: "test-echo",
modelRevision: "main",
modelHubRootUrl: "https://example.com",
modelHubUrlTemplate: "models/{model}/{revision}",
});
const engineInstance = await mlEngineParent.getEngine(options);
info("Run the inference");
const inferencePromise = engineInstance.run({
args: ["This gets echoed."],
});
info("Wait for the pending downloads.");
await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1);
let res = await inferencePromise;
Assert.equal(
res.output,
"This gets echoed.",
"The text get echoed exercising the whole flow."
);
Assert.equal(
res.config.modelHubRootUrl,
"https://example.com",
"The pipeline used the custom hub"
);
ok(
!EngineProcess.areAllEnginesTerminated(),
"The engine process is still active."
);
await EngineProcess.destroyMLEngine();
await cleanup();
});

Просмотреть файл

@ -0,0 +1,39 @@
/* Any copyright is dedicated to the Public Domain.
http://creativecommons.org/publicdomain/zero/1.0/ */
"use strict";
/**
* Check that options are overwritten, but not if there's a null value
*/
add_task(async function test_options_overwrite() {
const options = new PipelineOptions({
taskName: "summarization",
modelId: "test-echo",
modelRevision: "main",
});
Assert.equal(options.taskName, "summarization");
options.updateOptions({ taskName: "summarization2", modelId: null });
Assert.equal(options.taskName, "summarization2");
Assert.equal(options.modelId, "test-echo");
});
/**
* Check that updateOptions accepts a PipelineOptions object
*/
add_task(async function test_options_updated_with_options() {
const options = new PipelineOptions({
taskName: "summarization",
modelId: "test-echo",
modelRevision: "main",
});
const options2 = new PipelineOptions({
taskName: "summarization2",
modelId: "test-echo",
modelRevision: "main",
});
Assert.equal(options.taskName, "summarization");
options.updateOptions(options2);
Assert.equal(options.taskName, "summarization2");
});