Mayada's picture
Update app.py
107fab0 verified
raw
history blame contribute delete
No virus
4.97 kB
import gradio as gr
from gradio.themes.base import Base
from PIL import Image
import torch
import torchvision.transforms as transforms
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoModelForSeq2SeqLM
# Hugging Face Hub identifiers of the pretrained checkpoints used by this app.
_CAPTION_MODEL_ID = "Mayada/AIC-transformer"
_CAPTION_TOKENIZER_ID = "aubmindlab/bert-base-arabertv02"
_QUESTION_MODEL_ID = "Mihakram/AraT5-base-question-generation"

# Image-captioning encoder-decoder, plus the tokenizer used to decode its output ids.
caption_model = VisionEncoderDecoderModel.from_pretrained(_CAPTION_MODEL_ID)
caption_tokenizer = AutoTokenizer.from_pretrained(_CAPTION_TOKENIZER_ID)

# Seq2seq model (and its matching tokenizer) that turns a context + answer prompt into a question.
question_model = AutoModelForSeq2SeqLM.from_pretrained(_QUESTION_MODEL_ID)
question_tokenizer = AutoTokenizer.from_pretrained(_QUESTION_MODEL_ID)
# Per-channel ImageNet statistics used to normalise the captioning model's input.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

normalize = transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)

# Preprocessing applied to every input image before captioning:
# resize to 224x224, convert the PIL image to a tensor, then normalise.
inference_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ]
)
def _load_dictionary(path):
    """Load a tab-separated ``word<TAB>replacement`` file into a dict.

    Each line is stripped of surrounding whitespace and split on its FIRST tab
    only. Blank or tab-less lines are skipped rather than aborting start-up
    (the previous ``dict(line.strip().split("\\t") ...)`` raised ``ValueError``
    on any such line, or on a line containing more than one tab).

    Args:
        path: path of the UTF-8 encoded dictionary file.

    Returns:
        dict[str, str]: mapping from word to its corrected form.
    """
    mapping = {}
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            key, sep, value = raw_line.strip().partition("\t")
            if not sep:
                continue  # blank or malformed line: nothing to map
            mapping[key] = value
    return mapping


# Word-correction dictionary consulted by ``correct_caption``.
dictionary = _load_dictionary("DICTIONARY (3).txt")
def correct_caption(caption):
    """Return *caption* with each word replaced via the module-level ``dictionary``.

    The caption is split on whitespace; words that have no entry in
    ``dictionary`` are kept unchanged. The words are re-joined with single
    spaces.
    """
    fixed_words = []
    for token in caption.split():
        fixed_words.append(dictionary.get(token, token))
    return " ".join(fixed_words)
def generate_captions(image):
    """Generate caption(s) for a PIL image with ``caption_model``.

    Args:
        image: a ``PIL.Image.Image`` (the Gradio input uses ``type="pil"``).

    Returns:
        list[str]: decoded captions with special tokens removed and surrounding
        whitespace stripped; one entry per returned sequence
        (``num_return_sequences=1``).
    """
    # ``normalize`` carries exactly three channel statistics, so RGBA, palette
    # or greyscale images must be converted to 3-channel RGB before ToTensor.
    rgb_image = image.convert("RGB")
    img_tensor = inference_transforms(rgb_image).unsqueeze(0)  # add batch dimension
    # NOTE: do_sample=True together with beam search makes captions
    # non-deterministic between calls; parameters kept as originally tuned.
    generated = caption_model.generate(
        img_tensor,
        num_beams=3,
        max_length=10,
        early_stopping=True,
        do_sample=True,
        top_k=1000,
        num_return_sequences=1,
    )
    return [
        caption_tokenizer.decode(seq, skip_special_tokens=True).strip()
        for seq in generated
    ]
def generate_questions(context, answer):
    """Generate question(s) about *context* whose target answer is *answer*.

    Args:
        context: the text the question should be about (a corrected caption).
        answer: the word/phrase from *context* the question should target.

    Returns:
        list[str]: decoded questions (one, since ``num_return_sequences=1``),
        with any ``"question: "`` marker removed and surrounding whitespace
        stripped (previously a leading space was left behind).
    """
    # Prompt layout fed to the question-generation model.
    text = f"context: {context} answer: {answer} </s>"
    # Calling the tokenizer directly is the supported replacement for the
    # deprecated ``encode_plus``; it yields the same input_ids/attention_mask.
    encoding = question_tokenizer(text, return_tensors="pt")
    question_model.eval()
    generated_ids = question_model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        max_length=64,
        num_beams=5,
        num_return_sequences=1,
    )
    return [
        question_tokenizer.decode(
            ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        .replace("question: ", " ")
        .strip()
        for ids in generated_ids
    ]
# ---- Interface -------------------------------------------------------------
class Seafoam(Base):
    """Gradio theme for the app; adds nothing on top of ``gradio.themes.base.Base``."""


# Theme instance passed to ``gr.Interface``.
seafoam = Seafoam()
def _answer_candidates(words):
    """Return the keywords to generate questions for, in generation order.

    Order: 1st word, 2nd word, "1st 2nd" word pair, 3rd word, 4th word —
    each included only when the caption has enough words.
    """
    candidates = words[:2]  # first and second word (if present)
    if len(words) > 1:
        candidates.append(" ".join(words[:2]))  # first two words together
    candidates.extend(words[2:4])  # third and fourth word (if present)
    return candidates


def caption_question_interface(image):
    """Gradio callback: caption an image, correct it, and generate questions.

    Args:
        image: the uploaded ``PIL.Image.Image``, or ``None`` when the user
            submits without an image.

    Returns:
        tuple[str, str]: the corrected captions joined by newlines, and the
        generated questions formatted as ``"Question: ...\\nKeyword: ..."``
        entries joined by newlines. Both are empty strings when *image* is
        ``None``.
    """
    if image is None:
        return "", ""

    captions = generate_captions(image)
    # Proofread captions using the word-correction dictionary.
    corrected_captions = [correct_caption(caption) for caption in captions]

    formatted_questions = []
    for caption in corrected_captions:
        for answer in _answer_candidates(caption.split()):
            for question in generate_questions(caption, answer):
                formatted_questions.append(f"Question: {question}\nKeyword: {answer}")

    return "\n".join(corrected_captions), "\n".join(formatted_questions)
# Gradio UI: one image input, two text outputs (captions and questions).
gr_interface = gr.Interface(
    fn=caption_question_interface,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Textbox(label="Generated Captions"),
        gr.Textbox(label="Generated Questions"),
    ],
    title="Visual Question Generator",
    description="Generate captions and questions for images using Arabic image captioning model and question generation model",
    theme=seafoam,
)

# Launch only when run as a script, so importing this module (for tests or to
# reuse the helper functions) does not start a web server as a side effect.
if __name__ == "__main__":
    gr_interface.launch(share=True)