import time
import gc

import torch
from transformers import AutoModelForImageSegmentation
from PIL import Image
from torchvision import transforms
import gradio as gr


def load_model():
    model = AutoModelForImageSegmentation.from_pretrained(
        'zhengpeng7/BiRefNet_lite', trust_remote_code=True
    )
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()
    return model, device


birefnet, device = load_model()

# Preprocessing
image_size = (1024, 1024)
transform_image = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def run_inference(images, model, device):
    inputs = []
    original_sizes = []
    for img in images:
        original_sizes.append(img.size)
        inputs.append(transform_image(img))
    input_tensor = torch.stack(inputs).to(device)

    try:
        with torch.no_grad():
            preds = model(input_tensor)[-1].sigmoid().cpu()
    except torch.OutOfMemoryError:
        del input_tensor
        torch.cuda.empty_cache()
        raise

    # Post-process
    results = []
    for i, img in enumerate(images):
        pred = preds[i].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(original_sizes[i])
        result = Image.new("RGBA", original_sizes[i], (0, 0, 0, 0))
        result.paste(img, mask=mask)
        results.append(result)

    # Cleanup
    del input_tensor, preds
    gc.collect()
    torch.cuda.empty_cache()

    return results


def binary_search_max(images):
    # After OOM, search for the maximum feasible batch size
    global birefnet, device
    low, high = 1, len(images)
    best = None
    best_count = 0
    while low <= high:
        mid = (low + high) // 2
        batch = images[:mid]
        try:
            # Reload the model to avoid leftover memory fragmentation
            birefnet, device = load_model()
            res = run_inference(batch, birefnet, device)
            best = res
            best_count = mid
            low = mid + 1
        except torch.OutOfMemoryError:
            high = mid - 1
    return best, best_count


def extract_objects(filepaths):
    images = [Image.open(p).convert("RGB") for p in filepaths]
    start_time = time.time()

    # First attempt: all images at once
    try:
        results = run_inference(images, birefnet, device)
        end_time = time.time()
        total_time = end_time - start_time
        summary = f"Total request time: {total_time:.2f}s\nProcessed {len(images)} images successfully."
        return results, summary
    except torch.OutOfMemoryError:
        # OOM occurred; now search for a feasible batch size
        oom_time = time.time()
        initial_attempt_time = oom_time - start_time

        best, best_count = binary_search_max(images)
        end_time = time.time()
        total_time = end_time - start_time

        if best is None:
            # Not even a single image fits
            summary = (
                f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
                f"Could not process even a single image.\n"
                f"Total time including fallback attempts: {total_time:.2f}s."
            )
            return [], summary
        else:
            summary = (
                f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
                f"Found that {best_count} images can be processed without OOM.\n"
                f"Total time including fallback attempts: {total_time:.2f}s.\n"
                f"Next time, try using up to {best_count} images."
            )
            return best, summary


iface = gr.Interface(
    fn=extract_objects,
    inputs=gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple"),
    outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
    title="BiRefNet Bulk Background Removal with On-Demand Fallback",
    description=(
        "Upload as many images as you want. If an out-of-memory error occurs, a fallback "
        "search finds the maximum feasible batch size. Extra cleanup and model "
        "reinitialization are performed between attempts for more consistent results."
    ),
)

iface.launch()
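
# A minimal sketch of how the pipeline could be exercised without the Gradio UI,
# e.g. for a quick sanity check (comment out iface.launch() first, since it blocks).
# "sample.jpg" and "sample_rgba.png" are hypothetical paths used only for illustration:
#
#     results, summary = extract_objects(["sample.jpg"])
#     print(summary)
#     if results:
#         results[0].save("sample_rgba.png")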