import math

import fitz  # PyMuPDF
import gradio as gr
import numpy as np
import spaces
from ultralytics import YOLOv10

# Load the trained model once at import time so every request reuses it.
model = YOLOv10("best.pt")

# Define the class indices for figures and tables (as labelled in the trained model).
figure_class_index = 4  # class index for figures
table_class_index = 3  # class index for tables


def _detect_boxes(image, confidence_threshold=0.6):
    """Run the model on *image* and return figure/table boxes with float coordinates.

    Args:
        image: Any image source accepted by ``model.predict``. For numpy arrays
            Ultralytics expects HWC uint8 in BGR channel order.
        confidence_threshold: Boxes with confidence at or below this are dropped.

    Returns:
        List of ``(x1, y1, x2, y2, cls)`` tuples in *image* pixel coordinates.
        Coordinates are kept as floats so callers can rescale them without
        first losing up to a pixel to truncation.
    """
    wanted = {figure_class_index, table_class_index}
    boxes = []
    for result in model.predict(image):
        # Pull each tensor to Python once per result instead of once per box.
        coords = result.boxes.xyxy.tolist()
        classes = result.boxes.cls.tolist()
        confidences = result.boxes.conf.tolist()
        for (x1, y1, x2, y2), cls, conf in zip(coords, classes, confidences):
            cls = int(cls)
            if cls in wanted and conf > confidence_threshold:
                boxes.append((x1, y1, x2, y2, cls))
    return boxes


# Function to perform inference on an image and return bounding boxes for figures and tables
def infer_image_and_get_boxes(image, confidence_threshold=0.6):
    """Return figure/table boxes for *image* as integer ``(x1, y1, x2, y2, cls)`` tuples.

    Kept for backward compatibility; coordinates are truncated to ``int``.
    ``process_pdf`` uses the float-precision ``_detect_boxes`` instead so the
    rescaled crops do not lose their right/bottom edges.
    """
    return [
        (int(x1), int(y1), int(x2), int(y2), cls)
        for x1, y1, x2, y2, cls in _detect_boxes(image, confidence_threshold)
    ]


# Function to crop images from the boxes
def crop_images_from_boxes(image, boxes, scale_factor):
    """Crop the regions described by *boxes* out of *image*.

    Args:
        image: H x W x C array to crop from.
        boxes: Iterable of ``(x1, y1, x2, y2, cls)`` tuples whose coordinates
            become *image* pixels when multiplied by *scale_factor*.
        scale_factor: Multiplier mapping box coordinates onto *image* pixels.

    Returns:
        ``(figures, tables)``: lists of cropped arrays in box order. Crops are
        copies, so they do not keep the whole (large) page buffer alive.
        Boxes that are empty after scaling/clamping, or whose class is neither
        a figure nor a table, are skipped.
    """
    height, width = image.shape[:2]
    figures = []
    tables = []
    for x1, y1, x2, y2, cls in boxes:
        # Floor the start and ceil the end so the crop fully covers the scaled
        # box; clamp to the image so negative values cannot wrap the slice.
        left = max(0, math.floor(x1 * scale_factor))
        top = max(0, math.floor(y1 * scale_factor))
        right = min(width, math.ceil(x2 * scale_factor))
        bottom = min(height, math.ceil(y2 * scale_factor))
        if right <= left or bottom <= top:
            continue  # degenerate box: nothing to show in the gallery
        cropped_img = image[top:bottom, left:right].copy()
        if cls == figure_class_index:
            figures.append(cropped_img)
        elif cls == table_class_index:
            tables.append(cropped_img)
    return figures, tables


def _pixmap_to_array(pixmap):
    """Return a PyMuPDF pixmap's samples as a (height, width, n) uint8 array."""
    return np.frombuffer(pixmap.samples, dtype=np.uint8).reshape(
        pixmap.height, pixmap.width, pixmap.n
    )


# NOTE(review): decorator from the HF `spaces` package; presumably requests a
# GPU for the duration of each call — confirm against the Space's hardware.
@spaces.GPU
def process_pdf(pdf_file):
    """Extract figure and table crops from every page of a PDF.

    Each page is rendered at a low DPI for detection; only pages with
    detections are re-rendered at a high DPI, from which the crops are cut.

    Args:
        pdf_file: Value of the ``gr.File`` input — a filesystem path string,
            or (older Gradio versions) an object whose ``.name`` is the path.

    Returns:
        ``(figures, tables)``: lists of RGB uint8 arrays cropped at high DPI.

    Raises:
        gr.Error: if no file was uploaded.
    """
    if pdf_file is None:
        raise gr.Error("Please upload a PDF file first.")
    pdf_path = getattr(pdf_file, "name", pdf_file)

    # Low DPI keeps inference cheap; high DPI gives sharp crops.
    low_dpi = 50
    high_dpi = 300
    # Maps low-DPI box coordinates onto high-DPI pixels.
    scale_factor = high_dpi / low_dpi

    all_figures = []
    all_tables = []
    with fitz.open(pdf_path) as doc:
        # Render pages lazily so only one page's pixmaps are in memory at a time.
        for page in doc:
            low_res_pix = page.get_pixmap(dpi=low_dpi, alpha=False)
            low_res_rgb = _pixmap_to_array(low_res_pix)
            # PyMuPDF renders RGB, but Ultralytics interprets numpy input as
            # BGR (OpenCV order), so swap the channels before inference.
            low_res_bgr = np.ascontiguousarray(low_res_rgb[..., ::-1])
            boxes = _detect_boxes(low_res_bgr)
            if not boxes:
                continue  # skip the expensive high-DPI render
            high_res_pix = page.get_pixmap(dpi=high_dpi, alpha=False)
            high_res_img = _pixmap_to_array(high_res_pix)
            figures, tables = crop_images_from_boxes(high_res_img, boxes, scale_factor)
            all_figures.extend(figures)
            all_tables.extend(tables)
    return all_figures, all_tables


# Create Gradio interface
with gr.Blocks() as app:
    gr.Markdown(
        """
        # PDF Figures and Tables Extraction
        Upload a PDF file to extract figures and tables using YOLOv10.
        """
    )
    with gr.Row():
        with gr.Column():
            file_input = gr.File(label="Upload a PDF", file_types=[".pdf"])
        with gr.Column():
            extract_button = gr.Button("Extract")
    with gr.Row():
        with gr.Column():
            figures_gallery = gr.Gallery(label="Figures from PDF", object_fit='scale-down')
        with gr.Column():
            tables_gallery = gr.Gallery(label="Tables from PDF", object_fit='scale-down')

    extract_button.click(process_pdf, inputs=file_input, outputs=[figures_gallery, tables_gallery])

app.launch()