import gradio as gr
from ultralytics import YOLO
import cv2
import numpy as np
import fitz  # PyMuPDF
from PIL import Image

# Load the trained model
model_path = 'best.pt'  # Replace with the path to your trained .pt file
model = YOLO(model_path)

# Define the class indices for figures and tables (adjust based on your model's classes)
figure_class_index = 3  # class index for figures
table_class_index = 4   # class index for tables
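# Note: the class indices are model-specific. If unsure, the class-name mapping
# of the loaded Ultralytics model can be printed to check them:
# print(model.names)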

# Function to perform inference on an image and return bounding boxes for figures and tables
def infer_image_and_get_boxes(image):
    # The page image arrives as an RGB array (from PIL); swap channels to BGR,
    # the OpenCV-style order Ultralytics assumes for numpy array inputs
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    # Perform inference
    results = model(image_bgr)
    boxes = []
    # Extract results, keeping only figure and table detections
    for result in results:
        for box in result.boxes:
            cls = int(box.cls[0])
            if cls == figure_class_index or cls == table_class_index:
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                boxes.append((x1, y1, x2, y2))
    return boxes
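# If weak detections leak through, Ultralytics accepts a confidence threshold at
# predict time; a possible tweak to the call above (not part of the original code):
# results = model(image_bgr, conf=0.25)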

# Function to crop images from the boxes
def crop_images_from_boxes(image, boxes, scale_factor):
    cropped_images = []
    for box in boxes:
        # Boxes were detected on the low-DPI render, so scale them up first
        x1, y1, x2, y2 = [int(coord * scale_factor) for coord in box]
        cropped_image = image[y1:y2, x1:x2]
        cropped_images.append(cropped_image)
    return cropped_images
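# Optional helper (a sketch, not used above): clamp a scaled box to the page
# image bounds so a detection that slightly overshoots the edge never yields an
# empty crop.
def clamp_box(box, width, height):
    x1, y1, x2, y2 = box
    return max(0, x1), max(0, y1), min(width, x2), min(height, y2)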

def process_pdf(pdf_file):
    # Open the PDF file
    doc = fitz.open(pdf_file)
    all_cropped_images = []

    # Set the DPI for inference (low) and for cropping (high)
    low_dpi = 50
    high_dpi = 300

    # Calculate the scaling factor between the two renders
    scale_factor = high_dpi / low_dpi

    # Loop through each page
    for page_num in range(len(doc)):
        page = doc.load_page(page_num)

        # Perform inference at low DPI
        low_res_pix = page.get_pixmap(dpi=low_dpi)
        low_res_img = Image.frombytes("RGB", [low_res_pix.width, low_res_pix.height], low_res_pix.samples)
        low_res_img = np.array(low_res_img)

        # Get bounding boxes from the low-DPI image
        boxes = infer_image_and_get_boxes(low_res_img)

        # Render the page again at high DPI for cropping
        high_res_pix = page.get_pixmap(dpi=high_dpi)
        high_res_img = Image.frombytes("RGB", [high_res_pix.width, high_res_pix.height], high_res_pix.samples)
        high_res_img = np.array(high_res_img)

        # Crop figures and tables from the high-DPI image
        cropped_imgs = crop_images_from_boxes(high_res_img, boxes, scale_factor)
        all_cropped_images.extend(cropped_imgs)

    return all_cropped_images
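# Quick local check without the UI (a sketch; 'sample.pdf' is a placeholder path):
# crops = process_pdf("sample.pdf")
# print(f"Extracted {len(crops)} figure/table crops")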

# Create Gradio interface
iface = gr.Interface(
    fn=process_pdf,
    inputs=gr.File(label="Upload a PDF"),
    outputs=gr.Gallery(label="Cropped Figures and Tables from PDF Pages"),
    title="Fast document layout analysis based on YOLOv8",
    description="Upload a PDF file to get cropped figures and tables from each page."
)

# Launch the app
iface.launch()
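# Assumed dependency set for the Space's requirements.txt, based on the imports
# above: gradio, ultralytics, opencv-python-headless, numpy, PyMuPDF, Pillow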