Mayada's picture
Update app.py
5f85a7e verified
raw
history blame
4.99 kB
import gradio as gr
from gradio.themes.base import Base
from PIL import Image
import torch
import torchvision.transforms as transforms
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoModelForSeq2SeqLM
# Hugging Face Hub checkpoints used by the app.
_CAPTION_MODEL_ID = "Mayada/AIC-transformer"
_CAPTION_TOKENIZER_ID = "aubmindlab/bert-base-arabertv02"
_QUESTION_GEN_ID = "Mihakram/AraT5-base-question-generation"

# Image -> caption: encoder-decoder model, decoded with a separate tokenizer.
caption_model = VisionEncoderDecoderModel.from_pretrained(_CAPTION_MODEL_ID)
caption_tokenizer = AutoTokenizer.from_pretrained(_CAPTION_TOKENIZER_ID)

# (context, answer) -> question: seq2seq model and its own tokenizer.
question_model = AutoModelForSeq2SeqLM.from_pretrained(_QUESTION_GEN_ID)
question_tokenizer = AutoTokenizer.from_pretrained(_QUESTION_GEN_ID)
# Per-channel ImageNet statistics used to normalise the caption model's input.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

normalize = transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)

# Preprocessing pipeline for a single image: resize to 224x224, convert to a
# tensor, then apply the channel normalisation above.
inference_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ]
)
def _load_dictionary(path):
    """Read a tab-separated word-correction table into a dict.

    Each non-blank line holds a source word and its replacement separated by
    a single tab.  Both fields are stripped of surrounding whitespace so the
    keys can match whitespace-split caption tokens.  Later duplicates
    override earlier ones.

    Blank lines (including a trailing empty line at end of file) are skipped
    instead of breaking the whole load, and a leading UTF-8 BOM is ignored
    (``utf-8-sig`` reads plain UTF-8 files identically).

    Raises:
        ValueError: if a non-blank line does not contain exactly two
            tab-separated fields.
    """
    mapping = {}
    with open(path, "r", encoding="utf-8-sig") as file:
        for lineno, line in enumerate(file, start=1):
            stripped = line.strip()
            if not stripped:
                continue
            fields = stripped.split("\t")
            if len(fields) != 2:
                raise ValueError(
                    f"{path}:{lineno}: expected 'word<TAB>correction', got {line!r}"
                )
            word, correction = (field.strip() for field in fields)
            mapping[word] = correction
    return mapping


# Word -> corrected word table applied to generated captions.
dictionary = _load_dictionary("DICTIONARY (3).txt")
def correct_caption(caption, mapping=None):
    """Replace each whitespace-separated word of *caption* via a lookup table.

    Args:
        caption: Text to correct.
        mapping: Optional ``word -> replacement`` dict.  Defaults to the
            module-level ``dictionary`` loaded at startup.

    Returns:
        The caption with every word found in the table replaced.  Words are
        re-joined with single spaces, so runs of whitespace collapse.
    """
    if mapping is None:
        mapping = dictionary
    return " ".join(mapping.get(word, word) for word in caption.split())
def generate_captions(image):
    """Generate caption(s) for an image with the captioning model.

    The image is converted to RGB first.  The normalisation step uses
    3-channel mean/std values, so RGBA (e.g. PNG with alpha), greyscale or
    palette images would otherwise fail inside ``transforms.Normalize``.

    Args:
        image: A PIL image.  The Gradio input component uses ``type="pil"``.

    Returns:
        A list of decoded, whitespace-stripped caption strings, one per
        returned sequence.
    """
    # Batch of one image, placed on the same device as the model.
    pixel_values = inference_transforms(image.convert("RGB")).unsqueeze(0)
    pixel_values = pixel_values.to(caption_model.device)
    generated = caption_model.generate(
        pixel_values,
        num_beams=3,
        max_length=10,
        early_stopping=True,
        # Sampling is enabled, so the same image can yield different captions.
        do_sample=True,
        top_k=1000,
        num_return_sequences=1,
    )
    return [
        caption_tokenizer.decode(sequence, skip_special_tokens=True).strip()
        for sequence in generated
    ]
# Function to generate questions given a context and answer
def generate_questions(context, answer):
text = "context: " + context + " " + "answer: " + answer + " </s>"
text_encoding = question_tokenizer.encode_plus(
text, return_tensors="pt"
)
question_model.eval()
generated_ids = question_model.generate(
input_ids=text_encoding['input_ids'],
attention_mask=text_encoding['attention_mask'],
max_length=64,
num_beams=5,
num_return_sequences=1
)
questions = [question_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True).replace(
'question: ', ' ') for g in generated_ids]
return questions
# Theme passed to the Gradio interface.
class Seafoam(Base):
    """Gradio theme named "Seafoam".

    It overrides nothing, so it looks exactly like
    ``gradio.themes.base.Base``.  Put ``__init__``/``set(...)``
    customisations here to actually restyle the app.
    """


seafoam = Seafoam()
def caption_question_interface(image):
# Generate captions
captions = generate_captions(image)
# Correct captions using the dictionary
corrected_captions = [correct_caption(caption) for caption in captions]
# Generate questions for each caption
questions_with_answers = []
for caption in corrected_captions:
words = caption.split()
# Generate questions for the first word
if len(words) > 0:
answer = words[0]
question = generate_questions(caption, answer)
questions_with_answers.extend([(q, answer) for q in question])
# Generate questions for the second word
if len(words) > 1:
answer = words[1]
question = generate_questions(caption, answer)
questions_with_answers.extend([(q, answer) for q in question])
# Generate questions for the second word + first word
if len(words) > 1:
answer = " ".join(words[:2])
question = generate_questions(caption, answer)
questions_with_answers.extend([(q, answer) for q in question])
# Generate questions for the third word
if len(words) > 2:
answer = words[2]
question = generate_questions(caption, answer)
questions_with_answers.extend([(q, answer) for q in question])
# Generate questions for the fourth word
if len(words) > 3:
answer = words[3]
question = generate_questions(caption, answer)
questions_with_answers.extend([(q, answer) for q in question])
# Format questions with answers
formatted_questions = [f"Question: {q}\nAnswer: {a}" for q, a in questions_with_answers]
formatted_questions = "\n".join(formatted_questions)
# Return the generated captions and formatted questions with answers
return "\n".join(corrected_captions), formatted_questions
# One image in; captions and generated question/answer pairs out.
gr_interface = gr.Interface(
    fn=caption_question_interface,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Textbox(label="Generated Captions"),
        gr.Textbox(label="Generated Questions and Answers"),
    ],
    title="Image Captioning and Question Generation",
    description="Generate captions and questions for images using pre-trained models.",
    theme=seafoam,
)

# Launch only when run as a script (`python app.py`), so importing this
# module (e.g. for tests or reuse) does not start a server.
if __name__ == "__main__":
    # share=True asks Gradio for a public share link when run locally.
    gr_interface.launch(share=True)