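"""Gradio demo: streaming multimodal chat with
toandev/Viet-Receipt-Llama-3.2-11B-Vision-Instruct, a Llama 3.2 11B Vision
model fine-tuned for Vietnamese receipt understanding."""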
import time
from threading import Thread
import gradio as gr
import spaces  # Hugging Face Spaces helper; @spaces.GPU requests a GPU worker
import torch
from PIL import Image
from transformers import (
AutoProcessor,
MllamaForConditionalGeneration,
TextIteratorStreamer,
)
# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT = "toandev/Viet-Receipt-Llama-3.2-11B-Vision-Instruct"
# Model initialization: load the 11B checkpoint in bfloat16 to reduce memory use.
model = MllamaForConditionalGeneration.from_pretrained(
    CHECKPOINT, torch_dtype=torch.bfloat16
).to(DEVICE)
# AutoProcessor bundles the tokenizer and image preprocessor for this checkpoint.
processor = AutoProcessor.from_pretrained(CHECKPOINT)
def process_chat_history(history):
    """Convert Gradio tuple-format history into Mllama chat messages plus a
    parallel list of PIL images (one per {"type": "image"} placeholder)."""
    messages = []
    images = []
    for i, msg in enumerate(history):
        if isinstance(msg[0], tuple):
            # Image turn: Gradio stores the upload as a (filepath,) tuple, and
            # the accompanying text arrives as the next history entry, so pair
            # them into a single user turn here.
            messages.extend(
                [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": history[i + 1][0]},
                            {"type": "image"},
                        ],
                    },
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": history[i + 1][1]}],
                    },
                ]
            )
            images.append(Image.open(msg[0][0]).convert("RGB"))
        elif i > 0 and isinstance(history[i - 1][0], tuple) and isinstance(msg[0], str):
            # Text entry already folded into the preceding image turn; skip it.
            continue
        elif isinstance(msg[0], str):
            # Plain text-only turn.
            messages.extend(
                [
                    {"role": "user", "content": [{"type": "text", "text": msg[0]}]},
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": msg[1]}],
                    },
                ]
            )
    return messages, images
@spaces.GPU
def bot_streaming(message, history, max_new_tokens=250):
    """Yield a progressively longer reply for the latest multimodal message."""
    text = message["text"]
    messages, images = process_chat_history(history)
    # Handle the current message: attach the uploaded image, if any
    # (the demo expects at most one image per message).
    if len(message["files"]) == 1:
        # Gradio may pass the file as a plain path or as a dict with a "path" key.
        image = (
            Image.open(message["files"][0])
            if isinstance(message["files"][0], str)
            else Image.open(message["files"][0]["path"])
        ).convert("RGB")
        images.append(image)
        messages.append(
            {
                "role": "user",
                "content": [{"type": "text", "text": text}, {"type": "image"}],
            }
        )
    else:
        # Text-only message.
        messages.append({"role": "user", "content": [{"type": "text", "text": text}]})
    # Render the chat template and tokenize; only pass images when there are any.
    texts = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = (
        processor(text=texts, images=images, return_tensors="pt")
        if images
        else processor(text=texts, return_tensors="pt")
    ).to(DEVICE)
# Setup streaming
streamer = TextIteratorStreamer(
processor, skip_special_tokens=True, skip_prompt=True
)
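    # Run generate() on a background thread so this generator can consume the
    # streamer and yield partial output while tokens are still being produced.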
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
buffer = ""
for new_text in streamer:
buffer += new_text
time.sleep(0.01)
yield buffer
return "Hello"
demo = gr.ChatInterface(
fn=bot_streaming,
textbox=gr.MultimodalTextbox(),
additional_inputs=[
gr.Slider(
minimum=10,
maximum=500,
value=250,
step=10,
label="Maximum number of new tokens to generate",
)
],
    examples=[
        [
            {
                # "At which restaurant was this receipt printed?"
                "text": "Hóa đơn được in tại nhà hàng nào?",
                "files": ["./examples/01.jpg"],
            },
            200,
        ],
        [
            {
                # "Describe the receipt information in detail."
                "text": "Mô tả thông tin hóa đơn một cách chi tiết.",
                "files": ["./examples/02.jpg"],
            },
            500,
        ],
    ],
cache_examples=False,
stop_btn="Stop",
fill_height=True,
multimodal=True,
)
if __name__ == "__main__":
demo.launch(debug=True)