##!/usr/bin/python3
# -*- coding: utf-8 -*-
import os, random, sys
import numpy as np
import requests
import torch
import spaces
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download, snapshot_download
from scipy.ndimage import binary_dilation, binary_erosion
from transformers import (LlavaNextProcessor, LlavaNextForConditionalGeneration,
Qwen2VLForConditionalGeneration, Qwen2VLProcessor)
from segment_anything import SamPredictor, build_sam, SamAutomaticMaskGenerator
from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from app.src.vlm_pipeline import (
vlm_response_editing_type,
vlm_response_object_wait_for_edit,
vlm_response_mask,
vlm_response_prompt_after_apply_instruction
)
from app.src.brushedit_all_in_one_pipeline import BrushEdit_Pipeline
from app.utils.utils import load_grounding_dino_model
from app.src.vlm_template import vlms_template
from app.src.base_model_template import base_models_template
from app.src.aspect_ratio_template import aspect_ratios
from openai import OpenAI
# base_openai_url = ""
#### Description ####
# Static text blocks for the demo page. They are raw strings, so any
# backslashes inside them are kept literally.
# NOTE(review): presumably these are rendered by the Gradio layout defined
# later in this file -- confirm against the UI code.

# Logo markup; the literal currently holds only a newline.
logo = r"""
"""
# Page headline.
head = r"""
BrushEdit: All-In-One Image Inpainting and Editing
"""
# Short description of the demo.
descriptions = r"""
Official Gradio Demo for BrushEdit: All-In-One Image Inpainting and Editing
🧙 BrushEdit enables precise, user-friendly instruction-based image editing via a inpainting model.
"""
# Step-by-step usage guide for the two editing modes (fully automated and
# interactive), plus the VLM recommendation.
instructions = r"""
Currently, we support two modes: fully automated command editing and interactive command editing.
🛠️ Fully automated instruction-based editing:
- ⭐️ 1.Choose Image: Upload or select one image from Example.
- ⭐️ 2.Input ⌨️ Instructions: Input the instructions (supports addition, deletion, and modification), e.g. remove xxx .
- ⭐️ 3.Run: Click 💫 Run button to automatic edit image.
🛠️ Interactive instruction-based editing:
- ⭐️ 1.Choose Image: Upload or select one image from Example.
- ⭐️ 2.Finely Brushing: Use a brush to outline the area you want to edit. And You can also use the eraser to restore.
- ⭐️ 3.Input ⌨️ Instructions: Input the instructions.
- ⭐️ 4.Run: Click 💫 Run button to automatic edit image.
We strongly recommend using GPT-4o for reasoning. After selecting the VLM model as gpt4-o, enter the API KEY and click the Submit and Verify button. If the output is success, you can use gpt4-o normally. Secondarily, we recommend using the Qwen2VL model.
We recommend zooming out in your browser for a better viewing range and experience.
For more detailed feature descriptions, see the bottom.
☕️ Have fun! 🎄 Wishing you a merry Christmas!
"""
# Help text listing usage tips and a per-control feature reference.
# The "Base Model" entry points at base_model_template.py (the registry that
# update_base_model reads), not at the VLM template.
tips = r"""
💡 Some Tips:
- 🤠 After input the instructions, you can click the Generate Mask button. The mask generated by VLM will be displayed in the preview panel on the right side.
- 🤠 After generating the mask or when you use the brush to draw the mask, you can perform operations such as randomization, dilation, erosion, and movement.
- 🤠 After input the instructions, you can click the Generate Target Prompt button. The target prompt will be displayed in the text box, and you can modify it according to your ideas.
💡 Detailed Features:
- 🎨 Aspect Ratio: Select the aspect ratio of the image. To prevent OOM, 1024px is the maximum resolution.
- 🎨 VLM Model: Select the VLM model. We use preloaded models to save time. To use other VLM models, download them and uncomment the relevant lines in vlm_template.py from our GitHub repo.
- 🎨 Generate Mask: According to the input instructions, generate a mask for the area that may need to be edited.
- 🎨 Square/Circle Mask: Based on the existing mask, generate masks for squares and circles. (The coarse-grained mask provides more editing imagination.)
- 🎨 Invert Mask: Invert the mask to generate a new mask.
- 🎨 Dilation/Erosion Mask: Expand or shrink the mask to include or exclude more areas.
- 🎨 Move Mask: Move the mask to a new position.
- 🎨 Generate Target Prompt: Generate a target prompt based on the input instructions.
- 🎨 Target Prompt: Description for masking area, manual input or modification can be made when the content generated by VLM does not meet expectations.
- 🎨 Blending: Blending brushnet's output and the original input, ensuring the original image details in the unedited areas. (turn off is better when removing.)
- 🎨 Control length: The intensity of editing and inpainting.
💡 Advanced Features:
- 🎨 Base Model: We use preloaded models to save time. To use other base models, download them and uncomment the relevant lines in base_model_template.py from our GitHub repo.
- 🎨 Blending: Blending brushnet's output and the original input, ensuring the original image details in the unedited areas. (turn off is better when removing.)
- 🎨 Control length: The intensity of editing and inpainting.
- 🎨 Num samples: The number of samples to generate.
- 🎨 Negative prompt: The negative prompt for the classifier-free guidance.
- 🎨 Guidance scale: The guidance scale for the classifier-free guidance.
"""
# Footer text: star request, BibTeX entry and contact address.
# BibTeX separates authors with a single " and "; a doubled separator yields
# an empty author, so the list below uses exactly one between each name.
citation = r"""
If BrushEdit is helpful, please help to ⭐ the Github Repo. Thanks!
[![GitHub Stars](https://img.shields.io/github/stars/TencentARC/BrushEdit?style=social)](https://github.com/TencentARC/BrushEdit)
---
📝 **Citation**
If our work is useful for your research, please consider citing:
```bibtex
@misc{li2024brushedit,
title={BrushEdit: All-In-One Image Inpainting and Editing},
author={Yaowei Li and Yuxuan Bian and Xuan Ju and Zhaoyang Zhang and Junhao Zhuang and Ying Shan and Yuexian Zou and Qiang Xu},
year={2024},
eprint={2412.10316},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
📧 **Contact**
If you have any questions, please feel free to reach me out at liyaowei@gmail.com.
"""
# - - - - - examples - - - - - #
# One row per bundled example:
#   (image path, instruction text, seed, example key, first boolean column).
# Every example shares the same second boolean column (False) and the same
# VLM label, so those are filled in when the final rows are built.
# NOTE(review): the meaning of the two boolean columns is defined by whatever
# consumes EXAMPLES; it is not visible here.
_EXAMPLE_VLM_LABEL = "GPT4-o (Highly Recommended)"
_EXAMPLE_ROWS = (
    ("./assets/frog/frog.jpeg", "add a magic hat on frog head.", 642087011, "frog", True),
    ("./assets/chinese_girl/chinese_girl.png", "replace the background to ancient China.", 648464818, "chinese_girl", True),
    ("./assets/angel_christmas/angel_christmas.png", "remove the deer.", 648464818, "angel_christmas", False),
    ("./assets/sunflower_girl/sunflower_girl.png", "add a wreath on head.", 648464818, "sunflower_girl", True),
    ("./assets/girl_on_sun/girl_on_sun.png", "add a butterfly fairy.", 648464818, "girl_on_sun", True),
    ("./assets/spider_man_rm/spider_man.png", "remove the christmas hat.", 642087011, "spider_man_rm", False),
    ("./assets/anime_flower/anime_flower.png", "remove the flower.", 642087011, "anime_flower", False),
    ("./assets/chenduling/chengduling.jpg", "replace the clothes to a delicated floral skirt.", 648464818, "chenduling", True),
    ("./assets/hedgehog_rp_bg/hedgehog.png", "make the hedgehog in Italy.", 648464818, "hedgehog_rp_bg", True),
)
# Images are opened eagerly at import time and converted to RGBA; the example
# key appears twice in each row, exactly as in the hand-written table.
EXAMPLES = [
    [
        Image.open(image_path).convert("RGBA"),
        instruction,
        seed,
        example_key,
        example_key,
        first_flag,
        False,
        _EXAMPLE_VLM_LABEL,
    ]
    for image_path, instruction, seed, example_key, first_flag in _EXAMPLE_ROWS
]
# Asset files for each bundled example, keyed by example name.
# Tuple layout: (input image, mask, masked image, edited output).
_EXAMPLE_ASSETS = {
    "frog": (
        "./assets/frog/frog.jpeg",
        "./assets/frog/mask_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
        "./assets/frog/masked_image_f7b350de-6f2c-49e3-b535-995c486d78e7.png",
        "./assets/frog/image_edit_f7b350de-6f2c-49e3-b535-995c486d78e7_1.png",
    ),
    "chinese_girl": (
        "./assets/chinese_girl/chinese_girl.png",
        "./assets/chinese_girl/mask_54759648-0989-48e0-bc82-f20e28b5ec29.png",
        "./assets/chinese_girl/masked_image_54759648-0989-48e0-bc82-f20e28b5ec29.png",
        "./assets/chinese_girl/image_edit_54759648-0989-48e0-bc82-f20e28b5ec29_1.png",
    ),
    "angel_christmas": (
        "./assets/angel_christmas/angel_christmas.png",
        "./assets/angel_christmas/mask_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
        "./assets/angel_christmas/masked_image_f15d9b45-c978-4e3d-9f5f-251e308560c3.png",
        "./assets/angel_christmas/image_edit_f15d9b45-c978-4e3d-9f5f-251e308560c3_0.png",
    ),
    "sunflower_girl": (
        "./assets/sunflower_girl/sunflower_girl.png",
        "./assets/sunflower_girl/mask_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
        "./assets/sunflower_girl/masked_image_99cc50b4-7dc4-4de5-8748-ec10772f0317.png",
        "./assets/sunflower_girl/image_edit_99cc50b4-7dc4-4de5-8748-ec10772f0317_3.png",
    ),
    "girl_on_sun": (
        "./assets/girl_on_sun/girl_on_sun.png",
        "./assets/girl_on_sun/mask_264eac8b-8b65-479c-9755-020a60880c37.png",
        "./assets/girl_on_sun/masked_image_264eac8b-8b65-479c-9755-020a60880c37.png",
        "./assets/girl_on_sun/image_edit_264eac8b-8b65-479c-9755-020a60880c37_0.png",
    ),
    "spider_man_rm": (
        "./assets/spider_man_rm/spider_man.png",
        "./assets/spider_man_rm/mask_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
        "./assets/spider_man_rm/masked_image_a5d410e6-8e8d-432f-8144-defbc3e1eae9.png",
        "./assets/spider_man_rm/image_edit_a5d410e6-8e8d-432f-8144-defbc3e1eae9_0.png",
    ),
    "anime_flower": (
        "./assets/anime_flower/anime_flower.png",
        "./assets/anime_flower/mask_37553172-9b38-4727-bf2e-37d7e2b93461.png",
        "./assets/anime_flower/masked_image_37553172-9b38-4727-bf2e-37d7e2b93461.png",
        "./assets/anime_flower/image_edit_37553172-9b38-4727-bf2e-37d7e2b93461_2.png",
    ),
    "chenduling": (
        "./assets/chenduling/chengduling.jpg",
        "./assets/chenduling/mask_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
        "./assets/chenduling/masked_image_68e3ff6f-da07-4b37-91df-13d6eed7b997.png",
        "./assets/chenduling/image_edit_68e3ff6f-da07-4b37-91df-13d6eed7b997_0.png",
    ),
    "hedgehog_rp_bg": (
        "./assets/hedgehog_rp_bg/hedgehog.png",
        "./assets/hedgehog_rp_bg/mask_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
        "./assets/hedgehog_rp_bg/masked_image_db7f8bf8-8349-46d3-b14e-43d67fbe25d3.png",
        "./assets/hedgehog_rp_bg/image_edit_db7f8bf8-8349-46d3-b14e-43d67fbe25d3_3.png",
    ),
}
# Transpose the table into one {example name: path} dict per asset kind;
# key order follows _EXAMPLE_ASSETS.
INPUT_IMAGE_PATH, MASK_IMAGE_PATH, MASKED_IMAGE_PATH, OUTPUT_IMAGE_PATH = (
    dict(zip(_EXAMPLE_ASSETS, column)) for column in zip(*_EXAMPLE_ASSETS.values())
)
# os.environ['GRADIO_TEMP_DIR'] = 'gradio_temp_dir'
# os.makedirs('gradio_temp_dir', exist_ok=True)
# Selectable names are taken from the template registries in their iteration
# order; the DEFAULT_* strings must match a key of the respective registry.
VLM_MODEL_NAMES = [*vlms_template.keys()]
DEFAULT_VLM_MODEL_NAME = "Qwen2-VL-7B-Instruct (Default)"
BASE_MODELS = [*base_models_template.keys()]
DEFAULT_BASE_MODEL = "realisticVision (Default)"
ASPECT_RATIO_LABELS = [*aspect_ratios]
# The first aspect-ratio entry doubles as the default selection.
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
## init device
# Preference order: CUDA, then Apple MPS (macOS only), then CPU.
try:
    if torch.cuda.is_available():
        device = "cuda"
    elif sys.platform == "darwin" and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
except Exception:
    # Backend probing failed (e.g. a torch build without the probed backend
    # attribute): fall back to CPU. Catching Exception rather than using a
    # bare except lets KeyboardInterrupt / SystemExit propagate.
    device = "cpu"
# ## init torch dtype
# if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
# torch_dtype = torch.bfloat16
# else:
# torch_dtype = torch.float16
# if device == "mps":
# torch_dtype = torch.float16
# Half precision is used unconditionally; the bf16 selection above is disabled.
torch_dtype = torch.float16
# download hf models
# The checkpoint snapshot is fetched only when the local "models/" directory
# is missing; HF_TOKEN is read from the environment and may be unset.
BrushEdit_path = "models/"
if not os.path.exists(BrushEdit_path):
    BrushEdit_path = snapshot_download(
        repo_id="TencentARC/BrushEdit",
        local_dir=BrushEdit_path,
        token=os.getenv("HF_TOKEN"),
    )
## init default VLM
# Each vlms_template entry unpacks to (type, local path, processor, model).
# An empty string in the processor/model slot marks a model that was not
# preloaded; for the default model this aborts start-up with an error.
vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_MODEL_NAME]
if vlm_processor != "" and vlm_model != "":
    vlm_model.to(device)
else:
    raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
## init base model
# All checkpoint locations are resolved relative to the BrushEdit snapshot.
base_model_path = os.path.join(BrushEdit_path, "base_model/realisticVisionV60B1_v51VAE")
brushnet_path = os.path.join(BrushEdit_path, "brushnetX")
sam_path = os.path.join(BrushEdit_path, "sam/sam_vit_h_4b8939.pth")
groundingdino_path = os.path.join(BrushEdit_path, "grounding_dino/groundingdino_swint_ogc.pth")
# input brushnetX ckpt path
brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch_dtype)
pipe = StableDiffusionBrushNetPipeline.from_pretrained(
    base_model_path, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
)
# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# remove following line if xformers is not installed or when using Torch 2.0.
# pipe.enable_xformers_memory_efficient_attention()
pipe.enable_model_cpu_offload()
## init SAM
# One SAM model instance backs both the prompt-driven predictor and the
# automatic mask generator.
sam = build_sam(checkpoint=sam_path)
sam.to(device=device)
sam_predictor = SamPredictor(sam)
sam_automask_generator = SamAutomaticMaskGenerator(sam)
## init groundingdino_model
config_file = 'app/utils/GroundingDINO_SwinT_OGC.py'
groundingdino_model = load_grounding_dino_model(config_file, groundingdino_path, device=device)
## Ordinary function
def crop_and_resize(image: Image.Image,
                    target_width: int,
                    target_height: int) -> Image.Image:
    """
    Center-crop an image to the target aspect ratio, then resize it.

    The crop keeps the full extent of one dimension and trims the other one
    symmetrically, so the cropped region has (up to integer truncation) the
    aspect ratio ``target_width / target_height``. The result is resized with
    nearest-neighbour resampling.

    Args:
        image (Image.Image): Input PIL image.
        target_width (int): Width of the returned image.
        target_height (int): Height of the returned image.
    Returns:
        Image.Image: Image of size ``(target_width, target_height)``.
    """
    src_w, src_h = image.size
    wanted_ratio = target_width / target_height

    if src_w / src_h > wanted_ratio:
        # Source is wider than requested: keep the height, trim left/right.
        crop_w = int(src_h * wanted_ratio)
        x0 = (src_w - crop_w) / 2
        box = (x0, 0, x0 + crop_w, src_h)
    else:
        # Source is taller than (or matches) the request: keep the width,
        # trim top/bottom.
        crop_h = int(src_w / wanted_ratio)
        y0 = (src_h - crop_h) / 2
        box = (0, y0, src_w, y0 + crop_h)

    return image.crop(box).resize((target_width, target_height), Image.NEAREST)
## Ordinary function
def resize(image: Image.Image,
           target_width: int,
           target_height: int) -> Image.Image:
    """
    Resize an image to exactly ``(target_width, target_height)``.

    No cropping is performed, so the aspect ratio changes whenever the target
    ratio differs from the source ratio.

    Args:
        image (Image.Image): Input PIL image to be resized.
        target_width (int): Target width of the output image.
        target_height (int): Target height of the output image.
    Returns:
        Image.Image: Resized image.
    """
    # Nearest-neighbour resampling introduces no interpolated values.
    resized_image = image.resize((target_width, target_height), Image.NEAREST)
    return resized_image
def move_mask_func(mask, direction, units):
    """
    Shift a mask by ``units`` pixels in one direction, filling with False.

    Args:
        mask: Array that squeezes to 2-D; values > 0 count as foreground.
        direction: One of 'down', 'up', 'right', 'left'. Any other value
            yields an all-False mask.
        units: Shift distance in pixels. Values are clamped to
            ``[0, size along the shift axis]``: a negative distance leaves the
            mask in place, and a distance at least as large as the mask moves
            everything out of frame instead of raising a shape error.
    Returns:
        2-D boolean array with the same height and width as the squeezed mask.
    """
    binary_mask = mask.squeeze() > 0
    rows, cols = binary_mask.shape
    moved_mask = np.zeros_like(binary_mask, dtype=bool)

    # Clamp the distance so the source and destination slices below always
    # have matching shapes.
    axis_size = rows if direction in ('down', 'up') else cols
    shift = min(max(int(units), 0), axis_size)

    if direction == 'down':
        # move down
        moved_mask[shift:, :] = binary_mask[:rows - shift, :]
    elif direction == 'up':
        # move up
        moved_mask[:rows - shift, :] = binary_mask[shift:, :]
    elif direction == 'right':
        # move right
        moved_mask[:, shift:] = binary_mask[:, :cols - shift]
    elif direction == 'left':
        # move left
        moved_mask[:, :cols - shift] = binary_mask[:, shift:]
    return moved_mask
def random_mask_func(mask, dilation_type='square', dilation_size=20):
    """
    Derive a coarser or finer mask from an existing one.

    Args:
        mask: Array that squeezes to 2-D; values > 0 count as foreground.
        dilation_type: Transformation to apply:
            'square_dilation'  - dilate with a square structuring element,
            'square_erosion'   - erode with a square structuring element,
            'bounding_box'     - fill the foreground's bounding box,
            'bounding_ellipse' - fill an ellipse spanning the bounding box.
            The default value 'square' is kept for interface compatibility
            but is not a supported mode, so callers must pass one explicitly.
        dilation_size: Side length of the square structuring element (only
            used by the two 'square_*' modes).
    Returns:
        uint8 array of shape (H, W, 1) with values 0 or 255. For the bounding
        modes, an empty input mask is returned unchanged.
    Raises:
        ValueError: If ``dilation_type`` is not one of the supported modes.
    """
    binary_mask = mask.squeeze() > 0
    if dilation_type == 'square_dilation':
        structure = np.ones((dilation_size, dilation_size), dtype=bool)
        dilated_mask = binary_dilation(binary_mask, structure=structure)
    elif dilation_type == 'square_erosion':
        structure = np.ones((dilation_size, dilation_size), dtype=bool)
        dilated_mask = binary_erosion(binary_mask, structure=structure)
    elif dilation_type in ('bounding_box', 'bounding_ellipse'):
        # Extent of the foreground pixels
        rows, cols = np.where(binary_mask)
        if len(rows) == 0 or len(cols) == 0:
            return mask  # return original mask if no valid points
        min_row, max_row = np.min(rows), np.max(rows)
        min_col, max_col = np.min(cols), np.max(cols)
        dilated_mask = np.zeros_like(binary_mask, dtype=bool)
        if dilation_type == 'bounding_box':
            dilated_mask[min_row:max_row + 1, min_col:max_col + 1] = True
        else:
            # Center and half-axes of the ellipse. A one-pixel-thick mask
            # would give a zero half-axis and a division by zero, so each
            # half-axis is at least 1.
            center_col = (min_col + max_col) // 2
            center_row = (min_row + max_row) // 2
            half_width = max(int((max_col - min_col) // 2), 1)
            half_height = max(int((max_row - min_row) // 2), 1)
            y, x = np.ogrid[:binary_mask.shape[0], :binary_mask.shape[1]]
            ellipse_mask = (
                (x - center_col) ** 2 / half_width ** 2
                + (y - center_row) ** 2 / half_height ** 2
            ) <= 1
            dilated_mask[ellipse_mask] = True
    else:
        raise ValueError(
            "dilation_type must be one of 'square_dilation', 'square_erosion', "
            f"'bounding_box' or 'bounding_ellipse', got {dilation_type!r}"
        )
    # Back to the (H, W, 1) uint8 {0, 255} layout
    dilated_mask = np.uint8(dilated_mask[:, :, np.newaxis]) * 255
    return dilated_mask
## Gradio component function
def update_vlm_model(vlm_name):
    """
    Switch the global VLM (``vlm_model`` / ``vlm_processor``) to ``vlm_name``.

    The previously loaded model is dropped first. A model that the template
    already holds is only moved to ``device``. Otherwise it is loaded from the
    template's local path when that path exists, or else from the Hugging Face
    hub when ``vlm_name`` is one of the known names. For the "openai" type
    nothing is loaded here.

    Returns:
        ``vlm_model_dropdown`` when a preloaded model was activated, otherwise
        the string "success".
    """
    global vlm_model, vlm_processor
    # Drop the current model before bringing in the next one
    if vlm_model is not None:
        del vlm_model
        torch.cuda.empty_cache()
    vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[vlm_name]

    ## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via vlm_template.py
    # (processor class, model class) per locally hosted VLM family
    family_classes = {
        "llava-next": (LlavaNextProcessor, LlavaNextForConditionalGeneration),
        "qwen2-vl": (Qwen2VLProcessor, Qwen2VLForConditionalGeneration),
    }
    # Hub repositories used when no local copy is available
    hub_repos = {
        "llava-next": {
            "llava-v1.6-mistral-7b-hf (Preload)": "llava-hf/llava-v1.6-mistral-7b-hf",
            "llama3-llava-next-8b-hf (Preload)": "llava-hf/llama3-llava-next-8b-hf",
            "llava-v1.6-vicuna-13b-hf (Preload)": "llava-hf/llava-v1.6-vicuna-13b-hf",
            "llava-v1.6-34b-hf (Preload)": "llava-hf/llava-v1.6-34b-hf",
            "llava-next-72b-hf (Preload)": "llava-hf/llava-next-72b-hf",
        },
        "qwen2-vl": {
            "qwen2-vl-2b-instruct (Preload)": "Qwen/Qwen2-VL-2B-Instruct",
            "qwen2-vl-7b-instruct (Preload)": "Qwen/Qwen2-VL-7B-Instruct",
            "qwen2-vl-72b-instruct (Preload)": "Qwen/Qwen2-VL-72B-Instruct",
        },
    }

    if vlm_type in family_classes:
        if vlm_processor != "" and vlm_model != "":
            # Already instantiated by the template: just place it on the device
            vlm_model.to(device)
            return vlm_model_dropdown
        processor_cls, model_cls = family_classes[vlm_type]
        if os.path.exists(vlm_local_path):
            checkpoint = vlm_local_path
        else:
            # None when the name has no known hub repository: nothing is loaded
            checkpoint = hub_repos[vlm_type].get(vlm_name)
        if checkpoint is not None:
            vlm_processor = processor_cls.from_pretrained(checkpoint)
            vlm_model = model_cls.from_pretrained(checkpoint, torch_dtype="auto", device_map="auto")
    return "success"
def update_base_model(base_model_name):
    """
    Switch the global diffusion pipeline ``pipe`` to ``base_model_name``.

    A pipeline that the template already holds is moved to ``device``.
    Otherwise a new BrushNet pipeline is built from the template's path with
    model CPU offload enabled.

    Returns:
        The string "success".
    Raises:
        gr.Error: If no pipeline is preloaded and the model path does not exist.
    """
    global pipe
    ## we recommend using preload models, otherwise it will take a long time to download the model. you can edit the code via base_model_template.py
    # Drop the current pipeline before bringing in the next one
    if pipe is not None:
        del pipe
        torch.cuda.empty_cache()
    model_dir, pipe = base_models_template[base_model_name]

    if pipe != "":
        # Preloaded pipeline: only device placement is needed
        pipe.to(device)
        return "success"

    if not os.path.exists(model_dir):
        raise gr.Error(f"The base model {base_model_name} does not exist")

    pipe = StableDiffusionBrushNetPipeline.from_pretrained(
        model_dir, brushnet=brushnet, torch_dtype=torch_dtype, low_cpu_mem_usage=False
    )
    # pipe.enable_xformers_memory_efficient_attention()
    pipe.enable_model_cpu_offload()
    return "success"
def submit_GPT4o_KEY(GPT4o_KEY):
    """
    Install an OpenAI client as the global VLM and verify the key with a probe.

    The previously loaded VLM is dropped, an ``OpenAI`` client built from
    ``GPT4o_KEY`` becomes ``vlm_model`` (with an empty ``vlm_processor``), and
    one small chat completion is requested to check that the key works.

    Returns:
        Tuple ``(status message, VLM label)``. The message starts with
        "Success, " followed by the model's reply, or is
        "Invalid GPT4o API Key" if anything in the attempt raised.
    """
    global vlm_model, vlm_processor
    if vlm_model is not None:
        del vlm_model
        torch.cuda.empty_cache()

    vlm_label = "GPT4-o (Highly Recommended)"
    probe_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say this is a test"}
    ]
    try:
        vlm_model = OpenAI(api_key=GPT4o_KEY)
        vlm_processor = ""
        probe = vlm_model.chat.completions.create(
            model="gpt-4o-2024-08-06",
            messages=probe_messages,
        )
        reply = probe.choices[0].message.content
        return "Success, " + reply, vlm_label
    except Exception:
        # Any failure (bad key, network, unexpected reply) is reported the same way
        return "Invalid GPT4o API Key", vlm_label
@spaces.GPU(duration=180)
def process(input_image,
            original_image,
            original_mask,
            prompt,
            negative_prompt,
            control_strength,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
            num_samples,
            blending,
            category,
            target_prompt,
            resize_default,
            aspect_ratio_name,
            invert_mask_state):
    """
    Run one edit: resolve image, mask and target prompt, then call BrushEdit_Pipeline.

    Args:
        input_image: Editor payload; ``["background"]`` is the picture and the
            alpha channel of ``["layers"][0]`` is the brushed mask.
        original_image: Image array from a previous step, or None to take the
            editor background (converted to RGB).
        original_mask: Mask array from a previous step, or None.
        prompt: Editing instruction; may be empty only if ``target_prompt`` is set.
        negative_prompt, control_strength, guidance_scale, num_inference_steps,
        num_samples, blending: Forwarded unchanged to ``BrushEdit_Pipeline``.
        seed: Generator seed, used when ``randomize_seed`` is false.
        randomize_seed: Draw a random seed instead of using ``seed``.
        category: Editing category; when None it is requested from the VLM,
            unless a target prompt and a mask are both already available.
        target_prompt: When non-empty it is used as the pipeline prompt as is;
            otherwise the VLM derives it from ``prompt``.
        resize_default: Rescale so the short side of the output is 640 px.
        aspect_ratio_name: Key into ``aspect_ratios``; an entry with empty
            strings means "keep the image's own size".
        invert_mask_state: When true, ``original_mask`` is kept even if a mask
            was brushed in the editor.

    Returns:
        ``(images, [mask image], [masked image], prompt, '', False)``.

    Raises:
        gr.Error: If no image or no instruction is given, or if a VLM call
            fails. The underlying exception is chained as the cause so it
            stays visible in logs.
    """
    if original_image is None:
        if input_image is None:
            raise gr.Error('Please upload the input image')
        else:
            image_pil = input_image["background"].convert("RGB")
            original_image = np.array(image_pil)
    if prompt is None or prompt == "":
        if target_prompt is None or target_prompt == "":
            raise gr.Error("Please input your instructions, e.g., remove the xxx")

    # The brushed mask is the alpha channel of the first editor layer
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        # No fixed output size: start from the image's own dimensions
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
    else:
        # Fixed output size: image and masks are always resized to it
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)

    # A non-empty brushed mask replaces the stored one, except after an
    # inversion, where the stored mask stays authoritative.
    if not invert_mask_state and input_mask.max() != 0:
        original_mask = input_mask

    has_target_prompt = target_prompt is not None and len(target_prompt) >= 1

    ## inpainting directly if target_prompt is not None
    if category is None and not (has_target_prompt and original_mask is not None):
        try:
            category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
        except Exception as e:
            raise gr.Error("Please select the correct VLM model and input the correct API Key first!") from e

    if original_mask is not None:
        original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
    else:
        # No mask yet: let the VLM name the object and derive its mask
        try:
            object_wait_for_edit = vlm_response_object_wait_for_edit(
                vlm_processor,
                vlm_model,
                original_image,
                category,
                prompt,
                device)
            original_mask = vlm_response_mask(vlm_processor,
                                              vlm_model,
                                              category,
                                              original_image,
                                              prompt,
                                              object_wait_for_edit,
                                              sam,
                                              sam_predictor,
                                              sam_automask_generator,
                                              groundingdino_model,
                                              device)
        except Exception as e:
            raise gr.Error("Please select the correct VLM model and input the correct API Key first!") from e

    if original_mask.ndim == 2:
        original_mask = original_mask[:, :, None]

    if has_target_prompt:
        prompt_after_apply_instruction = target_prompt
    else:
        try:
            prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
                vlm_processor,
                vlm_model,
                original_image,
                prompt,
                device)
        except Exception as e:
            raise gr.Error("Please select the correct VLM model and input the correct API Key first!") from e

    generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)

    with torch.autocast(device):
        image, mask_image, mask_np, init_image_np = BrushEdit_Pipeline(pipe,
                                                                       prompt_after_apply_instruction,
                                                                       original_mask,
                                                                       original_image,
                                                                       generator,
                                                                       num_inference_steps,
                                                                       guidance_scale,
                                                                       control_strength,
                                                                       negative_prompt,
                                                                       num_samples,
                                                                       blending)
    # Preview of the input with the masked region blacked out
    original_image = np.array(init_image_np)
    masked_image = original_image * (1 - (mask_np > 0))
    masked_image = masked_image.astype(np.uint8)
    masked_image = Image.fromarray(masked_image)
    # Save the images (optional)
    # import uuid
    # uuid = str(uuid.uuid4())
    # image[0].save(f"outputs/image_edit_{uuid}_0.png")
    # image[1].save(f"outputs/image_edit_{uuid}_1.png")
    # image[2].save(f"outputs/image_edit_{uuid}_2.png")
    # image[3].save(f"outputs/image_edit_{uuid}_3.png")
    # mask_image.save(f"outputs/mask_{uuid}.png")
    # masked_image.save(f"outputs/masked_image_{uuid}.png")
    # gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=16)
    return image, [mask_image], [masked_image], prompt, '', False
def generate_target_prompt(input_image,
                           original_image,
                           prompt):
    """
    Ask the VLM for the target prompt that results from applying ``prompt``.

    Args:
        input_image: Editor payload or image; used as the picture source when
            ``original_image`` is a string.
        original_image: Image to describe, or a string placeholder (set when
            an example was loaded).
        prompt: Editing instruction.
    Returns:
        The value returned by ``vlm_response_prompt_after_apply_instruction``.
    """
    # load example image
    if isinstance(original_image, str):
        # The sibling callbacks read the picture from input_image["background"];
        # do the same for a mapping payload instead of handing the whole
        # payload to the VLM. Non-mapping inputs are passed through as before.
        # NOTE(review): confirm the image type the VLM helper expects
        # (PIL image vs. array).
        if isinstance(input_image, dict):
            original_image = input_image["background"]
        else:
            original_image = input_image
    prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
        vlm_processor,
        vlm_model,
        original_image,
        prompt,
        device)
    return prompt_after_apply_instruction
def process_mask(input_image,
                 original_image,
                 prompt,
                 resize_default,
                 aspect_ratio_name):
    """
    Produce the editing mask and a masked preview of the image.

    A mask brushed in the editor is used when present; otherwise the VLM
    pipeline derives one from the instruction.

    Args:
        input_image: Editor payload; ``["layers"][0]`` carries the brushed
            mask in its alpha channel, ``["background"]`` the picture.
        original_image: Image array, or a string placeholder (example images).
        prompt: Editing instruction.
        resize_default: Rescale so the short side of the output is 640 px.
        aspect_ratio_name: Key into ``aspect_ratios``.
    Returns:
        ``([masked image], [mask image], mask as uint8 array, category)``;
        ``category`` is None when the brushed mask was used.
    Raises:
        gr.Error: If ``original_image`` or ``prompt`` is None.
    """
    if original_image is None:
        raise gr.Error('Please upload the input image')
    if prompt is None:
        raise gr.Error("Please input your instructions, e.g., remove the xxx")
    ## load mask
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.array(alpha_mask)
    # load example image
    # NOTE(review): this leaves original_image as the editor's background
    # object while the code below uses `.shape` and array arithmetic --
    # confirm it is (or is converted to) a numpy array on this path.
    if isinstance(original_image, str):
        original_image = input_image["background"]
    if input_mask.max() == 0:
        # Nothing brushed: classify the edit, name the object, then segment it
        category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
        object_wait_for_edit = vlm_response_object_wait_for_edit(vlm_processor,
                                                                vlm_model,
                                                                original_image,
                                                                category,
                                                                prompt,
                                                                device)
        # original mask: h,w,1 [0, 255]
        original_mask = vlm_response_mask(
            vlm_processor,
            vlm_model,
            category,
            original_image,
            prompt,
            object_wait_for_edit,
            sam,
            sam_predictor,
            sam_automask_generator,
            groundingdino_model,
            device)
    else:
        original_mask = input_mask
        category = None
    ## resize mask if needed
    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        # No fixed output size: start from the image's own dimensions
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        # Fixed output size: image and masks are always resized to it
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)
    # Add a trailing channel axis so the mask broadcasts against the image
    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]
    mask_image = Image.fromarray(original_mask.squeeze().astype(np.uint8)).convert("RGB")
    # Preview of the image with the masked region blacked out
    masked_image = original_image * (1 - (original_mask>0))
    masked_image = masked_image.astype(np.uint8)
    masked_image = Image.fromarray(masked_image)
    return [masked_image], [mask_image], original_mask.astype(np.uint8), category
def process_random_mask(input_image,
                        original_image,
                        original_mask,
                        resize_default,
                        aspect_ratio_name,
                        ):
    """
    Replace the current mask by its bounding box or bounding ellipse.

    The shape is picked at random between the two options. A mask brushed in
    the editor takes priority over ``original_mask``.

    Args:
        input_image: Editor payload; ``["layers"][0]`` carries the brushed
            mask in its alpha channel.
        original_image: Image array.
        original_mask: Previously generated mask array, or None.
        resize_default: Rescale so the short side of the output is 640 px.
        aspect_ratio_name: Key into ``aspect_ratios``.
    Returns:
        ``([masked image], [mask image], mask as (H, W, 1) uint8 array)``.
    Raises:
        gr.Error: If there is neither a brushed nor a stored mask.
    """
    alpha_mask = input_image["layers"][0].split()[3]
    input_mask = np.asarray(alpha_mask)
    output_w, output_h = aspect_ratios[aspect_ratio_name]
    if output_w == "" or output_h == "":
        # No fixed output size: start from the image's own dimensions
        output_h, output_w = original_image.shape[:2]
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
            original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
            original_image = np.array(original_image)
            if input_mask is not None:
                input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
                input_mask = np.array(input_mask)
            if original_mask is not None:
                original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
                original_mask = np.array(original_mask)
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        else:
            gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
            pass
    else:
        # Fixed output size: image and masks are always resized to it
        if resize_default:
            short_side = min(output_w, output_h)
            scale_ratio = 640 / short_side
            output_w = int(output_w * scale_ratio)
            output_h = int(output_h * scale_ratio)
        gr.Info(f"Output aspect ratio: {output_w}:{output_h}")
        original_image = resize(Image.fromarray(original_image), target_width=int(output_w), target_height=int(output_h))
        original_image = np.array(original_image)
        if input_mask is not None:
            input_mask = resize(Image.fromarray(np.squeeze(input_mask)), target_width=int(output_w), target_height=int(output_h))
            input_mask = np.array(input_mask)
        if original_mask is not None:
            original_mask = resize(Image.fromarray(np.squeeze(original_mask)), target_width=int(output_w), target_height=int(output_h))
            original_mask = np.array(original_mask)
    # An empty brush layer keeps the stored mask; otherwise the brush wins
    if input_mask.max() == 0:
        original_mask = original_mask
    else:
        original_mask = input_mask
    if original_mask is None:
        raise gr.Error('Please generate mask first')
    if original_mask.ndim == 2:
        original_mask = original_mask[:,:,None]
    # Randomly choose which coarse shape to fit around the mask
    dilation_type = np.random.choice(['bounding_box', 'bounding_ellipse'])
    random_mask = random_mask_func(original_mask, dilation_type).squeeze()
    mask_image = Image.fromarray(random_mask.astype(np.uint8)).convert("RGB")
    # Preview of the image with the masked region blacked out
    masked_image = original_image * (1 - (random_mask[:,:,None]>0))
    masked_image = masked_image.astype(original_image.dtype)
    masked_image = Image.fromarray(masked_image)
    return [masked_image], [mask_image], random_mask[:,:,None].astype(np.uint8)
def process_dilation_mask(input_image,
                          original_image,
                          original_mask,
                          resize_default,
                          aspect_ratio_name,
                          dilation_size=20):
    """Apply ``square_dilation`` to the active mask and rebuild the previews.

    The active mask is the alpha band painted on the first editor layer when
    it is non-empty; otherwise the stored ``original_mask`` is used.

    Args:
        input_image: Editor value; ``input_image["layers"][0]`` must be an
            image with an alpha band (index 3 of ``split()``).
        original_image: Source image as a NumPy array; ``shape[:2]`` is read
            as ``(height, width)``.
        original_mask: Previously stored mask array, or ``None``.
        resize_default: If truthy, rescale the output size so that its short
            side becomes 640.
        aspect_ratio_name: Key into ``aspect_ratios``. An entry whose width
            or height is ``""`` means the image's own size is used.
        dilation_size: Forwarded unchanged to ``random_mask_func``.

    Returns:
        A 3-tuple ``([masked_image], [mask_image], mask)``: the image with
        masked pixels zeroed, the mask as an RGB image, and the mask as a
        ``uint8`` array with a trailing channel axis.

    Raises:
        gr.Error: If nothing is painted and no mask has been stored yet.
    """
    input_mask = np.asarray(input_image["layers"][0].split()[3])

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    keep_native_size = output_w == "" or output_h == ""
    if keep_native_size:
        output_h, output_w = original_image.shape[:2]
    if resize_default:
        scale_ratio = 640 / min(output_w, output_h)
        output_w = int(output_w * scale_ratio)
        output_h = int(output_h * scale_ratio)
    gr.Info(f"Output aspect ratio: {output_w}:{output_h}")

    # A fixed aspect ratio always needs a resize; the native size only needs
    # one when the short-side rescale was requested.
    if resize_default or not keep_native_size:
        def _fit(array):
            return np.array(resize(Image.fromarray(array),
                                   target_width=int(output_w),
                                   target_height=int(output_h)))

        original_image = _fit(original_image)
        input_mask = _fit(np.squeeze(input_mask))
        if original_mask is not None:
            original_mask = _fit(np.squeeze(original_mask))

    # A freshly painted mask takes precedence over the stored one.
    active_mask = original_mask if input_mask.max() == 0 else input_mask
    if active_mask is None:
        raise gr.Error('Please generate mask first')
    if active_mask.ndim == 2:
        active_mask = active_mask[:, :, None]

    dilated = random_mask_func(active_mask, 'square_dilation', dilation_size).squeeze()
    mask_image = Image.fromarray(dilated.astype(np.uint8)).convert("RGB")
    masked_image = original_image * (1 - (dilated[:, :, None] > 0))
    masked_image = Image.fromarray(masked_image.astype(original_image.dtype))
    return [masked_image], [mask_image], dilated[:, :, None].astype(np.uint8)
def process_erosion_mask(input_image,
                         original_image,
                         original_mask,
                         resize_default,
                         aspect_ratio_name,
                         dilation_size=20):
    """Apply ``square_erosion`` to the active mask and rebuild the previews.

    The active mask is the alpha band painted on the first editor layer when
    it is non-empty; otherwise the stored ``original_mask`` is used.

    Args:
        input_image: Editor value; ``input_image["layers"][0]`` must be an
            image with an alpha band (index 3 of ``split()``).
        original_image: Source image as a NumPy array; ``shape[:2]`` is read
            as ``(height, width)``.
        original_mask: Previously stored mask array, or ``None``.
        resize_default: If truthy, rescale the output size so that its short
            side becomes 640.
        aspect_ratio_name: Key into ``aspect_ratios``. An entry whose width
            or height is ``""`` means the image's own size is used.
        dilation_size: Kernel size argument forwarded unchanged to
            ``random_mask_func`` (the name is kept for caller compatibility).

    Returns:
        A 3-tuple ``([masked_image], [mask_image], mask)``: the image with
        masked pixels zeroed, the mask as an RGB image, and the mask as a
        ``uint8`` array with a trailing channel axis.

    Raises:
        gr.Error: If nothing is painted and no mask has been stored yet.
    """
    input_mask = np.asarray(input_image["layers"][0].split()[3])

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    keep_native_size = output_w == "" or output_h == ""
    if keep_native_size:
        output_h, output_w = original_image.shape[:2]
    if resize_default:
        scale_ratio = 640 / min(output_w, output_h)
        output_w = int(output_w * scale_ratio)
        output_h = int(output_h * scale_ratio)
    gr.Info(f"Output aspect ratio: {output_w}:{output_h}")

    # A fixed aspect ratio always needs a resize; the native size only needs
    # one when the short-side rescale was requested.
    if resize_default or not keep_native_size:
        def _fit(array):
            return np.array(resize(Image.fromarray(array),
                                   target_width=int(output_w),
                                   target_height=int(output_h)))

        original_image = _fit(original_image)
        input_mask = _fit(np.squeeze(input_mask))
        if original_mask is not None:
            original_mask = _fit(np.squeeze(original_mask))

    # A freshly painted mask takes precedence over the stored one.
    active_mask = original_mask if input_mask.max() == 0 else input_mask
    if active_mask is None:
        raise gr.Error('Please generate mask first')
    if active_mask.ndim == 2:
        active_mask = active_mask[:, :, None]

    eroded = random_mask_func(active_mask, 'square_erosion', dilation_size).squeeze()
    mask_image = Image.fromarray(eroded.astype(np.uint8)).convert("RGB")
    masked_image = original_image * (1 - (eroded[:, :, None] > 0))
    masked_image = Image.fromarray(masked_image.astype(original_image.dtype))
    return [masked_image], [mask_image], eroded[:, :, None].astype(np.uint8)
def move_mask_left(input_image,
                   original_image,
                   original_mask,
                   moving_pixels,
                   resize_default,
                   aspect_ratio_name):
    """Shift the active mask left and rebuild the previews.

    The active mask is the alpha band painted on the first editor layer when
    it is non-empty; otherwise the stored ``original_mask`` is used. The
    shift itself is delegated to ``move_mask_func(mask, 'left', pixels)``.

    Args:
        input_image: Editor value; ``input_image["layers"][0]`` must be an
            image with an alpha band (index 3 of ``split()``).
        original_image: Source image as a NumPy array; ``shape[:2]`` is read
            as ``(height, width)``.
        original_mask: Previously stored mask array, or ``None``.
        moving_pixels: Shift distance; converted with ``int()``.
        resize_default: If truthy, rescale the output size so that its short
            side becomes 640.
        aspect_ratio_name: Key into ``aspect_ratios``. An entry whose width
            or height is ``""`` means the image's own size is used.

    Returns:
        A 3-tuple ``([masked_image], [mask_image], mask)``: the image with
        masked pixels zeroed, the mask as an RGB image, and the shifted mask
        as a ``uint8`` array that always carries a trailing channel axis
        (scaled by 255 when its maximum is at most 1).

    Raises:
        gr.Error: If nothing is painted and no mask has been stored yet.
    """
    input_mask = np.asarray(input_image["layers"][0].split()[3])

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    keep_native_size = output_w == "" or output_h == ""
    if keep_native_size:
        output_h, output_w = original_image.shape[:2]
    if resize_default:
        scale_ratio = 640 / min(output_w, output_h)
        output_w = int(output_w * scale_ratio)
        output_h = int(output_h * scale_ratio)
    gr.Info(f"Output aspect ratio: {output_w}:{output_h}")

    # A fixed aspect ratio always needs a resize; the native size only needs
    # one when the short-side rescale was requested.
    if resize_default or not keep_native_size:
        def _fit(array):
            return np.array(resize(Image.fromarray(array),
                                   target_width=int(output_w),
                                   target_height=int(output_h)))

        original_image = _fit(original_image)
        input_mask = _fit(np.squeeze(input_mask))
        if original_mask is not None:
            original_mask = _fit(np.squeeze(original_mask))

    # A freshly painted mask takes precedence over the stored one.
    active_mask = original_mask if input_mask.max() == 0 else input_mask
    if active_mask is None:
        raise gr.Error('Please generate mask first')
    if active_mask.ndim == 2:
        active_mask = active_mask[:, :, None]

    shifted = move_mask_func(active_mask, 'left', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((shifted > 0).astype(np.uint8) * 255)).convert("RGB")
    masked_image = original_image * (1 - (shifted[:, :, None] > 0))
    masked_image = Image.fromarray(masked_image.astype(original_image.dtype))

    # Normalise 0/1 masks to 0/255; the channel axis is added in both cases so
    # the stored mask has the same rank whatever its value range was.
    if shifted.max() <= 1:
        shifted = shifted * 255
    return [masked_image], [mask_image], shifted[:, :, None].astype(np.uint8)
def move_mask_right(input_image,
                    original_image,
                    original_mask,
                    moving_pixels,
                    resize_default,
                    aspect_ratio_name):
    """Shift the active mask right and rebuild the previews.

    The active mask is the alpha band painted on the first editor layer when
    it is non-empty; otherwise the stored ``original_mask`` is used. The
    shift itself is delegated to ``move_mask_func(mask, 'right', pixels)``.

    Args:
        input_image: Editor value; ``input_image["layers"][0]`` must be an
            image with an alpha band (index 3 of ``split()``).
        original_image: Source image as a NumPy array; ``shape[:2]`` is read
            as ``(height, width)``.
        original_mask: Previously stored mask array, or ``None``.
        moving_pixels: Shift distance; converted with ``int()``.
        resize_default: If truthy, rescale the output size so that its short
            side becomes 640.
        aspect_ratio_name: Key into ``aspect_ratios``. An entry whose width
            or height is ``""`` means the image's own size is used.

    Returns:
        A 3-tuple ``([masked_image], [mask_image], mask)``: the image with
        masked pixels zeroed, the mask as an RGB image, and the shifted mask
        as a ``uint8`` array that always carries a trailing channel axis
        (scaled by 255 when its maximum is at most 1).

    Raises:
        gr.Error: If nothing is painted and no mask has been stored yet.
    """
    input_mask = np.asarray(input_image["layers"][0].split()[3])

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    keep_native_size = output_w == "" or output_h == ""
    if keep_native_size:
        output_h, output_w = original_image.shape[:2]
    if resize_default:
        scale_ratio = 640 / min(output_w, output_h)
        output_w = int(output_w * scale_ratio)
        output_h = int(output_h * scale_ratio)
    gr.Info(f"Output aspect ratio: {output_w}:{output_h}")

    # A fixed aspect ratio always needs a resize; the native size only needs
    # one when the short-side rescale was requested.
    if resize_default or not keep_native_size:
        def _fit(array):
            return np.array(resize(Image.fromarray(array),
                                   target_width=int(output_w),
                                   target_height=int(output_h)))

        original_image = _fit(original_image)
        input_mask = _fit(np.squeeze(input_mask))
        if original_mask is not None:
            original_mask = _fit(np.squeeze(original_mask))

    # A freshly painted mask takes precedence over the stored one.
    active_mask = original_mask if input_mask.max() == 0 else input_mask
    if active_mask is None:
        raise gr.Error('Please generate mask first')
    if active_mask.ndim == 2:
        active_mask = active_mask[:, :, None]

    shifted = move_mask_func(active_mask, 'right', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((shifted > 0).astype(np.uint8) * 255)).convert("RGB")
    masked_image = original_image * (1 - (shifted[:, :, None] > 0))
    masked_image = Image.fromarray(masked_image.astype(original_image.dtype))

    # Normalise 0/1 masks to 0/255; the channel axis is added in both cases so
    # the stored mask has the same rank whatever its value range was.
    if shifted.max() <= 1:
        shifted = shifted * 255
    return [masked_image], [mask_image], shifted[:, :, None].astype(np.uint8)
def move_mask_up(input_image,
                 original_image,
                 original_mask,
                 moving_pixels,
                 resize_default,
                 aspect_ratio_name):
    """Shift the active mask up and rebuild the previews.

    The active mask is the alpha band painted on the first editor layer when
    it is non-empty; otherwise the stored ``original_mask`` is used. The
    shift itself is delegated to ``move_mask_func(mask, 'up', pixels)``.

    Args:
        input_image: Editor value; ``input_image["layers"][0]`` must be an
            image with an alpha band (index 3 of ``split()``).
        original_image: Source image as a NumPy array; ``shape[:2]`` is read
            as ``(height, width)``.
        original_mask: Previously stored mask array, or ``None``.
        moving_pixels: Shift distance; converted with ``int()``.
        resize_default: If truthy, rescale the output size so that its short
            side becomes 640.
        aspect_ratio_name: Key into ``aspect_ratios``. An entry whose width
            or height is ``""`` means the image's own size is used.

    Returns:
        A 3-tuple ``([masked_image], [mask_image], mask)``: the image with
        masked pixels zeroed, the mask as an RGB image, and the shifted mask
        as a ``uint8`` array that always carries a trailing channel axis
        (scaled by 255 when its maximum is at most 1).

    Raises:
        gr.Error: If nothing is painted and no mask has been stored yet.
    """
    input_mask = np.asarray(input_image["layers"][0].split()[3])

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    keep_native_size = output_w == "" or output_h == ""
    if keep_native_size:
        output_h, output_w = original_image.shape[:2]
    if resize_default:
        scale_ratio = 640 / min(output_w, output_h)
        output_w = int(output_w * scale_ratio)
        output_h = int(output_h * scale_ratio)
    gr.Info(f"Output aspect ratio: {output_w}:{output_h}")

    # A fixed aspect ratio always needs a resize; the native size only needs
    # one when the short-side rescale was requested.
    if resize_default or not keep_native_size:
        def _fit(array):
            return np.array(resize(Image.fromarray(array),
                                   target_width=int(output_w),
                                   target_height=int(output_h)))

        original_image = _fit(original_image)
        input_mask = _fit(np.squeeze(input_mask))
        if original_mask is not None:
            original_mask = _fit(np.squeeze(original_mask))

    # A freshly painted mask takes precedence over the stored one.
    active_mask = original_mask if input_mask.max() == 0 else input_mask
    if active_mask is None:
        raise gr.Error('Please generate mask first')
    if active_mask.ndim == 2:
        active_mask = active_mask[:, :, None]

    shifted = move_mask_func(active_mask, 'up', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((shifted > 0).astype(np.uint8) * 255)).convert("RGB")
    masked_image = original_image * (1 - (shifted[:, :, None] > 0))
    masked_image = Image.fromarray(masked_image.astype(original_image.dtype))

    # Normalise 0/1 masks to 0/255; the channel axis is added in both cases so
    # the stored mask has the same rank whatever its value range was.
    if shifted.max() <= 1:
        shifted = shifted * 255
    return [masked_image], [mask_image], shifted[:, :, None].astype(np.uint8)
def move_mask_down(input_image,
                   original_image,
                   original_mask,
                   moving_pixels,
                   resize_default,
                   aspect_ratio_name):
    """Shift the active mask down and rebuild the previews.

    The active mask is the alpha band painted on the first editor layer when
    it is non-empty; otherwise the stored ``original_mask`` is used. The
    shift itself is delegated to ``move_mask_func(mask, 'down', pixels)``.

    Args:
        input_image: Editor value; ``input_image["layers"][0]`` must be an
            image with an alpha band (index 3 of ``split()``).
        original_image: Source image as a NumPy array; ``shape[:2]`` is read
            as ``(height, width)``.
        original_mask: Previously stored mask array, or ``None``.
        moving_pixels: Shift distance; converted with ``int()``.
        resize_default: If truthy, rescale the output size so that its short
            side becomes 640.
        aspect_ratio_name: Key into ``aspect_ratios``. An entry whose width
            or height is ``""`` means the image's own size is used.

    Returns:
        A 3-tuple ``([masked_image], [mask_image], mask)``: the image with
        masked pixels zeroed, the mask as an RGB image, and the shifted mask
        as a ``uint8`` array that always carries a trailing channel axis
        (scaled by 255 when its maximum is at most 1).

    Raises:
        gr.Error: If nothing is painted and no mask has been stored yet.
    """
    input_mask = np.asarray(input_image["layers"][0].split()[3])

    output_w, output_h = aspect_ratios[aspect_ratio_name]
    keep_native_size = output_w == "" or output_h == ""
    if keep_native_size:
        output_h, output_w = original_image.shape[:2]
    if resize_default:
        scale_ratio = 640 / min(output_w, output_h)
        output_w = int(output_w * scale_ratio)
        output_h = int(output_h * scale_ratio)
    gr.Info(f"Output aspect ratio: {output_w}:{output_h}")

    # A fixed aspect ratio always needs a resize; the native size only needs
    # one when the short-side rescale was requested.
    if resize_default or not keep_native_size:
        def _fit(array):
            return np.array(resize(Image.fromarray(array),
                                   target_width=int(output_w),
                                   target_height=int(output_h)))

        original_image = _fit(original_image)
        input_mask = _fit(np.squeeze(input_mask))
        if original_mask is not None:
            original_mask = _fit(np.squeeze(original_mask))

    # A freshly painted mask takes precedence over the stored one.
    active_mask = original_mask if input_mask.max() == 0 else input_mask
    if active_mask is None:
        raise gr.Error('Please generate mask first')
    if active_mask.ndim == 2:
        active_mask = active_mask[:, :, None]

    shifted = move_mask_func(active_mask, 'down', int(moving_pixels)).squeeze()
    mask_image = Image.fromarray(((shifted > 0).astype(np.uint8) * 255)).convert("RGB")
    masked_image = original_image * (1 - (shifted[:, :, None] > 0))
    masked_image = Image.fromarray(masked_image.astype(original_image.dtype))

    # Normalise 0/1 masks to 0/255; the channel axis is added in both cases so
    # the stored mask has the same rank whatever its value range was.
    if shifted.max() <= 1:
        shifted = shifted * 255
    return [masked_image], [mask_image], shifted[:, :, None].astype(np.uint8)
def invert_mask(input_image,
                original_image,
                original_mask,
                ):
    """Invert the active mask and rebuild the previews.

    The active mask is the alpha band painted on the first editor layer when
    it is non-empty; otherwise the stored ``original_mask`` is used.

    Args:
        input_image: Editor value; ``input_image["layers"][0]`` must be an
            image with an alpha band (index 3 of ``split()``).
        original_image: Source image as a NumPy array.
        original_mask: Previously stored mask array, or ``None``.

    Returns:
        A 4-tuple ``([masked_image], [mask_image], mask, True)``: the image
        with the inverted-mask pixels zeroed, the inverted mask as an RGB
        image, the inverted mask as a 0/255 ``uint8`` array, and ``True`` for
        the invert-state output.

    Raises:
        gr.Error: If nothing is painted and no mask has been stored yet.
    """
    painted_mask = np.asarray(input_image["layers"][0].split()[3])

    if painted_mask.max() == 0:
        # The guard has to run before the comparison below: ``None > 0``
        # raises TypeError, which used to hide this user-facing message.
        if original_mask is None:
            raise gr.Error('Please generate mask first')
        source_mask = original_mask
    else:
        source_mask = painted_mask

    inverted = (1 - (source_mask > 0).astype(np.uint8)).squeeze()
    mask_image = Image.fromarray(inverted * 255).convert("RGB")

    if inverted.ndim == 2:
        inverted = inverted[:, :, None]
    # ``inverted`` only holds 0 and 1 here, so it is always rescaled to 0/255.
    inverted = (inverted * 255).astype(np.uint8)

    masked_image = original_image * (1 - (inverted > 0))
    masked_image = Image.fromarray(masked_image.astype(original_image.dtype))
    return [masked_image], [mask_image], inverted, True
def init_img(base,
             init_type,
             prompt,
             aspect_ratio,
             example_change_times
             ):
    """Initialise the session state after an image lands in the editor.

    When ``init_type`` names a bundled example and ``example_change_times``
    is below 2, the example's stored mask, masked image and result are loaded
    and returned. Otherwise the galleries and the stored mask are cleared.

    Args:
        base: Editor value; ``base["background"]`` is the uploaded image.
        init_type: Example name, looked up in ``MASK_IMAGE_PATH``,
            ``MASKED_IMAGE_PATH`` and ``OUTPUT_IMAGE_PATH``.
        prompt: Current instruction text; returned unchanged on the example
            path and replaced by ``""`` otherwise.
        aspect_ratio: Currently selected aspect-ratio label.
        example_change_times: Counter maintained by the example callbacks.

    Returns:
        A 13-tuple matching the outputs wired to ``input_image.upload``:
        editor value, image array, mask array (or ``None``), prompt, mask
        gallery, masked gallery, result gallery, target prompt, init type,
        aspect-ratio label, resize flag, invert-state flag and the example
        counter.

    Raises:
        gr.Error: If the long side is more than twice the short side.
    """
    image_pil = base["background"].convert("RGB")
    original_image = np.array(image_pil)
    height, width = original_image.shape[:2]
    if max(height, width) * 1.0 / min(height, width) > 2.0:
        raise gr.Error('image aspect ratio cannot be larger than 2.0')

    if init_type in MASK_IMAGE_PATH and example_change_times < 2:
        mask_pil = Image.open(MASK_IMAGE_PATH[init_type]).convert("L")
        masked_pil = Image.open(MASKED_IMAGE_PATH[init_type]).convert("RGB")
        result_pil = Image.open(OUTPUT_IMAGE_PATH[init_type]).convert("RGB")

        image_processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
        height_new, width_new = image_processor.get_default_height_width(image_pil, height, width)
        target_size = (width_new, height_new)

        # NOTE(review): the example assets are resized but ``original_image``
        # is returned at the uploaded size, whereas update_example stores the
        # resized image. Confirm which of the two is intended.
        mask_gallery = [mask_pil.resize(target_size)]
        masked_gallery = [masked_pil.resize(target_size)]
        result_gallery = [result_pil.resize(target_size)]
        original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:, :, None]  # h,w,1
        return (base, original_image, original_mask, prompt,
                mask_gallery, masked_gallery, result_gallery,
                "", "", "Custom resolution", False, False, example_change_times)

    if aspect_ratio not in ASPECT_RATIO_LABELS:
        aspect_ratio = "Custom resolution"
    return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0
def reset_func(input_image,
               original_image,
               original_mask,
               prompt,
               target_prompt,
               ):
    """Return cleared values for every output wired to the Reset button.

    The incoming arguments are ignored. When CUDA is available its cached
    memory is released before returning.

    Returns:
        A 10-tuple: ``None`` for the editor value, image and mask, an empty
        prompt, three empty galleries, an empty target prompt, ``True`` for
        the resize flag and ``False`` for the invert-state flag.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    cleared_inputs = (None, None, None, '')
    cleared_galleries = ([], [], [])
    return (*cleared_inputs, *cleared_galleries, '', True, False)
def update_example(example_type,
                   prompt,
                   example_change_times):
    """Load the assets of a bundled example and return the new UI state.

    Args:
        example_type: Example name, looked up in ``INPUT_IMAGE_PATH``,
            ``MASK_IMAGE_PATH``, ``MASKED_IMAGE_PATH`` and
            ``OUTPUT_IMAGE_PATH``.
        prompt: Current instruction text; returned unchanged.
        example_change_times: Counter of example changes; returned plus one.

    Returns:
        An 11-tuple matching the outputs wired to ``example_type.change``:
        input image path, prompt, image array, mask array with a trailing
        channel axis, mask gallery, masked gallery, result gallery,
        ``"Custom resolution"``, an empty target prompt, ``False`` for the
        invert state and the incremented counter.
    """
    input_image = INPUT_IMAGE_PATH[example_type]
    source_pil = Image.open(input_image).convert("RGB")
    mask_pil = Image.open(MASK_IMAGE_PATH[example_type]).convert("L")
    masked_pil = Image.open(MASKED_IMAGE_PATH[example_type]).convert("RGB")
    result_pil = Image.open(OUTPUT_IMAGE_PATH[example_type]).convert("RGB")

    # Ask the VAE image processor which size the pipeline will use.
    width, height = source_pil.size
    processor = VaeImageProcessor(vae_scale_factor=pipe.vae_scale_factor, do_convert_rgb=True)
    height_new, width_new = processor.get_default_height_width(source_pil, height, width)
    target_size = (width_new, height_new)

    source_pil, mask_pil, masked_pil, result_pil = [
        picture.resize(target_size)
        for picture in (source_pil, mask_pil, masked_pil, result_pil)
    ]

    original_image = np.array(source_pil)
    original_mask = np.array(mask_pil).astype(np.uint8)[:, :, None]  # h,w,1
    return (
        input_image,
        prompt,
        original_image,
        original_mask,
        [mask_pil],
        [masked_pil],
        [result_pil],
        "Custom resolution",
        "",
        False,
        example_change_times + 1,
    )
# ---------------------------------------------------------------------------
# Gradio UI: theme, layout, callback wiring and launch.
# NOTE(review): this section has no leading indentation in the file, so the
# nesting of the `with` containers below cannot be read off the text; the
# code is left untouched and only comments were added.
# ---------------------------------------------------------------------------
block = gr.Blocks(
theme=gr.themes.Soft(
radius_size=gr.themes.sizes.radius_none,
text_size=gr.themes.sizes.text_md
)
)
with block as demo:
# Page header, description and usage instructions.
with gr.Row():
with gr.Column():
gr.HTML(head)
gr.Markdown(descriptions)
with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
with gr.Row(equal_height=True):
gr.Markdown(instructions)
# Non-visual state values that the callbacks below read and write.
original_image = gr.State(value=None)
original_mask = gr.State(value=None)
category = gr.State(value=None)
status = gr.State(value=None)
invert_mask_state = gr.State(value=False)
example_change_times = gr.State(value=0)
# Input side: image editor, instruction, model selection and run options.
with gr.Row():
with gr.Column():
with gr.Row():
input_image = gr.ImageEditor(
label="Input Image",
type="pil",
brush=gr.Brush(colors=["#FFFFFF"], default_size = 30, color_mode="fixed"),
layers = False,
interactive=True,
height=1024,
sources=["upload"],
)
prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
run_button = gr.Button("💫 Run")
vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
with gr.Group():
with gr.Row():
GPT4o_KEY = gr.Textbox(label="GPT4o API Key", placeholder="Please input your GPT4o API Key when use GPT4o VLM (highly recommended).", value="", lines=1)
GPT4o_KEY_submit = gr.Button("Submit and Verify")
aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)
with gr.Row():
mask_button = gr.Button("Generate Mask")
random_mask_button = gr.Button("Square/Circle Mask ")
with gr.Row():
generate_target_prompt_button = gr.Button("Generate Target Prompt")
target_prompt = gr.Text(
label="Input Target Prompt",
max_lines=5,
placeholder="VLM-generated target prompt, you can first generate if and then modify it (optional)",
value='',
lines=2
)
# Generation parameters forwarded to `process` through `ips` below.
with gr.Accordion("Advanced Options", open=False, elem_id="accordion1"):
base_model_dropdown = gr.Dropdown(label="Base model", choices=BASE_MODELS, value=DEFAULT_BASE_MODEL, interactive=True)
negative_prompt = gr.Text(
label="Negative Prompt",
max_lines=5,
placeholder="Please input your negative prompt",
value='ugly, low quality',lines=1
)
control_strength = gr.Slider(
label="Control Strength: ", show_label=True, minimum=0, maximum=1.1, value=1, step=0.01
)
with gr.Group():
seed = gr.Slider(
label="Seed: ", minimum=0, maximum=2147483647, step=1, value=648464818
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
blending = gr.Checkbox(label="Blending mode", value=True)
num_samples = gr.Slider(
label="Num samples", minimum=0, maximum=4, step=1, value=4
)
with gr.Group():
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=1,
maximum=12,
step=0.1,
value=7.5,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=50,
step=1,
value=50,
)
# Output side: masked-image / mask / result galleries and mask editing tools.
with gr.Column():
with gr.Row():
with gr.Tab(elem_classes="feedback", label="Masked Image"):
masked_gallery = gr.Gallery(label='Masked Image', show_label=True, elem_id="gallery", preview=True, height=360)
with gr.Tab(elem_classes="feedback", label="Mask"):
mask_gallery = gr.Gallery(label='Mask', show_label=True, elem_id="gallery", preview=True, height=360)
invert_mask_button = gr.Button("Invert Mask")
# This slider value is passed to both the dilation and the erosion callback.
dilation_size = gr.Slider(
label="Dilation size: ", minimum=0, maximum=50, step=1, value=20
)
with gr.Row():
dilation_mask_button = gr.Button("Dilation Generated Mask")
erosion_mask_button = gr.Button("Erosion Generated Mask")
moving_pixels = gr.Slider(
label="Moving pixels:", show_label=True, minimum=0, maximum=50, value=4, step=1
)
with gr.Row():
move_left_button = gr.Button("Move Left")
move_right_button = gr.Button("Move Right")
with gr.Row():
move_up_button = gr.Button("Move Up")
move_down_button = gr.Button("Move Down")
with gr.Tab(elem_classes="feedback", label="Output"):
result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)
# target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
reset_button = gr.Button("Reset")
# Hidden textboxes filled by the examples table; `example_type` changes
# trigger `update_example`, `init_type` is read by `init_img`.
init_type = gr.Textbox(label="Init Name", value="", visible=False)
example_type = gr.Textbox(label="Example Name", value="", visible=False)
with gr.Row():
example = gr.Examples(
label="Quick Example",
examples=EXAMPLES,
inputs=[input_image, prompt, seed, init_type, example_type, blending, resize_default, vlm_model_dropdown],
examples_per_page=10,
cache_examples=False,
)
with gr.Accordion(label="🎬 Feature Details:", open=True, elem_id="accordion"):
with gr.Row(equal_height=True):
gr.Markdown(tips)
with gr.Row():
gr.Markdown(citation)
## gr.Examples cannot update a gr.Gallery, so the two callbacks below (upload -> init_img, example_type change -> update_example) fill the galleries instead.
## Both can fire for the same example; `example_change_times` is passed through both so they can resolve that conflict.
input_image.upload(
init_img,
[input_image, init_type, prompt, aspect_ratio, example_change_times],
[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
)
example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])
## VLM and base model dropdowns, and API key submission
vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
base_model_dropdown.change(fn=update_base_model, inputs=[base_model_dropdown], outputs=[status])
GPT4o_KEY_submit.click(fn=submit_GPT4o_KEY, inputs=[GPT4o_KEY], outputs=[GPT4o_KEY, vlm_model_dropdown])
invert_mask_button.click(fn=invert_mask, inputs=[input_image, original_image, original_mask], outputs=[masked_gallery, mask_gallery, original_mask, invert_mask_state])
# Inputs for `process`, in the order they are passed to it.
ips=[input_image,
original_image,
original_mask,
prompt,
negative_prompt,
control_strength,
seed,
randomize_seed,
guidance_scale,
num_inference_steps,
num_samples,
blending,
category,
target_prompt,
resize_default,
aspect_ratio,
invert_mask_state]
## Run BrushEdit
run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])
## Mask generation and morphology callbacks
mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
random_mask_button.click(fn=process_random_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
dilation_mask_button.click(fn=process_dilation_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
erosion_mask_button.click(fn=process_erosion_mask, inputs=[input_image, original_image, original_mask, resize_default, aspect_ratio, dilation_size], outputs=[ masked_gallery, mask_gallery, original_mask])
## Mask shifting callbacks (one per direction)
move_left_button.click(fn=move_mask_left, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
move_right_button.click(fn=move_mask_right, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
move_up_button.click(fn=move_mask_up, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])
## Target prompt generation
generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])
## Reset all inputs, state and galleries
reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])
# Start the Gradio server; this runs whenever the module is executed.
demo.launch()