from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import gradio as gr
from PIL import Image
import re
# Load models
def initialize_models():
    """Loads and returns the RAG multimodal and Qwen2-VL models along with the processor."""
    multimodal_rag = RAGMultiModalModel.from_pretrained("vidore/colpali")
    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype=torch.float32
    )
    qwen_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
    return multimodal_rag, qwen_model, qwen_processor

multimodal_rag, qwen_model, qwen_processor = initialize_models()
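# The ColPali retriever loaded above is used for document indexing and retrieval.
# As a sketch (the input path and index name below are placeholders), a typical
# byaldi workflow looks like:
#   multimodal_rag.index(input_path="docs/", index_name="ocr_index", overwrite=True)
#   results = multimodal_rag.search("query text", k=3)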
# Text extraction function
def perform_ocr(image):
    """Extracts Sanskrit and English text from an image using the Qwen2-VL model."""
    query = "Extract text from the image in the original language"
    # Format the request for the model
    user_input = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": query},
            ],
        }
    ]
    # Preprocess the input
    input_text = qwen_processor.apply_chat_template(user_input, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(user_input)
    model_inputs = qwen_processor(
        text=[input_text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
    ).to("cpu")  # Use CPU for inference
    # Generate output, then strip the prompt tokens before decoding
    with torch.no_grad():
        generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=2000)
    trimmed_ids = [output[len(input_ids):] for input_ids, output in zip(model_inputs.input_ids, generated_ids)]
    ocr_result = qwen_processor.batch_decode(
        trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return ocr_result
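# Illustrative usage (the image path is a placeholder):
#   page = Image.open("manuscript_page.png")
#   text = perform_ocr(page)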
# Keyword search function
def highlight_keyword(text, keyword):
    """Searches for the keyword in the extracted text and highlights each match."""
    keyword_lowercase = keyword.lower()
    sentences = text.split('. ')
    results = []
    for sentence in sentences:
        if keyword_lowercase in sentence.lower():
            # Wrap each case-insensitive match in ** markers so it stands out
            highlighted = re.sub(f'({re.escape(keyword)})', r'**\1**', sentence, flags=re.IGNORECASE)
            results.append(highlighted)
    return results if results else ["No matches found."]
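# Illustrative example: highlight_keyword("The cat sat. A dog ran.", "cat")
# returns ["The **cat** sat"], with the match wrapped in ** markers.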
# Gradio app for text extraction
def extract_text(image):
    """Extracts text from an uploaded image."""
    return perform_ocr(image)
# Gradio app for keyword search
def search_in_text(extracted_text, keyword):
    """Searches for a keyword in the extracted text and highlights matches."""
    results = highlight_keyword(extracted_text, keyword)
    return "\n".join(results)
# HTML header for the Gradio interface
header_html = """