import time
from threading import Thread

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    MllamaForConditionalGeneration,
    TextIteratorStreamer,
)

# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT = "toandev/Viet-Receipt-Llama-3.2-11B-Vision-Instruct"

# Model initialization
model = MllamaForConditionalGeneration.from_pretrained(
    CHECKPOINT, torch_dtype=torch.bfloat16
).to(DEVICE)
processor = AutoProcessor.from_pretrained(CHECKPOINT)
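

# Convert the Gradio chat history (tuples format) into the chat-template
# message list expected by the processor, plus the PIL images referenced by
# each image turn.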
def process_chat_history(history):
    messages = []
    images = []
    for i, msg in enumerate(history):
        if isinstance(msg[0], tuple):
            # Image turn: the user part is a (filepath,) tuple; the matching
            # text question and assistant answer live in the next history entry.
            messages.extend(
                [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": history[i + 1][0]},
                            {"type": "image"},
                        ],
                    },
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": history[i + 1][1]}],
                    },
                ]
            )
            images.append(Image.open(msg[0][0]).convert("RGB"))
        elif isinstance(history[i - 1], tuple) and isinstance(msg[0], str):
            # Text turn immediately after an image turn: already consumed above.
            continue
        elif isinstance(history[i - 1][0], str) and isinstance(msg[0], str):
            # Plain text-only turn.
            messages.extend(
                [
                    {"role": "user", "content": [{"type": "text", "text": msg[0]}]},
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": msg[1]}],
                    },
                ]
            )
    return messages, images
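

# Streaming chat handler. The @spaces.GPU decorator requests a GPU worker on
# Hugging Face Spaces (ZeroGPU); generation runs in a background thread while
# partial text is yielded back to the UI.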
@spaces.GPU
def bot_streaming(message, history, max_new_tokens=250):
    text = message["text"]
    messages, images = process_chat_history(history)

    # Handle the current message; only a single attached image per turn is
    # supported, otherwise the turn is treated as text-only.
    if len(message["files"]) == 1:
        image = (
            Image.open(message["files"][0])
            if isinstance(message["files"][0], str)
            else Image.open(message["files"][0]["path"])
        ).convert("RGB")
        images.append(image)
        messages.append(
            {
                "role": "user",
                "content": [{"type": "text", "text": text}, {"type": "image"}],
            }
        )
    else:
        messages.append({"role": "user", "content": [{"type": "text", "text": text}]})

    # Apply the chat template and encode the text (and any images) as tensors
    texts = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = (
        processor(text=texts, images=images, return_tensors="pt")
        if images
        else processor(text=texts, return_tensors="pt")
    ).to(DEVICE)

    # Set up streaming generation in a background thread
    streamer = TextIteratorStreamer(
        processor, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)  # small delay to smooth the streamed UI updates
        yield buffer
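

# Gradio multimodal chat UI; the slider value is passed to bot_streaming as
# max_new_tokens.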
demo = gr.ChatInterface(
    fn=bot_streaming,
    textbox=gr.MultimodalTextbox(),
    additional_inputs=[
        gr.Slider(
            minimum=10,
            maximum=500,
            value=250,
            step=10,
            label="Maximum number of new tokens to generate",
        )
    ],
    examples=[
        [
            {
                # "At which restaurant was this receipt printed?"
                "text": "Hóa đơn được in tại nhà hàng nào?",
                "files": ["./examples/01.jpg"],
            },
            200,
        ],
        [
            {
                # "Describe the receipt information in detail."
                "text": "Mô tả thông tin hóa đơn một cách chi tiết.",
                "files": ["./examples/02.jpg"],
            },
            500,
        ],
    ],
    cache_examples=False,
    stop_btn="Stop",
    fill_height=True,
    multimodal=True,
)

if __name__ == "__main__":
    demo.launch(debug=True)
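
# Note: on a Hugging Face Space this file is typically named app.py; run it
# locally with `python app.py` (the examples assume ./examples/01.jpg and
# ./examples/02.jpg exist).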