zliang committed on
Commit
c65777e
1 Parent(s): ec2e6e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -24
app.py CHANGED
@@ -1,13 +1,13 @@
1
  import gradio as gr
2
  from ultralytics import YOLO
3
  import cv2
 
4
  import numpy as np
5
  import fitz # PyMuPDF
6
  from PIL import Image
7
- import spaces
8
 
9
  # Load the trained model
10
- model_path = 'best.pt' # Replace with the path to your trained .pt file
11
  model = YOLO(model_path)
12
 
13
  # Define the class indices for figures and tables
@@ -16,28 +16,20 @@ table_class_index = 4 # class index for tables
16
 
17
  # Function to perform inference on an image and return bounding boxes for figures and tables
18
  def infer_image_and_get_boxes(image, confidence_threshold=0.6):
19
- # Perform inference
20
  results = model(image)
21
-
22
- boxes = []
23
- # Extract results
24
- for result in results:
25
- for box in result.boxes:
26
- cls = int(box.cls[0])
27
- confidence = box.conf[0]
28
- if (cls == figure_class_index or cls == table_class_index) and confidence > confidence_threshold:
29
- x1, y1, x2, y2 = map(int, box.xyxy[0])
30
- boxes.append((x1, y1, x2, y2))
31
-
32
  return boxes
33
 
34
  # Function to crop images from the boxes
35
  def crop_images_from_boxes(image, boxes, scale_factor):
36
- cropped_images = []
37
- for box in boxes:
38
- x1, y1, x2, y2 = [int(coord * scale_factor) for coord in box]
39
- cropped_image = image[y1:y2, x1:x2]
40
- cropped_images.append(cropped_image)
41
  return cropped_images
42
 
43
  @spaces.GPU
@@ -49,7 +41,7 @@ def process_pdf(pdf_file):
49
  # Set the DPI for inference and high resolution for cropping
50
  low_dpi = 50
51
  high_dpi = 300
52
-
53
  # Calculate the scaling factor
54
  scale_factor = high_dpi / low_dpi
55
 
@@ -59,8 +51,7 @@ def process_pdf(pdf_file):
59
 
60
  # Perform inference at low DPI
61
  low_res_pix = page.get_pixmap(dpi=low_dpi)
62
- low_res_img = Image.frombytes("RGB", [low_res_pix.width, low_res_pix.height], low_res_pix.samples)
63
- low_res_img = np.array(low_res_img)
64
 
65
  # Get bounding boxes from low DPI image
66
  boxes = infer_image_and_get_boxes(low_res_img)
@@ -68,8 +59,7 @@ def process_pdf(pdf_file):
68
  if boxes:
69
  # Load high DPI image for cropping only if boxes are found
70
  high_res_pix = page.get_pixmap(dpi=high_dpi)
71
- high_res_img = Image.frombytes("RGB", [high_res_pix.width, high_res_pix.height], high_res_pix.samples)
72
- high_res_img = np.array(high_res_img)
73
 
74
  # Crop images at high DPI
75
  cropped_imgs = crop_images_from_boxes(high_res_img, boxes, scale_factor)
 
1
  import gradio as gr
2
  from ultralytics import YOLO
3
  import cv2
4
+ import spaces
5
  import numpy as np
6
  import fitz # PyMuPDF
7
  from PIL import Image
 
8
 
9
  # Load the trained model
10
+ model_path = 'runs/detect/train7/weights/best.pt' # Replace with the path to your trained .pt file
11
  model = YOLO(model_path)
12
 
13
  # Define the class indices for figures and tables
 
16
 
17
  # Function to perform inference on an image and return bounding boxes for figures and tables
18
  def infer_image_and_get_boxes(image, confidence_threshold=0.6):
 
19
  results = model(image)
20
+ boxes = [
21
+ (int(box.xyxy[0][0]), int(box.xyxy[0][1]), int(box.xyxy[0][2]), int(box.xyxy[0][3]))
22
+ for result in results for box in result.boxes
23
+ if int(box.cls[0]) in {figure_class_index, table_class_index} and box.conf[0] > confidence_threshold
24
+ ]
 
 
 
 
 
 
25
  return boxes
26
 
27
  # Function to crop images from the boxes
28
  def crop_images_from_boxes(image, boxes, scale_factor):
29
+ cropped_images = [
30
+ image[int(y1 * scale_factor):int(y2 * scale_factor), int(x1 * scale_factor):int(x2 * scale_factor)]
31
+ for (x1, y1, x2, y2) in boxes
32
+ ]
 
33
  return cropped_images
34
 
35
  @spaces.GPU
 
41
  # Set the DPI for inference and high resolution for cropping
42
  low_dpi = 50
43
  high_dpi = 300
44
+
45
  # Calculate the scaling factor
46
  scale_factor = high_dpi / low_dpi
47
 
 
51
 
52
  # Perform inference at low DPI
53
  low_res_pix = page.get_pixmap(dpi=low_dpi)
54
+ low_res_img = np.frombuffer(low_res_pix.samples, dtype=np.uint8).reshape(low_res_pix.height, low_res_pix.width, 3)
 
55
 
56
  # Get bounding boxes from low DPI image
57
  boxes = infer_image_and_get_boxes(low_res_img)
 
59
  if boxes:
60
  # Load high DPI image for cropping only if boxes are found
61
  high_res_pix = page.get_pixmap(dpi=high_dpi)
62
+ high_res_img = np.frombuffer(high_res_pix.samples, dtype=np.uint8).reshape(high_res_pix.height, high_res_pix.width, 3)
 
63
 
64
  # Crop images at high DPI
65
  cropped_imgs = crop_images_from_boxes(high_res_img, boxes, scale_factor)