"""Gradio demo: FLUX.1 inpainting with a user-supplied LoRA."""

import os
import random
import time
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import requests
import spaces
import torch
from diffusers import FluxInpaintPipeline
from huggingface_hub import login
from PIL import Image

MARKDOWN = """
# FLUX.1 Inpainting with lora
"""

MAX_SEED = np.iinfo(np.int32).max
IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Output dimensions are rounded down to a multiple of this value.
_DIMENSION_MULTIPLE = 32

HF_TOKEN = os.environ.get("HF_TOKEN")
# Only authenticate when a token is configured; calling login() without one
# falls back to an interactive prompt, which cannot work in a headless app.
if HF_TOKEN:
    login(token=HF_TOKEN)


class calculateDuration:
    """Context manager that prints start time, end time and elapsed seconds.

    Args:
        activity_name: Optional label included in the elapsed-time message.
    """

    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        self.start_time_formatted = time.strftime(
            self._TIME_FORMAT, time.localtime(self.start_time)
        )
        print(f"Start time: {self.start_time_formatted}")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None, so any exception raised inside the block propagates.
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        self.end_time_formatted = time.strftime(
            self._TIME_FORMAT, time.localtime(self.end_time)
        )
        # Fixed: the original printed the *start* timestamp under this label.
        print(f"End time: {self.end_time_formatted}")

        if self.activity_name:
            print(
                f"Elapsed time for {self.activity_name}: "
                f"{self.elapsed_time:.6f} seconds"
            )
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")


def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
    """Make dark pixels fully transparent.

    The image is converted to RGBA; every pixel whose mean RGB value is below
    ``threshold`` is replaced by ``(0, 0, 0, 0)``; all other pixels are kept.

    Args:
        image: Source image (any mode; it is converted to RGBA).
        threshold: Mean-RGB cutoff below which a pixel becomes transparent.

    Returns:
        The converted RGBA image with dark pixels cleared.
    """
    image = image.convert("RGBA")
    transparent = (0, 0, 0, 0)
    image.putdata(
        [
            transparent if sum(pixel[:3]) / 3 < threshold else pixel
            for pixel in image.getdata()
        ]
    )
    return image


pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to(DEVICE)


def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = IMAGE_SIZE,
) -> Tuple[int, int]:
    """Scale a resolution so its longer side equals ``maximum_dimension``.

    The aspect ratio is preserved (up to rounding) and both results are
    rounded down to a multiple of 32. Each side is clamped to at least 32 so
    that very elongated inputs cannot collapse to a zero-sized dimension.

    Args:
        original_resolution_wh: ``(width, height)`` of the source image.
        maximum_dimension: Target length of the longer side.

    Returns:
        ``(new_width, new_height)``.

    Raises:
        ValueError: If either source dimension is not positive.
    """
    width, height = original_resolution_wh
    if width <= 0 or height <= 0:
        raise ValueError(
            f"Image dimensions must be positive, got {width}x{height}."
        )

    scaling_factor = maximum_dimension / max(width, height)

    def _snap(value: int) -> int:
        # Round down to the required multiple, but never below one multiple.
        scaled = int(value * scaling_factor)
        snapped = scaled - (scaled % _DIMENSION_MULTIPLE)
        return max(_DIMENSION_MULTIPLE, snapped)

    return _snap(width), _snap(height)


@spaces.GPU(duration=100)
def process(
    input_image_editor: dict,
    lora_path: str,
    lora_weights: str,
    lora_scale: float,
    trigger_word: str,
    input_text: str,
    seed_slicer: int,
    randomize_seed_checkbox: bool,
    strength_slider: float,
    num_inference_steps_slider: int,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
    """Inpaint the masked region of the uploaded image.

    Args:
        input_image_editor: Image editor value; its ``'background'`` entry is
            used as the image and the first entry of ``'layers'`` as the mask.
        lora_path: LoRA location passed to ``pipe.load_lora_weights``. When
            blank, no LoRA is loaded.
        lora_weights: Weight file name passed as ``weight_name``. When blank,
            ``None`` is passed instead.
        lora_scale: Value sent as ``joint_attention_kwargs["scale"]``.
        trigger_word: Text appended to the prompt.
        input_text: The inpainting prompt.
        seed_slicer: Seed for the random generator.
        randomize_seed_checkbox: When true, a random seed replaces
            ``seed_slicer``.
        strength_slider: ``strength`` argument of the pipeline.
        num_inference_steps_slider: ``num_inference_steps`` of the pipeline.
        progress: Gradio progress tracker (tracks tqdm output).

    Returns:
        ``(result_image, resized_mask)``, or ``(None, None)`` when the prompt,
        image or mask is missing.
    """
    if not input_text:
        gr.Info("Please enter a text prompt.")
        return None, None

    editor_value = input_image_editor or {}

    # Validate the image before touching the layers: indexing layers[0] first
    # would raise IndexError instead of showing the message below.
    image = editor_value.get("background")
    if image is None:
        gr.Info("Please upload an image.")
        return None, None

    layers = editor_value.get("layers") or []
    mask = layers[0] if layers else None
    if mask is None:
        gr.Info("Please draw a mask on the image.")
        return None, None

    with calculateDuration("resize image"):
        width, height = resize_image_dimensions(
            original_resolution_wh=image.size
        )
        resized_image = image.resize((width, height), Image.LANCZOS)
        resized_mask = mask.resize((width, height), Image.LANCZOS)

    with calculateDuration("load lora"):
        # The pipeline is a shared global: drop adapters loaded by earlier
        # requests so LoRAs do not accumulate across calls.
        pipe.unload_lora_weights()
        lora_source = (lora_path or "").strip()
        if lora_source:
            pipe.load_lora_weights(
                lora_source,
                weight_name=(lora_weights or "").strip() or None,
            )

    if randomize_seed_checkbox:
        seed_slicer = random.randint(0, MAX_SEED)
    # Sliders may hand back floats; manual_seed and the step count need ints.
    generator = torch.Generator().manual_seed(int(seed_slicer))

    with calculateDuration("run pipe"):
        result = pipe(
            prompt=f"{input_text} {trigger_word}",
            image=resized_image,
            mask_image=resized_mask,
            width=width,
            height=height,
            strength=strength_slider,
            generator=generator,
            num_inference_steps=int(num_inference_steps_slider),
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]

    return result, resized_mask


# NOTE(review): the layout nesting below was reconstructed from the component
# order; confirm the placement of the submit button against the intended UI.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image_editor_component = gr.ImageEditor(
                label='Image',
                type='pil',
                sources=["upload", "webcam"],
                image_mode='RGB',
                layers=False,
                brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
            )

            with gr.Accordion("Prompt Settings", open=True):
                input_text_component = gr.Textbox(
                    label="Inpaint prompt",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                )
                trigger_word = gr.Textbox(
                    label="Lora trigger word",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your lora trigger word here",
                    value="a photo of TOK",
                )

            submit_button_component = gr.Button(
                value='Submit', variant='primary', scale=0
            )

            with gr.Accordion("Lora Settings", open=True):
                lora_path = gr.Textbox(
                    label="Lora path",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your lora's model path",
                )
                lora_weights = gr.Textbox(
                    label="Lora weights",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your lora weights name",
                )
                lora_scale = gr.Slider(
                    label="Lora scale",
                    show_label=True,
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.8,
                )

            with gr.Accordion("Advanced Settings", open=True):
                seed_slicer_component = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )
                randomize_seed_checkbox_component = gr.Checkbox(
                    label="Randomize seed", value=True
                )

                with gr.Row():
                    strength_slider_component = gr.Slider(
                        label="Strength",
                        info="Indicates extent to transform the reference `image`. "
                             "Must be between 0 and 1. `image` is used as a starting "
                             "point and more noise is added the higher the `strength`.",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0.85,
                    )
                    num_inference_steps_slider_component = gr.Slider(
                        label="Number of inference steps",
                        # Fixed: the original help text stopped mid-sentence
                        # after "at the".
                        info="The number of denoising steps. More denoising steps "
                             "usually lead to a higher quality image at the "
                             "expense of slower inference.",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=20,
                    )

        with gr.Column():
            output_image_component = gr.Image(
                type='pil', image_mode='RGB', label='Generated image', format="png"
            )
            with gr.Accordion("Debug", open=False):
                output_mask_component = gr.Image(
                    type='pil', image_mode='RGB', label='Input mask', format="png"
                )

    submit_button_component.click(
        fn=process,
        inputs=[
            input_image_editor_component,
            lora_path,
            lora_weights,
            lora_scale,
            trigger_word,
            input_text_component,
            seed_slicer_component,
            randomize_seed_checkbox_component,
            strength_slider_component,
            num_inference_steps_slider_component,
        ],
        outputs=[
            output_image_component,
            output_mask_component,
        ],
    )

demo.launch(debug=False, show_error=True)