zliang commited on
Commit
cff5fa2
1 Parent(s): 9b47e37

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -23
app.py CHANGED
@@ -11,18 +11,15 @@ model = YOLOv10("best.pt")
11
  figure_class_index = 3 # class index for figures
12
  table_class_index = 4 # class index for tables
13
 
14
- # Function to perform inference on a batch of images and return bounding boxes for figures and tables
15
- def infer_images_and_get_boxes(images, confidence_threshold=0.6):
16
- results = model.predict(images)
17
- all_boxes = []
18
- for result in results:
19
- boxes = [
20
- (int(box.xyxy[0][0]), int(box.xyxy[0][1]), int(box.xyxy[0][2]), int(box.xyxy[0][3]))
21
- for box in result.boxes
22
- if int(box.cls[0]) in {figure_class_index, table_class_index} and box.conf[0] > confidence_threshold
23
- ]
24
- all_boxes.append(boxes)
25
- return all_boxes
26
 
27
  # Function to crop images from the boxes
28
  def crop_images_from_boxes(image, boxes, scale_factor):
@@ -48,17 +45,13 @@ def process_pdf(pdf_file):
48
  # Pre-cache all page pixmaps at low DPI
49
  low_res_pixmaps = [page.get_pixmap(dpi=low_dpi) for page in doc]
50
 
51
- # Prepare a batch of low resolution images for inference
52
- low_res_imgs = [
53
- np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, 3)
54
- for pix in low_res_pixmaps
55
- ]
56
-
57
- # Run inference on the batch of low resolution images
58
- all_boxes = infer_images_and_get_boxes(low_res_imgs)
59
-
60
- # Loop through each page and corresponding boxes
61
- for page_num, (low_res_img, boxes) in enumerate(zip(low_res_imgs, all_boxes)):
62
  if boxes:
63
  # Load high DPI image for cropping only if boxes are found
64
  high_res_pix = doc[page_num].get_pixmap(dpi=high_dpi)
 
11
  figure_class_index = 3 # class index for figures
12
  table_class_index = 4 # class index for tables
13
 
14
+ # Function to perform inference on a single image and return bounding boxes for figures and tables
15
+ def infer_image_and_get_boxes(image, confidence_threshold=0.6):
16
+ results = model.predict(image)
17
+ boxes = [
18
+ (int(box.xyxy[0][0]), int(box.xyxy[0][1]), int(box.xyxy[0][2]), int(box.xyxy[0][3]))
19
+ for result in results for box in result.boxes
20
+ if int(box.cls[0]) in {figure_class_index, table_class_index} and box.conf[0] > confidence_threshold
21
+ ]
22
+ return boxes
 
 
 
23
 
24
  # Function to crop images from the boxes
25
  def crop_images_from_boxes(image, boxes, scale_factor):
 
45
  # Pre-cache all page pixmaps at low DPI
46
  low_res_pixmaps = [page.get_pixmap(dpi=low_dpi) for page in doc]
47
 
48
+ # Loop through each page
49
+ for page_num, low_res_pix in enumerate(low_res_pixmaps):
50
+ low_res_img = np.frombuffer(low_res_pix.samples, dtype=np.uint8).reshape(low_res_pix.height, low_res_pix.width, 3)
51
+
52
+ # Get bounding boxes from low DPI image
53
+ boxes = infer_image_and_get_boxes(low_res_img)
54
+
 
 
 
 
55
  if boxes:
56
  # Load high DPI image for cropping only if boxes are found
57
  high_res_pix = doc[page_num].get_pixmap(dpi=high_dpi)