"""Gradio app that crops figures and tables out of PDF pages with a YOLOv8 model.

Each page is rendered twice with PyMuPDF: once at a low DPI for fast layout
detection, and once at a high DPI from which the detected regions are cropped.
"""
import math
import os

import cv2
import fitz  # PyMuPDF
import gradio as gr
import numpy as np
from PIL import Image
from ultralytics import YOLO

# Load the trained model
model_path = 'best.pt'  # Replace with the path to your trained .pt file
model = YOLO(model_path)

# Define the class indices for figures and tables (adjust based on your model's classes)
figure_class_index = 3  # class index for figures
table_class_index = 4  # class index for tables


def infer_image_and_get_boxes(image, *, round_coords=True):
    """Run the layout model on *image* and return boxes of figures and tables.

    Args:
        image: HxWx3 uint8 array in RGB channel order (as produced by
            ``np.array`` on a PIL RGB image).
        round_coords: When True (default, original behaviour) coordinates are
            truncated to ``int``. Pass False to keep the model's float
            coordinates, e.g. when they will be rescaled afterwards and
            truncation would otherwise be amplified.

    Returns:
        List of ``(x1, y1, x2, y2)`` tuples in *image* pixel coordinates.
    """
    # Ultralytics treats numpy inputs as BGR (the OpenCV convention), so the
    # RGB page render must have its channels swapped before inference.
    image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    results = model(image_bgr)

    wanted_classes = {figure_class_index, table_class_index}
    boxes = []
    for result in results:
        for box in result.boxes:
            if int(box.cls[0]) not in wanted_classes:
                continue
            coords = [float(c) for c in box.xyxy[0]]
            if round_coords:
                coords = [int(c) for c in coords]
            boxes.append(tuple(coords))
    return boxes


def crop_images_from_boxes(image, boxes, scale_factor):
    """Crop regions, given in another image's coordinates, out of *image*.

    Args:
        image: HxW(xC) array to crop from (typically the high-DPI render).
        boxes: Iterable of ``(x1, y1, x2, y2)`` boxes (int or float).
        scale_factor: Either a single number applied to both axes, or an
            ``(sx, sy)`` pair, mapping box coordinates onto *image* pixels.

    Returns:
        List of non-empty crops (array views into *image*).
    """
    if isinstance(scale_factor, (tuple, list)):
        sx, sy = scale_factor
    else:
        sx = sy = scale_factor

    height, width = image.shape[:2]
    cropped_images = []
    for x1, y1, x2, y2 in boxes:
        # Round outward (floor the start, ceil the end) so the crop never
        # shaves pixels off the detected region, then clamp to the image so
        # negative coordinates cannot wrap the slice around.
        left = max(0, math.floor(x1 * sx))
        top = max(0, math.floor(y1 * sy))
        right = min(width, math.ceil(x2 * sx))
        bottom = min(height, math.ceil(y2 * sy))
        if right <= left or bottom <= top:
            continue  # degenerate or fully out-of-bounds box
        cropped_images.append(image[top:bottom, left:right])
    return cropped_images


def _render_page(page, dpi):
    """Render a PyMuPDF *page* at *dpi* and return an HxWx3 uint8 RGB array."""
    # get_pixmap defaults to an RGB colorspace without alpha, i.e. 3 channels.
    pix = page.get_pixmap(dpi=dpi)
    return np.array(Image.frombytes("RGB", [pix.width, pix.height], pix.samples))


def process_pdf(pdf_file, *, low_dpi=50, high_dpi=300):
    """Detect figures/tables on every page of a PDF and return high-DPI crops.

    Args:
        pdf_file: Path (str / os.PathLike), or an object exposing the path in
            ``.name`` (Gradio 3 passes a tempfile wrapper, Gradio 4 a path).
            ``None`` (nothing uploaded) yields an empty list.
        low_dpi: Resolution used for model inference.
        high_dpi: Resolution used for the returned crops.

    Returns:
        List of RGB numpy arrays, one per detected figure/table, in page order.
    """
    if pdf_file is None:
        return []
    if isinstance(pdf_file, (str, os.PathLike)):
        pdf_path = pdf_file
    else:
        pdf_path = pdf_file.name

    all_cropped_images = []
    with fitz.open(pdf_path) as doc:
        for page in doc:
            # Perform inference at low DPI
            low_res_img = _render_page(page, low_dpi)
            boxes = infer_image_and_get_boxes(low_res_img, round_coords=False)
            if not boxes:
                continue  # nothing found: skip the expensive high-DPI render

            # Load high DPI image for cropping
            high_res_img = _render_page(page, high_dpi)
            # PyMuPDF rounds pixmap dimensions, so derive the per-axis scale
            # from the actual render sizes rather than high_dpi / low_dpi.
            scale = (
                high_res_img.shape[1] / low_res_img.shape[1],
                high_res_img.shape[0] / low_res_img.shape[0],
            )
            all_cropped_images.extend(
                crop_images_from_boxes(high_res_img, boxes, scale)
            )
    return all_cropped_images


# Create Gradio interface
iface = gr.Interface(
    fn=process_pdf,
    inputs=gr.File(label="Upload a PDF"),
    outputs=gr.Gallery(label="Cropped Figures and Tables from PDF Pages"),
    title="Fast document layout analysis based on YOLOv8",
    description="Upload a PDF file to get cropped figures and tables from each page."
)

if __name__ == "__main__":
    # Launch the app
    iface.launch()