Patrick Brosset 2024-07-15 16:47:01 +02:00
Parent 6d89893ee3
Commit 649affc3ef
17 changed files with 67498 additions and 0 deletions

52
on-device-ai/README.md Normal file

@@ -0,0 +1,52 @@
# Local Chatbot in the browser using Phi3, ONNX Runtime Web and WebGPU
This repository contains an example of running [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) in your browser using [ONNX Runtime Web](https://github.com/microsoft/onnxruntime) with WebGPU.
You can try out the live demo [here](https://guschmue.github.io/ort-webgpu/chat/index.html).
We keep this example simple and use the onnxruntime-web API directly. ONNX Runtime Web also powers
higher-level frameworks like [transformers.js](https://github.com/xenova/transformers.js).
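At a high level, `main.js` exposes three functions that both demo pages build on: `Init()` loads the tokenizer and the ONNX model, `Query()` streams a completion, and `Abort()` stops generation. A minimal usage sketch (the prompt string is just an example):
```js
import { Init, Query, Abort } from "./main.js";

// load the tokenizer and model, using WebGPU (fp16 where supported)
await Init();

// stream a completion; the callback receives the decoded text so far.
// Abort() can be called at any time to stop generation.
await Query(false, "Tell me about the lighthouse of Alexandria.", (text) => {
  console.log(text);
});
```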
## Getting Started
### Prerequisites
Ensure that you have [Node.js](https://nodejs.org/) installed on your machine.
### Installation
Install the required dependencies:
```sh
npm install
```
### Building the project
Build the project:
```sh
npm run build
```
The output can be found in the ***dist*** directory.
### Building for development
```sh
npm run dev
```
This will build the project and start a dev server.
Point your browser to http://localhost:8080/.
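The page also accepts a few query-string options, parsed by `getConfig()` in `main.js`. For example, `http://localhost:8080/?max_tokens=512` caps the number of generated tokens, and `?local=1` loads the model from a local `models/` directory instead of Hugging Face.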
### The Phi3 ONNX Model
The model used in this example is hosted on [Hugging Face](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx-web). It is an optimized ONNX version specific to the web, and slightly different from the ONNX models for CUDA or CPU:
1. The model output 'logits' is kept as float32 (even for float16 models) since JavaScript does not support float16.
2. Our WebGPU implementation uses a custom MultiHeadAttention operator instead of Group Query Attention.
3. Phi3 is larger than 2GB, so we need to use external data files. To keep them cacheable in the browser,
both model.onnx and model.onnx.data are kept under 2GB.
If you would like to optimize your own fine-tuned PyTorch Phi-3-mini model, you can use [Olive](https://github.com/microsoft/Olive/), which supports float data type conversion, together with the [ONNX GenAI model builder toolkit](https://github.com/microsoft/onnxruntime-genai/tree/main/src/python/py/models).
An example of how to optimize the Phi-3-mini model for ONNX Runtime Web with Olive can be found [here](https://github.com/microsoft/Olive/tree/main/examples/phi3).
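For reference, here is a minimal sketch of how a session is created with external data files in onnxruntime-web (see `llm.js` for the real implementation; the file names below are illustrative):
```js
import * as ort from "onnxruntime-web/webgpu";

// Fetch the graph and the external weights separately so each stays under
// the browser cache's 2GB-per-entry limit (llm.js additionally stores them
// with the Cache API so later visits skip the download).
const model = await fetch("model.onnx").then((r) => r.arrayBuffer());
const weights = await fetch("model.onnx.data").then((r) => r.arrayBuffer());

const session = await ort.InferenceSession.create(model, {
  executionProviders: ["webgpu"],
  externalData: [{ data: weights, path: "model.onnx.data" }],
});
```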

57
on-device-ai/chat.css Normal file

@@ -0,0 +1,57 @@
body {
color: #f5f5f5;
font-family: 'Arial', sans-serif;
}
.user-message {
background-color: rgb(86, 144, 163);
color: white;
padding: 10px;
border-radius: 10px;
white-space: pre-wrap;
width: fit-content;
}
.response-message {
background-color: rgb(62, 62, 62);
color: white;
padding: 10px;
border-radius: 10px;
padding-right: 20px;
position: relative;
margin-right: auto;
}
.response-message p {
margin-right: 40px;
}
#chat-container {
display: none;
margin: 0 auto;
overflow: auto;
}
#chat-history {
display: flex;
flex-direction: column;
}
.copy-button {
position: absolute;
bottom: 5px;
right: 5px;
margin: 0 5px 5px 0;
}
#scroll-wrapper {
padding-bottom: 5.5rem;
}
#input-area {
position: fixed;
bottom: 0;
margin-bottom: 5px;
left: 50%;
transform: translateX(-50%);
}

43
on-device-ai/chat.html Normal file

@@ -0,0 +1,43 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.1/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-4bw+/aepP/YC94hEpVNVgiZdgIC5+VKNBQNGCHeKRQN+PtmoHDEXuppvnDJzQIu9" crossorigin="anonymous" />
<link rel="stylesheet" href="chat.css">
<title>Chat with Phi-3 mini</title>
</head>
<body data-bs-theme="dark">
<div id="root"></div>
<div class="container">
<div class="row pt-3">
<div class="col-md-8 col-12">
<h2>Chat with Phi-3 mini</h2>
</div>
<div id="status">
</div>
</div>
<div id="scroll-wrapper">
<div id="chat-container" class="card">
<div class="card-body">
<div id="chat-history"></div>
</div>
</div>
</div>
</div>
<div class="container p-0 card" id="input-area">
<div class="input-group">
<textarea class="form-control" id="user-input" placeholder="Type your question here ..."></textarea>
<button id="send-button" class="btn btn-primary">Send</button>
</div>
</div>
<script type="module" src="dist/chat.js"></script>
</body>
</html>

174
on-device-ai/chat.js Normal file

@@ -0,0 +1,174 @@
import { Init, Query, Abort } from "./main.js";
import { marked } from "marked";
const preCannedQueries = {
1: "Tell me about the lighthouse of Alexandria.",
2: "Did the lighthouse of Alexandria existed at the same time the library of Alexandria existed?",
3: "How did the Pharos lighthouse impact ancient maritime trade?",
4: "Tell me about Constantinople.",
};
const clipboardIcon = `<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-clipboard" viewBox="0 0 16 16">
<path d="M4 1.5H3a2 2 0 0 0-2 2V14a2 2 0 0 0 2 2h10a2 2 0 0 0 2-2V3.5a2 2 0 0 0-2-2h-1v1h1a1 1 0 0 1 1 1V14a1 1 0 0 1-1 1H3a1 1 0 0 1-1-1V3.5a1 1 0 0 1 1-1h1v-1z"/>
<path d="M9.5 1a.5.5 0 0 1 .5.5v1a.5.5 0 0 1-.5.5h-3a.5.5 0 0 1-.5-.5v-1a.5.5 0 0 1 .5-.5h3zm-3-1A1.5 1.5 0 0 0 5 1.5v1A1.5 1.5 0 0 0 6.5 4h3A1.5 1.5 0 0 0 11 2.5v-1A1.5 1.5 0 0 0 9.5 0h-3z"/>
</svg>`;
marked.use({ mangle: false, headerIds: false });
const sendButton = document.getElementById("send-button");
const scrollWrapper = document.getElementById("scroll-wrapper");
//
// auto scroll the content area until a user scrolls up
//
let isAutoScrollOn = true;
let lastKnownScrollPosition = 0;
let ticking = false;
const autoScroller = new ResizeObserver(() => {
if (isAutoScrollOn) {
scrollWrapper.scrollIntoView({ behavior: "smooth", block: "end" });
}
});
document.addEventListener("scroll", () => {
if (!ticking && isAutoScrollOn && window.scrollY < lastKnownScrollPosition) {
window.requestAnimationFrame(() => {
isAutoScrollOn = false;
ticking = false;
});
ticking = true;
} else if (
!ticking &&
!isAutoScrollOn &&
window.scrollY > lastKnownScrollPosition &&
window.scrollY >=
document.documentElement.scrollHeight - window.innerHeight - 30
) {
window.requestAnimationFrame(() => {
isAutoScrollOn = true;
ticking = false;
});
ticking = true;
}
lastKnownScrollPosition = window.scrollY;
});
//
// make response available for copying to clipboard
//
function copyTextToClipboard(responseDiv) {
const copyButton = document.createElement("button");
copyButton.className = "btn btn-secondary copy-button";
copyButton.innerHTML = clipboardIcon;
copyButton.onclick = () => {
navigator.clipboard.writeText(responseDiv.innerText);
};
responseDiv.appendChild(copyButton);
}
//
// user hits send, Enter or Ctrl+Enter
//
async function submitRequest(e) {
if (sendButton.innerHTML == "Stop") {
Abort();
return;
}
// Enter clears the chat history, Ctrl+Enter will continue the conversation
const continuation = e.ctrlKey && e.key === "Enter";
document.getElementById("chat-container").style.display = "block";
let input = document.getElementById("user-input").value;
if (input.length == 0) {
document.getElementById("chat-history").context = "";
let chatHistory = document.getElementById("chat-history");
while (chatHistory.firstChild) {
chatHistory.firstChild.remove();
}
return;
}
let context = document.getElementById("chat-history").context;
if (context === undefined) {
context = "";
}
// append to chat history
let chatHistory = document.getElementById("chat-history");
let userMessageDiv = document.createElement("div");
userMessageDiv.className = "mb-2 user-message";
userMessageDiv.innerText = input;
chatHistory.appendChild(userMessageDiv);
// container for llm response
let responseDiv = document.createElement("div");
responseDiv.className = "response-message mb-2 text-start";
responseDiv.style.minHeight = "3em";
let spinner = document.createElement("div");
spinner.className = "spinner-border text-light";
spinner.setAttribute("role", "status");
responseDiv.appendChild(spinner);
chatHistory.appendChild(responseDiv);
// toggle button to stop text generation
sendButton.innerHTML = "Stop";
// change autoScroller to keep track of our new responseDiv
autoScroller.observe(responseDiv);
if (continuation) {
input = context + " " + input;
}
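// Query streams the response: the callback receives the full decoded text
// generated so far, which we re-render as markdown on each update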
Query(continuation, input, (word) => {
responseDiv.innerHTML = marked.parse(word);
})
.then(() => {
chatHistory.context = responseDiv.innerHTML;
copyTextToClipboard(responseDiv);
sendButton.innerHTML = "Send";
spinner.remove();
})
.catch((error) => {
console.error(error);
sendButton.innerHTML = "Send";
spinner.remove();
});
// Clear user input
document.getElementById("user-input").value = "";
}
//
// event listener for Ctrl+Enter or Enter
//
document.getElementById("user-input").addEventListener("keydown", function (e) {
if (e.ctrlKey) {
if (e.key === "Enter") {
submitRequest(e);
} else {
const query = preCannedQueries[e.key];
if (query) {
document.getElementById("user-input").value = query;
submitRequest(e);
}
}
} else if (e.key === "Enter") {
e.preventDefault();
submitRequest(e);
}
});
window.onload = () => {
Init().then(() => {
// adjustPadding();
sendButton.addEventListener("click", submitRequest);
const userInput = document.getElementById("user-input");
document.getElementById("status").style.display = "none";
userInput.focus();
});
};

111
on-device-ai/check_coc.html Normal file

@@ -0,0 +1,111 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Check that a comment adheres to the Code of Conduct</title>
<style>
body {
font-family: system-ui;
font-size: 1rem;
margin: 2rem;
}
input, textarea, button {
padding: .5rem;
border-radius: .25rem;
font-size: inherit;
font-family: inherit;
}
button {
color: white;
border: 0;
background: green;
}
button[disabled] {
background: linear-gradient(to right, #8bc78b, green);
background-size: 200% 100%;
animation: animate-button 5s alternate infinite;
}
@keyframes animate-button {
0% {
background-position: 100%;
}
100% {
background-position: 0%;
}
}
input, textarea {
border: 1px solid #ccc;
}
textarea {
height: 10rem;
}
.form-row {
margin: 1rem 0;
display: flex;
flex-direction: column;
gap: .5rem;
}
label {
font-weight: bold;
color: #333;
}
button {
align-self: end;
}
#status {
background: #eee;
padding: .5rem;
border-radius: .25rem;
}
#message {
margin: 1rem 0;
padding: 1rem;
border-radius: .25rem;
background: #ac2a2a;
color: white;
}
#message:empty {
display: none;
}
</style>
</head>
<body>
<h1>Submit a new issue</h1>
<form id="form">
<div class="form-row">
<label for="title">Add a title</label>
<input type="text" id="title" placeholder="Title">
</div>
<div class="form-row">
<label for="description">Add a description</label>
<div id="message"></div>
<textarea id="description" placeholder="Add your description here..."></textarea>
</div>
<div class="form-row">
<button type="submit" id="submit-button">Submit new issue</button>
</div>
</form>
<div id="status"></div>
<script type="module" src="dist/check_coc.js"></script>
</body>
</html>

125
on-device-ai/check_coc.js Normal file

@@ -0,0 +1,125 @@
import { Init, Query, Abort } from "./main.js";
const CoC = `
Participation in the community must be a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation.
Interactions must happen in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
Examples of behavior that contributes to a positive environment for our community include:
- Demonstrating empathy and kindness toward other people
- Being respectful of differing opinions, viewpoints, and experiences
- Giving and gracefully accepting constructive feedback
- Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
- Focusing on what is best not just for us as individuals, but for the overall community
Examples of unacceptable behavior include:
- The use of sexualized language or imagery, and sexual attention or advances of any kind
- Trolling, insulting or derogatory comments, and personal or political attacks
- Public or private harassment
- Disruptive behavior
- Publishing others' private information, such as a physical or email address, without their explicit permission
- Other conduct which could reasonably be considered inappropriate in a professional setting
`;
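// Few-shot prompt: the Code of Conduct followed by example comments and their
// JSON verdicts. The final entry is left incomplete so the model completes the
// "isAcceptableBasedOnCodeOfConduct" value (and, when false, a "reason").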
const PROMPT = `
<|CODE OF CONDUCT|>
${CoC}
<|END|>
<|COMMENT|>
{
"comment": "Thank you for this project, very useful to me!",
"isAcceptableBasedOnCodeOfConduct": true
}
<|END|>
<|COMMENT|>
{
"comment": "I found a bug in the code, when I try to run the code in Edge, it doesn't work."
"isAcceptableBasedOnCodeOfConduct": true
}
<|END|>
<|COMMENT|>
{
"comment": "What is this pile of garbage code? It doesn't work at all!"
"isAcceptableBasedOnCodeOfConduct": false,
"reason": "The comment is disrespectful and derogatory."
}
<|END|>
<|COMMENT|>
{
"comment": "[COMMENT]",
"isAcceptableBasedOnCodeOfConduct": `;
const descriptionEl = document.getElementById("description");
const form = document.getElementById("form");
const button = document.getElementById("submit-button");
const message = document.getElementById("message");
let isCheckingComment = false;
function showAsLoading() {
displayMessage();
button.textContent = "Checking your comment ...";
button.classList.add("loading");
button.disabled = true;
}
function showAsNormal() {
button.textContent = "Submit new issue";
button.classList.remove("loading");
button.disabled = false;
}
function displayMessage(str) {
// also called with no argument to clear the message (assigning undefined
// would otherwise render as the literal string "undefined")
message.textContent = str || "";
}
async function startApp() {
await Init();
form.onsubmit = (e) => {
if (isCheckingComment) {
return;
}
showAsLoading();
isCheckingComment = true;
e.preventDefault();
const prompt = PROMPT.replace("[COMMENT]", descriptionEl.value);
let data = "";
Query(true, prompt, (word) => {
data = word;
console.log(word);
if (word.includes("<|END|>")) {
Abort();
}
}).then(() => {
try {
const result = JSON.parse(
'{"isAcceptableBasedOnCodeOfConduct":' + data.substring(0, data.indexOf("<|END|>"))
);
if (!result.isAcceptableBasedOnCodeOfConduct) {
displayMessage(result.reason + " Please review the Code of Conduct and modify your comment accordingly.");
} else {
displayMessage();
}
} catch (e) {
displayMessage();
}
showAsNormal();
isCheckingComment = false;
});
};
}
window.onload = () => {
startApp();
};

31773
on-device-ai/dist/chat.js vendored Normal file

File diff suppressed because one or more lines are too long

1
on-device-ai/dist/chat.js.map vendored Normal file

File diff suppressed because one or more lines are too long

29260
on-device-ai/dist/check_coc.js vendored Normal file

File diff suppressed because one or more lines are too long

1
on-device-ai/dist/check_coc.js.map vendored Normal file

File diff suppressed because one or more lines are too long

Binary data
on-device-ai/dist/ort-wasm-simd-threaded.jsep.wasm vendored Normal file

Binary file not shown.

Binary data
on-device-ai/dist/ort-wasm-simd.jsep.wasm vendored Normal file

Binary file not shown.

283
on-device-ai/llm.js Normal file

@@ -0,0 +1,283 @@
import * as ort from "onnxruntime-web/webgpu";
ort.env.wasm.numThreads = 1;
ort.env.wasm.simd = true;
ort.env.wasm.wasmPaths =
document.location.pathname.substring(
0,
document.location.pathname.lastIndexOf("/") + 1
) + "dist/";
function log(i) {
console.log(i);
let logger = document.getElementById("status");
if (!logger) {
logger = document.createElement("div");
logger.id = "status";
document.body.appendChild(logger);
}
logger.innerText += `\n${i}`;
}
//
// load file from server or cache
//
async function fetchAndCache(url) {
try {
const cache = await caches.open("onnx");
let cachedResponse = await cache.match(url);
if (cachedResponse === undefined) {
log(`${url} (network)`);
const buffer = await fetch(url).then((response) =>
response.arrayBuffer()
);
try {
await cache.put(url, new Response(buffer));
} catch (error) {
console.error(error);
}
return buffer;
}
log(`${url} (cached)`);
const data = await cachedResponse.arrayBuffer();
return data;
} catch (error) {
log(`can't fetch ${url}`);
throw error;
}
}
//
// class to handle a large language model on top of onnxruntime-web
//
export class LLM {
sess = undefined;
profiler = false;
feed = {};
output_tokens = [];
eos = 2;
need_position_ids = true;
stop = false;
kv_dims = [];
dtype = "float16";
max_tokens = 9999;
constructor() {}
async load(model, options) {
const provider = options.provider || "webgpu";
const verbose = options.verbose;
const local = options.local;
const hasFP16 = provider === "wasm" ? false : options.hasFP16;
this.profiler = options.profiler;
const model_path = local
? "models/" + model.path
: "https://huggingface.co/" + model.path + "/resolve/main";
let model_file = model.file || "model";
model_file = hasFP16 ? model_file + "_q4f16.onnx" : model_file + "_q4.onnx";
log(`loading... ${model.name}, ${provider}`);
const json_bytes = await fetchAndCache(model_path + "/config.json");
let textDecoder = new TextDecoder();
const model_config = JSON.parse(textDecoder.decode(json_bytes));
const model_bytes = await fetchAndCache(model_path + "/onnx/" + model_file);
const externaldata = model.externaldata
? await fetchAndCache(model_path + "/onnx/" + model_file + "_data")
: false;
let modelSize = model_bytes.byteLength;
if (externaldata) {
modelSize += externaldata.byteLength;
}
log(`model size ${Math.round(modelSize / 1024 / 1024)} MB`);
const opt = {
executionProviders: [provider],
preferredOutputLocation: {},
};
switch (provider) {
case "webgpu":
for (let i = 0; i < model_config.num_hidden_layers; ++i) {
opt.preferredOutputLocation[`present.${i}.key`] = "gpu-buffer";
opt.preferredOutputLocation[`present.${i}.value`] = "gpu-buffer";
}
break;
}
if (externaldata) {
opt.externalData = [
{
data: externaldata,
path: model_file + "_data",
},
];
}
if (verbose) {
opt.logSeverityLevel = 0;
opt.logVerbosityLevel = 0;
ort.env.logLevel = "verbose";
}
ort.env.webgpu.profiling = {};
if (this.profiler) {
opt.enableProfiling = true;
ort.env.webgpu.profilingMode = "default";
ort.env.webgpu.profiling.mode = "default";
}
this.sess = await ort.InferenceSession.create(model_bytes, opt);
this.eos = model_config.eos_token_id;
this.kv_dims = [
1,
model_config.num_key_value_heads,
0,
model_config.hidden_size / model_config.num_attention_heads,
];
this.dtype = hasFP16 ? "float16" : "float32";
this.num_layers = model_config.num_hidden_layers;
this.initilize_feed();
}
initilize_feed() {
const feed = this.feed;
// dispose of previous gpu buffers
for (const name in feed) {
const t = feed[name];
if (t.location === "gpu-buffer") {
t.dispose();
}
}
this.feed = {};
// key value cache is zero copy, just pass gpu buffer as reference
const empty = this.dtype === "float16" ? new Uint16Array() : [];
for (let i = 0; i < this.num_layers; ++i) {
this.feed[`past_key_values.${i}.key`] = new ort.Tensor(
this.dtype,
empty,
this.kv_dims
);
this.feed[`past_key_values.${i}.value`] = new ort.Tensor(
this.dtype,
empty,
this.kv_dims
);
}
this.output_tokens = [];
}
//
// poor man's argmax
//
argmax(t) {
const arr = t.data;
const start = t.dims[2] * (t.dims[1] - 1);
let max = arr[start];
let maxidx = 0;
for (let i = 0; i < t.dims[2]; i++) {
const val = arr[i + start];
if (!isFinite(val)) {
throw new Error("found non-finite value in logits");
}
if (val > max) {
max = arr[i + start];
maxidx = i;
}
}
return maxidx;
}
//
// update key value cache
//
update_kv_cache(feed, outputs) {
for (const name in outputs) {
if (name.startsWith("present")) {
let newName = name.replace("present", "past_key_values");
// dispose previous gpu buffers
const t = feed[newName];
if (t.location === "gpu-buffer") {
t.dispose();
}
feed[newName] = outputs[name];
}
}
}
//
// tell generate to stop()
//
abort() {
this.stop = true;
}
//
// prefill prompt and generate tokens, greedy search only
//
async generate(tokens, callback, options) {
const max_tokens = options.max_tokens || 256;
const feed = this.feed;
const input_ids = new ort.Tensor(
"int64",
BigInt64Array.from(tokens.map(BigInt)),
[1, tokens.length]
);
feed["input_ids"] = input_ids;
this.stop = false;
this.output_tokens.push(...input_ids.data);
let last_token = 0n;
let seqlen = this.output_tokens.length;
const input_len = input_ids.size;
if (this.need_position_ids) {
feed["position_ids"] = new ort.Tensor(
"int64",
BigInt64Array.from({ length: input_len }, (_, i) =>
BigInt(seqlen - input_len + i)
),
[1, input_len]
);
}
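// greedy decoding: stop on eos, on Phi-3's <|end|> token (id 32007), when
// the token budget runs out, or when the user aborts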
while (
last_token != this.eos &&
last_token != 32007 &&
seqlen < max_tokens &&
!this.stop
) {
seqlen = this.output_tokens.length;
feed["attention_mask"] = new ort.Tensor(
"int64",
BigInt64Array.from({ length: seqlen }, () => 1n),
[1, seqlen]
);
const outputs = await this.sess.run(feed);
last_token = BigInt(this.argmax(outputs.logits));
this.output_tokens.push(last_token);
if (callback && !this.profiler) {
callback(this.output_tokens);
}
this.update_kv_cache(feed, outputs);
feed["input_ids"] = new ort.Tensor(
"int64",
BigInt64Array.from([last_token]),
[1, 1]
);
if (this.need_position_ids) {
feed["position_ids"] = new ort.Tensor(
"int64",
BigInt64Array.from([BigInt(seqlen)]),
[1, 1]
);
}
}
if (this.profiler) {
this.sess.endProfiling();
}
return this.output_tokens;
}
}

155
on-device-ai/main.js Normal file

@@ -0,0 +1,155 @@
import { env, AutoTokenizer } from "@xenova/transformers";
import { LLM } from "./llm.js";
const MODELS = {
phi3: {
name: "phi3",
path: "microsoft/Phi-3-mini-4k-instruct-onnx-web",
externaldata: true,
},
phi3dev: {
name: "phi3dev",
path: "schmuell/Phi-3-mini-4k-instruct-onnx-web",
externaldata: true,
},
};
function getConfig() {
const query = window.location.search.substring(1);
const config = {
model: "phi3",
provider: "webgpu",
profiler: 0,
verbose: 0,
threads: 1,
show_special: 0,
csv: 0,
max_tokens: 9999,
local: 0,
};
let vars = query.split("&");
for (let i = 0; i < vars.length; i++) {
let pair = vars[i].split("=");
if (pair[0] in config) {
const key = pair[0];
const value = decodeURIComponent(pair[1]);
if (typeof config[key] == "number") {
config[key] = parseInt(value);
}
else {
config[key] = value;
}
} else if (pair[0].length > 0) {
throw new Error("unknown argument: " + pair[0]);
}
}
if (MODELS[config.model] !== undefined) {
config.model = MODELS[config.model];
}
return config;
}
const config = getConfig();
// setup for transformers.js tokenizer
env.localModelPath = 'models';
env.allowRemoteModels = config.local == 0;
env.allowLocalModels = config.local == 1;
let tokenizer;
const llm = new LLM();
function token_to_text(tokenizer, tokens, startidx) {
const txt = tokenizer.decode(tokens.slice(startidx), { skip_special_tokens: config.show_special != 1, });
return txt;
}
export async function Query(continuation, query, cb) {
let prompt = (continuation) ? query : `<|system|>\nYou are a friendly assistant.<|end|>\n<|user|>\n${query}<|end|>\n<|assistant|>\n`;
const { input_ids } = await tokenizer(prompt, { return_tensor: false, padding: true, truncation: true });
// clear caches
// TODO: use kv_cache for continuation
llm.initilize_feed();
const start_timer = performance.now();
const output_index = llm.output_tokens.length + input_ids.length;
const output_tokens = await llm.generate(input_ids, (output_tokens) => {
if (output_tokens.length == input_ids.length + 1) {
// time to first token
const took = (performance.now() - start_timer) / 1000;
console.log(`time to first token in ${took.toFixed(1)}sec, ${input_ids.length} tokens`);
}
cb(token_to_text(tokenizer, output_tokens, output_index));
}, { max_tokens: config.max_tokens });
const took = (performance.now() - start_timer) / 1000;
cb(token_to_text(tokenizer, output_tokens, output_index));
const seqlen = output_tokens.length - output_index;
console.log(`${seqlen} tokens in ${took.toFixed(1)}sec, ${(seqlen / took).toFixed(2)} tokens/sec`);
}
export function Abort() {
llm.abort();
}
//
// Load the model and tokenizer
//
async function Start(hasFP16) {
try {
tokenizer = await AutoTokenizer.from_pretrained(config.model.path);
console.log("Loading model...");
await llm.load(config.model, {
provider: config.provider,
profiler: config.profiler,
verbose: config.verbose,
local: config.local,
max_tokens: config.max_tokens,
hasFP16: hasFP16,
});
console.log("Ready.");
} catch (error) {
console.log(error);
}
}
//
// Check if we have webgpu and fp16
//
async function hasWebGPU() {
// returns 0 for webgpu with f16, 1 for webgpu without f16, 2 for no webgpu
if (!("gpu" in navigator)) {
return 2;
}
try {
const adapter = await navigator.gpu.requestAdapter();
// requestAdapter() can resolve to null when no suitable GPU is available
if (!adapter) {
return 2;
}
if (adapter.features.has('shader-f16')) {
return 0;
}
return 1;
} catch (e) {
return 2;
}
}
// Main entry point, which will load the model.
export function Init() {
return new Promise(resolve => {
hasWebGPU().then((supported) => {
if (supported < 2) {
if (supported == 1) {
console.log("Your GPU or Browser does not support webgpu with fp16, using fp32 instead.");
}
Start(supported === 0).then(() => {
resolve();
});
} else {
console.log("Your GPU or Browser does not support webgpu");
}
});
});
}

5400
on-device-ai/package-lock.json generated Normal file

The file diff is not shown because it is too large

22
on-device-ai/package.json Normal file

@@ -0,0 +1,22 @@
{
"name": "localchat",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "webpack serve --no-client-overlay",
"build": "webpack",
"lint": "eslint . --ext js --report-unused-disable-directives"
},
"dependencies": {
"@xenova/transformers": "^2.17.1",
"copy-webpack-plugin": "^12.0.2",
"marked": "^12.0.2",
"onnxruntime-web": "1.19.0-dev.20240509-69cfcba38a",
"webpack": "^5.91.0"
},
"devDependencies": {
"webpack-cli": "^5.1.4",
"webpack-dev-server": "^5.0.4"
}
}

41
on-device-ai/webpack.config.js Normal file

@@ -0,0 +1,41 @@
import CopyWebpackPlugin from 'copy-webpack-plugin';
import { fileURLToPath } from 'url';
import path from 'path';
const __dirname = path.dirname(fileURLToPath(import.meta.url));
export default {
mode: 'development',
devtool: 'source-map',
entry: {
'dist/chat': './chat.js',
'dist/check_coc': './check_coc.js',
},
output: {
filename: '[name].js',
path: __dirname,
library: {
type: 'module',
},
},
plugins: [
// Copy .wasm files to dist folder
new CopyWebpackPlugin({
patterns: [
{
from: 'node_modules/onnxruntime-web/dist/*.jsep.*',
to: 'dist/[name][ext]'
},
],
}),
],
devServer: {
static: {
directory: __dirname
},
port: 8080
},
experiments: {
outputModule: true,
},
};