Mayada committed
Commit c6ad764
1 Parent(s): 2e12dcf

Update app.py

Files changed (1)
  1. app.py +29 -12
app.py CHANGED
@@ -6,7 +6,7 @@ import torchvision.transforms as transforms
 from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoModelForSeq2SeqLM
 
 # Load the models
-caption_model = VisionEncoderDecoderModel.from_pretrained('Mayada/AIC-transformer')  # Your model on Hugging Face
+caption_model = VisionEncoderDecoderModel.from_pretrained('/content/drive/MyDrive/ICModel')
 caption_tokenizer = AutoTokenizer.from_pretrained('aubmindlab/bert-base-arabertv02')
 question_model = AutoModelForSeq2SeqLM.from_pretrained("Mihakram/AraT5-base-question-generation")
 question_tokenizer = AutoTokenizer.from_pretrained("Mihakram/AraT5-base-question-generation")
@@ -23,10 +23,9 @@ inference_transforms = transforms.Compose([
     normalize
 ])
 
-# Load the dictionary (use it from your Hugging Face Space or include in the repo)
-dictionary = {
-    "caption": "alternative_caption"  # Replace with your actual dictionary
-}
+# Load the dictionary
+with open("/content/drive/MyDrive/DICTIONARY (3).txt", "r", encoding="utf-8") as file:
+    dictionary = dict(line.strip().split("\t") for line in file)
 
 # Function to correct words in the caption using the dictionary
 def correct_caption(caption):
@@ -67,49 +66,67 @@ def generate_questions(context, answer):
         'question: ', ' ') for g in generated_ids]
     return questions
 
-# Gradio Interface Function
+# Define the Gradio interface with Seafoam theme
+class Seafoam(Base):
+    pass
+
+seafoam = Seafoam()
+
 def caption_question_interface(image):
+    # Generate captions
     captions = generate_captions(image)
+
+    # Correct captions using the dictionary
     corrected_captions = [correct_caption(caption) for caption in captions]
+
+    # Generate questions for each caption
     questions_with_answers = []
-
     for caption in corrected_captions:
         words = caption.split()
+        # Generate questions for the first word
         if len(words) > 0:
             answer = words[0]
             question = generate_questions(caption, answer)
             questions_with_answers.extend([(q, answer) for q in question])
+        # Generate questions for the second word
         if len(words) > 1:
             answer = words[1]
             question = generate_questions(caption, answer)
             questions_with_answers.extend([(q, answer) for q in question])
+        # Generate questions for the second word + first word
        if len(words) > 1:
             answer = " ".join(words[:2])
             question = generate_questions(caption, answer)
             questions_with_answers.extend([(q, answer) for q in question])
+        # Generate questions for the third word
         if len(words) > 2:
             answer = words[2]
             question = generate_questions(caption, answer)
             questions_with_answers.extend([(q, answer) for q in question])
+        # Generate questions for the fourth word
         if len(words) > 3:
             answer = words[3]
             question = generate_questions(caption, answer)
             questions_with_answers.extend([(q, answer) for q in question])
 
+    # Format questions with answers
     formatted_questions = [f"Question: {q}\nAnswer: {a}" for q, a in questions_with_answers]
     formatted_questions = "\n".join(formatted_questions)
 
+    # Return the generated captions and formatted questions with answers
     return "\n".join(corrected_captions), formatted_questions
 
 gr_interface = gr.Interface(
     fn=caption_question_interface,
-    inputs=gr.inputs.Image(type="pil", label="Input Image"),
+    inputs=gr.Image(type="pil", label="Input Image"),
     outputs=[
-        gr.outputs.Textbox(label="Generated Captions"),
-        gr.outputs.Textbox(label="Generated Questions and Answers")
+        gr.Textbox(label="Generated Captions"),
+        gr.Textbox(label="Generated Questions and Answers")
     ],
     title="Image Captioning and Question Generation",
-    description="Generate captions and questions for images using pre-trained models."
+    description="Generate captions and questions for images using pre-trained models.",
+    theme=seafoam,
 )
 
-gr_interface.launch()
+# Launch the interface
+gr_interface.launch(share=True)
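The new dictionary loader reads one tab-separated pair per line (original word, replacement). The body of correct_caption falls outside the changed hunks, so the following is only a hypothetical sketch of how such a mapping is typically applied word by word:

# Hypothetical sketch of correct_caption; the committed body is not shown in this diff.
# Assumes dictionary maps a variant spelling to its preferred replacement.
def correct_caption(caption):
    words = caption.split()
    corrected = [dictionary.get(word, word) for word in words]  # unknown words pass through unchanged
    return " ".join(corrected)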
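Note on the Seafoam theme added above: the class subclasses Base, but the corresponding import lies outside the changed hunks. A minimal sketch of the assumed setup, following Gradio's custom-theme API (3.22 and later) where themes derive from gradio.themes.base.Base:

from gradio.themes.base import Base  # assumed import; not visible in this diff

class Seafoam(Base):
    # No overrides here, so the theme renders with Base's defaults;
    # colors and fonts could be customized via Base.__init__ arguments.
    pass

seafoam = Seafoam()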