# Source: Hugging Face Space (status: Running)
# File size: 5,517 bytes; revision 65f5126
import html
import re

import gradio as gr
import torch
from byaldi import RAGMultiModalModel
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
# Load models
def initialize_models():
    """Load the RAG multimodal model, the Qwen2-VL model and its processor.

    Returns:
        A ``(multimodal_rag, qwen_model, qwen_processor)`` tuple, in that order.
    """
    # Model and processor are loaded from the same checkpoint.
    qwen_checkpoint = "Qwen/Qwen2-VL-2B-Instruct"

    rag_model = RAGMultiModalModel.from_pretrained("vidore/colpali")
    vision_language_model = Qwen2VLForConditionalGeneration.from_pretrained(
        qwen_checkpoint,
        trust_remote_code=True,
        torch_dtype=torch.float32,
    )
    processor = AutoProcessor.from_pretrained(
        qwen_checkpoint,
        trust_remote_code=True,
    )
    return rag_model, vision_language_model, processor
# Load every model once at import time; the functions below read qwen_model
# and qwen_processor as module-level globals.
# NOTE(review): multimodal_rag is not referenced anywhere else in the visible
# code — confirm whether it is still needed.
multimodal_rag, qwen_model, qwen_processor = initialize_models()
# Text extraction function
def perform_ocr(image):
    """Extract the text visible in an image using the Qwen2-VL model.

    The image is wrapped in a single-turn chat request asking the model to
    transcribe the text in its original language. The generated answer, with
    the prompt tokens removed, is returned.

    Args:
        image: The image to read. It is placed in the ``image`` entry of the
            chat message handed to the processor.

    Returns:
        The decoded model output for the image, as a string.

    Raises:
        ValueError: If ``image`` is ``None`` (for example, nothing was uploaded).
    """
    # Fail early with a clear message instead of an opaque error from deep
    # inside the preprocessing pipeline.
    if image is None:
        raise ValueError("No image provided: upload an image before running OCR.")

    query = "Extract text from the image in original language"

    # Format the request for the model
    user_input = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": query},
            ],
        }
    ]

    # Preprocess the input
    input_text = qwen_processor.apply_chat_template(
        user_input, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(user_input)
    # Put the inputs on whatever device the model lives on, so the two can
    # never end up on different devices.
    model_inputs = qwen_processor(
        text=[input_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(qwen_model.device)

    # Generate output
    with torch.no_grad():
        generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=2000)

    # generate() returns prompt + completion; keep only the newly generated part.
    trimmed_ids = [
        output[len(input_ids):]
        for input_ids, output in zip(model_inputs.input_ids, generated_ids)
    ]
    ocr_result = qwen_processor.batch_decode(
        trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return ocr_result
# Keyword search function
def highlight_keyword(text, keyword):
    """Find the sentences containing a keyword and wrap each match in ``<mark>``.

    The text is split into sentences on ``'. '`` and the keyword is matched
    literally and case-insensitively. Every returned sentence is HTML-escaped
    (apart from the inserted ``<mark>`` tags) so it can be rendered safely.

    Args:
        text: The text to search.
        keyword: The literal keyword to look for.

    Returns:
        A list of HTML strings, one per matching sentence, or
        ``["No matches found."]`` when nothing matches or when ``text`` or
        ``keyword`` is empty.
    """
    no_matches = ["No matches found."]

    # An empty pattern matches at every position, which would wrap the gap
    # between every character in <mark> tags; treat a blank query as no match.
    if not text or not keyword or not keyword.strip():
        return no_matches

    # One compiled pattern decides both "does it match" and "what to wrap",
    # so detection and highlighting can never disagree.
    pattern = re.compile(f"({re.escape(keyword)})", flags=re.IGNORECASE)

    results = []
    for sentence in text.split('. '):
        # With a single capturing group, split() alternates
        # [plain, match, plain, match, ..., plain].
        pieces = pattern.split(sentence)
        if len(pieces) == 1:
            continue
        results.append(
            "".join(
                f"<mark>{html.escape(piece)}</mark>" if index % 2 else html.escape(piece)
                for index, piece in enumerate(pieces)
            )
        )
    return results if results else no_matches
# Gradio callback for text extraction
def extract_text(image):
    """Run OCR on the uploaded image.

    Args:
        image: The uploaded image, or ``None`` when nothing has been uploaded.

    Returns:
        The text extracted by ``perform_ocr``.

    Raises:
        gr.Error: If no image was uploaded, so the UI shows a readable
            message instead of an opaque failure.
    """
    if image is None:
        raise gr.Error("Please upload an image before extracting text.")
    return perform_ocr(image)
# Gradio callback for keyword search
def search_in_text(extracted_text, keyword):
    """Return the sentences of the text that contain the keyword, as HTML.

    The entries returned by ``highlight_keyword`` are joined with ``<br>`` so
    each one is displayed on its own line.
    """
    return "<br>".join(highlight_keyword(extracted_text, keyword))
# HTML for the page title; rendered at the top of the UI through gr.HTML.
# NOTE(review): the "gradient-text" class has no matching rule in this file's
# custom CSS — confirm whether it is styled elsewhere or can be dropped.
header_html = """
<h1 style="text-align: center; color: #4CAF50;"><span class="gradient-text">OCR and Text Search Prototype</span></h1>
"""
# Custom CSS handed to gr.Blocks: gives buttons a fixed width and padding, and
# caps the height of text boxes so long output scrolls instead of growing.
# NOTE(review): this relies on the ".gr-button" / ".gr-textbox" class names
# being present in the rendered page — verify against the Gradio version in use.
custom_css = """
.gr-button {
width: 200px; /* Set a fixed width for the buttons */
padding: 12px 20px; /* Add padding to buttons for consistency */
}
.gr-textbox {
max-height: 300px; /* Set a maximum height for the extracted text output */
overflow-y: scroll; /* Enable scrolling when text exceeds the height */
}
"""
# Gradio Interface
# Layout: a title header, then a single row holding a narrow instructions
# column (scale=1) and a wide column (scale=3) with the two workflow tabs.
# NOTE(review): the nesting of the output widgets (extracted text box, search
# results) relative to the button rows is assumed to be directly inside each
# tab — confirm against the intended layout.
with gr.Blocks(css=custom_css) as interface:
    # Header section
    gr.HTML(header_html)

    # Sidebar section
    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            gr.Markdown("## Instructions")
            gr.Markdown("""
1. Upload an image containing text.
2. Extract the text from the image.
3. Search for specific keywords within the extracted text.
""")
            gr.Markdown("### Features")
            gr.Markdown("""
- **OCR**: Extract text from images.
- **Keyword Search**: Search and highlight keywords in extracted text.
""")

        with gr.Column(scale=3):
            # Main content in tabs
            with gr.Tabs():
                # First Tab: Text Extraction
                with gr.Tab("Extract Text"):
                    gr.Markdown("### Upload an image to extract text:")
                    with gr.Row():
                        image_upload = gr.Image(type="pil", label="Upload Image", interactive=True)
                    with gr.Row():
                        extract_btn = gr.Button("Extract Text")
                    extracted_textbox = gr.Textbox(label="Extracted Text")
                    # Clicking the button runs OCR and fills the text box.
                    extract_btn.click(extract_text, inputs=image_upload, outputs=extracted_textbox)

                # Second Tab: Keyword Search
                with gr.Tab("Search in Extracted Text"):
                    gr.Markdown("### Search for a keyword in the extracted text:")
                    with gr.Row():
                        keyword_searchbox = gr.Textbox(label="Enter Keyword", placeholder="Keyword to search")
                    with gr.Row():
                        search_btn = gr.Button("Search")
                    search_results = gr.HTML(label="Results")
                    # The search reads the text box from the first tab, so the
                    # two tabs share the extracted text.
                    search_btn.click(search_in_text, inputs=[extracted_textbox, keyword_searchbox], outputs=search_results)
# Launch the Gradio App: start serving the `interface` Blocks UI.
interface.launch()