Spaces:
Running
Running
import re
from threading import Thread
from typing import List

import solara
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from typing_extensions import TypedDict

# Hugging Face Hub id of the chat model served by this app.
MODEL_ID = "Qwen/Qwen2-0.5B-Instruct"

# Loaded once at import time and shared by every caller in this process.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Iterator-style streamer that yields decoded text while generation runs;
# prompt tokens are skipped so only newly generated text comes through.
# NOTE(review): a single module-level streamer has one internal queue, so two
# overlapping generations writing to it would interleave their output.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
# One chat turn as stored in ``messages``: who spoke (this app uses "user" and
# "assistant") and the text of that turn.
MessageDict = TypedDict("MessageDict", {"role": str, "content": str})
def response_generator(message):
    """Stream the model's reply to a single user message, chunk by chunk.

    The message is wrapped in the tokenizer's chat template as a one-turn
    conversation (no earlier history is sent). Generation runs on a
    background thread, and decoded text is yielded as it arrives.

    Args:
        message: The user's prompt text.

    Yields:
        str: Successive pieces of the generated reply (at most 512 new tokens).

    Raises:
        Exception: Whatever ``model.generate`` raised on the worker thread,
            re-raised here once the stream has been drained.
    """
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": message}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt")

    # Use a fresh streamer per call. A single shared streamer has one queue,
    # so two overlapping generations would interleave their tokens.
    # skip_special_tokens should keep end-of-turn markers such as <|im_end|>
    # out of the streamed text.
    text_stream = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(inputs, streamer=text_stream, max_new_tokens=512)
    errors = []

    def _generate():
        # Runs on the worker thread. If generate() fails, end the stream so
        # the consumer loop below does not block forever waiting for tokens.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised in the consumer's thread below
            errors.append(exc)
            text_stream.end()

    worker = Thread(target=_generate)
    worker.start()
    for chunk in text_stream:
        yield chunk
    worker.join()
    if errors:
        raise errors[0]
def add_chunk_to_ai_message(chunk: str):
    """Append ``chunk`` to the assistant reply at the end of the chat.

    A new list is assigned to ``messages.value`` instead of mutating the
    current one in place, so the reactive variable receives a new value.

    If the chat is empty, or its last entry is not an assistant message, a new
    assistant message holding ``chunk`` is appended. Whatever message is last
    (e.g. a user message) is therefore never overwritten.

    Args:
        chunk: A piece of streamed model output.
    """
    history = messages.value
    if history and history[-1]["role"] == "assistant":
        updated = {"role": "assistant", "content": history[-1]["content"] + chunk}
        messages.value = [*history[:-1], updated]
    else:
        messages.value = [*history, {"role": "assistant", "content": chunk}]
# Conversation rendered by ``Page``, oldest message first. Code in this module
# updates it by assigning a new list to ``messages.value``.
messages: solara.Reactive[List[MessageDict]] = solara.reactive(list())
def Page():
    """Render the chat page for Qwen2-0.5B-Instruct.

    Shows the conversation held in ``messages`` inside a fixed, scrollable
    ``ChatBox``, with a ``ChatInput`` pinned near the bottom. Sending a message
    appends it to ``messages``. A task keyed on the number of user messages
    then streams the model's reply into a new assistant message.
    """
    # Blue primary/secondary colours for both the light and dark themes.
    solara.lab.theme.themes.light.primary = "#0000ff"
    solara.lab.theme.themes.light.secondary = "#0000ff"
    solara.lab.theme.themes.dark.primary = "#0000ff"
    solara.lab.theme.themes.dark.secondary = "#0000ff"
    title = "Qwen2-0.5B-Instruct"
    with solara.Head():
        solara.Title(title)
    with solara.Column(align="center"):
        user_message_count = sum(1 for m in messages.value if m["role"] == "user")

        def send(message):
            # ChatInput callback: record the user's turn. The task below reacts
            # to the changed user-message count.
            messages.value = [*messages.value, {"role": "user", "content": message}]

        def response(message):
            # Add an empty assistant turn, then grow it as chunks stream in.
            messages.value = [*messages.value, {"role": "assistant", "content": ""}]
            for chunk in response_generator(message):
                add_chunk_to_ai_message(chunk)

        def result():
            # Only answer when the newest turn is a user message. Without the
            # role check, a run whose history ends in an assistant reply would
            # send that reply back to the model as a prompt.
            if messages.value and messages.value[-1]["role"] == "user":
                response(messages.value[-1]["content"])

        # Re-run the reply task whenever the number of user messages changes.
        solara.lab.use_task(result, dependencies=[user_message_count])

        # Read once per render, outside the loop, so this lookup happens the
        # same number of times on every render regardless of chat length.
        bubble_style = (
            "background-color:darkgrey!important;"
            if solara.lab.theme.dark_effective
            else "background-color:lightgrey!important;"
        )
        chat_box_style = {
            "position": "fixed",
            "overflow-y": "scroll",
            "scrollbar-width": "none",
            "-ms-overflow-style": "none",
            "top": "0",
            "bottom": "10rem",
            "width": "70%",
        }
        with solara.lab.ChatBox(style=chat_box_style):
            for item in messages.value:
                is_user = item["role"] == "user"
                with solara.lab.ChatMessage(
                    user=is_user,
                    name="User" if is_user else "Qwen2-0.5B-Instruct",
                    avatar_background_color="#33cccc" if item["role"] == "assistant" else "#ff991f",
                    border_radius="20px",
                    style=bubble_style,
                ):
                    # Hide the end-of-turn marker for display only. The stored
                    # message is not modified during render.
                    solara.Markdown(item["content"].replace("<|im_end|>", ""))
        solara.lab.ChatInput(
            send_callback=send,
            style={"position": "fixed", "bottom": "3rem", "width": "70%"},
        )