"""Script to create a Gradio app for the finetuned paligemma model."""
from threading import Thread
from typing import Dict
import gradio as gr
import torch
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor,TextIteratorStreamer
# Page title shown via gr.HTML below.
# FIX: the original used a single double-quote with a raw newline inside the
# string, which is a SyntaxError in Python; a triple-quoted string is required
# for a multi-line literal.
TITLE = """
Multimodal and Multilingual Bot for Africans
"""
# Custom CSS injected into the Blocks app (styles the duplicate-space button).
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
"""
# Finetuned checkpoint; the processor comes from the base PaliGemma
# checkpoint the finetune was derived from.
MODEL_ID = "heisguyy/kagglex-paligemma"
model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL_ID)
processor = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
def stream_chat(message: Dict[str, str], history: list):
    """Stream a chat response for a multimodal (text + optional image) message.

    Args:
        message: Gradio multimodal message dict with "text" and "files" keys.
        history: Chat history; if its first entry is an image tuple, that
            image is reused for the current turn.

    Yields:
        The accumulated generated text so far (Gradio streaming convention).
    """
    # FIX: the original blocked on generate() and returned once, despite the
    # name and the unused Thread/TextIteratorStreamer imports; it also stripped
    # the prompt by slicing the decoded string with len(message["text"]),
    # which is fragile. Streaming with skip_prompt=True handles both.
    image_path = None
    if message["files"]:
        image_path = message["files"][0]
    # Fall back to an image uploaded earlier in the conversation.
    if history and isinstance(history[0][0], tuple):
        image_path = history[0][0][0]
        history = history[1:]
    if image_path is not None:
        image = Image.open(image_path).convert("RGB")
    else:
        # PaliGemma always expects an image input; use a blank placeholder.
        image = Image.new("RGB", (100, 100), (255, 255, 255))
    inputs = processor(message["text"], image, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    # generate() blocks, so run it in a worker thread and consume the streamer
    # on this one.
    thread = Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=20),
    )
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    thread.join()
# Chat display component (600 px tall), wired into the ChatInterface below.
chatbot = gr.Chatbot(height=600)
with gr.Blocks(css=CSS) as demo:
    # Components are registered in context-manager order: title first,
    # then the chat interface.
    gr.HTML(TITLE)
    gr.ChatInterface(
        fn=stream_chat,  # handler defined above
        multimodal=True,  # accept file (image) uploads alongside text
        chatbot=chatbot,
        fill_height=True,
        cache_examples=False,
    )
if __name__ == "__main__":
    demo.launch()