# app.py by hiyouga — last commit 659f477 ("Update app.py", verified); file size 2.59 kB.
from threading import Thread
from typing import Dict
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, TextIteratorStreamer
# Page heading rendered above the chat widget.
TITLE = (
    "<h1><center>"
    "Chat with PaliGemma-3B-Chat-v0.1"
    "</center></h1>"
)

# Sub-heading with a link (opened in a new tab) to the model card.
DESCRIPTION = (
    "<h3><center>Visit "
    "<a href='https://huggingface.co/hiyouga/PaliGemma-3B-Chat-v0.1' target='_blank'>"
    "our model page</a> for details.</center></h3>"
)

# Stylesheet for elements tagged with elem_classes="duplicate-button":
# centred, white-on-black, pill-shaped.
CSS = (
    "\n"
    ".duplicate-button {\n"
    "margin: auto !important;\n"
    "color: white !important;\n"
    "background: black !important;\n"
    "border-radius: 100vh !important;\n"
    "}\n"
)
# Hub repository from which every component below is loaded.
model_id = "hiyouga/PaliGemma-3B-Chat-v0.1"

# Text side: used for the chat template, <image> token lookup and streaming decode.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Image side: turns PIL images into pixel_values for the model.
processor = AutoProcessor.from_pretrained(model_id)
# Weights load in the checkpoint's own dtype and are placed on whatever
# device(s) device_map="auto" selects.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
)
def _find_image_path(message: dict, history: list):
    """Return the image file to condition on, or ``None`` if no image is available.

    An image attached to the current message wins. Otherwise the most recent
    upload found in ``history`` is reused, so text-only follow-up turns keep
    working.
    """
    files = message.get("files") or []
    if files:
        return files[0]
    # NOTE(review): assumes gradio's multimodal ChatInterface stores uploads as
    # history rows whose prompt is a (path, ...) tuple — confirm against the
    # pinned gradio version.
    for prompt, _answer in reversed(history):
        if isinstance(prompt, (tuple, list)) and prompt:
            return prompt[0]
    return None


def _build_conversation(message: dict, history: list) -> list:
    """Convert chat history plus the new message into chat-template turns.

    Only rows where both the prompt and the answer are strings become
    user/assistant turn pairs. Other rows (file uploads, image-only prompts)
    are skipped, so the template always receives strictly alternating text
    turns.
    """
    conversation = []
    for prompt, answer in history:
        if not isinstance(prompt, str) or not isinstance(answer, str):
            continue
        conversation.append({"role": "user", "content": prompt})
        conversation.append({"role": "assistant", "content": answer})
    conversation.append({"role": "user", "content": message["text"]})
    return conversation


@spaces.GPU
def stream_chat(message: dict, history: list):
    """Stream the model's reply to a multimodal chat message.

    Args:
        message: Gradio multimodal payload, e.g.
            ``{'text': 'what is this', 'files': ['image-xxx.jpg']}``.
        history: Prior turns as supplied by ``gr.ChatInterface``.

    Yields:
        The accumulated reply text after each newly decoded chunk.

    Raises:
        gr.Error: if neither this message nor any earlier turn carries an image.
    """
    image_path = _find_image_path(message, history)
    if image_path is None:
        raise gr.Error("Please upload an image to chat about.")
    # Normalise to RGB (RGBA / greyscale uploads); this also forces PIL to
    # read the lazily-opened file now.
    image = Image.open(image_path).convert("RGB")
    pixel_values = processor(images=[image], return_tensors="pt").to(model.device)["pixel_values"]

    conversation = _build_conversation(message, history)
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    # Prepend processor.image_seq_length copies of the <image> token to the
    # text prompt.
    image_token_id = tokenizer.convert_tokens_to_ids("<image>")
    image_prefix = torch.full((1, processor.image_seq_length), image_token_id, dtype=input_ids.dtype)
    input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)

    # Run generation on a worker thread; the streamer hands decoded text back
    # here so it can be yielded incrementally.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        pixel_values=pixel_values,
        streamer=streamer,
        max_new_tokens=256,
    )
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    output = ""
    for new_token in streamer:
        output += new_token
        yield output
    worker.join()
# Transcript widget, 450 px tall; built here and handed to ChatInterface below.
chatbot = gr.Chatbot(height=450)

with gr.Blocks(css=CSS) as demo:
    # Header: title, model-card link, and a button to duplicate the Space.
    gr.HTML(value=TITLE)
    gr.HTML(value=DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_classes="duplicate-button",  # styled by the .duplicate-button rule in CSS
    )
    # Multimodal chat: each submitted message carries text plus uploaded files,
    # and replies are streamed from stream_chat.
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        multimodal=True,
        cache_examples=False,
        fill_height=True,
    )

if __name__ == "__main__":
    demo.launch()