import { type ChatCompletionInputMessage } from '@huggingface/tasks';
import { HfInference } from '@huggingface/inference';
import type { Conversation, ModelEntryWithTokenizer } from '$lib/types';
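
// Thin factory that returns an HfInference client authenticated with the given access token.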
export function createHfInference(token: string): HfInference {
return new HfInference(token);
}
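
/**
 * Streams a chat completion for the given conversation. The system message is
 * prepended only when the model's chat template supports a system role. The
 * text accumulated so far is passed to `onChunk` after every delta, and the
 * request can be cancelled through the provided AbortController.
 */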
export async function handleStreamingResponse(
hf: HfInference,
conversation: Conversation,
onChunk: (content: string) => void,
abortController: AbortController,
systemMessage?: ChatCompletionInputMessage
): Promise<void> {
const messages = [
...(isSystemPromptSupported(conversation.model) && systemMessage?.content?.length
? [systemMessage]
: []),
...conversation.messages
];
let out = '';
try {
for await (const chunk of hf.chatCompletionStream(
{
model: conversation.model.id,
messages,
temperature: conversation.config.temperature,
max_tokens: conversation.config.maxTokens
},
{ signal: abortController.signal }
)) {
if (chunk.choices && chunk.choices.length > 0 && chunk.choices[0]?.delta?.content) {
out += chunk.choices[0].delta.content;
onChunk(out);
}
}
	} catch (error) {
		// A strict TypeScript catch clause types `error` as `unknown`, so narrow it before reading `.name`.
		if (error instanceof Error && error.name === 'AbortError') {
			console.log('Stream aborted');
		} else {
			throw error;
		}
}
}
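
/**
 * Requests a single (non-streaming) chat completion for the conversation and
 * returns the first choice's message, or throws if the model returned no choices.
 */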
export async function handleNonStreamingResponse(
hf: HfInference,
conversation: Conversation,
systemMessage?: ChatCompletionInputMessage
): Promise<ChatCompletionInputMessage> {
const messages = [
...(isSystemPromptSupported(conversation.model) && systemMessage?.content?.length
? [systemMessage]
: []),
...conversation.messages
];
const response = await hf.chatCompletion({
		model: conversation.model.id, // pass the model id, matching the streaming request above
messages,
temperature: conversation.config.temperature,
max_tokens: conversation.config.maxTokens
});
if (response.choices && response.choices.length > 0) {
return response.choices[0].message;
}
throw new Error('No response from the model');
}
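
/**
 * A model is considered to support a system prompt when its chat template references a "system" role.
 */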
export function isSystemPromptSupported(model: ModelEntryWithTokenizer) {
return model.tokenizerConfig?.chat_template?.includes('system');
}
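
// Example usage (a minimal sketch; `token`, `conversation`, and `updateUi` are
// hypothetical caller-side values, not part of this module):
//
//   const hf = createHfInference(token);
//   const abortController = new AbortController();
//   await handleStreamingResponse(hf, conversation, (content) => updateUi(content), abortController);
//   // ...or, without streaming:
//   const reply = await handleNonStreamingResponse(hf, conversation);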