pranshh committed
Commit 86c3f4a · verified · 1 Parent(s): 4af6e9e

Update app.py

Files changed (1)
  1. app.py +33 -20
app.py CHANGED
@@ -1,55 +1,68 @@
 # -*- coding: utf-8 -*-
 """OCR Web Application Prototype.ipynb
-
 Automatically generated by Colab.
-
 Original file is located at
     https://colab.research.google.com/drive/1vzsQ17-W1Vy6yJ60XUwFy0QRkOR_SIg7
 """
 
-from transformers import AutoProcessor
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
 import torch
 import gradio as gr
 from PIL import Image
-# Hypothetical imports
-from byaldi import ByaldiProcessor
-from colpali import ColPaliQwen2VLModel
+
 
 processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
-byaldi_processor = ByaldiProcessor()
 
+# Initialize the model with float32 precision for CPU inference
+# Simplified model loading function for CPU
 def load_model():
-    return ColPaliQwen2VLModel.from_pretrained(
+    return Qwen2VLForConditionalGeneration.from_pretrained(
         "Qwen/Qwen2-VL-2B-Instruct",
-        torch_dtype=torch.float32,
-        low_cpu_mem_usage=True,
-        device_map="auto"
+        torch_dtype=torch.float32,  # Use float32 for CPU
+        low_cpu_mem_usage=True
     )
 
+# Load the model
 vlm = load_model()
 
+# OCR function to extract text from an image
 def ocr_image(image, query="Extract text from the image", keyword=""):
-    processed_image = byaldi_processor.process_image(image)
-
     messages = [
         {
             "role": "user",
             "content": [
                 {
                     "type": "image",
-                    "image": processed_image,
+                    "image": image,
                 },
                 {"type": "text", "text": query},
             ],
         }
     ]
 
-    inputs = processor(messages, return_tensors="pt")
+    # Prepare inputs for the model
+    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    image_inputs, video_inputs = process_vision_info(messages)
+    inputs = processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
     inputs = inputs.to("cpu")
 
-    output = vlm.generate(**inputs, max_new_tokens=512)
-    output_text = processor.decode(output[0], skip_special_tokens=True)
-
+    # Generate the output text using the model
+    generated_ids = vlm.generate(**inputs, max_new_tokens=512)
+    generated_ids_trimmed = [
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+
+    output_text = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )[0]
+
     if keyword:
         keyword_lower = keyword.lower()
         if keyword_lower in output_text.lower():
@@ -60,14 +73,14 @@ def ocr_image(image, query="Extract text from the image", keyword=""):
     else:
         return output_text
 
+# Gradio interface
 def process_image(image, keyword=""):
-    # Resize image if it's too large
     max_size = 1024
     if max(image.size) > max_size:
         image.thumbnail((max_size, max_size))
     return ocr_image(image, keyword=keyword)
 
-# Gradio interface:
+# Update the Gradio interface:
 interface = gr.Interface(
     fn=process_image,
     inputs=[
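
A few notes on the new version follow. The loading comments pin torch.float32 because the Space runs on CPU. If a GPU were available, a common variant is to choose the dtype and device map at load time; the function below is a hedged sketch of that idea, not part of this commit:

# Hypothetical GPU-aware variant of load_model (an assumption, not in the commit).
def load_model_auto():
    use_cuda = torch.cuda.is_available()
    return Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        torch_dtype=torch.float16 if use_cuda else torch.float32,  # fp16 only on GPU
        low_cpu_mem_usage=True,
        device_map="auto" if use_cuda else None,  # let accelerate place weights on GPU
    )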
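The generation step trims the prompt because vlm.generate returns the input ids followed by the newly generated ones; slicing each sequence at len(in_ids) keeps only the completion before decoding. A minimal illustration of the idiom with plain lists (the token values are made up):

# Illustration of the prompt-trimming idiom used in ocr_image.
in_ids = [101, 5, 7, 9]                # hypothetical prompt token ids
out_ids = [101, 5, 7, 9, 42, 43, 102]  # generate() echoes the prompt, then appends new tokens
trimmed = out_ids[len(in_ids):]
print(trimmed)  # [42, 43, 102]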
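The diff is cut off inside the inputs=[ list of gr.Interface. For orientation only, here is a sketch of how such an interface is commonly completed and launched; the component labels, output component, and title are assumptions rather than the commit's actual code:

# Hypothetical completion of the truncated gr.Interface call (not from the commit).
interface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload image"),        # process_image expects a PIL image
        gr.Textbox(label="Keyword (optional)", value=""),  # forwarded to ocr_image's keyword filter
    ],
    outputs=gr.Textbox(label="Extracted text"),
    title="OCR Web Application Prototype",
)

interface.launch()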