import gradio as gr
from gradio.themes.base import Base
from PIL import Image
import torch
import torchvision.transforms as transforms
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoModelForSeq2SeqLM

# Load the Arabic image-captioning model and the AraT5 question-generation
# model, together with their tokenizers
caption_model = VisionEncoderDecoderModel.from_pretrained('Mayada/AIC-transformer')
caption_tokenizer = AutoTokenizer.from_pretrained('aubmindlab/bert-base-arabertv02')
question_model = AutoModelForSeq2SeqLM.from_pretrained("Mihakram/AraT5-base-question-generation")
question_tokenizer = AutoTokenizer.from_pretrained("Mihakram/AraT5-base-question-generation")

# Put both models in eval mode once at load time so dropout is disabled
# during inference
caption_model.eval()
question_model.eval()
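
# Optional: move the models to a GPU when one is available (the input tensors
# built below would then also need .to(device)); a sketch, left commented out:
# device = "cuda" if torch.cuda.is_available() else "cpu"
# caption_model.to(device)
# question_model.to(device)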

# Define the normalization and transformations
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],  # ImageNet mean
    std=[0.229, 0.224, 0.225]  # ImageNet standard deviation
)

inference_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize
])

# Load the correction dictionary: one tab-separated
# "misspelling<TAB>correction" pair per line
with open("DICTIONARY (3).txt", "r", encoding="utf-8") as file:
    dictionary = dict(
        line.strip().split("\t", 1)
        for line in file
        if "\t" in line  # skip blank or malformed lines
    )

# Function to correct words in the caption using the dictionary
def correct_caption(caption):
    corrected_words = [dictionary.get(word, word) for word in caption.split()]
    corrected_caption = " ".join(corrected_words)
    return corrected_caption
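
# For example, if the dictionary maps "كوره" to "كرة", then
# correct_caption("ولد يلعب كوره") returns "ولد يلعب كرة"; words without a
# dictionary entry pass through unchanged.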

# Function to generate captions for an image
def generate_captions(image):
    # Convert to RGB first: uploads may be grayscale or RGBA, which would
    # break the 3-channel normalization above
    img_tensor = inference_transforms(image.convert("RGB")).unsqueeze(0)
    generated = caption_model.generate(
        img_tensor,
        num_beams=3,
        max_length=10,
        early_stopping=True,
        do_sample=True,
        top_k=1000,
        num_return_sequences=1,
    )
    captions = [caption_tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated]
    return captions
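
# Note: do_sample=True makes caption generation stochastic, so repeated runs
# on the same image can return different captions; call torch.manual_seed(...)
# beforehand if reproducible output is needed.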

# Function to generate questions given a context and a target answer span
def generate_questions(context, answer):
    # AraT5 question-generation input format: "context: ... answer: ..."
    text = f"context: {context} answer: {answer} </s>"
    text_encoding = question_tokenizer(text, return_tensors="pt")
    generated_ids = question_model.generate(
        input_ids=text_encoding['input_ids'],
        attention_mask=text_encoding['attention_mask'],
        max_length=64,
        num_beams=5,
        num_return_sequences=1
    )
    questions = [
        question_tokenizer.decode(
            g, skip_special_tokens=True, clean_up_tokenization_spaces=True
        ).replace('question: ', '').strip()
        for g in generated_ids
    ]
    return questions
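
# Example call with hypothetical Arabic strings:
#   generate_questions("ولد يلعب كرة في الحديقة", "ولد")
# returns a single-element list containing a question whose expected answer
# is "ولد".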

# Interface theme: a Base subclass with no overrides (hook for customization)
class Seafoam(Base):
    pass

seafoam = Seafoam()

def caption_question_interface(image):
    # Generate captions
    captions = generate_captions(image)

    # Proofread captions using the dictionary
    corrected_captions = [correct_caption(caption) for caption in captions]

    # Build the answer spans to ask about: the first word, the second word,
    # the first two words together, then the third and fourth words
    questions_with_answers = []
    for caption in corrected_captions:
        words = caption.split()
        answer_spans = words[:2]                      # first and second words
        if len(words) > 1:
            answer_spans.append(" ".join(words[:2]))  # first two words together
        answer_spans += words[2:4]                    # third and fourth words
        # Generate one question per answer span
        for answer in answer_spans:
            questions = generate_questions(caption, answer)
            questions_with_answers.extend((q, answer) for q in questions)

    # Format each (question, keyword) pair for display
    formatted_questions = "\n".join(
        f"Question: {q}\nKeyword: {a}" for q, a in questions_with_answers
    )

    # Return the generated captions and formatted questions with answers
    return "\n".join(corrected_captions), formatted_questions

gr_interface = gr.Interface(
    fn=caption_question_interface,
    inputs=gr.Image(type="pil", label="Input Image"), 
    outputs=[
        gr.Textbox(label="Generated Captions"), 
        gr.Textbox(label="Generated Questions")
    ],
    title="Visual Question Generator",
    description="Generate captions and questions for images using an Arabic image-captioning model and a question-generation model",
    theme=seafoam,
)

# Launch the interface
gr_interface.launch(share=True)
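
# Note: share=True asks Gradio for a temporary public *.gradio.live link; it
# is unnecessary (and ignored, with a warning) when the app runs on
# Hugging Face Spaces.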