import gradio as gr
import spaces
import torch
from loadimg import load_img
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from diffusers import FluxFillPipeline
from PIL import Image, ImageDraw

torch.set_float32_matmul_precision("high")

# Background-removal model (BiRefNet) used by the rmbg endpoint
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")

transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# FLUX.1 Fill pipeline used for outpainting
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")


def can_expand(source_width, source_height, target_width, target_height, alignment):
    """Return True if the source image can be expanded toward the given alignment."""
    if alignment in ("Left", "Right") and source_width >= target_width:
        return False
    if alignment in ("Top", "Bottom") and source_height >= target_height:
        return False
    return True


def prepare_image_and_mask(
    image,
    width,
    height,
    overlap_percentage,
    resize_percentage,
    alignment,
    overlap_left,
    overlap_right,
    overlap_top,
    overlap_bottom,
):
    target_size = (width, height)

    # Fit the source image inside the target canvas, preserving its aspect ratio
    scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
    new_width = int(image.width * scale_factor)
    new_height = int(image.height * scale_factor)
    source = image.resize((new_width, new_height), Image.LANCZOS)

    # NOTE: the resize percentage is fixed at 50 here, overriding the argument
    resize_percentage = 50

    # Calculate new dimensions based on the percentage
    resize_factor = resize_percentage / 100
    new_width = int(source.width * resize_factor)
    new_height = int(source.height * resize_factor)

    # Ensure a minimum size of 64 pixels
    new_width = max(new_width, 64)
    new_height = max(new_height, 64)

    # Resize the image
    source = source.resize((new_width, new_height), Image.LANCZOS)

    # Calculate the overlap in pixels based on the percentage
    overlap_x = int(new_width * (overlap_percentage / 100))
    overlap_y = int(new_height * (overlap_percentage / 100))

    # Ensure a minimum overlap of 1 pixel
    overlap_x = max(overlap_x, 1)
    overlap_y = max(overlap_y, 1)

    # Calculate margins based on alignment
    if alignment == "Middle":
        margin_x = (target_size[0] - new_width) // 2
        margin_y = (target_size[1] - new_height) // 2
    elif alignment == "Left":
        margin_x = 0
        margin_y = (target_size[1] - new_height) // 2
    elif alignment == "Right":
        margin_x = target_size[0] - new_width
        margin_y = (target_size[1] - new_height) // 2
    elif alignment == "Top":
        margin_x = (target_size[0] - new_width) // 2
        margin_y = 0
    elif alignment == "Bottom":
        margin_x = (target_size[0] - new_width) // 2
        margin_y = target_size[1] - new_height
    else:
        # Fall back to centred placement for unrecognised alignment values
        margin_x = (target_size[0] - new_width) // 2
        margin_y = (target_size[1] - new_height) // 2

    # Clamp margins so the source stays inside the canvas
    margin_x = max(0, min(margin_x, target_size[0] - new_width))
    margin_y = max(0, min(margin_y, target_size[1] - new_height))

    # Create a new background image and paste the resized source image onto it
    background = Image.new("RGB", target_size, (255, 255, 255))
    background.paste(source, (margin_x, margin_y))

    # Create the mask: 255 (white) marks the area to be generated
    mask = Image.new("L", target_size, 255)
    mask_draw = ImageDraw.Draw(mask)

    # Calculate the overlap areas
    white_gaps_patch = 2

    left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
    right_overlap = (
        margin_x + new_width - overlap_x
        if overlap_right
        else margin_x + new_width - white_gaps_patch
    )
    top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
    bottom_overlap = (
        margin_y + new_height - overlap_y
        if overlap_bottom
        else margin_y + new_height - white_gaps_patch
    )

    # Along the edge touching the canvas border, drop the safety patch
    if alignment == "Left":
        left_overlap = margin_x + overlap_x if overlap_left else margin_x
    elif alignment == "Right":
        right_overlap = (
            margin_x + new_width - overlap_x if overlap_right else margin_x + new_width
        )
    elif alignment == "Top":
        top_overlap = margin_y + overlap_y if overlap_top else margin_y
    elif alignment == "Bottom":
        bottom_overlap = (
            margin_y + new_height - overlap_y
            if overlap_bottom
            else margin_y + new_height
        )

    # Draw the kept (non-generated) region as black on the mask
    mask_draw.rectangle(
        [(left_overlap, top_overlap), (right_overlap, bottom_overlap)], fill=0
    )

    return background, mask
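# Quick local sanity check (illustrative sketch only, not part of the app):
# for any source image and a 1024x1024 target, prepare_image_and_mask returns a
# white canvas with the shrunken source pasted according to `alignment`, plus an
# "L" mask that is 0 over the kept region and 255 over the area to outpaint.
#
#   from PIL import Image
#   src = Image.new("RGB", (512, 768), (128, 128, 128))
#   bg, msk = prepare_image_and_mask(
#       src, 1024, 1024, overlap_percentage=10, resize_percentage=50,
#       alignment="Middle", overlap_left=True, overlap_right=True,
#       overlap_top=True, overlap_bottom=True,
#   )
#   assert bg.size == msk.size == (1024, 1024)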
def inpaint(
    image,
    width,
    height,
    overlap_percentage,
    num_inference_steps,
    custom_resize_percentage,
    prompt_input,
    alignment,
    overlap_left,
    overlap_right,
    overlap_top,
    overlap_bottom,
    progress=gr.Progress(track_tqdm=True),
):
    # If the source already covers the target in the requested direction,
    # expanding that way makes no sense, so fall back to centred placement.
    # This must be decided before the background and mask are built.
    if not can_expand(image.width, image.height, width, height, alignment):
        alignment = "Middle"

    background, mask = prepare_image_and_mask(
        image,
        width,
        height,
        overlap_percentage,
        custom_resize_percentage,
        alignment,
        overlap_left,
        overlap_right,
        overlap_top,
        overlap_bottom,
    )

    # Black out the region to be generated; the original pixels stay where mask == 0
    cnet_image = background.copy()
    cnet_image.paste(0, (0, 0), mask)

    final_prompt = prompt_input

    # generator = torch.Generator(device="cuda").manual_seed(42)

    result = pipe(
        prompt=final_prompt,
        height=height,
        width=width,
        image=cnet_image,
        mask_image=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=30,
    ).images[0]

    # Composite the generated pixels back over the outpainted area only
    result = result.convert("RGBA")
    cnet_image.paste(result, (0, 0), mask)

    return cnet_image


@spaces.GPU
def rmbg(image, url):
    # Fall back to the URL input when no image file is provided
    if image is None:
        image = url
    image = load_img(image).convert("RGB")
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cuda")

    # Prediction
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Use the predicted matte as an alpha channel on the original image
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image


def placeholder(img):
    return img


rmbg_tab = gr.Interface(
    fn=rmbg, inputs=["image", "text"], outputs=["image"], api_name="rmbg"
)

# The outpainting tab is currently wired to the placeholder function, not inpaint
outpaint_tab = gr.Interface(
    fn=placeholder, inputs=["image"], outputs=["image"], api_name="outpainting"
)

demo = gr.TabbedInterface(
    [rmbg_tab, outpaint_tab],
    ["remove background", "outpainting"],
    title="Utilities that require GPU",
)

demo.launch()
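# Minimal client-side usage sketch (assumption: this app is deployed as a Space
# with id "<user>/<space-name>", a placeholder). The rmbg endpoint declared
# above with api_name="rmbg" could then be called via gradio_client like this:
#
#   from gradio_client import Client, handle_file
#
#   client = Client("<user>/<space-name>")  # placeholder Space id
#   result = client.predict(
#       handle_file("input.png"),  # "image" input
#       "",                        # "text" input (optional image URL)
#       api_name="/rmbg",
#   )
#   print(result)  # filepath of the image with the background removed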