# gpu-utils / app.py
# Last commit aa16383 by not-lain: "wip outpainting" (file size: 6.7 kB).
import gradio as gr
import spaces
import torch
from loadimg import load_img
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from diffusers import FluxFillPipeline
from PIL import Image, ImageDraw
from diffusers.utils import load_image
# Allow PyTorch to use its faster, reduced-precision float32 matmul kernels.
torch.set_float32_matmul_precision("high")

# Background-removal model; trust_remote_code lets transformers run the
# model code hosted in the Hub repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")

# BiRefNet preprocessing: fixed 1024x1024 input, normalized with the standard
# ImageNet per-channel mean / std.
_BIREFNET_INPUT_SIZE = (1024, 1024)
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform_image = transforms.Compose(
    [
        transforms.Resize(_BIREFNET_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

# FLUX.1 Fill pipeline (bfloat16 weights) used for outpainting.
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")
def can_expand(source_width, source_height, target_width, target_height, alignment):
    """Report whether the source can be expanded toward the target.

    "Left"/"Right" alignments are blocked once the source is at least as wide
    as the target; "Top"/"Bottom" alignments are blocked once it is at least
    as tall. Any other alignment (such as "Middle") is never blocked.

    Returns:
        True when expansion is possible, False otherwise.
    """
    if alignment in ("Left", "Right"):
        blocked = source_width >= target_width
    elif alignment in ("Top", "Bottom"):
        blocked = source_height >= target_height
    else:
        blocked = False
    return not blocked
def prepare_image_and_mask(
    image,
    width,
    height,
    overlap_percentage,
    resize_percentage,
    alignment,
    overlap_left,
    overlap_right,
    overlap_top,
    overlap_bottom,
):
    """Place *image* on a white canvas and build the matching outpainting mask.

    The image is scaled to fit inside ``(width, height)`` (aspect ratio kept),
    shrunk to ``resize_percentage`` percent of that fitted size (never below
    64 px per side), then pasted onto a white ``width`` x ``height`` RGB canvas
    at the position given by *alignment*.

    Args:
        image: Source ``PIL.Image``.
        width, height: Target canvas size in pixels.
        overlap_percentage: Size of the band, as a percentage of the placed
            image's width/height, that is left for regeneration along every
            edge whose ``overlap_*`` flag is set (at least 1 px).
        resize_percentage: Percentage of the fitted size kept for the source.
        alignment: One of "Middle", "Left", "Right", "Top", "Bottom".
        overlap_left, overlap_right, overlap_top, overlap_bottom: Whether the
            corresponding edge of the placed image gets an overlap band.

    Returns:
        ``(background, mask)`` -- the RGB canvas with the source pasted in,
        and an "L" mask of the same size: 0 over the rectangle kept from the
        source, 255 everywhere that is to be filled in.

    Raises:
        ValueError: if *alignment* is not a supported value.
    """
    target_size = (width, height)

    # Fit the source inside the canvas, preserving aspect ratio.
    scale_factor = min(width / image.width, height / image.height)
    fitted = image.resize(
        (int(image.width * scale_factor), int(image.height * scale_factor)),
        Image.LANCZOS,
    )

    # Shrink to the requested percentage of the fitted size, leaving room
    # around the image to outpaint. Each side is kept at >= 64 px.
    resize_factor = resize_percentage / 100
    new_width = max(int(fitted.width * resize_factor), 64)
    new_height = max(int(fitted.height * resize_factor), 64)
    source = fitted.resize((new_width, new_height), Image.LANCZOS)

    # Overlap band in pixels, at least 1 px on each axis.
    overlap_x = max(int(new_width * (overlap_percentage / 100)), 1)
    overlap_y = max(int(new_height * (overlap_percentage / 100)), 1)

    # Free space on each axis (negative if the 64 px floor made the source
    # larger than the canvas).
    free_x = width - new_width
    free_y = height - new_height
    margins_by_alignment = {
        "Middle": (free_x // 2, free_y // 2),
        "Left": (0, free_y // 2),
        "Right": (free_x, free_y // 2),
        "Top": (free_x // 2, 0),
        "Bottom": (free_x // 2, free_y),
    }
    if alignment not in margins_by_alignment:
        raise ValueError(
            f"Unsupported alignment {alignment!r}; expected one of "
            f"{', '.join(margins_by_alignment)}"
        )
    margin_x, margin_y = margins_by_alignment[alignment]
    # Clamp so the pasted source never starts outside the canvas.
    margin_x = max(0, min(margin_x, free_x))
    margin_y = max(0, min(margin_y, free_y))

    background = Image.new("RGB", target_size, (255, 255, 255))
    background.paste(source, (margin_x, margin_y))

    # Each side of the kept rectangle is pulled inward: by the overlap band
    # when that side's overlap flag is set; otherwise by a small 2 px patch
    # (presumably to hide seams at the source border), except on the side
    # the image is aligned to, which stays flush.
    white_gaps_patch = 2

    def _inset(overlap_enabled, overlap_px, is_aligned_edge):
        # Pixels to move one side of the kept rectangle inward.
        if overlap_enabled:
            return overlap_px
        return 0 if is_aligned_edge else white_gaps_patch

    left = margin_x + _inset(overlap_left, overlap_x, alignment == "Left")
    right = margin_x + new_width - _inset(
        overlap_right, overlap_x, alignment == "Right"
    )
    top = margin_y + _inset(overlap_top, overlap_y, alignment == "Top")
    bottom = margin_y + new_height - _inset(
        overlap_bottom, overlap_y, alignment == "Bottom"
    )

    mask = Image.new("L", target_size, 255)
    # With overlaps above ~50% the bands cross and nothing is kept. Recent
    # Pillow rejects rectangles with x1 < x0 or y1 < y0, so skip the draw and
    # leave the whole mask at 255.
    if left <= right and top <= bottom:
        ImageDraw.Draw(mask).rectangle([(left, top), (right, bottom)], fill=0)
    return background, mask
@spaces.GPU
def inpaint(
    image,
    width,
    height,
    overlap_percentage,
    num_inference_steps,
    custom_resize_percentage,
    prompt_input,
    alignment,
    overlap_left,
    overlap_right,
    overlap_top,
    overlap_bottom,
    progress=gr.Progress(track_tqdm=True),
):
    """Outpaint *image* onto a ``width`` x ``height`` canvas with FLUX.1 Fill.

    Runs under ``@spaces.GPU`` like ``rmbg``, because it drives the CUDA
    ``pipe``.

    Args:
        image: Source ``PIL.Image``.
        width, height: Output size in pixels.
        overlap_percentage: Size of the repainted band along enabled edges,
            as a percentage of the placed image.
        num_inference_steps: Number of denoising steps for the pipeline.
        custom_resize_percentage: Percentage of the fitted size kept for the
            source image.
        prompt_input: Text prompt guiding the fill.
        alignment: Placement of the source ("Middle", "Left", "Right",
            "Top", "Bottom").
        overlap_left, overlap_right, overlap_top, overlap_bottom: Per-edge
            overlap flags.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
            report tqdm progress produced during the call.

    Returns:
        The canvas as a PIL image: source pixels where the mask is 0,
        pipeline output where the mask is set.
    """
    background, mask = prepare_image_and_mask(
        image,
        width,
        height,
        overlap_percentage,
        custom_resize_percentage,
        alignment,
        overlap_left,
        overlap_right,
        overlap_top,
        overlap_bottom,
    )

    # Conditioning image: the canvas with the area to generate blacked out.
    cnet_image = background.copy()
    cnet_image.paste(0, (0, 0), mask)

    result = pipe(
        prompt=prompt_input,
        height=height,
        width=width,
        image=cnet_image,
        mask_image=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=30,
    ).images[0]

    # A masked paste() needs the pasted image and the mask to be the same
    # size. The pipeline may return a different resolution (e.g. if it rounds
    # width/height to a supported multiple), so match the mask first.
    if result.size != mask.size:
        result = result.resize(mask.size, Image.LANCZOS)

    # Keep the original pixels and take generated pixels only under the mask.
    cnet_image.paste(result, (0, 0), mask)
    return cnet_image
@spaces.GPU
def rmbg(image, url):
    """Cut out the foreground of an image with BiRefNet.

    Args:
        image: Image from the Gradio image input, or ``None``.
        url: Text fallback handed to ``load_img`` when *image* is ``None``
            (presumably a URL or path -- confirm which forms load_img
            accepts).

    Returns:
        The image as an RGBA PIL image whose alpha channel is the predicted
        foreground mask, resized back to the input resolution.

    Raises:
        gr.Error: if neither an image nor a non-blank URL was provided.
    """
    if image is None:
        # Gradio sends "" for an empty textbox; treat blank input as missing
        # instead of letting load_img fail with an obscure error.
        url = (url or "").strip()
        if not url:
            raise gr.Error("Please provide an image or an image URL.")
        image = url

    rgb = load_img(image).convert("RGB")
    batch = transform_image(rgb).unsqueeze(0).to("cuda")

    with torch.no_grad():
        # Use the model's last output; sigmoid squashes it into (0, 1).
        preds = birefnet(batch)[-1].sigmoid().cpu()

    alpha = transforms.ToPILImage()(preds[0].squeeze()).resize(rgb.size)
    rgb.putalpha(alpha)
    return rgb
def placeholder(img):
    """Identity handler: hand the input image straight back unchanged."""
    echoed = img
    return echoed
# "remove background" tab: an uploaded image (or, failing that, the text
# field) goes through rmbg. Exposed over the API as "rmbg".
rmbg_tab = gr.Interface(
    fn=rmbg, inputs=["image", "text"], outputs=["image"], api_name="rmbg"
)

# "outpainting" tab: still work in progress -- it currently echoes the input
# image through `placeholder`. Exposed over the API as "outpainting".
# (Was `fr=placeholder`, an unknown keyword that broke app startup.)
outpaint_tab = gr.Interface(
    fn=placeholder, inputs=["image"], outputs=["image"], api_name="outpainting"
)

demo = gr.TabbedInterface(
    [rmbg_tab, outpaint_tab],
    ["remove background", "outpainting"],
    title="Utilities that require GPU",
)
demo.launch()