# phi-2 / app.py
# Benjamin Gonzalez
# try to implement streaming
# c4f947a
# raw
# history blame
# 2.79 kB
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TextIteratorStreamer,
StoppingCriteriaList,
)
from threading import Thread
import gradio as gr
# Decide once whether a CUDA device is usable; both the default device and the
# weight dtype below depend on the same answer.
_cuda_ready = torch.cuda.is_available()
if _cuda_ready:
    torch.set_default_device("cuda")

# float16 weights when running on CUDA, float32 otherwise.
_weights_dtype = torch.float16 if _cuda_ready else torch.float32

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    torch_dtype=_weights_dtype,
    trust_remote_code=True,
)
def Phi2StoppingCriteria(
    input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs
) -> bool:
    """Report whether the generated sequence now ends with a stop string.

    Written to be placed in a ``StoppingCriteriaList``: it receives the token
    ids produced so far and returns ``True`` once the most recent tokens match
    the tokenization of one of the stop strings.

    Args:
        input_ids: Token ids so far. Read as ``input_ids[0]``, i.e. this assumes
            a ``(batch, sequence)`` layout and inspects only the first row
            (NOTE(review): fine for the single prompt this app submits; revisit
            if batching is ever added).
        score: Unused; accepted for call compatibility.
        **kwargs: Ignored.

    Returns:
        ``True`` if the first sequence ends with the token ids of any stop
        string, ``False`` otherwise.
    """
    stop_list = ["Exercise", "Exercises", "<|endoftext|>"]
    # Tokenize to plain Python int lists so the check below is an ordinary
    # list comparison. Comparing a tensor against a list of tensors with
    # ``in`` is ambiguous and fails on shape mismatches.
    stop_sequences = [
        tokenizer(stop, add_special_tokens=False).input_ids for stop in stop_list
    ]
    longest = max(len(seq) for seq in stop_sequences)
    # Only the trailing tokens can match, so convert just that tail.
    tail = input_ids[0][-longest:].tolist()
    # A stop string may span several tokens; compare it against the suffix of
    # matching length. Empty tokenizations never match.
    return any(bool(seq) and tail[-len(seq):] == seq for seq in stop_sequences)
# Container holding the stop-string check; generate() forwards it to
# model.generate through the ``stopping_criteria`` keyword.
stopping_criteria = StoppingCriteriaList([Phi2StoppingCriteria])
def generate(prompt, max_new_tokens):
    """Stream a sampled continuation of ``prompt`` from the loaded model.

    ``model.generate`` runs on a background thread and pushes decoded text
    into a ``TextIteratorStreamer``; this generator re-yields the accumulated
    text every time a new chunk arrives so the UI can update incrementally.

    Args:
        prompt: Text to feed to the model.
        max_new_tokens: Upper bound on the number of generated tokens. Coerced
            to ``int`` because a numeric UI widget may deliver a float.

    Yields:
        All text received from the streamer so far. The streamer is created
        with its default options, so the prompt is not stripped from the
        output.

    Returns:
        The complete text (exposed as ``StopIteration.value``).
    """
    # Place the encoded prompt on the same device as the model weights.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # thanks https://huggingface.co/spaces/joaogante/transformers_streaming/blob/main/app.py
    # The streamer decodes token ids back to text, so it needs the tokenizer
    # itself rather than the encoded prompt.
    streamer = TextIteratorStreamer(tokenizer)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        stopping_criteria=stopping_criteria,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output
    # The streamer is exhausted only after generation ends; reap the worker.
    thread.join()
    return model_output
# Gradio UI: a prompt box plus a token budget, wired to generate().
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Text(
            label="prompt",
            value="Write a detailed analogy between mathematics and a lighthouse.",
        ),
        # precision=0 makes the widget deliver a whole number, and minimum=1
        # rejects zero/negative budgets before they reach the model.
        gr.Number(
            value=100,
            label="max new tokens",
            minimum=1,
            maximum=500,
            precision=0,
        ),
    ],
    outputs="text",
    # Each example is [prompt, max new tokens].
    examples=[
        [
            "Write a detailed analogy between mathematics and a lighthouse.",
            75,
        ],
        [
            "Instruct: Write a detailed analogy between mathematics and a lighthouse.\nOutput:",
            75,
        ],
        [
            "Alice: I don't know why, I'm struggling to maintain focus while studying. Any suggestions?\n\nBob: ",
            150,
        ],
        [
            '''def print_prime(n):
"""
Print all primes between 1 and n
"""\n''',
            100,
        ],
        ["User: How does sleep affect mood?\nAI:", 125],
        ["Who was Ada Lovelace?", 100],
        ["Explain the concept of skip lists.", 125],
    ],
    title="Microsoft Phi-2",
    description="Unofficial demo of Microsoft Phi-2, a high performing model with only 2.7B parameters.",
)
if __name__ == "__main__":
    # generate() is a generator handler. Gradio 3.x only accepts generator
    # handlers when the queue is enabled; on releases where the queue is
    # already on by default this call is a no-op.
    demo.queue().launch(show_api=False)