|
import time |
|
import torch |
|
from transformers import AutoModelForImageSegmentation |
|
from PIL import Image |
|
from torchvision import transforms |
|
import gradio as gr |
|
import gc |
|
|
|
def load_model():
    """Load BiRefNet-lite and place it on the preferred device in eval mode.

    Returns:
        tuple: ``(model, device)``. ``device`` is the string ``"cuda"`` when
        ``torch.cuda.is_available()`` is true and ``"cpu"`` otherwise; the
        model has already been moved there and switched to eval mode.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    net = AutoModelForImageSegmentation.from_pretrained(
        "zhengpeng7/BiRefNet_lite",
        trust_remote_code=True,
    )
    net.to(target)
    net.eval()
    return net, target
|
|
|
# Load the model once when the module is imported; handlers read these globals.
birefnet, device = load_model()

# Every image is resized to this fixed resolution before inference.
image_size = (1024, 1024)

# Per-channel normalisation statistics (these are the standard ImageNet values).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# PIL image -> normalised float tensor at ``image_size``.
transform_image = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
|
|
|
def run_inference(images, model, device):
    """Cut out the foreground of a batch of PIL images in one forward pass.

    Args:
        images: sequence of PIL images. Each goes through the module-level
            ``transform_image`` before being stacked into one batch (presumably
            RGB, as ``extract_objects`` supplies; confirm for other callers).
        model: segmentation model. The last element of its output
            (``[-1]``) is passed through ``sigmoid`` and used as the mask.
        device: device the stacked batch is moved to.

    Returns:
        list: one RGBA ``PIL.Image`` per input, at that input's original size,
        whose alpha channel is the predicted mask. Empty if ``images`` is empty.

    Raises:
        torch.OutOfMemoryError: re-raised after dropping the input batch and
            emptying the CUDA cache, so callers can retry with fewer images.
    """
    if not images:
        # torch.stack rejects an empty list, and there is nothing to process.
        return []

    input_tensor = torch.stack([transform_image(img) for img in images]).to(device)

    try:
        with torch.no_grad():
            preds = model(input_tensor)[-1].sigmoid().cpu()
    except torch.OutOfMemoryError:
        del input_tensor
        torch.cuda.empty_cache()
        raise

    to_pil = transforms.ToPILImage()
    results = []
    for img, pred in zip(images, preds):
        # Scale the predicted mask back to the image's own size.
        mask = to_pil(pred.squeeze()).resize(img.size)
        # putalpha sets only the alpha band. Pasting onto a transparent black
        # canvas with this mask would also blend the RGB bands toward black
        # wherever the mask is partial, darkening soft edges.
        cutout = img.convert("RGBA")
        cutout.putalpha(mask)
        results.append(cutout)

    # Release batch tensors and cached GPU blocks before the next request.
    del input_tensor, preds
    gc.collect()
    torch.cuda.empty_cache()
    return results
|
|
|
def binary_search_max(images):
    """Find the longest prefix of ``images`` that ``run_inference`` handles without OOM.

    Binary-searches the prefix length over ``[1, len(images)]``. This assumes
    that if ``k`` images run out of memory, every larger batch does too. The
    model is reloaded into the module-level ``birefnet``/``device`` globals
    before every probe.

    Args:
        images: list of PIL images (the full request).

    Returns:
        tuple: ``(best, best_count)``. ``best`` is the result list for the
        longest successful prefix and ``best_count`` is that prefix length.
        Returns ``(None, 0)`` if no prefix succeeded, including when
        ``images`` is empty.
    """
    global birefnet, device

    low, high = 1, len(images)
    best, best_count = None, 0
    while low <= high:
        mid = (low + high) // 2
        # Reload outside the try: an OOM while loading the model says nothing
        # about whether `mid` images fit, so it must not shrink the range.
        birefnet, device = load_model()
        try:
            res = run_inference(images[:mid], birefnet, device)
        except torch.OutOfMemoryError:
            high = mid - 1
        else:
            best, best_count = res, mid
            low = mid + 1
    return best, best_count
|
|
|
def extract_objects(filepaths):
    """Gradio handler: remove the background from every uploaded image.

    First tries the whole upload as one batch. If that raises
    ``torch.OutOfMemoryError``, it calls ``binary_search_max`` to find the
    longest leading subset that fits. It then returns that subset's results
    with a hint about how many images to send next time.

    Args:
        filepaths: iterable of image file paths from the ``gr.Files`` input.
            May be ``None`` or empty (presumably when nothing was uploaded;
            confirm against the Gradio version in use).

    Returns:
        tuple: ``(results, summary)``. ``results`` is a list of RGBA images for
        the gallery and ``summary`` is a human-readable timing/status string.
    """
    if not filepaths:
        return [], "No images were uploaded."

    images = []
    for path in filepaths:
        # convert() returns a new, fully loaded image, so the file can be
        # closed immediately afterwards.
        with Image.open(path) as img:
            images.append(img.convert("RGB"))

    start_time = time.perf_counter()

    try:
        results = run_inference(images, birefnet, device)
    except torch.OutOfMemoryError:
        # Only record when the failure happened. The fallback runs after this
        # handler exits: while the exception is being handled, its traceback
        # keeps the failed forward pass's frames (and their GPU tensors) alive,
        # which would starve the retries of memory.
        oom_time = time.perf_counter()
    else:
        total_time = time.perf_counter() - start_time
        summary = f"Total request time: {total_time:.2f}s\nProcessed {len(images)} images successfully."
        return results, summary

    initial_attempt_time = oom_time - start_time

    best, best_count = binary_search_max(images)

    total_time = time.perf_counter() - start_time

    if best is None:
        summary = (
            f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
            f"Could not process even a single image.\n"
            f"Total time including fallback attempts: {total_time:.2f}s."
        )
        return [], summary

    summary = (
        f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
        f"Found that {best_count} images can be processed without OOM.\n"
        f"Total time including fallback attempts: {total_time:.2f}s.\n"
        f"Next time, try using up to {best_count} images."
    )
    return best, summary
|
|
|
# Gradio UI: a multi-file upload in; a gallery plus a timing/status textbox out.
_upload_input = gr.Files(label="Upload Multiple Images", type="filepath", file_count="multiple")
_gallery_output = gr.Gallery(label="Processed Images")
_timing_output = gr.Textbox(label="Timing Info")

iface = gr.Interface(
    fn=extract_objects,
    inputs=_upload_input,
    outputs=[_gallery_output, _timing_output],
    title="BiRefNet Bulk Background Removal with On-Demand Fallback",
    description=(
        "Upload as many images as you want. "
        "If OOM occurs, a fallback will find the max feasible number. "
        "Extra cleanup steps and reinitialization for more consistent results."
    ),
)

iface.launch()
|
|