zliang commited on
Commit
e91a768
1 Parent(s): 743990b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -29
app.py CHANGED
@@ -1,55 +1,89 @@
 
1
  import gradio as gr
2
  from ultralytics import YOLO
3
  import cv2
4
  import numpy as np
5
- #import spaces
 
6
 
7
  # Load the trained model
8
  model_path = 'best.pt' # Replace with the path to your trained .pt file
9
  model = YOLO(model_path)
10
 
11
- # Function to perform inference on an image
12
- colors = {
13
- 0: (255, 0, 0), # Red for category 0
14
- 1: (0, 255, 0), # Green for category 1
15
- 2: (0, 0, 255), # Blue for category 2
16
- 3: (255, 255, 0), # Cyan for category 3
17
- 4: (255, 0, 255) # Magenta for category 4
18
- }
19
-
20
- # Function to perform inference on an image
21
- def infer_image(image):
22
  # Convert the image from BGR to RGB
23
  image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
24
 
25
  # Perform inference
26
  results = model(image_rgb)
27
 
28
- # Extract results and annotate image
 
29
  for result in results:
30
  for box in result.boxes:
31
- x1, y1, x2, y2 = box.xyxy[0]
32
  cls = int(box.cls[0])
33
- conf = float(box.conf[0])
34
-
35
- # Get the color for the current class
36
- color = colors.get(cls, (0, 255, 0)) # Default to green if class not found
37
-
38
- # Draw bounding box
39
- cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
40
- # Draw label
41
- label = f'{model.names[cls]} {conf:.2f}'
42
- cv2.putText(image, label, (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- return image
45
 
46
  # Create Gradio interface
47
  iface = gr.Interface(
48
- fn=infer_image,
49
- inputs=gr.Image(type="numpy", label="Upload an Image"),
50
- outputs=gr.Image(type="numpy", label="Annotated Image"),
51
  title="Fast document layout analysis based on YOLOv8",
52
- description="Upload an image to get document layout analysis results."
53
  )
54
 
55
  # Launch the app
 
1
+ # Load the trained model
2
  import gradio as gr
3
  from ultralytics import YOLO
4
  import cv2
5
  import numpy as np
6
+ import fitz # PyMuPDF
7
+ from PIL import Image
8
 
9
  # Load the trained model
10
  model_path = 'best.pt' # Replace with the path to your trained .pt file
11
  model = YOLO(model_path)
12
 
13
+ # Define the class indices for figures and tables (adjust based on your model's classes)
14
+ figure_class_index = 3 # class index for figures
15
+ table_class_index = 4 # class index for tables
16
+
17
+ # Function to perform inference on an image and return bounding boxes for figures and tables
18
+ def infer_image_and_get_boxes(image):
 
 
 
 
 
19
  # Convert the image from BGR to RGB
20
  image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
21
 
22
  # Perform inference
23
  results = model(image_rgb)
24
 
25
+ boxes = []
26
+ # Extract results
27
  for result in results:
28
  for box in result.boxes:
 
29
  cls = int(box.cls[0])
30
+ if cls == figure_class_index or cls == table_class_index:
31
+ x1, y1, x2, y2 = map(int, box.xyxy[0])
32
+ boxes.append((x1, y1, x2, y2))
33
+
34
+ return boxes
35
+
36
+ # Function to crop images from the boxes
37
+ def crop_images_from_boxes(image, boxes, scale_factor):
38
+ cropped_images = []
39
+ for box in boxes:
40
+ x1, y1, x2, y2 = [int(coord * scale_factor) for coord in box]
41
+ cropped_image = image[y1:y2, x1:x2]
42
+ cropped_images.append(cropped_image)
43
+ return cropped_images
44
+
45
+ def process_pdf(pdf_file):
46
+ # Open the PDF file
47
+ doc = fitz.open(pdf_file)
48
+ all_cropped_images = []
49
+
50
+ # Set the DPI for inference and high resolution for cropping
51
+ low_dpi = 50
52
+ high_dpi = 300
53
+
54
+ # Calculate the scaling factor
55
+ scale_factor = high_dpi / low_dpi
56
+
57
+ # Loop through each page
58
+ for page_num in range(len(doc)):
59
+ page = doc.load_page(page_num)
60
+
61
+ # Perform inference at low DPI
62
+ low_res_pix = page.get_pixmap(dpi=low_dpi)
63
+ low_res_img = Image.frombytes("RGB", [low_res_pix.width, low_res_pix.height], low_res_pix.samples)
64
+ low_res_img = np.array(low_res_img)
65
+
66
+ # Get bounding boxes from low DPI image
67
+ boxes = infer_image_and_get_boxes(low_res_img)
68
+
69
+ # Load high DPI image for cropping
70
+ high_res_pix = page.get_pixmap(dpi=high_dpi)
71
+ high_res_img = Image.frombytes("RGB", [high_res_pix.width, high_res_pix.height], high_res_pix.samples)
72
+ high_res_img = np.array(high_res_img)
73
+
74
+ # Crop images at high DPI
75
+ cropped_imgs = crop_images_from_boxes(high_res_img, boxes, scale_factor)
76
+ all_cropped_images.extend(cropped_imgs)
77
 
78
+ return all_cropped_images
79
 
80
  # Create Gradio interface
81
  iface = gr.Interface(
82
+ fn=process_pdf,
83
+ inputs=gr.File(label="Upload a PDF"),
84
+ outputs=gr.Gallery(label="Cropped Figures and Tables from PDF Pages"),
85
  title="Fast document layout analysis based on YOLOv8",
86
+ description="Upload a PDF file to get cropped figures and tables from each page."
87
  )
88
 
89
  # Launch the app