fastpaperlayout / app.py
zliang's picture
Update app.py
028ddc2 verified
raw
history blame
3.37 kB
import gradio as gr
import numpy as np
import fitz # PyMuPDF
from ultralytics import YOLOv10
import spaces
# Module-level detector shared by every request; the weights are read from
# "best.pt" once, when this file is imported.
model = YOLOv10("best.pt")

# Label ids, in the detector's class space, of the two region types this app
# extracts: figures -> 3, tables -> 4.
figure_class_index, table_class_index = 3, 4
# Function to perform inference on an image and return bounding boxes for figures and tables
def infer_image_and_get_boxes(image, confidence_threshold=0.6):
    """Run the detector on one page image and keep confident figure/table boxes.

    Args:
        image: Page raster passed unchanged to ``model.predict`` (in this
            file, a NumPy array rendered by PyMuPDF in ``process_pdf``).
        confidence_threshold: A detection is kept only if its confidence is
            strictly greater than this value.

    Returns:
        List of ``(x1, y1, x2, y2, cls)`` int tuples in ``image`` pixel
        coordinates. ``cls`` is ``figure_class_index`` or
        ``table_class_index``; detections of any other class are dropped.
    """
    import math

    # Build the class filter once instead of once per detected box.
    wanted_classes = {figure_class_index, table_class_index}
    boxes = []
    for result in model.predict(image):
        for box in result.boxes:
            cls = int(box.cls[0])
            if cls in wanted_classes and box.conf[0] > confidence_threshold:
                x1, y1, x2, y2 = box.xyxy[0].tolist()
                # Round the box outward (floor top-left, ceil bottom-right).
                # Plain int() truncation moved the bottom-right corner inward
                # by up to one pixel of this (low-DPI) image. Once the caller
                # scales the box up to a higher-DPI render, that shaved several
                # pixels off the right/bottom edge of the crop.
                boxes.append(
                    (math.floor(x1), math.floor(y1), math.ceil(x2), math.ceil(y2), cls)
                )
    return boxes
# Function to crop images from the boxes
def crop_images_from_boxes(image, boxes, scale_factor):
    """Cut the detected regions out of a (usually higher-resolution) page image.

    Every box coordinate is multiplied by ``scale_factor`` and truncated to an
    int before slicing. This lets boxes found on one render be cropped from a
    render ``scale_factor`` times larger.

    Args:
        image: Array sliced as ``image[rows, cols]`` (rows are y, cols are x).
        boxes: Iterable of ``(x1, y1, x2, y2, cls)`` tuples.
        scale_factor: Multiplier from box coordinates to ``image`` pixels.

    Returns:
        ``(figures, tables)``: lists of slices of ``image``. A region goes to
        ``figures`` when its ``cls`` equals ``figure_class_index``, otherwise
        to ``tables`` when it equals ``table_class_index``. Any other class is
        discarded.
    """
    def to_pixels(coord):
        return int(coord * scale_factor)

    figures, tables = [], []
    for box in boxes:
        left, top, right, bottom, label = box
        region = image[to_pixels(top):to_pixels(bottom), to_pixels(left):to_pixels(right)]
        if label == figure_class_index:
            figures.append(region)
        elif label == table_class_index:
            tables.append(region)
    return figures, tables
def _pixmap_to_array(pix):
    """Return the samples of a PyMuPDF pixmap as a (height, width, n) uint8 array."""
    # pix.n is the per-pixel component count, so this works for any colorspace
    # (3 for the default RGB render) instead of hard-coding 3 channels.
    return np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)


@spaces.GPU
def process_pdf(pdf_file):
    """Extract figure and table crops from every page of a PDF.

    Each page is rendered at a low DPI for fast detection. A page is
    re-rendered at a high DPI only when the detector finds at least one
    figure or table on it. The detected boxes are then scaled up and cropped
    from that high-DPI render.

    Args:
        pdf_file: The value from the ``gr.File`` component. This is either a
            path, or (older Gradio releases) a temp-file wrapper whose
            ``.name`` is the path. NOTE(review): confirm against the pinned
            Gradio version. ``None`` (nothing uploaded) yields empty results.

    Returns:
        ``(figures, tables)``: two lists of image arrays in page order.
    """
    import os

    all_figures = []
    all_tables = []
    if pdf_file is None:
        # Nothing uploaded: return empty galleries explicitly. Otherwise
        # fitz.open(None) silently creates a new blank document.
        return all_figures, all_tables
    if not isinstance(pdf_file, (str, os.PathLike)):
        # Temp-file wrappers carry the on-disk path in ``.name``.
        pdf_file = pdf_file.name

    # Render at a low DPI for inference and a high DPI for cropping; boxes
    # found at low DPI are mapped onto the high-DPI render by this ratio.
    low_dpi = 50
    high_dpi = 300
    scale_factor = high_dpi / low_dpi

    # The context manager closes the document (and its file handle) even if
    # rendering or inference raises.
    with fitz.open(pdf_file) as doc:
        # Render pages one at a time rather than pre-rendering the whole
        # document, so only one low-DPI pixmap is held at a time.
        for page in doc:
            low_res_img = _pixmap_to_array(page.get_pixmap(dpi=low_dpi))
            boxes = infer_image_and_get_boxes(low_res_img)
            if not boxes:
                continue
            # Pay for the high-DPI render only on pages that have detections.
            high_res_img = _pixmap_to_array(page.get_pixmap(dpi=high_dpi))
            figures, tables = crop_images_from_boxes(high_res_img, boxes, scale_factor)
            all_figures.extend(figures)
            all_tables.extend(tables)
    return all_figures, all_tables
# Create Gradio interface: one PDF upload, one button, and two galleries that
# receive the (figures, tables) pair returned by process_pdf.
with gr.Blocks() as app:
    gr.Markdown(
        """
# PDF Figures and Tables Extraction
Upload a PDF file to extract figures and tables using YOLOv10.
"""
    )
    with gr.Row():
        with gr.Column():
            # Restrict the upload picker to PDF files, matching the label.
            file_input = gr.File(label="Upload a PDF", file_types=[".pdf"])
        with gr.Column():
            extract_button = gr.Button("Extract")
    with gr.Row():
        with gr.Column():
            figures_gallery = gr.Gallery(label="Figures from PDF", object_fit='scale-down')
        with gr.Column():
            tables_gallery = gr.Gallery(label="Tables from PDF", object_fit='scale-down')
    extract_button.click(process_pdf, inputs=file_input, outputs=[figures_gallery, tables_gallery])

# Start the server only when run as a script, so importing this module (from
# tests or another app) does not block on app.launch().
if __name__ == "__main__":
    app.launch()