Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,091 Bytes
08a6c8d 5f6343c 08a6c8d f7151f4 08a6c8d f7151f4 08a6c8d f7151f4 08a6c8d f7151f4 08a6c8d e325f49 659f477 e325f49 a2f5d42 e325f49 a2f5d42 e325f49 a2f5d42 e325f49 659f477 08a6c8d f7151f4 659f477 f7151f4 08a6c8d f7151f4 08a6c8d bca5e76 08a6c8d 659f477 1ae1376 08a6c8d f7151f4 08a6c8d f7151f4 08a6c8d f7151f4 08a6c8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
from threading import Thread
from typing import Dict
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, TextIteratorStreamer
# Raw HTML heading shown at the top of the page.
TITLE = "<h1><center>" + "Chat with PaliGemma-3B-Chat-v0.1" + "</center></h1>"

# Raw HTML sub-heading that links to the model page.
DESCRIPTION = (
    "<h3><center>Visit "
    "<a href='https://huggingface.co/hiyouga/PaliGemma-3B-Chat-v0.1' target='_blank'>"
    "our model page</a>"
    " for details.</center></h3>"
)

# Stylesheet for elements carrying the "duplicate-button" class.
# The empty first and last items reproduce the leading and trailing newline.
CSS = "\n".join(
    [
        "",
        ".duplicate-button {",
        "margin: auto !important;",
        "color: white !important;",
        "background: black !important;",
        "border-radius: 100vh !important;",
        "}",
        "",
    ]
)
# Repository id shared by the tokenizer, the processor and the model.
model_id = "hiyouga/PaliGemma-3B-Chat-v0.1"

# Loaded once at module level; stream_chat reads these globals on every request.
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
)
@spaces.GPU
def stream_chat(message: Dict[str, str], history: list):
    """Stream the model's reply to one multimodal chat turn.

    Example of the inputs over two turns::

        Turn 1:
            message = {'text': 'what is this', 'files': ['image-xxx.jpg']}
            history = []
        Turn 2:
            message = {'text': 'continue?', 'files': []}
            history = [[('image-xxx.jpg',), None], ['what is this', 'a image.']]

    Args:
        message: The current user turn. ``"text"`` is the typed prompt and
            ``"files"`` is a (possibly empty) list of uploaded image paths.
        history: Earlier turns as ``[user, assistant]`` pairs. An uploaded
            file occupies its own row whose user entry is a tuple holding the
            path and whose assistant entry is ``None``.

    Yields:
        The reply text accumulated so far; each yield extends the previous one.
    """
    # Exactly one image prefix is prepended to the prompt below, so a single
    # image is chosen for the whole conversation: the earliest upload found in
    # the history, otherwise the file attached to the current message.
    # File rows are dropped wherever they occur so that they are never fed to
    # the chat template as text turns.
    image_path = None
    text_turns = []
    for prompt, answer in history:
        if isinstance(prompt, tuple):
            if image_path is None and prompt:
                image_path = prompt[0]
            continue
        text_turns.append((prompt, answer))

    attached_files = message.get("files") or []
    if image_path is None and attached_files:
        image_path = attached_files[0]

    if image_path is not None:
        # convert() loads the pixel data, so the file can be closed right away;
        # it also gives every upload the same mode as the blank fallback.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")
    else:
        # No image anywhere in the conversation: use a blank white placeholder.
        image = Image.new("RGB", (100, 100), (255, 255, 255))

    pixel_values = processor(images=[image], return_tensors="pt").to(model.device)["pixel_values"]

    conversation = []
    for prompt, answer in text_turns:
        conversation.append({"role": "user", "content": prompt})
        conversation.append({"role": "assistant", "content": answer})
    conversation.append({"role": "user", "content": message["text"]})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")

    # Prepend processor.image_seq_length copies of the "<image>" token id.
    image_token_id = tokenizer.convert_tokens_to_ids("<image>")
    image_prefix = torch.full(
        (1, processor.image_seq_length),
        image_token_id,
        dtype=input_ids.dtype,
    )
    input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        pixel_values=pixel_values,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
    )

    # Generation runs in a worker thread so this generator can forward text
    # pieces as soon as the streamer produces them.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    output = ""
    for new_token in streamer:
        output += new_token
        yield output

    # The streamer is exhausted, so generation has finished; reap the thread.
    worker.join()
# Created before the Blocks context is opened and passed to ChatInterface below.
chatbot = gr.Chatbot(height=450)

with gr.Blocks(css=CSS) as demo:
    # Page header: title, link to the model page, then the duplicate button.
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_classes="duplicate-button",  # class name targeted by the rule in CSS
    )

    # Chat UI whose replies come from stream_chat.
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        multimodal=True,
        fill_height=True,
        cache_examples=False,
    )


if __name__ == "__main__":
    demo.launch()
|