File size: 3,371 Bytes
0bbf6ef
 
cbe1985
 
36ec266
07f5bd9
aa8cd87
0bbf6ef
43d306c
ea76efb
 
e91a768
b1e4794
cff5fa2
cbe1985
cff5fa2
eb98323
cff5fa2
 
 
 
e91a768
 
 
eb98323
 
 
 
 
 
 
 
 
b296597
e8ad557
4504622
cbe1985
 
eb98323
 
3cadd69
 
 
 
c65777e
3cadd69
 
ec2e6e8
cbe1985
 
 
 
 
 
 
cff5fa2
 
 
649e38b
cbe1985
 
 
ff2c42f
 
eb98323
 
 
cbe1985
eb98323
0bbf6ef
 
c5d6690
028ddc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bbf6ef
028ddc2
1b3f90f
cbe1985
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
import numpy as np
import fitz  # PyMuPDF
from ultralytics import YOLOv10
import spaces
# Detector weights: loaded once when this module is imported and reused by
# every inference call below.
model = YOLOv10("best.pt")

# Label ids, in the trained model's class map, of the two region types kept:
# tables are class 3, figures are class 4.
table_class_index, figure_class_index = 3, 4

# Run the detector on one page image and keep only confident figure/table boxes.
def infer_image_and_get_boxes(image, confidence_threshold=0.6):
    """Detect figure and table regions in a single page image.

    Args:
        image: passed straight to ``model.predict``. For numpy input,
            ultralytics expects an HWC uint8 array in BGR channel order.
        confidence_threshold: a detection must score strictly above this
            value to be kept.

    Returns:
        List of ``(x1, y1, x2, y2, cls)`` tuples, one per kept detection.
        Coordinates are floats in the pixel space of ``image``. They are left
        unrounded so the caller can scale them to a higher resolution without
        first losing sub-pixel precision. ``cls`` is an int equal to either
        ``figure_class_index`` or ``table_class_index``.
    """
    wanted_classes = {figure_class_index, table_class_index}
    boxes = []
    # verbose=False: stop ultralytics from logging a line for every page.
    for result in model.predict(image, verbose=False):
        detections = result.boxes
        # Convert coordinates, classes and scores once per result instead of
        # indexing the tensors element by element for every box.
        for (x1, y1, x2, y2), cls, conf in zip(
            detections.xyxy.tolist(),
            detections.cls.tolist(),
            detections.conf.tolist(),
        ):
            cls = int(cls)
            if cls in wanted_classes and conf > confidence_threshold:
                boxes.append((x1, y1, x2, y2, cls))
    return boxes

# Cut the detected regions out of the high-resolution page render.
def crop_images_from_boxes(image, boxes, scale_factor):
    """Crop figure and table regions from a (higher-resolution) page image.

    Args:
        image: page raster indexed as ``image[row, col]``
            (H x W or H x W x C numpy array).
        boxes: iterable of ``(x1, y1, x2, y2, cls)`` tuples whose coordinates
            are in the pixel space of the image the detector ran on
            (int or float).
        scale_factor: ratio of ``image``'s resolution to the detection
            resolution, e.g. ``300 / 50``.

    Returns:
        ``(figures, tables)``: two lists of cropped arrays, in box order.
        Boxes of any other class are ignored. So are boxes that are empty
        once clamped to the image bounds.
    """
    import math

    height, width = image.shape[:2]
    figures = []
    tables = []
    for x1, y1, x2, y2, cls in boxes:
        # Round outward so the crop never trims the detected region. Clamp to
        # the image, because a negative start would wrap around in numpy slicing.
        left = max(0, math.floor(x1 * scale_factor))
        top = max(0, math.floor(y1 * scale_factor))
        right = min(width, math.ceil(x2 * scale_factor))
        bottom = min(height, math.ceil(y2 * scale_factor))
        if right <= left or bottom <= top:
            continue  # nothing left to show after clamping
        # Copy so each crop does not keep the whole high-res page buffer alive.
        crop = image[top:bottom, left:right].copy()
        if cls == figure_class_index:
            figures.append(crop)
        elif cls == table_class_index:
            tables.append(crop)
    return figures, tables

def _pixmap_to_array(pix):
    """Return a PyMuPDF pixmap's samples as a (height, width, n) uint8 array."""
    # Pixmap rows carry no padding (stride == width * n), so a plain reshape is exact.
    return np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)


@spaces.GPU
def process_pdf(pdf_file):
    """Extract figure and table images from every page of a PDF.

    Each page is rendered at a low DPI for detection. Pages with at least one
    detection are re-rendered at a high DPI, and the detected regions are
    cropped from that sharper render.

    Args:
        pdf_file: the value delivered by the Gradio ``File`` input, passed
            straight to ``fitz.open`` (presumably a file path; confirm against
            the installed Gradio version). ``None`` (nothing uploaded) yields
            no results.

    Returns:
        ``(figures, tables)``: lists of RGB uint8 arrays, ordered by page and,
        within a page, by detection order.
    """
    if pdf_file is None:
        return [], []

    low_dpi = 50    # resolution used for detection
    high_dpi = 300  # resolution used for cropping
    # Maps box coordinates from the low-res render onto the high-res render.
    scale_factor = high_dpi / low_dpi

    all_figures = []
    all_tables = []
    # The context manager closes the document even if rendering or inference raises.
    with fitz.open(pdf_file) as doc:
        # Render one page at a time rather than caching every low-res pixmap
        # up front, so peak memory does not grow with page count.
        for page in doc:
            low_res_img = _pixmap_to_array(page.get_pixmap(dpi=low_dpi))
            # PyMuPDF renders RGB, but ultralytics treats numpy input as BGR.
            # Flip channels for detection only; crops stay RGB for display.
            boxes = infer_image_and_get_boxes(np.ascontiguousarray(low_res_img[..., ::-1]))
            if not boxes:
                continue  # skip the expensive high-DPI render

            high_res_img = _pixmap_to_array(page.get_pixmap(dpi=high_dpi))
            figures, tables = crop_images_from_boxes(high_res_img, boxes, scale_factor)
            all_figures.extend(figures)
            all_tables.extend(tables)

    return all_figures, all_tables

# Create the Gradio interface. The first row holds the PDF upload box and the
# "Extract" button; the second row holds one gallery for figures and one for
# tables. The nesting of the Row/Column context managers below *is* the layout,
# so statement order matters.

with gr.Blocks() as app:
    # Page header shown above the controls.
    gr.Markdown(
        """
        # PDF Figures and Tables Extraction
        Upload a PDF file to extract figures and tables using YOLOv10.
        """
    )
    
    # Input row: file picker on the left, trigger button on the right.
    with gr.Row():
        with gr.Column():
            file_input = gr.File(label="Upload a PDF")
        with gr.Column():
            extract_button = gr.Button("Extract")
    
    # Output row: extracted figures on the left, extracted tables on the right.
    with gr.Row():
        with gr.Column():
            figures_gallery = gr.Gallery(label="Figures from PDF", object_fit='scale-down')
        with gr.Column():
            tables_gallery = gr.Gallery(label="Tables from PDF", object_fit='scale-down')
    
    # Clicking "Extract" calls process_pdf with the uploaded file. Its two
    # return values fill the figures and tables galleries, in that order.
    extract_button.click(process_pdf, inputs=file_input, outputs=[figures_gallery, tables_gallery])

# Called unconditionally at module level: running (or importing) this file
# starts the web app.
app.launch()