from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import re
import gradio as gr
from PIL import Image


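# Load the processor (tokenizer plus image pre-processing) for Qwen2-VL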
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

# Load the model for CPU inference; float32 avoids half-precision issues on CPU
def load_model():
    return Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        torch_dtype=torch.float32,  # Use float32 for CPU
        low_cpu_mem_usage=True
    )
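
# A possible GPU variant (a sketch only, not called here): with CUDA available,
# float16 plus device_map="auto" is the usual Hugging Face loading pattern.
# def load_model_gpu():
#     return Qwen2VLForConditionalGeneration.from_pretrained(
#         "Qwen/Qwen2-VL-2B-Instruct",
#         torch_dtype=torch.float16,
#         device_map="auto",
#     )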

# Load the model
vlm = load_model()

# OCR function to extract text from an image
def ocr_image(image, query="Extract text from the image", keyword=""):
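    # Build a chat-style message that pairs the uploaded image with the text query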
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                },
                {"type": "text", "text": query},
            ],
        }
    ]

    # Prepare inputs for the model
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cpu")

    # Generate the output text using the model
    generated_ids = vlm.generate(**inputs, max_new_tokens=512)
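    # Strip the prompt tokens so only the newly generated text is decoded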
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    
    # Optional keyword search: both the containment check and the highlighting
    # must be case-insensitive, otherwise a match found in a different case
    # would never be marked (str.replace is case-sensitive).
    if keyword:
        pattern = re.compile(re.escape(keyword), re.IGNORECASE)
        if pattern.search(output_text):
            highlighted_text = pattern.sub(lambda m: f"**{m.group(0)}**", output_text)
            return f"Keyword '{keyword}' found in the text:\n\n{highlighted_text}"
        return f"Keyword '{keyword}' not found in the text:\n\n{output_text}"
    return output_text
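
# Example direct call (a sketch; assumes a local image "sample.png" exists):
# print(ocr_image(Image.open("sample.png"), keyword="invoice"))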

# Gradio callback: downscale large uploads before OCR to keep CPU inference fast
def process_image(image, keyword=""):
    if image is None:
        return "Please upload an image."
    max_size = 1024
    if max(image.size) > max_size:
        # thumbnail() resizes in place, preserving the aspect ratio
        image.thumbnail((max_size, max_size))
    return ocr_image(image, keyword=keyword)

# Build the Gradio interface
interface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Enter keyword to search (optional)")
    ],
    outputs="text",
    title="Hindi & English OCR with Keyword Search",
)

# Launch the Gradio interface (works in Colab as well)
interface.launch()