Spaces:
Sleeping
Sleeping
File size: 8,386 Bytes
88d564e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face Hub id of the chat model; the model weights and the tokenizer
# must come from the same checkpoint.
_MODEL_ID = "Qwen/Qwen2-1.5B-Instruct"

# Loaded once at import time and used by the module-level streamer and by the
# Page component's generation task.
model = AutoModelForCausalLM.from_pretrained(_MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
class BaseStreamer:
    """Interface that streamers passed to `.generate()` must implement.

    Subclasses override both hooks; the base versions only raise
    `NotImplementedError`.
    """

    def put(self, value):
        """Receive the tokens that `.generate()` has just produced."""
        raise NotImplementedError

    def end(self):
        """Handle the signal from `.generate()` that generation has finished."""
        raise NotImplementedError
class TextStreamer(BaseStreamer):
    """
    Streamer that decodes generated tokens and appends them to the last chat message as soon as entire words
    are formed.

    Adapted from the `transformers` text streamer: instead of printing to stdout, `on_finalized_text` appends
    each finalized chunk of text to the content of the last entry of the module-level `messages` reactive list,
    provided that entry is an assistant message.

    <Tip warning={true}>
    The API for the streamer classes is still under development and may change in the future.
    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Example usage (not a doctest):

        streamer = TextStreamer(tokenizer, skip_prompt=True)
        messages.value = [*messages.value, {"role": "assistant", "content": ""}]
        model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        # messages.value[-1]["content"] now holds the generated reply.
    """

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # Variables used in the streaming process.
        self.token_cache = []  # token ids decoded since the last flush
        self.print_len = 0  # number of characters of the decoded cache already emitted
        self.next_tokens_are_prompt = True  # the next put() is treated as the prompt

    def put(self, value):
        """
        Receive newly generated tokens, decode them, and forward text once it forms entire words.

        Args:
            value: Token ids exposing `.shape` and `.tolist()` (e.g. a torch tensor), either 1-D or 2-D with a
                leading dimension of size 1.

        Raises:
            ValueError: If `value` is 2-D with more than one row.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        # The first put() after construction or end() is treated as the prompt; drop it when asked to.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new tokens to the cache and decode the entire cache.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # If the last token is a CJK character, we emit the characters immediately.
        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # Otherwise, emit up to the last space char (simple heuristic to avoid emitting incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

    def end(self):
        """Flush any remaining cached text and reset state so the next put() is treated as a new prompt."""
        # Flush the cache, if it exists.
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """
        Append newly finalized text to the content of the last chat message.

        Args:
            text: Decoded text that will not change with later tokens. May be empty.
            stream_end: True when called from `end()`. Unused here; the stdout version of this streamer used it
                to print a trailing newline.
        """
        # Most put() calls finalize no text (the current word is still incomplete). Skip them so the
        # reactive list is not reassigned, and the UI not re-rendered, for nothing.
        if not text:
            return
        history = messages.value
        # Only stream into an assistant message. Never write into a user's turn, and don't crash on an
        # empty history.
        if not history or history[-1]["role"] != "assistant":
            return
        messages.value = [
            *history[:-1],
            {
                "role": "assistant",
                "content": history[-1]["content"] + text,
            },
        ]

    def _is_chinese_char(self, cp):
        """Return True if the code point `cp` lies in one of the CJK ideograph blocks listed below."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as are Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        return (
            0x4E00 <= cp <= 0x9FFF
            or 0x3400 <= cp <= 0x4DBF
            or 0x20000 <= cp <= 0x2A6DF
            or 0x2A700 <= cp <= 0x2B73F
            or 0x2B740 <= cp <= 0x2B81F
            or 0x2B820 <= cp <= 0x2CEAF
            or 0xF900 <= cp <= 0xFAFF
            or 0x2F800 <= cp <= 0x2FA1F
        )
# Streamer used by every generation started from Page.
# - skip_prompt=True keeps the echoed prompt out of the reply.
# - skip_special_tokens=True is forwarded to tokenizer.decode, so end-of-turn
#   markers such as "<|im_end|>" never reach the chat history.
# NOTE(review): this single module-level instance keeps per-generation state
# (token_cache, print_len); concurrent generations would interleave. Confirm
# the Space only serves one generation at a time.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
import re
import solara
from typing import List
from typing_extensions import TypedDict
# One chat turn: `role` is the speaker ("user" or "assistant" in this app) and
# `content` is the message text. Both keys are required.
MessageDict = TypedDict("MessageDict", {"role": str, "content": str})
# Chat history rendered by Page and extended by TextStreamer.on_finalized_text.
# It starts out empty.
# NOTE(review): created at module scope, so it is presumably shared by every
# browser session of the app. Confirm against Solara's reactive-variable docs.
messages: solara.Reactive[List[MessageDict]] = solara.reactive(list())
@solara.component
def Page():
    """Render the chat page for the Qwen2 model.

    The page does four things:
    - Sets the theme colours.
    - Renders the history stored in the module-level `messages` reactive.
    - Appends user input through a ChatInput.
    - Starts a generation task whenever the number of user messages changes.
    """
    solara.lab.theme.themes.light.primary = "#0000ff"
    solara.lab.theme.themes.light.secondary = "#0000ff"
    solara.lab.theme.themes.dark.primary = "#0000ff"
    solara.lab.theme.themes.dark.secondary = "#0000ff"
    title = "Qwen2-1.5B-Instruct"
    with solara.Head():
        solara.Title(title)
    with solara.Column(align="center"):
        user_message_count = len([m for m in messages.value if m["role"] == "user"])

        def send(message):
            """Append the text submitted in the chat input as a user message."""
            messages.value = [*messages.value, {"role": "user", "content": message}]

        def response(message):
            """Add an empty assistant message and stream the model's reply to `message` into it."""
            messages.value = [*messages.value, {"role": "assistant", "content": ""}]
            text = tokenizer.apply_chat_template(
                [{"role": "user", "content": message}],
                tokenize=False,
                add_generation_prompt=True,
            )
            inputs = tokenizer(text, return_tensors="pt")
            # `streamer` appends decoded text to the last (assistant) message as it is generated.
            _ = model.generate(**inputs, streamer=streamer, max_new_tokens=512)

        def reply_to_latest():
            """Generate a reply to the newest message, but only if the user sent it."""
            # The task also runs on first render. If the history already ends with an
            # assistant reply at that point, feeding it back as a prompt would make
            # the model answer itself.
            if messages.value and messages.value[-1]["role"] == "user":
                response(messages.value[-1]["content"])

        solara.lab.use_task(reply_to_latest, dependencies=[user_message_count])

        with solara.lab.ChatBox(
            style={
                "position": "fixed",
                "overflow-y": "scroll",
                "scrollbar-width": "none",
                "-ms-overflow-style": "none",
                "top": "0",
                "bottom": "10rem",
                "width": "70%",
            }
        ):
            for item in messages.value:
                is_user = item["role"] == "user"
                # Strip any leftover end-of-turn marker for display only. The stored
                # history is not mutated during render.
                content = item["content"].replace("<|im_end|>", "")
                with solara.lab.ChatMessage(
                    user=is_user,
                    # Use the page title so the label always matches the loaded model.
                    name="User" if is_user else title,
                    avatar_background_color="#33cccc" if item["role"] == "assistant" else "#ff991f",
                    border_radius="20px",
                    style=(
                        "background-color:darkgrey!important;"
                        if solara.lab.theme.dark_effective
                        else "background-color:lightgrey!important;"
                    ),
                ):
                    solara.Markdown(content)
        solara.lab.ChatInput(send_callback=send, style={"position": "fixed", "bottom": "3rem", "width": "70%"})
|