// lauro1's picture
// updates
// d54ea4b
// raw
// history blame
// 10.4 kB
import { pipeline, env } from "@xenova/transformers";
import init, { Model } from "./phi/m.js";
import { streamToAsyncIterable } from "$lib/utils/streamToAsyncIterable";
import URI from "urijs";
import { compileTemplate2 } from "$lib/utils/template";
// Shamelessly stolen from Transformers.js
/**
 * Looks up each of `names` in `cache`, in order, and returns the first hit.
 *
 * Lookups that throw are skipped so one bad key does not prevent trying the rest.
 *
 * @param {Cache} cache - Cache Storage bucket to query.
 * @param {...string} names - keys to try, in priority order.
 * @returns {Promise<Response|undefined>} the first truthy match, or undefined if none.
 */
export async function tryCache(cache, ...names) {
  for (const candidate of names) {
    let hit;
    try {
      console.log(candidate);
      hit = await cache.match(candidate);
    } catch (e) {
      // Best effort: move on to the next candidate key.
      continue;
    }
    if (hit) {
      return hit;
    }
  }
  return undefined;
}
async function read_stream(url, response) {
const reader = response.body.getReader();
const contentLength = +response.headers.get("Content-Length");
let receivedLength = 0;
let chunks = [];
let uri = new URI(url);
while (true) {
const { done, value } = await reader.read();
if (done) {
break;
}
chunks.push(value);
receivedLength += value.length;
let percent = (receivedLength / contentLength) * 100;
self.postMessage({ status: "progress", file: uri.filename(), progress: percent });
}
let chunksAll = new Uint8Array(receivedLength);
let position = 0;
for (let chunk of chunks) {
chunksAll.set(chunk, position);
position += chunk.length;
}
return chunksAll;
}
/**
 * Returns the bytes at `url`. They are served from the "transformers-cache" Cache Storage
 * bucket when present; otherwise they are downloaded and then cached. Download progress is
 * reported through `read_stream`.
 *
 * @param {string} url - absolute URL of the file.
 * @returns {Promise<Uint8Array>} the file contents.
 * @throws {Error} if the network response is not OK (non-2xx status).
 */
async function fetchArrayBuffer(url) {
  // Cache Storage is not exposed in every context; fall back to a plain fetch without it.
  let cache = null;
  if (typeof caches !== "undefined") {
    try {
      cache = await caches.open("transformers-cache");
    } catch (e) {
      console.warn("Cache Storage unavailable, downloading without caching:", e);
    }
  }

  const cached = cache ? await tryCache(cache, url) : undefined;
  if (cached !== undefined) {
    // Already cached: read it (with progress) but do not write the same bytes back again.
    return read_stream(url, cached);
  }

  const response = await fetch(url);
  if (!response.ok) {
    // Never cache, or hand to the model loader, an error page.
    throw new Error(`Failed to fetch ${url}: HTTP ${response.status}`);
  }
  const bytes = await read_stream(url, response);
  if (cache) {
    // Best effort and not awaited, as before. A failed write (e.g. quota exceeded) only
    // costs a re-download next time, so log it instead of leaving the rejection unhandled.
    // The Response constructor copies `bytes`, so returning the same array below is safe.
    cache
      .put(url, new Response(bytes, { headers: response.headers }))
      .catch((e) => console.warn(`Could not cache ${url}:`, e));
  }
  return bytes;
}
/**
 * Lazily loads the wasm Phi `Model`, keyed by model id, and shares it across requests.
 */
class Phi {
  // modelID -> Promise<Model>. Storing the promise rather than the model means concurrent
  // callers share one download/initialisation instead of each starting their own.
  static instance = {};

  /**
   * Returns the model for `modelID`. On first use it initialises the wasm module, fetches
   * the weights and tokenizer in parallel, and builds the `Model`, posting "loading",
   * "init_model" and "ready" status messages along the way.
   *
   * @param {string} weightsURL - URL of the model weights.
   * @param {*} modelID - cache key for the loaded model.
   * @param {string} tokenizerURL - URL of tokenizer.json.
   * @param {boolean} quantized - forwarded to the `Model` constructor.
   * @returns {Promise<Model>} the (possibly shared) model instance.
   */
  static async getInstance(weightsURL, modelID, tokenizerURL, quantized) {
    if (!this.instance[modelID]) {
      const loading = (async () => {
        await init();
        self.postMessage({ status: "loading", message: "Loading Model" });
        const [weightsArrayU8, tokenizerArrayU8] = await Promise.all([
          fetchArrayBuffer(weightsURL),
          fetchArrayBuffer(tokenizerURL),
        ]);
        self.postMessage({ status: "init_model" });
        const model = new Model(weightsArrayU8, tokenizerArrayU8, quantized);
        self.postMessage({ status: "ready", model: "phi-1_5" });
        return model;
      })();
      // Forget a failed load so a later call retries instead of replaying the same error.
      this.instance[modelID] = loading.catch((err) => {
        delete this.instance[modelID];
        throw err;
      });
    }
    return this.instance[modelID];
  }
}
/**
 * Caches a single Transformers.js pipeline and rebuilds it whenever a different model or
 * task is requested.
 */
export class FlanPipeline {
  // Model id of the cached pipeline.
  static curr_model = "";
  // Task of the cached pipeline. It is part of the cache key because one model id can be
  // requested with different tasks.
  static curr_task = "";
  // Value returned by `pipeline(...)` for the cached model/task (callers await it), or null.
  static instance = null;

  /**
   * @param {Function|null} [progress_callback=null] - forwarded to `pipeline` as `progress_callback`.
   * @param {string} model - model id passed to `pipeline`.
   * @param {string} task - pipeline task passed to `pipeline`.
   * @returns {Promise<*>} the cached or newly created pipeline.
   */
  static async getInstance(progress_callback = null, model, task) {
    const isCached =
      this.instance !== null && this.curr_model === model && this.curr_task === task;
    if (!isCached) {
      const loading = pipeline(task, model, { progress_callback });
      this.instance = loading;
      this.curr_model = model;
      this.curr_task = task;
      // If loading fails, forget it (unless a newer request already replaced it) so the next
      // call retries instead of receiving the same rejected promise forever.
      Promise.resolve(loading).catch(() => {
        if (this.instance === loading) {
          this.instance = null;
          this.curr_model = "";
          this.curr_task = "";
        }
      });
    }
    return this.instance;
  }
}
// AbortController for the generation currently in flight. A new one is created whenever a
// Phi or remote generation starts, and an "abort" command from the main thread aborts it.
let controller = null;
// NOTE(review): not read or written anywhere else in this file; appears unused.
let phi_model = null;
// Listen for messages from the main thread.
//
// Message kinds handled (all fields are read from `event.data`):
//  - `command === "abort"`: abort the generation tied to the shared `controller`, if any.
//  - `model_obj.is_local` (null/undefined counts as true) with `model_obj.is_phi`: generate
//    in-worker with Phi.
//  - `model_obj.is_local` (null/undefined counts as true) otherwise: run a Transformers.js
//    pipeline.
//  - `model_obj.is_local` explicitly falsy: stream tokens from `<server_addr>/generate_stream`.
self.addEventListener("message", async (event) => {
  const data = event.data;
  if (data.command === "abort") {
    console.log("ABORT");
    if (controller != null) {
      try {
        controller.abort();
      } catch (e) {
        console.log(e);
      }
    }
  } else if (data.model_obj.is_local ?? true) {
    if (data.model_obj.is_phi ?? false) {
      controller = new AbortController();
      // generate_phi wraps its work in its own try/catch, so it is intentionally not awaited.
      void generate_phi(data);
    } else {
      await runLocalPipeline(data);
    }
  } else {
    await runRemoteGeneration(data);
  }
});

/**
 * Runs the Transformers.js pipeline `data.model_obj.type` with model `data.model` on
 * `data.text`. Posts "update" messages with the decoded partial output and a final
 * "complete" message. Failures are reported as "aborted" + "error" messages, the same
 * shape used by the remote path.
 */
async function runLocalPipeline(data) {
  const params = data.model_obj.parameters;
  try {
    const pipe = await FlanPipeline.getInstance(
      (x) => {
        self.postMessage(x);
      },
      data.model,
      data.model_obj.type
    );
    const output = await pipe(data.text, {
      max_new_tokens: params?.max_new_tokens ?? 256,
      temperature: params?.temperature ?? 0.7,
      callback_function: (x) => {
        self.postMessage({
          status: "update",
          output: pipe.tokenizer.decode(x[0].output_token_ids, { skip_special_tokens: true }),
          id_now: data.id_now,
        });
      },
    });
    // Send the output back to the main thread.
    // NOTE(review): `model` is hard-coded to "phi-1_5" even on this Transformers.js path;
    // kept as-is in case the UI relies on it. Confirm.
    self.postMessage({
      status: "complete",
      output: output,
      searchID: data.searchID,
      id_now: data.id_now,
      model: "phi-1_5",
    });
  } catch (e) {
    console.log(e);
    self.postMessage({
      status: "aborted",
      output: "",
      searchID: data.searchID,
      id_now: data.id_now,
    });
    self.postMessage({
      status: "error",
      output: "",
      error: "Error while running the local model",
    });
  }
}

/**
 * Renders the conversation with the model's chat template and streams the completion from
 * the remote `/generate_stream` endpoint. Posts a cumulative "update" per token and then
 * "complete". On abort or failure it posts "aborted", plus "error" (and "invalid_jwt" for
 * 401/403) when appropriate.
 */
async function runRemoteGeneration(data) {
  const modelObj = data.model_obj;
  const m = {
    preprompt: modelObj.preprompt,
    userMessageToken: modelObj.userMessageToken,
    userMessageEndToken: modelObj.userMessageEndToken,
    assistantMessageToken: modelObj.assistantMessageToken,
    assistantMessageEndToken: modelObj.assistantMessageEndToken,
  };
  const renderPrompt = compileTemplate2(modelObj.chatPromptTemplate, m);
  const prompt = renderPrompt({ messages: data.messages, preprompt: m.preprompt });
  controller = new AbortController();
  const body = JSON.stringify({
    inputs: prompt,
    parameters: {
      max_new_tokens: modelObj.parameters?.max_new_tokens ?? 256,
      temperature: modelObj.parameters?.temperature ?? 0.7,
      truncate: modelObj.parameters?.truncate ?? 2048,
      return_full_text: false,
    },
  });
  let text_output = "";
  const server_addr = modelObj.server_addr ?? "";

  // Handles one line of the event stream; only `data:` lines carry a JSON payload.
  const handleLine = (line) => {
    if (!line.startsWith("data:")) {
      return;
    }
    const payload = line.slice("data:".length).trim();
    try {
      const lastMessageJSON = JSON.parse(payload);
      // NOTE(review): the event that carries `generated_text` is skipped entirely, including
      // its token (unchanged behavior). Confirm the final token is meant to be dropped.
      if (!lastMessageJSON.generated_text) {
        text_output += lastMessageJSON.token.text;
        self.postMessage({
          status: "update",
          output: text_output,
          id_now: data.id_now,
        });
      }
    } catch (e) {
      console.log(payload);
      console.log(e);
    }
  };

  try {
    const resp = await fetch(server_addr + "/generate_stream", {
      headers: {
        "Content-Type": "application/json",
        accesstoken: data.jwt,
      },
      method: "POST",
      body: body,
      signal: controller.signal,
    });
    if (!resp.ok) {
      if (resp.status === 401 || resp.status === 403) {
        self.postMessage({
          status: "invalid_jwt",
        });
      }
      console.log(resp);
      self.postMessage({
        status: "aborted",
        output: text_output,
        searchID: data.searchID,
        id_now: data.id_now,
      });
      self.postMessage({
        status: "error",
        output: text_output,
        error: "Error while trying to communicate with the server",
      });
      return;
    }
    // Use one decoder in streaming mode plus a carry-over buffer: a network chunk may end
    // in the middle of a line or of a multi-byte UTF-8 character.
    const decoder = new TextDecoder();
    let pending = "";
    for await (const chunk of streamToAsyncIterable(resp.body)) {
      pending += decoder.decode(chunk, { stream: true });
      const lines = pending.split("\n");
      pending = lines.pop(); // possibly incomplete; completed by the next chunk
      for (const line of lines) {
        handleLine(line);
      }
    }
    pending += decoder.decode();
    if (pending) {
      handleLine(pending);
    }
  } catch (e) {
    console.log(e);
    self.postMessage({
      status: "aborted",
      output: text_output,
      searchID: data.searchID,
      id_now: data.id_now,
    });
    if (e.name !== "AbortError") {
      self.postMessage({
        status: "error",
        output: text_output,
        error: "Error while trying to communicate with the server",
      });
    }
    return;
  }
  self.postMessage({
    status: "complete",
    output: text_output,
    searchID: data.searchID,
    id_now: data.id_now,
  });
}
/**
 * Generates a completion for `data.text` with the in-worker Phi model, one token at a time.
 *
 * Messages posted via `self.postMessage`:
 *  - "initiate" up front;
 *  - "update" after each token, carrying the running output;
 *  - exactly one terminal outcome: "complete" (end-of-text token or token budget reached),
 *    "aborted" (the shared `controller` was aborted), or "aborted" followed by "error" if
 *    loading or generation throws.
 *
 * @param {object} data - message from the main thread. Reads `text`, `searchID`, `id_now`
 *   and `model_obj.parameters.{max_new_tokens, temperature}`.
 * @returns {Promise<void>} resolves once a terminal message has been posted.
 */
async function generate_phi(data) {
  const tokenizerURL = "https://huggingface.co/microsoft/phi-1_5/raw/main/tokenizer.json";
  const weightsURL = "https://huggingface.co/lmz/candle-quantized-phi/resolve/main/model-q4k.gguf";
  const prompt = data.text;
  const maxSeqLen = data.model_obj.parameters?.max_new_tokens ?? 256;
  const temp = data.model_obj.parameters?.temperature ?? 0.7;
  const modelID = 0;
  const quantized = true;
  const top_p = 1;
  const repeatPenalty = 1.1;
  const seed = 299792458;
  self.postMessage({ status: "initiate", file: "tokenizer.json", name: "phi-1_5" }); // Fake init
  let sentence = "";
  try {
    const model = await Phi.getInstance(weightsURL, modelID, tokenizerURL, quantized);
    sentence = model.init_with_prompt(prompt, temp, top_p, repeatPenalty, 64, BigInt(seed));
    const seq_len = 2048;
    // Fallback only when max_new_tokens is explicitly 0; note prompt.length counts
    // characters, not tokens.
    const maxTokens = maxSeqLen ? maxSeqLen : seq_len - prompt.length - 1;
    const startTime = performance.now();
    for (let tokensCount = 0; tokensCount < maxTokens; tokensCount++) {
      if (controller && controller.signal.aborted) {
        self.postMessage({
          status: "aborted",
          message: "Aborted",
          output: sentence,
          searchID: data.searchID,
          id_now: data.id_now,
        });
        return;
      }
      const token = await model.next_token();
      if (token === "<|endoftext|>") {
        break;
      }
      const tokensSec = ((tokensCount + 1) / (performance.now() - startTime)) * 1000;
      sentence += token;
      self.postMessage({
        status: "update",
        message: "Generating token",
        token: token,
        output: sentence,
        totalTime: performance.now() - startTime,
        tokensSec,
        prompt: prompt,
        id_now: data.id_now,
      });
      // Yield a macrotask so a pending "abort" message can be dispatched before the next token.
      await new Promise((resolve) => setTimeout(resolve, 0));
    }
    self.postMessage({
      status: "complete",
      output: sentence,
      searchID: data.searchID,
      id_now: data.id_now,
    });
  } catch (e) {
    console.log(e);
    // Report failures in the same status-based shape as the rest of the worker.
    self.postMessage({
      status: "aborted",
      output: sentence,
      searchID: data.searchID,
      id_now: data.id_now,
    });
    self.postMessage({
      status: "error",
      output: sentence,
      error: e?.message ?? String(e),
    });
  }
}
/**
 * Flattens a chat history into a single `[INST]`-style prompt string.
 *
 * Consecutive non-empty user messages are merged into one `<s>[INST] ...` turn, separated by
 * spaces. Any other non-empty message closes the turn with ` [/INST]<content> </s>`.
 * Whitespace-only messages are ignored. If the history does not end inside a user turn,
 * `data.text` is appended as a fresh user turn.
 *
 * @param {{messages: Array<{from: string, content: string}>, text: string}} data
 * @returns {string} the assembled prompt.
 */
function buildContext(data) {
  // Will be replaced by the original contextManager made by HF
  const parts = [];
  let inUserTurn = false;
  for (const { from, content } of data.messages) {
    if (content.trim().length === 0) {
      continue;
    }
    if (from !== "user") {
      inUserTurn = false;
      parts.push(` [/INST]${content} </s>`);
    } else if (inUserTurn) {
      parts.push(` ${content}`);
    } else {
      inUserTurn = true;
      parts.push(`<s>[INST] ${content}`);
    }
  }
  parts.push(inUserTurn ? " [/INST]" : `<s>[INST] ${data.text} [/INST]`);
  return parts.join("");
}