"""Gradio prototype: OCR an uploaded image with Qwen2-VL, then search the text."""

import html
import re

import gradio as gr
import torch
from byaldi import RAGMultiModalModel
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration


# Load models
def initialize_models():
    """Load the RAG multimodal model, the Qwen2-VL model and its processor.

    Returns:
        A ``(multimodal_rag, qwen_model, qwen_processor)`` tuple.
    """
    # NOTE(review): ``multimodal_rag`` (ColPali via byaldi) is loaded here but
    # nothing below references it; drop it if no other module needs it, since
    # loading an extra model costs startup time and memory.
    multimodal_rag = RAGMultiModalModel.from_pretrained("vidore/colpali")
    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.float32,
    )
    qwen_processor = AutoProcessor.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True
    )
    return multimodal_rag, qwen_model, qwen_processor


multimodal_rag, qwen_model, qwen_processor = initialize_models()


# Text extraction function
def perform_ocr(image):
    """Extract the text shown in *image* using the Qwen2-VL model.

    The prompt asks the model to transcribe the text in its original language
    (the app targets Sanskrit and English documents).

    Args:
        image: The image to transcribe, as delivered by ``gr.Image(type="pil")``.

    Returns:
        The decoded model output as a string.
    """
    query = "Extract text from the image in original language"

    # Chat-format request: a single user turn holding the image and the prompt.
    user_input = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": query},
            ],
        }
    ]

    # Render the chat template and collect the vision inputs it references.
    input_text = qwen_processor.apply_chat_template(
        user_input, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(user_input)
    # Send inputs to wherever the model lives (CPU as loaded above) rather than
    # hard-coding "cpu", so moving the model to another device keeps working.
    model_inputs = qwen_processor(
        text=[input_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(qwen_model.device)

    with torch.no_grad():
        generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=2000)

    # generate() returns prompt + completion; keep only the completion tokens.
    trimmed_ids = [
        output[len(input_ids):]
        for input_ids, output in zip(model_inputs.input_ids, generated_ids)
    ]
    return qwen_processor.batch_decode(
        trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]


# Sentence boundaries: ". " plus the Devanagari danda (U+0964) and double
# danda (U+0965), which terminate sentences/verses in Sanskrit text.
_SENTENCE_BOUNDARY = re.compile(r"\. |[\u0964\u0965]\s*")


# Keyword search function
def highlight_keyword(text, keyword):
    """Find the sentences of *text* containing *keyword* and highlight it.

    Matching is case-insensitive. Every returned sentence is HTML-escaped and
    each occurrence of the keyword is wrapped in ``<mark>...</mark>``, so the
    output is safe to render with ``gr.HTML``.

    Args:
        text: The extracted text to search (may be empty or None).
        keyword: The search term; surrounding whitespace is ignored.

    Returns:
        A list of HTML strings, one per matching sentence, or a single
        message string when the keyword is empty or nothing matches.
    """
    keyword = (keyword or "").strip()
    if not keyword:
        # An empty pattern would "match" every sentence at every position.
        return ["Please enter a keyword to search."]

    pattern = re.compile(f"({re.escape(keyword)})", re.IGNORECASE)
    results = []
    for sentence in _SENTENCE_BOUNDARY.split(text or ""):
        # With one capturing group, split() alternates plain text (even
        # indices) with the matched keyword text (odd indices).
        parts = pattern.split(sentence)
        if len(parts) == 1:
            continue  # keyword does not occur in this sentence
        results.append(
            "".join(
                f"<mark>{html.escape(part)}</mark>" if i % 2 else html.escape(part)
                for i, part in enumerate(parts)
            )
        )
    return results if results else ["No matches found."]


# Gradio handler for text extraction
def extract_text(image):
    """Return the text extracted from the uploaded image, or a prompt if none."""
    if image is None:
        # Clicking "Extract Text" before uploading an image passes None.
        return "Please upload an image first."
    return perform_ocr(image)


# Gradio handler for keyword search
def search_in_text(extracted_text, keyword):
    """Return the matching, highlighted sentences as HTML, one per line."""
    return "<br>".join(highlight_keyword(extracted_text, keyword))


# Page header rendered at the top of the app.
header_html = """
<div style="text-align: center;">
    <h1>OCR and Text Search Prototype</h1>
</div>
"""

# CSS to fix button sizes and make the extracted-text box scrollable.
# NOTE(review): newer Gradio releases do not put `gr-button` / `gr-textbox`
# classes on components, so these rules may not apply; verify against the
# installed version (components accept `elem_classes` for custom hooks).
custom_css = """
.gr-button {
    width: 200px;         /* Set a fixed width for the buttons */
    padding: 12px 20px;   /* Add padding to buttons for consistency */
}
.gr-textbox {
    max-height: 300px;    /* Set a maximum height for the extracted text output */
    overflow-y: scroll;   /* Enable scrolling when text exceeds the height */
}
"""

# Gradio interface: instructions sidebar plus two tabs (extract, search).
with gr.Blocks(css=custom_css) as interface:
    gr.HTML(header_html)

    with gr.Row():
        # Sidebar with usage instructions and a feature summary.
        with gr.Column(scale=1, min_width=200):
            gr.Markdown("## Instructions")
            gr.Markdown(
                """
                1. Upload an image containing text.
                2. Extract the text from the image.
                3. Search for specific keywords within the extracted text.
                """
            )
            gr.Markdown("### Features")
            gr.Markdown(
                """
                - **OCR**: Extract text from images.
                - **Keyword Search**: Search and highlight keywords in extracted text.
                """
            )

        # Main content in tabs.
        with gr.Column(scale=3):
            with gr.Tabs():
                # First tab: text extraction.
                with gr.Tab("Extract Text"):
                    gr.Markdown("### Upload an image to extract text:")
                    with gr.Row():
                        image_upload = gr.Image(
                            type="pil", label="Upload Image", interactive=True
                        )
                    with gr.Row():
                        extract_btn = gr.Button("Extract Text")
                    extracted_textbox = gr.Textbox(label="Extracted Text")
                    extract_btn.click(
                        extract_text, inputs=image_upload, outputs=extracted_textbox
                    )

                # Second tab: keyword search over the extracted text.
                with gr.Tab("Search in Extracted Text"):
                    gr.Markdown("### Search for a keyword in the extracted text:")
                    with gr.Row():
                        keyword_searchbox = gr.Textbox(
                            label="Enter Keyword", placeholder="Keyword to search"
                        )
                    with gr.Row():
                        search_btn = gr.Button("Search")
                    search_results = gr.HTML(label="Results")
                    search_btn.click(
                        search_in_text,
                        inputs=[extracted_textbox, keyword_searchbox],
                        outputs=search_results,
                    )

# Launch the Gradio app.
interface.launch()