|
import os |
|
|
|
|
|
|
|
import os |
|
import copy |
|
from PIL import Image |
|
import matplotlib |
|
import numpy as np |
|
import gradio as gr |
|
from utils import load_mask, load_mask_edit |
|
from utils_mask import process_mask_to_follow_priority, mask_union, visualize_mask_list_clean |
|
from pathlib import Path |
|
from PIL import Image |
|
from functools import partial |
|
from main import run_main |
|
import time |
|
# Display height and width (pixels) of the "Draw Mask" canvas in tab 1.
LENGTH: int = 512

# Alpha value given to the segmentation overlay: passed as ``transparency``
# to transparent_paste_with_mask, which applies it with PIL ``putalpha``.
TRANSPARENCY: int = 150
|
|
|
def add_mask(mask_np_list_updated, mask_label_list):
    """Append an empty mask labelled ``"new"`` to the mask and label lists.

    The new mask is an all-zero array with the same shape and dtype as the
    first existing mask. Both lists are extended in place, and the same list
    objects are returned.

    Args:
        mask_np_list_updated: Non-empty list of mask arrays. Element 0 is the
            shape/dtype template.
        mask_label_list: Labels parallel to the masks.

    Returns:
        ``(mask_np_list_updated, mask_label_list)``, each one element longer.
    """
    template = mask_np_list_updated[0]
    mask_np_list_updated.append(np.zeros_like(template))
    mask_label_list.append("new")
    return mask_np_list_updated, mask_label_list
|
|
|
def create_segmentation(mask_np_list):
    """Render a list of masks as one colour-coded RGB image.

    Mask ``i`` is painted with colour ``i`` of a viridis colormap sampled into
    ``len(mask_np_list)`` entries. The colour contributions of all masks are
    summed per pixel, clipped to [0, 1] and scaled to 0-255.

    Args:
        mask_np_list: Non-empty list of 2-D arrays of equal shape. Values are
            used as per-pixel weights (presumably 0/1 masks; TODO confirm
            callers never pass soft masks).

    Returns:
        PIL.Image.Image: RGB image with the masks' height and width.

    Raises:
        ValueError: If ``mask_np_list`` is empty.
    """
    # ``import matplotlib`` alone does not load the ``pyplot`` submodule, so
    # ``matplotlib.pyplot`` was only reachable if some other module had
    # already imported it. Import it explicitly.
    import matplotlib.pyplot as plt
    from matplotlib import colors as mcolors

    if len(mask_np_list) == 0:
        raise ValueError("create_segmentation needs at least one mask")

    viridis = plt.get_cmap(name="viridis", lut=len(mask_np_list))

    segmentation = 0
    for i, mask in enumerate(mask_np_list):
        color = np.asarray(mcolors.to_rgb(viridis(i)))  # RGB, alpha dropped
        segmentation = segmentation + mask[:, :, np.newaxis] * color

    # Overlapping masks can push a channel above 1.0. Clip so the uint8
    # conversion saturates at 255 instead of producing out-of-range values.
    segmentation = np.clip(segmentation, 0.0, 1.0)
    return Image.fromarray(np.uint8(segmentation * 255))
|
|
|
def load_mask_ui(input_folder="example_tmp", load_edit=False):
    """Load saved masks and their labels from *input_folder* as NumPy arrays.

    Args:
        input_folder: Folder handed to the mask loader.
        load_edit: When truthy, read with ``load_mask_edit``; otherwise with
            ``load_mask``.

    Returns:
        ``(mask_np_list, mask_label_list)``. Each loaded mask is converted
        with ``.cpu().numpy()``; the labels are returned as the loader gave
        them.
    """
    loader = load_mask_edit if load_edit else load_mask
    mask_list, mask_label_list = loader(input_folder)
    mask_np_list = [mask.cpu().numpy() for mask in mask_list]
    return mask_np_list, mask_label_list
|
|
|
def load_image_ui(load_edit, input_folder="example_tmp", image=None):
    """Reload masks from disk and rebuild the tab-1 outputs.

    ``load_mask_ui`` returns only ``(masks, labels)``. The previous version
    unpacked three values from it and therefore always raised ``ValueError``.
    The image is not read by the mask loader, so callers may pass it through
    *image* (default ``None``).

    Args:
        load_edit: If truthy, load the edited masks instead of the original ones.
        input_folder: Folder containing the saved masks.
        image: Image echoed back into both image outputs (optional).

    Returns:
        ``(image, segmentation, mask_np_list, mask_label_list, image, slider,
        slider)``. ``segmentation`` is the colour-coded mask image. Both
        sliders are reset to 0, with their maximum set to the last mask index.
    """
    mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit=load_edit)
    segmentation = create_segmentation(mask_np_list)

    max_val = len(mask_np_list) - 1
    sliderup = gr.Slider(value=0, minimum=0, maximum=max_val, step=1, interactive=True)

    return image, segmentation, mask_np_list, mask_label_list, image, sliderup, sliderup
|
|
|
|
|
|
|
|
|
|
|
def run_segmentation_wrapper(image):
    """Segment *image* and package the results for the tab-1 outputs.

    Args:
        image: Canvas image, forwarded unchanged to ``run_segmentation``.

    Returns:
        ``(image, segmentation, mask_np_list, mask_label_list, image, slider,
        slider, status)``.
        - ``image``: the image returned by ``run_segmentation``.
        - ``segmentation``: the colour-coded mask image.
        - ``slider`` (both copies): reset to 0, with its maximum set to the
          last mask index.
        - ``status``: a message for the response box.
    """
    image, mask_np_list, mask_label_list = run_segmentation(image)

    segmentation = create_segmentation(mask_np_list)

    max_val = len(mask_np_list) - 1
    sliderup = gr.Slider(value=0, minimum=0, maximum=max_val, step=1, interactive=True)

    return (
        image,
        segmentation,
        mask_np_list,
        mask_label_list,
        image,
        sliderup,
        sliderup,
        "Segmentation finished.",
    )
|
|
|
|
|
def transparent_paste_with_mask(backimg, foreimg, mask_np, transparency=128):
    """Overlay *foreimg* semi-transparently on *backimg*, only where *mask_np* is set.

    The foreground copy gets a constant alpha of *transparency* and is pasted
    at the top-left corner of a copy of the background. That composite is
    then blended with the untouched background:
    ``out = composite * mask + background * (1 - mask)``.

    Args:
        backimg: Background PIL image (not modified).
        foreimg: Foreground PIL image (not modified).
        mask_np: 2-D array matching the background's (height, width). 1 selects
            the composite; 0 keeps the background.
        transparency: Alpha passed to PIL ``putalpha`` for the foreground.

    Returns:
        PIL.Image.Image: The blended image.
    """
    background_np = np.array(backimg)

    composite = backimg.copy()
    fore = foreimg.copy()
    fore.putalpha(transparency)
    composite.paste(fore, (0, 0), fore)  # the foreground's own alpha is the paste mask
    composite_np = np.array(composite)

    # Compute the blend in floating point whatever the mask's dtype is.
    mask = np.asarray(mask_np, dtype=np.float64)[:, :, np.newaxis]
    blended = composite_np * mask + (1.0 - mask) * background_np

    return Image.fromarray(np.uint8(blended))
|
|
|
def show_segmentation(image, segmentation, flag):
    """Toggle the segmentation overlay on *image*.

    Args:
        image: PIL image currently shown.
        segmentation: PIL image to overlay.
        flag: ``False`` means the overlay is currently hidden. Any other value
            means it is shown.

    Returns:
        ``(image_to_show, new_flag)``.
        - If *flag* is ``False``: the overlay is applied over the whole image,
          and ``True`` is returned as the new flag.
        - Otherwise: the plain *image* and ``False`` are returned.
    """
    if flag is not False:
        return image, False

    # PIL's ``size`` is (width, height), but ``np.array(image)`` is
    # (height, width, channels). The mask must use the NumPy order, or the
    # blend fails to broadcast for non-square images.
    width, height = image.size
    mask_np = np.ones((height, width)).astype(np.uint8)

    image_edit = transparent_paste_with_mask(image, segmentation, mask_np, transparency=TRANSPARENCY)
    return image_edit, True
|
|
|
def edit_mask_add(canvas, image, idx, mask_np_list):
    """Merge a user-drawn stroke into mask *idx* and redraw the overlay.

    Args:
        canvas: Dict-like canvas value whose ``"mask"`` entry is an image array.
            Channel 0 is divided by 255 and truncated to uint8, so (assuming
            0-255 values) only full-intensity pixels become 1.
        image: PIL image the segmentation overlay is drawn on.
        idx: Index of the mask to extend.
        mask_np_list: Current list of masks.

    Returns:
        ``(updated_masks, overlay_image)``. In ``updated_masks``, mask *idx*
        is unioned with the stroke. The list is then passed through
        ``process_mask_to_follow_priority`` with priority 1 for *idx* and 0
        for every other mask.
    """
    selected = mask_np_list[idx]
    stroke = np.uint8(canvas["mask"][:, :, 0] / 255.)

    merged = [
        mask_union(selected, stroke) if position == idx else existing
        for position, existing in enumerate(mask_np_list)
    ]

    priorities = [0] * len(merged)
    priorities[idx] = 1
    merged = process_mask_to_follow_priority(merged, priorities)

    full_mask = np.ones((selected.shape[0], selected.shape[1])).astype(np.uint8)
    overlay = transparent_paste_with_mask(
        image,
        create_segmentation(merged),
        full_mask,
        transparency=TRANSPARENCY,
    )
    return merged, overlay
|
|
|
def slider_release(index, image, mask_np_list_updated, mask_label_list):
    """Highlight the mask selected with the slider.

    Args:
        index: Mask index from the slider (int, or an integral float).
        image: PIL image the overlay is drawn on.
        mask_np_list_updated: List of masks.
        mask_label_list: Labels parallel to the masks.

    Returns:
        ``(image, label)``. This is the image with the segmentation overlay
        applied only inside mask *index*, plus that mask's label. If *index*
        is outside ``[0, len(mask_np_list_updated))``, the unchanged image and
        the string ``"out of range"`` are returned.
    """
    index = int(index)

    # Reject both ends: index == len would raise IndexError, and a negative
    # index would silently select a mask from the end of the list.
    if not 0 <= index < len(mask_np_list_updated):
        return image, "out of range"

    mask_np = mask_np_list_updated[index]
    mask_label = mask_label_list[index]

    segmentation = create_segmentation(mask_np_list_updated)
    new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency=TRANSPARENCY)
    return new_image, mask_label
|
|
|
def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Save masks as ``mask{i}_{label}.npy`` files plus a ``seg_current.png`` preview.

    A warning is printed when the masks do not partition the image, i.e. when
    the per-pixel sum of all masks is not exactly 1 everywhere. The masks are
    saved regardless.

    Args:
        mask_np_list_updated: List of mask arrays to save.
        mask_label_list: Labels parallel to the masks, used in the file names.
        input_folder: Folder the files are written to (not created here).
    """
    # Validate explicitly rather than with ``assert``, which is stripped under
    # ``python -O``. Never drop into ``pdb`` here: in a served app it blocks
    # the request handler waiting for terminal input.
    try:
        is_partition = bool(np.all(sum(mask_np_list_updated) == 1))
    except ValueError:
        # Masks whose shapes do not broadcast cannot be summed.
        is_partition = False
    if not is_partition:
        print("please check mask")

    for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
        np.save(os.path.join(input_folder, "mask{}_{}.npy".format(midx, mask_label)), mask)

    savepath = os.path.join(input_folder, "seg_current.png")
    visualize_mask_list_clean(mask_np_list_updated, savepath)
|
|
|
def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Save edited masks as ``maskEdited{i}_{label}.npy`` plus a ``seg_edited.png`` preview.

    A warning is printed when the masks do not partition the image, i.e. when
    the per-pixel sum of all masks is not exactly 1 everywhere. The masks are
    saved regardless.

    Args:
        mask_np_list_updated: List of mask arrays to save.
        mask_label_list: Labels parallel to the masks, used in the file names.
        input_folder: Folder the files are written to (not created here).
    """
    # Validate explicitly rather than with ``assert``, which is stripped under
    # ``python -O``. Never drop into ``pdb`` here: in a served app it blocks
    # the request handler waiting for terminal input.
    try:
        is_partition = bool(np.all(sum(mask_np_list_updated) == 1))
    except ValueError:
        # Masks whose shapes do not broadcast cannot be summed.
        is_partition = False
    if not is_partition:
        print("please check mask")

    for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
        np.save(os.path.join(input_folder, "maskEdited{}_{}.npy".format(midx, mask_label)), mask)

    savepath = os.path.join(input_folder, "seg_edited.png")
    visualize_mask_list_clean(mask_np_list_updated, savepath)
|
|
|
|
|
def image_change(directory_path="./example_tmp/"):
    """Empty the scratch folder when a new input image is chosen.

    Every file, symlink and sub-directory inside *directory_path* is removed;
    the folder itself is kept. A missing folder is treated as already empty.
    The script deletes ``./example_tmp`` at start-up, so the folder may not
    exist yet.

    Args:
        directory_path: Folder to empty. Defaults to the app's scratch folder.

    Returns:
        A ``gr.Button`` update that hides the "1.2 Load original masks" button.
    """
    # Imported locally so this function does not rely on the module-level
    # ``import shutil`` that appears further down the file.
    import shutil

    if os.path.isdir(directory_path):
        for filename in os.listdir(directory_path):
            file_path = os.path.join(directory_path, filename)
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)

    return gr.Button("1.2 Load original masks", visible=False)
|
|
|
def button_clickable(is_clickable):
    """Return a ``gr.Button`` update that sets the button's interactivity.

    Args:
        is_clickable: Value for the button's ``interactive`` property.
    """
    update = gr.Button(interactive=is_clickable)
    return update
|
|
|
|
|
|
|
def load_pil_img(path="example_tmp/text/out_text_0.png"):
    """Load an editing-result image from disk into memory.

    Args:
        path: Image file to open. The default is the file read after a
            text-based edit run.

    Returns:
        PIL.Image.Image: A fully loaded in-memory copy. The underlying file
        handle is closed before returning.
    """
    # ``Image.open`` is lazy and keeps the file open. ``copy()`` forces the
    # pixel data to load, so the handle can be released by the ``with``.
    with Image.open(path) as img:
        return img.copy()
|
|
|
import shutil  # also used by image_change()

# Start-up side effect: delete the scratch folder. The helpers in this file
# read and write under "example_tmp" (saved masks, seg previews, text-edit
# output images), presumably so files from a previous run are not picked up.
if os.path.isdir("./example_tmp"):
    shutil.rmtree("./example_tmp")

# NOTE(review): this import runs after the cleanup above rather than with the
# other imports. Keep the order unless it is confirmed that importing
# `segment` does not depend on, or create, ./example_tmp.
from segment import run_segmentation
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI, three tabs:
#   1) run segmentation and inspect masks,
#   2) optimize with run_main,
#   3) text-based editing with the trained model.
# Event listeners must be attached inside the ``gr.Blocks`` context, so all
# wiring lives inside this ``with`` statement.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    # Session state shared between callbacks.
    image = gr.State()
    image_loaded = gr.State()  # image returned by run_segmentation
    segmentation = gr.State()  # colour-coded mask image
    mask_np_list = gr.State([])
    mask_label_list = gr.State([])
    mask_np_list_updated = gr.State([])
    true = gr.State(True)
    false = gr.State(False)
    block_flag = gr.State(0)
    num_tokens_global = gr.State(5)

    with gr.Row():
        gr.Markdown("""# D-Edit""")

    # ---------------------------------------------------------------- tab 1
    with gr.Tab(label="1 Edit mask"):
        with gr.Row():
            with gr.Column():
                canvas = gr.Image(value="./img.png", type="numpy", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
                segment_button = gr.Button("1.1 Run segmentation")
                flag = gr.State(False)
                # Alias: the slider below reads the same state that
                # segmentation writes into.
                mask_np_list_updated = mask_np_list

            with gr.Column():
                result_info0 = gr.Text(label="Response")
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Edit Mask (Optional)</p>""")
                slider = gr.Slider(0, 20, step=1, label='mask id', interactive=False)
                label = gr.Text(label='label')
                slider.release(
                    slider_release,
                    inputs=[slider, image_loaded, mask_np_list_updated, mask_label_list],
                    outputs=[canvas, label],
                )

    # ---------------------------------------------------------------- tab 2
    with gr.Tab(label="2 Optimization"):
        with gr.Row():
            with gr.Column():
                result_info = gr.Text(label="Response")
                opt_flag = gr.State(0)
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
                num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive=True)
                # Tab 3 takes its token count from this alias, so editing
                # receives the same value used for optimization.
                num_tokens_global = num_tokens
                embedding_learning_rate = gr.Textbox(value="0.00005", label="Embedding optimization: Learning rate", interactive=True)
                max_emb_train_steps = gr.Number(value="100", label="embedding optimization: Training steps", interactive=True)
                diffusion_model_learning_rate = gr.Textbox(value="0.00002", label="UNet Optimization: Learning rate", interactive=True)
                max_diffusion_train_steps = gr.Number(value="100", label="UNet Optimization: Learning rate: Training steps", interactive=True)
                train_batch_size = gr.Number(value="5", label="Batch size", interactive=True)
                gradient_accumulation_steps = gr.Number(value="5", label="Gradient accumulation", interactive=True)
                add_button = gr.Button("Run optimization")

                def run_optimization_wrapper(
                    mask_np_list,
                    mask_label_list,
                    image,
                    opt_flag,
                    num_tokens,
                    embedding_learning_rate,
                    max_emb_train_steps,
                    diffusion_model_learning_rate,
                    max_diffusion_train_steps,
                    train_batch_size,
                    gradient_accumulation_steps,
                ):
                    """Call ``run_main`` with this tab's settings and return a status string.

                    Text and number inputs are converted with ``int``/``float``.
                    ``opt_flag`` is accepted for wiring compatibility but unused.
                    Failures are reported in the UI rather than raised. The
                    "CUDA out of memory" hint is only shown when the error
                    message mentions running out of memory.
                    """
                    import traceback

                    try:
                        run_main(
                            mask_np_list=mask_np_list,
                            mask_label_list=mask_label_list,
                            image_gt=np.array(image),
                            num_tokens=int(num_tokens),
                            embedding_learning_rate=float(embedding_learning_rate),
                            max_emb_train_steps=int(max_emb_train_steps),
                            diffusion_model_learning_rate=float(diffusion_model_learning_rate),
                            max_diffusion_train_steps=int(max_diffusion_train_steps),
                            train_batch_size=int(train_batch_size),
                            gradient_accumulation_steps=int(gradient_accumulation_steps),
                        )
                    except Exception as err:  # UI boundary: report instead of crashing the callback
                        # A bare ``except`` also caught KeyboardInterrupt and
                        # reported every failure (bad input, missing masks, ...)
                        # as out-of-memory. Log the real cause.
                        traceback.print_exc()
                        if "out of memory" in str(err).lower():
                            return "CUDA out of memory, use a smaller batch size or try another picture."
                        return f"Optimization failed: {err}"

                    print('finish')
                    return "Optimization finished!"

                def immediate_update():
                    """Relabel the optimization button as soon as it is clicked."""
                    return gr.Button("Run Optimization (Check Log for Completion).", interactive=True)

                add_button.click(
                    run_optimization_wrapper,
                    inputs=[
                        mask_np_list,
                        mask_label_list,
                        image_loaded,
                        opt_flag,
                        num_tokens,
                        embedding_learning_rate,
                        max_emb_train_steps,
                        diffusion_model_learning_rate,
                        max_diffusion_train_steps,
                        train_batch_size,
                        gradient_accumulation_steps,
                    ],
                    outputs=[result_info],
                    api_name=False,
                    concurrency_limit=45,
                )
                add_button.click(fn=immediate_update, inputs=[], outputs=[add_button])

                def change_text():
                    """Return a read-only textbox update saying optimization finished."""
                    return gr.Textbox("Optimization Finished!", interactive=False)

    # ---------------------------------------------------------------- tab 3
    with gr.Tab(label="3 Editing"):
        with gr.Tab(label="3.1 Text-based editing"):
            with gr.Row():
                with gr.Column():
                    canvas_text_edit = gr.Image(value=None, type="pil", label="Editing results", show_label=True, visible=True)

                with gr.Column():
                    gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")
                    tgt_prompt = gr.Textbox(value="White bag", label="Editing: Text prompt", interactive=True)
                    slider2 = gr.Slider(0, 20, step=1, label='mask id', interactive=False)
                    guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive=True)
                    num_sampling_steps = gr.Number(value="50", label="Editing: Sampling steps", interactive=True)
                    edge_thickness = gr.Number(value="10", label="Editing: Edge thickness", interactive=True)
                    strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive=True)
                    add_button = gr.Button("Run Editing (Check Log for Completion)")

                    def run_edit_text_wrapper(
                        mask_np_list,
                        mask_label_list,
                        image,
                        num_tokens,
                        guidance_scale,
                        num_sampling_steps,
                        strength,
                        edge_thickness,
                        tgt_prompt,
                        tgt_index,
                    ):
                        """Run text-based editing of mask *tgt_index* and return the result image.

                        Calls ``run_main`` with ``load_trained=True`` and
                        ``text=True``, then loads the output via ``load_pil_img``.
                        """
                        run_main(
                            mask_np_list=mask_np_list,
                            mask_label_list=mask_label_list,
                            image_gt=np.array(image),
                            load_trained=True,
                            text=True,
                            # Use the value Gradio passes in. A component's
                            # ``.value`` is only its initial value, so reading
                            # ``num_tokens_global.value`` ignored the user's
                            # entry and could disagree with the optimization run.
                            num_tokens=int(num_tokens),
                            guidance_scale=float(guidance_scale),
                            num_sampling_steps=int(num_sampling_steps),
                            strength=float(strength),
                            edge_thickness=int(edge_thickness),
                            num_imgs=1,
                            tgt_prompt=tgt_prompt,
                            tgt_index=int(tgt_index),
                        )
                        return load_pil_img()

                    add_button.click(
                        run_edit_text_wrapper,
                        inputs=[
                            mask_np_list,
                            mask_label_list,
                            image_loaded,
                            num_tokens_global,
                            guidance_scale,
                            num_sampling_steps,
                            strength,
                            edge_thickness,
                            tgt_prompt,
                            slider2,
                        ],
                        outputs=[canvas_text_edit],
                        queue=True,
                    )

    # Mirror the tab-1 mask selector into the editing tab.
    slider.change(
        lambda x: x,
        inputs=[slider],
        outputs=[slider2],
    )

    segment_button.click(
        run_segmentation_wrapper,
        [canvas],
        [image_loaded, segmentation, mask_np_list, mask_label_list, canvas, slider, slider2, result_info0],
    )

demo.queue().launch(debug=True)
|
|