# evo-1-131k-base / streamer.py
# Zymrael
# init
# 27140ac
from transformers import AutoTokenizer
class BaseStreamer:
    """
    Interface for objects that receive tokens streamed out of `.generate()`.

    Concrete streamers must override both `put` and `end`; the defaults here only raise.
    """

    def put(self, value):
        """Hook invoked by `.generate()` with each newly produced chunk of tokens."""
        raise NotImplementedError

    def end(self):
        """Hook invoked by `.generate()` once generation has finished."""
        raise NotImplementedError
class ByteStreamer(BaseStreamer):
    """
    Streamer that prints decoded text to stdout incrementally as tokens arrive from `.generate()`.

    Each `put` appends the new token ids to a cache, decodes the whole cache, and prints every
    character that has not been printed yet. The cache is reset whenever the decoded text ends
    with a newline, and flushed once more by `end`.

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer
        >>> tok = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = ByteStreamer(tok)
        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
        ```
    """

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # Streaming state.
        # Token ids received since the last cache reset.
        self.token_cache = []
        # Number of characters of the decoded cache already printed.
        self.print_len = 0
        # True until the first `put` of a generation has been seen; only consulted when skip_prompt is set.
        self.next_tokens_are_prompt = True

    def put(self, value):
        """
        Receives token ids, decodes the cache, and prints any text not printed yet.

        Args:
            value: Token ids pushed by `.generate()`. Must expose `.shape` and `.tolist()`
                (e.g. a tensor). May be 0-d (a single id), 1-D, or 2-D with a leading batch
                dimension of size 1.

        Raises:
            ValueError: If `value` has a batch dimension larger than 1.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError(f"{type(self).__name__} only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            # The first chunk of a generation is the prompt; drop it when asked to.
            self.next_tokens_are_prompt = False
            return

        token_ids = value.tolist()
        if not isinstance(token_ids, list):
            # `.tolist()` on a 0-d value yields a bare scalar, which `list.extend` would reject.
            token_ids = [token_ids]

        # Add the new tokens to the cache and decode the entire thing.
        self.token_cache.extend(token_ids)
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        if text.endswith("\n"):
            # After a newline, print the remainder and start a fresh cache.
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            # Print everything new, not just one character, so that multi-token chunks
            # (e.g. the prompt) do not leave a backlog. Trailing U+FFFD is held back because
            # a byte-level decoder may emit it for an incomplete multibyte character that the
            # next token completes (NOTE(review): depends on the tokenizer's decode; `end`
            # flushes anything still held back).
            stable_text = text.rstrip("\ufffd")
            printable_text = stable_text[self.print_len :]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        if self.token_cache:
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        # Re-arm prompt skipping for the next generation that reuses this streamer.
        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Prints the new text to stdout. If the stream is ending, also prints a newline."""
        # end=None falls back to print's default "\n".
        print(text, flush=True, end="" if not stream_end else None)