"""Gradio demo: caption an image, then generate questions about the caption.

Pipeline: image -> caption (VisionEncoderDecoderModel) -> dictionary-based
word correction -> question generation (seq2seq model), using the first few
words of the caption as candidate answers.
"""

import gradio as gr
from gradio.themes.base import Base
from PIL import Image
import torch
import torchvision.transforms as transforms
from transformers import VisionEncoderDecoderModel, AutoTokenizer, AutoModelForSeq2SeqLM

# Load the models
caption_model = VisionEncoderDecoderModel.from_pretrained('Mayada/AIC-transformer')
caption_tokenizer = AutoTokenizer.from_pretrained('aubmindlab/bert-base-arabertv02')
question_model = AutoModelForSeq2SeqLM.from_pretrained("Mihakram/AraT5-base-question-generation")
question_tokenizer = AutoTokenizer.from_pretrained("Mihakram/AraT5-base-question-generation")

# Define the normalization and transformations
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],  # ImageNet mean
    std=[0.229, 0.224, 0.225],  # ImageNet standard deviation
)
inference_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])

# Location of the tab-separated "wrong word<TAB>corrected word" file.
DICTIONARY_PATH = "/content/drive/MyDrive/DICTIONARY (3).txt"


def _load_dictionary(path: str) -> dict[str, str]:
    """Read a tab-separated replacement dictionary from *path*.

    Each usable line holds ``key<TAB>value``. Blank lines and lines without
    a tab are skipped instead of aborting the whole load; columns beyond the
    second are ignored. Later duplicates of a key override earlier ones.

    Raises:
        OSError: if the file cannot be opened (e.g. it does not exist).
    """
    mapping: dict[str, str] = {}
    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            fields = line.strip().split("\t")
            if len(fields) < 2:
                # Blank or malformed line: nothing to map.
                continue
            mapping[fields[0]] = fields[1]
    return mapping


# Load the dictionary
dictionary = _load_dictionary(DICTIONARY_PATH)


def correct_caption(caption: str) -> str:
    """Replace each whitespace-separated word of *caption* via ``dictionary``.

    Words without a dictionary entry are kept unchanged. The result is the
    words re-joined with single spaces.
    """
    return " ".join(dictionary.get(word, word) for word in caption.split())


def generate_captions(image: Image.Image) -> list[str]:
    """Generate caption strings for a PIL image.

    The image is converted to RGB first because the normalization step is
    defined for exactly three channels; grayscale or RGBA input would
    otherwise fail there.

    Returns:
        A list with one decoded, stripped caption per generated sequence
        (``num_return_sequences=1``, so a single caption).
    """
    img_tensor = inference_transforms(image.convert("RGB")).unsqueeze(0)
    generated = caption_model.generate(
        img_tensor,
        num_beams=3,
        max_length=10,
        early_stopping=True,
        do_sample=True,
        top_k=1000,
        num_return_sequences=1,
    )
    return [
        caption_tokenizer.decode(g, skip_special_tokens=True).strip()
        for g in generated
    ]


def generate_questions(context: str, answer: str) -> list[str]:
    """Generate questions whose answer within *context* is *answer*.

    The model input is the prompt ``"context: <context> answer: <answer> "``.

    Returns:
        A list of decoded questions (one, given ``num_return_sequences=1``)
        with every ``'question: '`` occurrence replaced by a single space.
    """
    text = f"context: {context} answer: {answer} "
    # Calling the tokenizer directly is the supported replacement for the
    # deprecated ``encode_plus``; it yields the same input_ids/attention_mask.
    text_encoding = question_tokenizer(text, return_tensors="pt")
    question_model.eval()
    generated_ids = question_model.generate(
        input_ids=text_encoding['input_ids'],
        attention_mask=text_encoding['attention_mask'],
        max_length=64,
        num_beams=5,
        num_return_sequences=1,
    )
    return [
        question_tokenizer.decode(
            g, skip_special_tokens=True, clean_up_tokenization_spaces=True
        ).replace('question: ', ' ')
        for g in generated_ids
    ]


# Define the Gradio interface with Seafoam theme
class Seafoam(Base):
    """Gradio theme subclass that adds no customization to ``Base``."""


seafoam = Seafoam()


def _candidate_answers(words: list[str]) -> list[str]:
    """Return the answer candidates taken from the start of a caption.

    Order: first word, second word, first two words joined by a space,
    third word, fourth word; entries are omitted when the caption is too
    short to provide them.
    """
    candidates = list(words[:2])
    if len(words) > 1:
        candidates.append(" ".join(words[:2]))
    candidates.extend(words[2:4])
    return candidates


def caption_question_interface(image):
    """Caption *image* and generate question/answer pairs from the caption.

    Args:
        image: PIL image supplied by the Gradio input component, or ``None``
            when the user submitted without providing an image.

    Returns:
        A ``(captions, questions)`` tuple of strings: the corrected captions
        joined by newlines, and the ``"Question: ...\\nAnswer: ..."`` entries
        joined by newlines. Both are empty strings when *image* is ``None``.
    """
    if image is None:
        # Nothing to caption; return empty outputs rather than crashing
        # inside the image transforms.
        return "", ""

    # Generate captions, then correct them using the dictionary
    corrected_captions = [
        correct_caption(caption) for caption in generate_captions(image)
    ]

    # Generate questions for each candidate answer of each caption
    questions_with_answers = []
    for caption in corrected_captions:
        for answer in _candidate_answers(caption.split()):
            questions_with_answers.extend(
                (question, answer)
                for question in generate_questions(caption, answer)
            )

    # Format questions with answers
    formatted_questions = "\n".join(
        f"Question: {q}\nAnswer: {a}" for q, a in questions_with_answers
    )

    # Return the generated captions and formatted questions with answers
    return "\n".join(corrected_captions), formatted_questions


gr_interface = gr.Interface(
    fn=caption_question_interface,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Textbox(label="Generated Captions"),
        gr.Textbox(label="Generated Questions and Answers"),
    ],
    title="Image Captioning and Question Generation",
    description="Generate captions and questions for images using pre-trained models.",
    theme=seafoam,
)

# Launch the interface
gr_interface.launch(share=True)