|
import gradio as gr |
|
from gradio.themes.base import Base |
|
from PIL import Image |
|
import torch |
|
import torchvision.transforms as transforms |
|
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
|
|
# Hugging Face checkpoint ids. The question-generation id is shared by its
# model and tokenizer, so it is named once instead of repeated.
_CAPTION_CHECKPOINT = 'Mayada/AIC-transformer'
_CAPTION_TOKENIZER_CHECKPOINT = 'aubmindlab/bert-base-arabertv02'
_QG_CHECKPOINT = "Mihakram/AraT5-base-question-generation"

# Captioning: encoder-decoder model that emits token ids, plus the tokenizer
# used to decode those ids back into text.
caption_model = VisionEncoderDecoderModel.from_pretrained(_CAPTION_CHECKPOINT)
caption_tokenizer = AutoTokenizer.from_pretrained(_CAPTION_TOKENIZER_CHECKPOINT)

# Question generation: seq2seq model and its matching tokenizer.
question_model = AutoModelForSeq2SeqLM.from_pretrained(_QG_CHECKPOINT)
question_tokenizer = AutoTokenizer.from_pretrained(_QG_CHECKPOINT)
|
|
|
|
|
# Per-channel RGB mean/std; these are the standard ImageNet statistics.
# NOTE(review): presumably what the caption model was trained with — confirm.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

normalize = transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)

# Preprocessing applied to every input image before captioning:
# resize to a fixed 224x224, convert to a float tensor, then normalize.
inference_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ]
)
|
|
|
|
|
def _load_dictionary(path):
    """Load a tab-separated word-correction table into a dict.

    Each non-blank line must hold exactly two tab-separated fields,
    ``word<TAB>replacement``. Blank or whitespace-only lines (including a
    trailing empty line at end of file) are skipped rather than crashing.
    Later duplicates of a key overwrite earlier ones.

    The file is read as ``utf-8-sig``, which decodes plain UTF-8 identically
    but also drops a leading byte-order mark. Otherwise the BOM would stay
    glued to the first key, because ``str.strip`` does not remove it.

    Raises:
        ValueError: if a non-blank line does not contain exactly two fields;
            the message names the file and the 1-based line number.
    """
    table = {}
    with open(path, "r", encoding="utf-8-sig") as fh:
        for lineno, raw in enumerate(fh, start=1):
            line = raw.strip()
            if not line:
                continue
            fields = line.split("\t")
            if len(fields) != 2:
                raise ValueError(
                    f"{path}:{lineno}: expected 2 tab-separated fields, "
                    f"got {len(fields)}"
                )
            word, replacement = fields
            table[word] = replacement
    return table


# Word -> corrected word, used by correct_caption().
dictionary = _load_dictionary("DICTIONARY (3).txt")
|
|
|
|
|
def correct_caption(caption, *, mapping=None):
    """Replace each whitespace-separated word of *caption* via a lookup table.

    Words that are not in the table are kept unchanged. The caption is split
    on arbitrary whitespace and re-joined with single spaces. As a result,
    runs of whitespace collapse and leading/trailing whitespace is dropped.

    Args:
        caption: Text to correct.
        mapping: Optional ``{word: replacement}`` table. When omitted, the
            module-level ``dictionary`` is used, so existing callers behave
            exactly as before.

    Returns:
        The corrected caption as a single-space-joined string.
    """
    table = dictionary if mapping is None else mapping
    return " ".join(table.get(word, word) for word in caption.split())
|
|
|
|
|
def generate_captions(image):
    """Generate caption text for a PIL image using ``caption_model``.

    The image is converted to RGB first. ``inference_transforms`` normalizes
    with 3-element mean/std, so a grayscale ("L"), palette ("P") or RGBA image
    would otherwise fail inside ``Normalize``. For images that are already RGB
    the conversion only makes a copy.

    Decoding uses beam search with sampling enabled (``do_sample=True``), so
    repeated calls on the same image may return different captions. Output is
    capped at 10 tokens.

    Args:
        image: A PIL image.

    Returns:
        A list of decoded caption strings, with special tokens removed and
        whitespace stripped. It holds one entry, since a single image is
        batched and ``num_return_sequences=1``.
    """
    # unsqueeze(0) adds the batch dimension expected by generate().
    pixel_values = inference_transforms(image.convert("RGB")).unsqueeze(0)

    output_ids = caption_model.generate(
        pixel_values,
        num_beams=3,
        max_length=10,
        early_stopping=True,
        do_sample=True,
        top_k=1000,
        num_return_sequences=1,
    )

    return [
        caption_tokenizer.decode(ids, skip_special_tokens=True).strip()
        for ids in output_ids
    ]
|
|
|
|
|
def generate_questions(context, answer):
    """Generate a question about *context* whose answer is *answer*.

    The prompt is built as ``"context: <context> answer: <answer> </s>"`` and
    fed to ``question_model`` with 5-beam search, capped at 64 tokens.

    Args:
        context: Source sentence (here, a corrected caption).
        answer: Span of *context* the question should be answered by.

    Returns:
        A list with one question string (``num_return_sequences=1``). Any
        ``"question: "`` marker is removed and surrounding whitespace is
        stripped.
    """
    prompt = "context: " + context + " " + "answer: " + answer + " </s>"
    encoding = question_tokenizer.encode_plus(prompt, return_tensors="pt")

    question_model.eval()
    generated_ids = question_model.generate(
        input_ids=encoding['input_ids'],
        attention_mask=encoding['attention_mask'],
        max_length=64,
        num_beams=5,
        num_return_sequences=1,
    )

    questions = []
    for ids in generated_ids:
        text = question_tokenizer.decode(
            ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        # The decoded text may carry a "question: " marker. Replacing it with
        # " " alone left a leading space, which rendered as "Question:  ..."
        # (double space) in the UI, so the result is stripped as well.
        questions.append(text.replace('question: ', ' ').strip())
    return questions
|
|
|
|
|
class Seafoam(Base):
    """Named Gradio theme.

    Currently overrides nothing, so it is identical to
    ``gradio.themes.base.Base``.
    """


# Theme instance handed to gr.Interface(theme=...).
seafoam = Seafoam()
|
|
|
def _answer_candidates(words):
    """Return the answer spans to ask about for one caption, in order.

    For words w0..w3 the order is: w0, w1, "w0 w1", w2, w3. Each entry is
    included only when the caption is long enough to contain it.
    """
    candidates = list(words[:2])
    if len(words) > 1:
        candidates.append(" ".join(words[:2]))
    candidates.extend(words[2:4])
    return candidates


def caption_question_interface(image):
    """Gradio callback: caption *image*, then generate Q/A pairs per caption.

    Each caption is passed through ``correct_caption``. Questions are then
    generated for several answer spans taken from its first four words.

    Args:
        image: A PIL image, or None. None is presumably what the UI passes
            when submitted without an upload.

    Returns:
        ``(captions, qa_text)``: two newline-joined strings, one per output
        Textbox. ``("", "")`` is returned when *image* is None, instead of
        crashing inside the preprocessing transforms.
    """
    if image is None:
        return "", ""

    corrected_captions = [correct_caption(c) for c in generate_captions(image)]

    questions_with_answers = []
    for caption in corrected_captions:
        for answer in _answer_candidates(caption.split()):
            questions_with_answers.extend(
                (question, answer)
                for question in generate_questions(caption, answer)
            )

    formatted_questions = "\n".join(
        f"Question: {q}\nAnswer: {a}" for q, a in questions_with_answers
    )
    return "\n".join(corrected_captions), formatted_questions
|
|
|
# UI wiring: one PIL image in; two text boxes out (captions, then Q/A pairs).
# Components are created in the same order as before: image first.
_image_input = gr.Image(type="pil", label="Input Image")
_text_outputs = [
    gr.Textbox(label="Generated Captions"),
    gr.Textbox(label="Generated Questions and Answers"),
]

gr_interface = gr.Interface(
    fn=caption_question_interface,
    inputs=_image_input,
    outputs=_text_outputs,
    title="Image Captioning and Question Generation",
    description="Generate captions and questions for images using pre-trained models.",
    theme=seafoam,
)

# Starts the server when this module runs. share=True additionally requests a
# public Gradio share link.
gr_interface.launch(share=True)
|
|