# alonsosilva's picture
# Make it compliant with blog
# 4478f21
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
# Single checkpoint identifier used for both the tokenizer and the model.
MODEL_ID = "Qwen/Qwen2-1.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

# Module-level streamer; skip_prompt=True is passed so the prompt text is not
# echoed back through the stream.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
import re
import solara
from typing import List
from typing_extensions import TypedDict
# One chat entry: "role" names the speaker (this file uses "user" and
# "assistant") and "content" holds the message text.
MessageDict = TypedDict("MessageDict", {"role": str, "content": str})
def response_generator(message):
    """Yield the model's reply to ``message`` chunk by chunk.

    The message is wrapped as a single user turn with the tokenizer's chat
    template, generation runs on a worker thread, and decoded text chunks
    are yielded as the streamer produces them.

    Args:
        message: The user's prompt text.

    Yields:
        str: Successive pieces of generated text (at most 512 new tokens
        are requested from the model).
    """
    text = tokenizer.apply_chat_template(
        [{"role": "user", "content": message}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text, return_tensors="pt")

    # Use a streamer private to this call: a streamer shared between calls
    # would mix the tokens of overlapping generations in one queue.
    token_stream = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = dict(inputs, streamer=token_stream, max_new_tokens=512)

    worker = Thread(target=model.generate, kwargs=generation_kwargs)
    worker.start()
    yield from token_stream
    # The stream is exhausted, so generation is finishing; wait for the
    # worker so no thread is left dangling.
    worker.join()
def add_chunk_to_ai_message(chunk: str):
    """Append ``chunk`` to the text of the last message in the history.

    The last entry is replaced by a new assistant message whose content is
    the previous content followed by ``chunk``; ``messages.value`` is
    reassigned to a new list rather than modified in place.
    """
    history = messages.value
    extended_text = history[-1]["content"] + chunk
    rebuilt_last = {"role": "assistant", "content": extended_text}
    messages.value = history[:-1] + [rebuilt_last]
# Chat history: a reactive list of MessageDict entries, initially empty.
messages: solara.Reactive[List[MessageDict]] = solara.reactive([])
@solara.component
def Page():
    """Render the chat page: title, message history and input box.

    A task registered with ``solara.lab.use_task`` depends on the number of
    user messages; when it runs and the latest message is from the user, it
    appends an empty assistant message and fills it with the chunks yielded
    by ``response_generator``.
    """
    # Same blue for primary and secondary colours in both light and dark themes.
    for palette in (solara.lab.theme.themes.light, solara.lab.theme.themes.dark):
        palette.primary = "#0000ff"
        palette.secondary = "#0000ff"

    title = "Qwen2-1.5B-Instruct"
    with solara.Head():
        solara.Title(title)

    with solara.Column(align="center"):
        user_message_count = sum(1 for m in messages.value if m["role"] == "user")

        def send(message):
            """Add the submitted text to the history as a user message."""
            messages.value = [*messages.value, {"role": "user", "content": message}]

        def response(message):
            """Stream the reply to ``message`` into a new assistant message."""
            messages.value = [*messages.value, {"role": "assistant", "content": ""}]
            for chunk in response_generator(message):
                add_chunk_to_ai_message(chunk)

        def generate_reply():
            """Answer the latest message, but only if it came from the user."""
            history = messages.value
            # Guard on the role so an assistant reply is never sent back to
            # the model as if it were a prompt.
            if history and history[-1]["role"] == "user":
                response(history[-1]["content"])

        solara.lab.use_task(generate_reply, dependencies=[user_message_count])

        with solara.lab.ChatBox(
            style={
                "position": "fixed",
                "overflow-y": "scroll",
                "scrollbar-width": "none",
                "-ms-overflow-style": "none",
                "top": "0",
                "bottom": "10rem",
                "width": "70%",
            }
        ):
            for item in messages.value:
                is_user = item["role"] == "user"
                with solara.lab.ChatMessage(
                    user=is_user,
                    name="User" if is_user else "Qwen2-1.5B-Instruct",
                    avatar_background_color="#33cccc" if item["role"] == "assistant" else "#ff991f",
                    border_radius="20px",
                    style="background-color:darkgrey!important;" if solara.lab.theme.dark_effective else "background-color:lightgrey!important;",
                ):
                    # Remove the "<|im_end|>" token from the displayed text
                    # only; the stored message dict is left unmodified.
                    solara.Markdown(item["content"].replace("<|im_end|>", ""))

        solara.lab.ChatInput(
            send_callback=send,
            style={"position": "fixed", "bottom": "3rem", "width": "70%"},
        )