from transformers import AutoModelForCausalLM, AutoTokenizer
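
# Load the instruction-tuned Qwen2 1.5B checkpoint and its tokenizer from the Hugging Face Hub.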
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")


class BaseStreamer:
    """
    Base class from which `.generate()` streamers should inherit.
    """

    def put(self, value):
        """Function that is called by `.generate()` to push new tokens"""
        raise NotImplementedError()

    def end(self):
        """Function that is called by `.generate()` to signal the end of generation"""
        raise NotImplementedError()
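

# `.generate()` drives a streamer roughly like this (sketch):
#     streamer.put(prompt_ids)      # once, with the prompt token ids (skipped here via skip_prompt=True)
#     streamer.put(new_token_ids)   # repeatedly, as new tokens are generated
#     streamer.end()                # once, when generation finishes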
class TextStreamer(BaseStreamer):
    """
    Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

        >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextStreamer(tok)

        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
        ```
    """

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # variables used in the streaming process
        self.token_cache = []
        self.print_len = 0
        self.next_tokens_are_prompt = True

    def put(self, value):
        """
        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new token to the cache and decode the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # If the last token is a CJK character, we print the characters.
        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        # Flush the cache, if it exists
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)
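
    # This is the app-specific part: instead of printing to stdout (the stock `transformers`
    # behaviour, kept as the commented-out `print` below), the finalized text is appended to the
    # last assistant entry of the module-level `messages` reactive (defined further down), which
    # makes Solara re-render the chat as tokens stream in.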
    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Appends the new text to the last assistant message in `messages` instead of printing it to stdout."""
        # print(text, flush=True, end="" if not stream_end else None)
        messages.value = [
            *messages.value[:-1],
            {
                "role": "assistant",
                "content": messages.value[-1]["content"] + text,
            },
        ]

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)
            or (cp >= 0x20000 and cp <= 0x2A6DF)
            or (cp >= 0x2A700 and cp <= 0x2B73F)
            or (cp >= 0x2B740 and cp <= 0x2B81F)
            or (cp >= 0x2B820 and cp <= 0x2CEAF)
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)
        ):
            return True

        return False
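

# A single module-level streamer instance; skip_prompt=True keeps the prompt tokens out of the chat.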
streamer = TextStreamer(tokenizer, skip_prompt=True)
import re
import solara
from typing import List
from typing_extensions import TypedDict


class MessageDict(TypedDict):
    role: str
    content: str
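

# Module-level reactive chat history, shared by the Page component below and by
# TextStreamer.on_finalized_text above.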
messages: solara.Reactive[List[MessageDict]] = solara.reactive([])


@solara.component
def Page():
    solara.lab.theme.themes.light.primary = "#0000ff"
    solara.lab.theme.themes.light.secondary = "#0000ff"
    solara.lab.theme.themes.dark.primary = "#0000ff"
    solara.lab.theme.themes.dark.secondary = "#0000ff"
    title = "Qwen2-1.5B-Instruct"
    with solara.Head():
        solara.Title(f"{title}")
    with solara.Column(align="center"):
        user_message_count = len([m for m in messages.value if m["role"] == "user"])

        def send(message):
            messages.value = [*messages.value, {"role": "user", "content": message}]

        def response(message):
            messages.value = [*messages.value, {"role": "assistant", "content": ""}]
            text = tokenizer.apply_chat_template(
                [{"role": "user", "content": message}],
                tokenize=False,
                add_generation_prompt=True,
            )
            inputs = tokenizer(text, return_tensors="pt")
            _ = model.generate(**inputs, streamer=streamer, max_new_tokens=512)
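
        # use_task re-runs result() in the background whenever user_message_count changes,
        # i.e. whenever the user sends a new message; result() then streams the reply into `messages`.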
        def result():
            if messages.value != []:
                response(messages.value[-1]["content"])

        result = solara.lab.use_task(result, dependencies=[user_message_count])

        with solara.lab.ChatBox(style={"position": "fixed", "overflow-y": "scroll", "scrollbar-width": "none", "-ms-overflow-style": "none", "top": "0", "bottom": "10rem", "width": "70%"}):
            for item in messages.value:
                with solara.lab.ChatMessage(
                    user=item["role"] == "user",
                    name="User" if item["role"] == "user" else "Qwen2-1.5B-Instruct",
                    avatar_background_color="#33cccc" if item["role"] == "assistant" else "#ff991f",
                    border_radius="20px",
                    style="background-color:darkgrey!important;" if solara.lab.theme.dark_effective else "background-color:lightgrey!important;",
                ):
                    item["content"] = re.sub(r"<\|im_end\|>", "", item["content"])
                    solara.Markdown(item["content"])
        solara.lab.ChatInput(send_callback=send, style={"position": "fixed", "bottom": "3rem", "width": "70%"})
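
# To try this locally (assuming this file is saved as app.py and the solara package is installed):
#     solara run app.py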