hiyouga's picture
Update app.py
bca5e76 verified
raw
history blame
2.59 kB
from threading import Thread
from typing import Dict
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, TextIteratorStreamer
TITLE = "<h1><center>Chat with PaliGemma-3B-Chat-v0.1</center></h1>"
DESCRIPTION = "<h3><center>Visit <a href='https://huggingface.co/hiyouga/PaliGemma-3B-Chat-v0.1' target='_blank'>our model page</a> for details.</center></h3>"
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
"""
model_id = "hiyouga/PaliGemma-3B-Chat-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
@spaces.GPU
def stream_chat(message: Dict[str, str], history: list):
# {'text': 'what is this', 'files': ['image-xxx.jpg']}
image = Image.open(message["files"][0])
pixel_values = processor(images=[image], return_tensors="pt").to(model.device)["pixel_values"]
conversation = []
for prompt, answer in history:
conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
conversation.append({"role": "user", "content": message["text"]})
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
image_token_id = tokenizer.convert_tokens_to_ids("<image>")
image_prefix = torch.empty((1, getattr(processor, "image_seq_length")), dtype=input_ids.dtype).fill_(image_token_id)
input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
input_ids=input_ids,
pixel_values=pixel_values,
streamer=streamer,
max_new_tokens=256,
)
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
output = ""
for new_token in streamer:
output += new_token
yield output
chatbot = gr.Chatbot(height=450)
with gr.Blocks(css=CSS) as demo:
gr.HTML(TITLE)
gr.HTML(DESCRIPTION)
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
gr.ChatInterface(
fn=stream_chat,
multimodal=True,
chatbot=chatbot,
fill_height=True,
cache_examples=False,
)
if __name__ == "__main__":
demo.launch()