"use server"
import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"
import { LLMEngine } from "@/types"
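
/**
 * Ask the configured LLM to generate text for the given prompt, streaming
 * tokens until a stop/special token shows up or `nbMaxNewTokens` is reached.
 * The backend (Inference Endpoint vs Inference API) is selected via LLM_ENGINE.
 */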
export async function predict(inputs: string, nbMaxNewTokens: number): Promise<string> {
  const hf = new HfInference(process.env.AUTH_HF_API_TOKEN)

  const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
  const inferenceEndpoint = `${process.env.LLM_HF_INFERENCE_ENDPOINT_URL || ""}`
  const inferenceModel = `${process.env.LLM_HF_INFERENCE_API_MODEL || ""}`

  let hfie: HfInferenceEndpoint = hf
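
  // Pick the inference backend: a dedicated Inference Endpoint (custom URL),
  // the serverless Inference API (model id), or fail fast if neither is set.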
  switch (llmEngine) {
    case "INFERENCE_ENDPOINT":
      if (inferenceEndpoint) {
        // console.log("Using a custom HF Inference Endpoint")
        hfie = hf.endpoint(inferenceEndpoint)
      } else {
        const error = "No Inference Endpoint URL defined"
        console.error(error)
        throw new Error(error)
      }
      break;

    case "INFERENCE_API":
      if (inferenceModel) {
        // console.log("Using an HF Inference API Model")
      } else {
        const error = "No Inference API model defined"
        console.error(error)
        throw new Error(error)
      }
      break;

    default: {
      const error = "Please check your Hugging Face Inference API or Inference Endpoint settings"
      console.error(error)
      throw new Error(error)
    }
  }
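
  // Stream from the dedicated endpoint client when configured, otherwise from
  // the default Inference API client.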
  const api = llmEngine === "INFERENCE_ENDPOINT" ? hfie : hf

  let instructions = ""

  try {
    for await (const output of api.textGenerationStream({
      model: llmEngine === "INFERENCE_ENDPOINT" ? undefined : (inferenceModel || undefined),
      inputs,
      parameters: {
        do_sample: true,
        max_new_tokens: nbMaxNewTokens,
        return_full_text: false,
      }
    })) {
      instructions += output.token.text
      // process.stdout.write(output.token.text)
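
      // stop as soon as the model starts leaking prompt-template special tokens:
      // anything generated after them is not part of the actual answer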
      if (
        instructions.includes("</s>") ||
        instructions.includes("<s>") ||
        instructions.includes("/s>") ||
        instructions.includes("[INST]") ||
        instructions.includes("[/INST]") ||
        instructions.includes("<SYS>") ||
        instructions.includes("<<SYS>>") ||
        instructions.includes("</SYS>") ||
        instructions.includes("<</SYS>>") ||
        instructions.includes("<|user|>") ||
        instructions.includes("<|end|>") ||
        instructions.includes("<|system|>") ||
        instructions.includes("<|assistant|>")
      ) {
        break
      }
    }
  } catch (err) {
    // console.error(`error during generation: ${err}`)

    // a common issue with Llama-2 might be that the model receives too many requests
    if (`${err}` === "Error: Model is overloaded") {
      instructions = ``
    }
  }
  // need to do some cleanup of the garbage the LLM might have given us
  return (
    instructions
      .replaceAll("<|end|>", "")
      .replaceAll("<s>", "")
      .replaceAll("</s>", "")
      .replaceAll("/s>", "")
      .replaceAll("[INST]", "")
      .replaceAll("[/INST]", "")
      .replaceAll("<SYS>", "")
      .replaceAll("<<SYS>>", "")
      .replaceAll("</SYS>", "")
      .replaceAll("<</SYS>>", "")
      .replaceAll("<|system|>", "")
      .replaceAll("<|user|>", "")
      .replaceAll("<|all|>", "")
      .replaceAll("<|assistant|>", "")
      .replaceAll('""', '"')
  )
}
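
// Example usage (hypothetical caller, not part of the original file; it assumes
// the LLM_ENGINE and LLM_HF_* environment variables above are configured):
//
//   const text = await predict("Describe the first panel of a comic about cats.", 200)
//   console.log(text)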