pranshh commited on
Commit
bf86837
1 Parent(s): d21030f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -24
app.py CHANGED
@@ -4,9 +4,11 @@ import gradio as gr
4
  from PIL import Image
5
  from byaldi import RAGMultiModalModel
6
  from qwen_vl_utils import process_vision_info
 
 
7
 
8
  # Load ColPali model
9
- RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
10
 
11
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
12
 
@@ -16,30 +18,33 @@ def load_model():
16
  vlm = load_model()
17
 
18
  def ocr_image(image, keyword=""):
19
- # Convert PIL Image to file-like object
20
- import io
21
- img_byte_arr = io.BytesIO()
22
- image.save(img_byte_arr, format='PNG')
23
- img_byte_arr = img_byte_arr.getvalue()
24
-
25
- # Index the image
26
- RAG.index(input_data=img_byte_arr, index_name="temp_index", overwrite=True)
27
-
28
- # Retrieve text from the image
29
- results = RAG.search("Extract all text from this image", k=1)
30
-
31
- # Extract text from results
32
- output_text = results[0].get('text', '')
33
-
34
- if keyword:
35
- keyword_lower = keyword.lower()
36
- if keyword_lower in output_text.lower():
37
- highlighted_text = output_text.replace(keyword, f"**{keyword}**")
38
- return f"Keyword '{keyword}' found in the text:\n\n{highlighted_text}"
 
 
39
  else:
40
- return f"Keyword '{keyword}' not found in the text:\n\n{output_text}"
41
- else:
42
- return output_text
 
43
 
44
  def process_image(image, keyword=""):
45
  max_size = 1024
 
4
  from PIL import Image
5
  from byaldi import RAGMultiModalModel
6
  from qwen_vl_utils import process_vision_info
7
+ import os
8
+ import tempfile
9
 
10
  # Load ColPali model
11
+ RAG = RAGMultiModalModel.from_pretrained("vidore/colpali", device_map="cpu", torch_dtype=torch.float32)
12
 
13
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
14
 
 
18
  vlm = load_model()
19
 
20
  def ocr_image(image, keyword=""):
21
+ # Save the image to a temporary file
22
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
23
+ image.save(temp_file, format='PNG')
24
+ temp_file_path = temp_file.name
25
+
26
+ try:
27
+ # Index the image
28
+ RAG.index(input_path=temp_file_path, index_name="temp_index", overwrite=True)
29
+
30
+ # Retrieve text from the image
31
+ results = RAG.search("Extract all text from this image", k=1)
32
+
33
+ # Extract text from results
34
+ output_text = results[0].get('text', '')
35
+
36
+ if keyword:
37
+ keyword_lower = keyword.lower()
38
+ if keyword_lower in output_text.lower():
39
+ highlighted_text = output_text.replace(keyword, f"**{keyword}**")
40
+ return f"Keyword '{keyword}' found in the text:\n\n{highlighted_text}"
41
+ else:
42
+ return f"Keyword '{keyword}' not found in the text:\n\n{output_text}"
43
  else:
44
+ return output_text
45
+ finally:
46
+ # Clean up the temporary file
47
+ os.unlink(temp_file_path)
48
 
49
  def process_image(image, keyword=""):
50
  max_size = 1024