import time
from threading import Thread

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    AutoProcessor,
    MllamaForConditionalGeneration,
    TextIteratorStreamer,
)

# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT = "toandev/Viet-Receipt-Llama-3.2-11B-Vision-Instruct"

# Model initialization
model = MllamaForConditionalGeneration.from_pretrained(
    CHECKPOINT, torch_dtype=torch.bfloat16
).to(DEVICE)
processor = AutoProcessor.from_pretrained(CHECKPOINT)


def process_chat_history(history):
    """Convert Gradio tuples-format history into chat-template messages.

    In the tuples format, an uploaded image appears as its own entry
    ``[(filepath,), None]`` immediately followed by the text entry
    ``[text, response]`` for the same turn.
    """
    messages = []
    images = []

    for i, msg in enumerate(history):
        if isinstance(msg[0], tuple):
            # Image entry: merge it with the text entry that follows it.
            messages.extend(
                [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": history[i + 1][0]},
                            {"type": "image"},
                        ],
                    },
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": history[i + 1][1]}],
                    },
                ]
            )
            images.append(Image.open(msg[0][0]).convert("RGB"))
        elif i > 0 and isinstance(history[i - 1][0], tuple):
            # Text entry already consumed by the preceding image entry.
            continue
        elif isinstance(msg[0], str):
            # Plain text turn with no image attached.
            messages.extend(
                [
                    {"role": "user", "content": [{"type": "text", "text": msg[0]}]},
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": msg[1]}],
                    },
                ]
            )

    return messages, images


@spaces.GPU
def bot_streaming(message, history, max_new_tokens=250):
    text = message["text"]
    messages, images = process_chat_history(history)

    # Handle the current message, which may carry an image attachment.
    if len(message["files"]) == 1:
        # Files arrive as plain paths or as dicts with a "path" key,
        # depending on the Gradio version.
        image = (
            Image.open(message["files"][0])
            if isinstance(message["files"][0], str)
            else Image.open(message["files"][0]["path"])
        ).convert("RGB")
        images.append(image)
        messages.append(
            {
                "role": "user",
                "content": [{"type": "text", "text": text}, {"type": "image"}],
            }
        )
    else:
        messages.append({"role": "user", "content": [{"type": "text", "text": text}]})

    # Build model inputs; only pass images when at least one is present.
    texts = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = (
        processor(text=texts, images=images, return_tensors="pt")
        if images
        else processor(text=texts, return_tensors="pt")
    ).to(DEVICE)

    # Stream tokens from a background thread so the UI updates incrementally.
    streamer = TextIteratorStreamer(
        processor, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer


demo = gr.ChatInterface(
    fn=bot_streaming,
    textbox=gr.MultimodalTextbox(),
    additional_inputs=[
        gr.Slider(
            minimum=10,
            maximum=500,
            value=250,
            step=10,
            label="Maximum number of new tokens to generate",
        )
    ],
    examples=[
        [
            {
                # "Which restaurant printed this receipt?"
                "text": "Hóa đơn được in tại nhà hàng nào?",
                "files": ["./examples/01.jpg"],
            },
            200,
        ],
        [
            {
                # "Describe the receipt information in detail."
                "text": "Mô tả thông tin hóa đơn một cách chi tiết.",
                "files": ["./examples/02.jpg"],
            },
            500,
        ],
    ],
    cache_examples=False,
    stop_btn="Stop",
    fill_height=True,
    multimodal=True,
)

if __name__ == "__main__":
    demo.launch(debug=True)