DeF0017 committed on
Commit 65f5126
1 parent: d0b78ba

Update app.py

Files changed (1)
  1. app.py +139 -175
app.py CHANGED
@@ -1,175 +1,139 @@
-import gradio as gr
-import spaces
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
-from qwen_vl_utils import process_vision_info
-import torch
-from PIL import Image
-import subprocess
-import numpy as np
-import os
-from threading import Thread
-import uuid
-import io
-import re # Import regular expressions for word highlighting
-
-# Model and Processor Loading (Done once at startup)
-MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
-model = Qwen2VLForConditionalGeneration.from_pretrained(
-    MODEL_ID,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
-).to("cuda").eval()
-processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
-
-DESCRIPTION = "[Qwen2-VL-2B Demo](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)"
-
-# Define supported media extensions
-image_extensions = Image.registered_extensions()
-video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")
-
-
-def identify_and_save_blob(blob_path):
-    """Identifies if the blob is an image or video and saves it accordingly."""
-    try:
-        with open(blob_path, 'rb') as file:
-            blob_content = file.read()
-
-        # Try to identify if it's an image
-        try:
-            Image.open(io.BytesIO(blob_content)).verify() # Check if it's a valid image
-            extension = ".png" # Default to PNG for saving
-            media_type = "image"
-        except (IOError, SyntaxError):
-            # If it's not a valid image, assume it's a video
-            extension = ".mp4" # Default to MP4 for saving
-            media_type = "video"
-
-        # Create a unique filename
-        filename = f"temp_{uuid.uuid4()}_media{extension}"
-        with open(filename, "wb") as f:
-            f.write(blob_content)
-
-        return filename, media_type
-
-    except FileNotFoundError:
-        raise ValueError(f"The file {blob_path} was not found.")
-    except Exception as e:
-        raise ValueError(f"An error occurred while processing the file: {e}")
-
-
-@spaces.GPU
-def qwen_inference(media_input, search_word):
-    """
-    Performs OCR on the input media and highlights the search_word in the extracted text.
-
-    Args:
-        media_input (str): Path to the uploaded image or video file.
-        search_word (str): The word to search and highlight in the OCR result.
-
-    Yields:
-        str: The OCR result with highlighted search words.
-    """
-    text_input = "Extract text" # Hardcoded text query
-
-    if isinstance(media_input, str): # If it's a filepath
-        media_path = media_input
-        if media_path.endswith(tuple([i for i, f in image_extensions.items()])):
-            media_type = "image"
-        elif media_path.endswith(video_extensions):
-            media_type = "video"
-        else:
-            try:
-                media_path, media_type = identify_and_save_blob(media_input)
-                print(media_path, media_type)
-            except Exception as e:
-                print(e)
-                raise ValueError(
-                    "Unsupported media type. Please upload an image or video."
-                )
-
-    print(f"Processing media: {media_path} (Type: {media_type})")
-
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {
-                    "type": media_type,
-                    media_type: media_path,
-                    **({"fps": 8.0} if media_type == "video" else {}),
-                },
-                {"type": "text", "text": text_input},
-            ],
-        }
-    ]
-
-    # Apply chat template to format the input for the model
-    text = processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-    image_inputs, video_inputs = process_vision_info(messages)
-
-    # Prepare model inputs
-    inputs = processor(
-        text=[text],
-        images=image_inputs,
-        videos=video_inputs,
-        padding=True,
-        return_tensors="pt",
-    ).to("cuda")
-
-    # Initialize the streamer for iterative generation
-    streamer = TextIteratorStreamer(
-        processor, skip_prompt=True, **{"skip_special_tokens": True}
-    )
-    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-
-    # Start the generation in a separate thread
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-
-    buffer = ""
-    for new_text in streamer:
-        buffer += new_text
-        # Highlight the search_word in the buffer
-        if search_word:
-            # Use regex for case-insensitive search and highlight
-            pattern = re.compile(re.escape(search_word), re.IGNORECASE)
-            highlighted_text = pattern.sub(lambda m: f"<mark>{m.group(0)}</mark>", buffer)
-        else:
-            highlighted_text = buffer
-        yield highlighted_text
-
-
-css = """
-#output {
-    height: 500px;
-    overflow: auto;
-    border: 1px solid #ccc;
-}
-"""
-
-with gr.Blocks(css=css) as demo:
-    gr.Markdown(DESCRIPTION)
-
-    with gr.Tab(label="Image/Video Input"):
-        with gr.Row():
-            with gr.Column():
-                input_media = gr.File(
-                    label="Upload Image or Video", type="filepath"
-                )
-                search_word = gr.Textbox(
-                    label="Search Word", placeholder="Enter word to highlight", lines=1
-                )
-                submit_btn = gr.Button(value="Submit")
-            with gr.Column():
-                # Use HTML component to display highlighted text
-                output_text = gr.HTML(label="Output Text")
-
-        submit_btn.click(
-            qwen_inference,
-            inputs=[input_media, search_word],
-            outputs=[output_text]
-        )
-
-demo.launch(debug=True)
 
+from byaldi import RAGMultiModalModel
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
+import torch
+import gradio as gr
+from PIL import Image
+import re
+
+
+# Load models
+def initialize_models():
+    """Loads and returns the RAG multimodal and Qwen2-VL models along with the processor."""
+    multimodal_rag = RAGMultiModalModel.from_pretrained("vidore/colpali")
+    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype=torch.float32)
+    qwen_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
+    return multimodal_rag, qwen_model, qwen_processor
+
+multimodal_rag, qwen_model, qwen_processor = initialize_models()
+
+# Text extraction function
+def perform_ocr(image):
+    """Extracts Sanskrit and English text from an image using the Qwen model."""
+    query = "Extract text from the image in original language"
+
+    # Format the request for the model
+    user_input = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": query}
+            ]
+        }
+    ]
+
+    # Preprocess the input
+    input_text = qwen_processor.apply_chat_template(user_input, tokenize=False, add_generation_prompt=True)
+    image_inputs, video_inputs = process_vision_info(user_input)
+    model_inputs = qwen_processor(
+        text=[input_text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
+    ).to("cpu") # Use CPU for inference
+
+    # Generate output
+    with torch.no_grad():
+        generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=2000)
+    trimmed_ids = [output[len(input_ids):] for input_ids, output in zip(model_inputs.input_ids, generated_ids)]
+    ocr_result = qwen_processor.batch_decode(trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+
+    return ocr_result
+
+# Keyword search function
+def highlight_keyword(text, keyword):
+    """Searches and highlights the keyword in the extracted text."""
+    keyword_lowercase = keyword.lower()
+    sentences = text.split('. ')
+    results = []
+
+    for sentence in sentences:
+        if keyword_lowercase in sentence.lower():
+            highlighted = re.sub(f'({re.escape(keyword)})', r'<mark>\1</mark>', sentence, flags=re.IGNORECASE)
+            results.append(highlighted)
+
+    return results if results else ["No matches found."]
+
+# Gradio app for text extraction
+def extract_text(image):
+    """Extracts text from an uploaded image."""
+    return perform_ocr(image)
+
+# Gradio app for keyword search
+def search_in_text(extracted_text, keyword):
+    """Searches for a keyword in the extracted text and highlights matches."""
+    results = highlight_keyword(extracted_text, keyword)
+    return "<br>".join(results)
+
+# Updated title with revised phrasing
+header_html = """
+<h1 style="text-align: center; color: #4CAF50;"><span class="gradient-text">OCR and Text Search Prototype</span></h1>
+"""
+
+# CSS to fix button sizes
+custom_css = """
+.gr-button {
+    width: 200px; /* Set a fixed width for the buttons */
+    padding: 12px 20px; /* Add padding to buttons for consistency */
+}
+.gr-textbox {
+    max-height: 300px; /* Set a maximum height for the extracted text output */
+    overflow-y: scroll; /* Enable scrolling when text exceeds the height */
+}
+"""
+
+# Gradio Interface
+with gr.Blocks(css=custom_css) as interface:
+
+    # Header section
+    gr.HTML(header_html)
+
+    # Sidebar section
+    with gr.Row():
+        with gr.Column(scale=1, min_width=200):
+            gr.Markdown("## Instructions")
+            gr.Markdown("""
+            1. Upload an image containing text.
+            2. Extract the text from the image.
+            3. Search for specific keywords within the extracted text.
+            """)
+            gr.Markdown("### Features")
+            gr.Markdown("""
+            - **OCR**: Extract text from images.
+            - **Keyword Search**: Search and highlight keywords in extracted text.
+            """)
+
+        with gr.Column(scale=3):
+            # Main content in tabs
+            with gr.Tabs():
+
+                # First Tab: Text Extraction
+                with gr.Tab("Extract Text"):
+                    gr.Markdown("### Upload an image to extract text:")
+                    with gr.Row():
+                        image_upload = gr.Image(type="pil", label="Upload Image", interactive=True)
+                    with gr.Row():
+                        extract_btn = gr.Button("Extract Text")
+                        extracted_textbox = gr.Textbox(label="Extracted Text")
+                    extract_btn.click(extract_text, inputs=image_upload, outputs=extracted_textbox)
+
+                # Second Tab: Keyword Search
+                with gr.Tab("Search in Extracted Text"):
+                    gr.Markdown("### Search for a keyword in the extracted text:")
+                    with gr.Row():
+                        keyword_searchbox = gr.Textbox(label="Enter Keyword", placeholder="Keyword to search")
+                    with gr.Row():
+                        search_btn = gr.Button("Search")
+                        search_results = gr.HTML(label="Results")
+                    search_btn.click(search_in_text, inputs=[extracted_textbox, keyword_searchbox], outputs=search_results)
+
+# Launch the Gradio App
+interface.launch()
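
A minimal, self-contained sketch of the new highlight_keyword helper added in this commit (the function body is copied from the diff above; the sample sentence and printed output are illustrative, not part of the commit):

import re

# Copied from the new app.py: sentence-scoped, case-insensitive highlighting.
def highlight_keyword(text, keyword):
    keyword_lowercase = keyword.lower()
    sentences = text.split('. ')
    results = []
    for sentence in sentences:
        if keyword_lowercase in sentence.lower():
            highlighted = re.sub(f'({re.escape(keyword)})', r'<mark>\1</mark>', sentence, flags=re.IGNORECASE)
            results.append(highlighted)
    return results if results else ["No matches found."]

# Only sentences containing the keyword are returned; the match keeps its original case.
print(highlight_keyword("Qwen reads the image. The Text is returned. Nothing else", "text"))
# ['The <mark>Text</mark> is returned']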