|
import gradio as gr |
|
import spaces |
|
import torch |
|
from loadimg import load_img |
|
from torchvision import transforms |
|
from transformers import AutoModelForImageSegmentation |
|
from diffusers import FluxFillPipeline |
|
from PIL import Image, ImageDraw |
|
from diffusers.utils import load_image |
|
|
|
# Select the "high" float32 matmul precision mode (see the PyTorch docs for
# the exact speed/precision trade-off this implies on the current hardware).
torch.set_float32_matmul_precision("high")

# Background-segmentation model used by `rmbg`.
# NOTE(review): trust_remote_code=True executes Python code shipped in the
# model repository; keep the repo id pinned to a source you trust.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cuda")

# Preprocessing applied to every image before it is fed to `birefnet`:
# fixed 1024x1024 resize, tensor conversion, then per-channel normalisation
# (the mean/std values are the widely used ImageNet statistics).
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Fill (in/out-painting) pipeline used by `inpaint`, loaded in bfloat16 and
# moved to the GPU.
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("cuda")
|
|
|
|
|
def can_expand(source_width, source_height, target_width, target_height, alignment):
    """Tell whether an image can grow toward the side opposite ``alignment``.

    Args:
        source_width: Width of the image being placed.
        source_height: Height of the image being placed.
        target_width: Width of the destination canvas.
        target_height: Height of the destination canvas.
        alignment: Anchor of the image on the canvas. Only "Left", "Right",
            "Top" and "Bottom" are inspected; any other value is accepted.

    Returns:
        False when the image is anchored horizontally ("Left"/"Right") and is
        at least as wide as the target, or anchored vertically
        ("Top"/"Bottom") and is at least as tall as the target. True in every
        other case.
    """
    # Horizontal anchors only care about width.
    if alignment in ("Left", "Right"):
        return not source_width >= target_width
    # Vertical anchors only care about height.
    if alignment in ("Top", "Bottom"):
        return not source_height >= target_height
    return True
|
|
|
|
|
def prepare_image_and_mask(
    image,
    width,
    height,
    overlap_percentage,
    resize_percentage,
    alignment,
    overlap_left,
    overlap_right,
    overlap_top,
    overlap_bottom,
):
    """Place ``image`` on a white canvas and build the matching fill mask.

    The image is first scaled to fit inside ``(width, height)`` while keeping
    its aspect ratio, then scaled again by ``resize_percentage`` (never below
    64 pixels per side), and finally pasted on the canvas at the position
    selected by ``alignment``.

    Args:
        image: Source picture; must expose ``width``, ``height`` and
            ``resize`` (a PIL image is assumed).
        width: Width of the output canvas in pixels.
        height: Height of the output canvas in pixels.
        overlap_percentage: Percentage of the pasted image's width/height
            that forms the overlap band on sides where overlap is enabled
            (at least 1 pixel).
        resize_percentage: Percentage applied to the fitted image. ``None``
            falls back to 50.
        alignment: One of "Middle", "Left", "Right", "Top", "Bottom".
        overlap_left: Whether the mask reaches ``overlap_x`` pixels into the
            image on its left side.
        overlap_right: Same for the right side.
        overlap_top: Same for the top side (using ``overlap_y``).
        overlap_bottom: Same for the bottom side.

    Returns:
        A ``(background, mask)`` tuple: ``background`` is an RGB canvas with
        the resized image pasted on white; ``mask`` is an "L" image that is
        255 everywhere except a 0-filled rectangle over the region of the
        pasted image that is kept.

    Raises:
        ValueError: If ``alignment`` is not one of the supported values.
    """
    # Validate before touching the image so a bad alignment fails with a
    # clear message instead of an unbound-variable error further down.
    if alignment not in ("Middle", "Left", "Right", "Top", "Bottom"):
        raise ValueError(
            f"Unknown alignment {alignment!r}; expected one of "
            "'Middle', 'Left', 'Right', 'Top', 'Bottom'."
        )
    # Honour the caller's value; 50 is only the default when none is given.
    if resize_percentage is None:
        resize_percentage = 50

    target_size = (width, height)

    # Step 1: fit the image inside the canvas, preserving aspect ratio.
    scale_factor = min(width / image.width, height / image.height)
    fitted_size = (
        int(image.width * scale_factor),
        int(image.height * scale_factor),
    )
    source = image.resize(fitted_size, Image.LANCZOS)

    # Step 2: apply the requested percentage, keeping each side >= 64 px.
    resize_factor = resize_percentage / 100
    new_width = max(int(source.width * resize_factor), 64)
    new_height = max(int(source.height * resize_factor), 64)
    source = source.resize((new_width, new_height), Image.LANCZOS)

    # Overlap band size in pixels, never less than one pixel.
    overlap_x = max(int(new_width * (overlap_percentage / 100)), 1)
    overlap_y = max(int(new_height * (overlap_percentage / 100)), 1)

    # Start centred, then pin the axis selected by the alignment.
    margin_x = (width - new_width) // 2
    margin_y = (height - new_height) // 2
    if alignment == "Left":
        margin_x = 0
    elif alignment == "Right":
        margin_x = width - new_width
    elif alignment == "Top":
        margin_y = 0
    elif alignment == "Bottom":
        margin_y = height - new_height

    # Keep the paste position inside the canvas.
    margin_x = max(0, min(margin_x, width - new_width))
    margin_y = max(0, min(margin_y, height - new_height))

    background = Image.new("RGB", target_size, (255, 255, 255))
    background.paste(source, (margin_x, margin_y))

    mask = Image.new("L", target_size, 255)
    mask_draw = ImageDraw.Draw(mask)

    # On sides without overlap the kept rectangle is still shrunk by a couple
    # of pixels (presumably to hide white seams at the image border), except
    # on the side the image is pinned to, where it is flush with the edge.
    white_gaps_patch = 2

    def _inset(overlap_enabled, overlap_size, pinned_side):
        """Return how far the kept rectangle is pulled in on one side."""
        if overlap_enabled:
            return overlap_size
        return 0 if alignment == pinned_side else white_gaps_patch

    left_overlap = margin_x + _inset(overlap_left, overlap_x, "Left")
    right_overlap = margin_x + new_width - _inset(overlap_right, overlap_x, "Right")
    top_overlap = margin_y + _inset(overlap_top, overlap_y, "Top")
    bottom_overlap = (
        margin_y + new_height - _inset(overlap_bottom, overlap_y, "Bottom")
    )

    # 0 marks the pixels to keep; the surrounding 255 area is to be filled.
    mask_draw.rectangle(
        [(left_overlap, top_overlap), (right_overlap, bottom_overlap)], fill=0
    )

    return background, mask
|
|
|
|
|
def inpaint(
    image,
    width,
    height,
    overlap_percentage,
    num_inference_steps,
    custom_resize_percentage,
    prompt_input,
    alignment,
    overlap_left,
    overlap_right,
    overlap_top,
    overlap_bottom,
    progress=gr.Progress(track_tqdm=True),
):
    """Outpaint ``image`` to ``width`` x ``height`` with the fill pipeline.

    Args:
        image: Source picture, forwarded to ``prepare_image_and_mask``.
        width: Output width in pixels.
        height: Output height in pixels.
        overlap_percentage: Overlap band size, forwarded to
            ``prepare_image_and_mask``.
        num_inference_steps: Number of denoising steps passed to ``pipe``.
        custom_resize_percentage: Resize percentage, forwarded to
            ``prepare_image_and_mask``.
        prompt_input: Text prompt passed to ``pipe``.
        alignment: Placement of the source on the canvas, forwarded to
            ``prepare_image_and_mask``.
        overlap_left: Enable the overlap band on the left side.
        overlap_right: Enable the overlap band on the right side.
        overlap_top: Enable the overlap band on the top side.
        overlap_bottom: Enable the overlap band on the bottom side.
        progress: Gradio progress tracker. It is not referenced in the body;
            it is declared so Gradio can surface tqdm progress in the UI.

    Returns:
        The composited image: pipeline output where the mask is set, the
        pasted source picture elsewhere.
    """
    # `alignment` is consumed entirely by prepare_image_and_mask; nothing
    # below depends on it, so no post-hoc alignment fallback is applied.
    background, mask = prepare_image_and_mask(
        image,
        width,
        height,
        overlap_percentage,
        custom_resize_percentage,
        alignment,
        overlap_left,
        overlap_right,
        overlap_top,
        overlap_bottom,
    )

    # Black out the area to be generated so the pipeline only sees the kept
    # part of the source picture.
    cnet_image = background.copy()
    cnet_image.paste(0, (0, 0), mask)

    result = pipe(
        prompt=prompt_input,
        height=height,
        width=width,
        image=cnet_image,
        mask_image=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=30,
    ).images[0]

    # Copy the generated pixels back only where the mask is set, leaving the
    # kept region of the source untouched.
    result = result.convert("RGBA")
    cnet_image.paste(result, (0, 0), mask)

    return cnet_image
|
|
|
|
|
@spaces.GPU
def rmbg(image, url):
    """Remove the background of a picture using the BiRefNet model.

    Args:
        image: Uploaded image, or ``None`` when the user supplied a URL
            instead.
        url: Image location used when ``image`` is ``None``.

    Returns:
        The RGB input as a PIL image whose alpha channel is the predicted
        foreground mask, resized back to the input's original size.

    Raises:
        gr.Error: If neither an image nor a non-empty URL was provided.
    """
    if image is None:
        # Fail with a readable message instead of handing an empty value to
        # the loader.
        if not url or not url.strip():
            raise gr.Error("Please provide an image or an image URL.")
        image = url
    image = load_img(image).convert("RGB")
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cuda")

    # Inference only: gradients are not needed.
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Turn the prediction into an image and scale it back to the input size
    # before using it as the alpha channel.
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image
|
|
|
|
|
def placeholder(img):
    """Return ``img`` unchanged; an identity handler that does no processing."""
    return img
|
|
|
|
|
# Background-removal tab: takes an uploaded image and/or a URL.
rmbg_tab = gr.Interface(
    fn=rmbg, inputs=["image", "text"], outputs=["image"], api_name="rmbg"
)

# Outpainting tab: currently wired to the identity `placeholder` handler.
# The handler must be passed as `fn`; any other keyword makes gr.Interface
# fail at construction time.
outpaint_tab = gr.Interface(
    fn=placeholder, inputs=["image"], outputs=["image"], api_name="outpainting"
)

# Combine both tools into a single tabbed app.
demo = gr.TabbedInterface(
    [rmbg_tab, outpaint_tab],
    ["remove background", "outpainting"],
    title="Utilities that require GPU",
)

demo.launch()
|
|