from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import gradio as gr
from PIL import Image
import re


# Load models
def initialize_models():
    """Loads and returns the RAG multimodal and Qwen2-VL models along with the processor."""
    multimodal_rag = RAGMultiModalModel.from_pretrained("vidore/colpali")
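    # Note: the ColPali retriever is loaded and returned here, but the
    # OCR/search flow below does not use it; it is kept for potential
    # document-retrieval extensions.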
    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype=torch.float32
    )
    qwen_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
    return multimodal_rag, qwen_model, qwen_processor

multimodal_rag, qwen_model, qwen_processor = initialize_models()

# Text extraction function
def perform_ocr(image):
    """Extracts Sanskrit and English text from an image using the Qwen model."""
    query = "Extract text from the image in original language"

    # Format the request for the model
    user_input = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": query}
            ]
        }
    ]

    # Preprocess the input
    input_text = qwen_processor.apply_chat_template(user_input, tokenize=False, add_generation_prompt=True)
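    # process_vision_info gathers the image (and video) inputs referenced
    # in the chat messages so they can be passed to the processor.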
    image_inputs, video_inputs = process_vision_info(user_input)
    model_inputs = qwen_processor(
        text=[input_text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
    ).to("cpu")  # Use CPU for inference

    # Generate output
    with torch.no_grad():
        generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=2000)
        trimmed_ids = [output[len(input_ids):] for input_ids, output in zip(model_inputs.input_ids, generated_ids)]
        ocr_result = qwen_processor.batch_decode(trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

    return ocr_result
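
# Example usage (the image path is hypothetical):
#   text = perform_ocr(Image.open("page_scan.png"))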

# Keyword search function
def highlight_keyword(text, keyword):
    """Searches and highlights the keyword in the extracted text."""
    keyword_lowercase = keyword.lower()
    # Split on sentence-ending punctuation so matches are shown with context
    # (handles '.', '?', and '!' and keeps the punctuation on each sentence).
    sentences = re.split(r'(?<=[.?!])\s+', text)
    results = []

    for sentence in sentences:
        if keyword_lowercase in sentence.lower():
            highlighted = re.sub(f'({re.escape(keyword)})', r'<mark>\1</mark>', sentence, flags=re.IGNORECASE)
            results.append(highlighted)

    return results if results else ["No matches found."]
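
# Example (illustrative): highlight_keyword("The cat sat. The dog ran.", "cat")
# returns ['The <mark>cat</mark> sat.']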

# Gradio app for text extraction
def extract_text(image):
    """Extracts text from an uploaded image."""
    return perform_ocr(image)

# Gradio app for keyword search
def search_in_text(extracted_text, keyword):
    """Searches for a keyword in the extracted text and highlights matches."""
    results = highlight_keyword(extracted_text, keyword)
    return "<br>".join(results)

# App header
header_html = """
<h1 style="text-align: center; color: #4CAF50;"><span class="gradient-text">OCR and Text Search Prototype</span></h1>
"""

# CSS for fixed-width buttons and a scrollable text output
custom_css = """
    .gr-button {
        width: 200px; /* Set a fixed width for the buttons */
        padding: 12px 20px; /* Add padding to buttons for consistency */
    }
    .gr-textbox {
        max-height: 300px; /* Set a maximum height for the extracted text output */
        overflow-y: scroll; /* Enable scrolling when text exceeds the height */
    }
"""

# Gradio Interface
with gr.Blocks(css=custom_css) as interface:

    # Header section
    gr.HTML(header_html)

    # Sidebar section
    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            gr.Markdown("## Instructions")
            gr.Markdown("""
                1. Upload an image containing text.
                2. Extract the text from the image.
                3. Search for specific keywords within the extracted text.
            """)
            gr.Markdown("### Features")
            gr.Markdown("""
                - **OCR**: Extract text from images.
                - **Keyword Search**: Search and highlight keywords in extracted text.
            """)

        with gr.Column(scale=3):
            # Main content in tabs
            with gr.Tabs():

                # First Tab: Text Extraction
                with gr.Tab("Extract Text"):
                    gr.Markdown("### Upload an image to extract text:")
                    with gr.Row():
                        image_upload = gr.Image(type="pil", label="Upload Image", interactive=True)
                    with gr.Row():
                        extract_btn = gr.Button("Extract Text")
                        extracted_textbox = gr.Textbox(label="Extracted Text")
                    extract_btn.click(extract_text, inputs=image_upload, outputs=extracted_textbox)

                # Second Tab: Keyword Search
                with gr.Tab("Search in Extracted Text"):
                    gr.Markdown("### Search for a keyword in the extracted text:")
                    with gr.Row():
                        keyword_searchbox = gr.Textbox(label="Enter Keyword", placeholder="Keyword to search")
                    with gr.Row():
                        search_btn = gr.Button("Search")
                        search_results = gr.HTML(label="Results")
                    search_btn.click(search_in_text, inputs=[extracted_textbox, keyword_searchbox], outputs=search_results)

# Launch the Gradio app
if __name__ == "__main__":
    interface.launch()