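# NOTE: the lines below are page chrome captured from the Hugging Face file
# viewer together with this file; they are kept as comments for provenance.
#   Uploader avatar: multimodalart
#   Commit message:  "Update app.py"
#   Commit:          69b2b0d (verified)
#   Viewer links:    raw / history / blame
#   File size:       3.78 kB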
import gradio as gr
import torch
from diffusers.utils import load_image
from controlnet_flux import FluxControlNetModel
from transformer_flux import FluxTransformer2DModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
from PIL import Image, ImageDraw
import spaces
# Model setup. Everything below runs once at import time: the ControlNet and
# the FLUX transformer are loaded in bfloat16, combined into one inpainting
# pipeline, and the pipeline is moved to the CUDA device.
_MODEL_DTYPE = torch.bfloat16
_FLUX_REPO = "black-forest-labs/FLUX.1-dev"
_CONTROLNET_REPO = "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"

controlnet = FluxControlNetModel.from_pretrained(
    _CONTROLNET_REPO,
    torch_dtype=_MODEL_DTYPE,
)
transformer = FluxTransformer2DModel.from_pretrained(
    _FLUX_REPO,
    subfolder="transformer",
    torch_dtype=_MODEL_DTYPE,
)
pipe = FluxControlNetInpaintingPipeline.from_pretrained(
    _FLUX_REPO,
    controlnet=controlnet,
    transformer=transformer,
    torch_dtype=_MODEL_DTYPE,
).to("cuda")

# Explicitly cast both sub-models to bfloat16 again after the device move.
for _submodel in (pipe.transformer, pipe.controlnet):
    _submodel.to(_MODEL_DTYPE)
def prepare_image_and_mask(image, width, height, overlap_percentage):
    """Centre *image* on a white canvas and build the matching outpainting mask.

    The image is shrunk (aspect ratio preserved) to fit inside
    ``width`` x ``height`` and pasted in the middle of a white RGB canvas of
    exactly that size. The mask is an ``'L'`` image of the same size that is
    white (255) everywhere except for a black (0) rectangle over the pasted
    image, inset on every side by ``overlap_percentage`` percent of the
    resized image's width/height.

    NOTE(review): the black rectangle covers the *original* pixels, so white
    is presumably the region the pipeline regenerates (border + overlap band)
    and black the region it keeps -- confirm against the pipeline's mask
    convention.

    Args:
        image: PIL image to outpaint. It is not modified.
        width: Target canvas width in pixels.
        height: Target canvas height in pixels.
        overlap_percentage: Percentage (0-50) of the resized image's size, per
            side, that is left white in the mask in addition to the border.

    Returns:
        ``(background, mask)``: the RGB canvas with the image pasted in, and
        the single-channel mask described above.
    """
    # Canvas sizes must be integers for Image.new(); coerce defensively in
    # case the caller hands over float-valued numbers.
    width, height = int(width), int(height)

    # thumbnail() resizes in place, so operate on a copy to leave the
    # caller's image untouched.
    fitted = image.copy()
    fitted.thumbnail((width, height), Image.LANCZOS)

    # White canvas of the target size with the resized image centred on it.
    background = Image.new('RGB', (width, height), (255, 255, 255))
    offset = ((width - fitted.width) // 2, (height - fitted.height) // 2)
    background.paste(fitted, offset)

    # Start fully white; the black keep-rectangle is drawn below.
    mask = Image.new('L', (width, height), 255)

    # Width of the overlap band on each side, in pixels.
    overlap_x = int(fitted.width * overlap_percentage / 100)
    overlap_y = int(fitted.height * overlap_percentage / 100)

    # ImageDraw.rectangle() includes both corner points, so the last black
    # column/row is at "end - 1"; without the -1 the rectangle would spill one
    # pixel onto the white padding to the right of / below the image.
    left = offset[0] + overlap_x
    top = offset[1] + overlap_y
    right = offset[0] + fitted.width - overlap_x - 1
    bottom = offset[1] + fitted.height - overlap_y - 1

    # With a large overlap the keep-region can be empty; skip drawing instead
    # of passing an inverted box (which recent Pillow versions reject).
    if right >= left and bottom >= top:
        ImageDraw.Draw(mask).rectangle([(left, top), (right, bottom)], fill=0)

    return background, mask
@spaces.GPU
def inpaint(image, prompt, width, height, overlap_percentage, num_inference_steps, guidance_scale, seed=42):
    """Outpaint *image* to ``width`` x ``height`` with the FLUX ControlNet pipeline.

    The input is first placed on a white canvas and paired with a mask by
    ``prepare_image_and_mask``; both are then passed to the module-level
    ``pipe`` as ``control_image`` / ``control_mask``.

    Args:
        image: PIL image supplied by the UI, or ``None`` if nothing was uploaded.
        prompt: Text prompt forwarded to the pipeline.
        width: Output width in pixels.
        height: Output height in pixels.
        overlap_percentage: Forwarded to ``prepare_image_and_mask``.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Used for both ``guidance_scale`` and
            ``true_guidance_scale`` of the pipeline call.
        seed: Seed for the CUDA random generator. Defaults to 42, the value
            that was previously hard-coded, so existing callers are unaffected.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Clicking "Generate" without an upload passes None; report that to the
    # user instead of failing with an AttributeError further down.
    if image is None:
        raise gr.Error("Please upload an input image before generating.")

    # These three must be integers for PIL and the pipeline; coerce
    # defensively in case the UI layer delivers float-valued numbers.
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    # Prepare image and mask
    image, mask = prepare_image_and_mask(image, width, height, overlap_percentage)

    # Set up generator for reproducibility
    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    # Run inpainting
    result = pipe(
        prompt=prompt,
        height=height,
        width=width,
        control_image=image,
        control_mask=mask,
        num_inference_steps=num_inference_steps,
        generator=generator,
        controlnet_conditioning_scale=0.9,
        guidance_scale=guidance_scale,
        negative_prompt="",
        true_guidance_scale=guidance_scale,
    ).images[0]
    return result
# Gradio interface: controls in the left column, result in the right column.
with gr.Blocks() as demo:
    gr.Markdown("# FLUX Outpainting Demo")

    with gr.Row():
        # Left column: every input that is forwarded to inpaint().
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            prompt_input = gr.Textbox(label="Prompt")
            width_slider = gr.Slider(
                label="Width",
                minimum=256,
                maximum=1024,
                step=64,
                value=768,
            )
            height_slider = gr.Slider(
                label="Height",
                minimum=256,
                maximum=1024,
                step=64,
                value=768,
            )
            overlap_slider = gr.Slider(
                label="Overlap Percentage",
                minimum=0,
                maximum=50,
                step=1,
                value=10,
            )
            steps_slider = gr.Slider(
                label="Inference Steps",
                minimum=1,
                maximum=100,
                step=1,
                value=28,
            )
            guidance_slider = gr.Slider(
                label="Guidance Scale",
                minimum=0.1,
                maximum=10.0,
                step=0.1,
                value=3.5,
            )
            run_button = gr.Button("Generate")

        # Right column: the generated image.
        with gr.Column():
            output_image = gr.Image(label="Output Image")

    # Order must match the positional parameters of inpaint().
    inpaint_inputs = [
        input_image,
        prompt_input,
        width_slider,
        height_slider,
        overlap_slider,
        steps_slider,
        guidance_slider,
    ]
    run_button.click(fn=inpaint, inputs=inpaint_inputs, outputs=output_image)

demo.launch()