|
|
|
import os |
|
import copy |
|
from PIL import Image |
|
import matplotlib |
|
import numpy as np |
|
import gradio as gr |
|
from utils import load_mask, load_mask_edit |
|
from utils_mask import process_mask_to_follow_priority, mask_union, visualize_mask_list_clean |
|
from pathlib import Path |
|
import subprocess |
|
from PIL import Image |
|
from functools import partial |
|
from main import run_main |
|
# Pixel size passed as both ``height`` and ``width`` of the mask-drawing canvas.
LENGTH=512

# Alpha value handed to ``putalpha`` for the segmentation overlay before it is
# pasted over the image (presumably on PIL's 0-255 scale).
TRANSPARENCY = 150
|
|
|
def add_mask(mask_np_list_updated, mask_label_list):
    """Append an all-zero mask and the label ``"new"`` to the given lists.

    The new mask is built with ``np.zeros_like`` from the first existing
    mask, so it takes that mask's shape and dtype. Both lists are modified
    in place and the very same list objects are returned.

    Args:
        mask_np_list_updated: List of mask arrays; may be empty or ``None``.
        mask_label_list: List of labels, one per mask.

    Returns:
        The tuple ``(mask_np_list_updated, mask_label_list)``. When there is
        no existing mask to copy the shape from, the inputs are returned
        untouched.
    """
    if not mask_np_list_updated:
        # Without a template mask there is no shape to copy, so do nothing
        # instead of failing on ``mask_np_list_updated[0]``.
        print("No masks loaded yet: cannot add a new empty mask")
        return mask_np_list_updated, mask_label_list

    mask_new = np.zeros_like(mask_np_list_updated[0])
    mask_np_list_updated.append(mask_new)
    mask_label_list.append("new")
    return mask_np_list_updated, mask_label_list
|
|
|
def create_segmentation(mask_np_list):
    """Render a list of masks as one color-coded RGB image.

    Mask ``i`` is painted with the ``i``-th entry of a viridis colormap that
    is resampled to ``len(mask_np_list)`` entries; the per-mask color layers
    are summed and converted to an 8-bit image.

    Args:
        mask_np_list: Non-empty list of 2-D mask arrays of identical shape
            (assumed to hold 0/1 values -- TODO confirm against the loader).

    Returns:
        A ``PIL.Image.Image`` built from the summed color layers.

    Raises:
        ValueError: If ``mask_np_list`` is empty.
    """
    # ``import matplotlib`` alone does not load these submodules; import them
    # explicitly instead of relying on another module having done so.
    import matplotlib.colors
    import matplotlib.pyplot

    if not mask_np_list:
        raise ValueError("create_segmentation needs at least one mask")

    viridis = matplotlib.pyplot.get_cmap(name='viridis', lut=len(mask_np_list))
    segmentation = 0
    for i, m in enumerate(mask_np_list):
        color = matplotlib.colors.to_rgb(viridis(i))
        # Build a constant-color (H, W, 3) layer, then keep it only where the
        # mask is set.
        color_mat = np.ones_like(m)
        color_mat = np.stack(
            [color_mat * color[0], color_mat * color[1], color_mat * color[2]],
            axis=2,
        )
        color_mat = color_mat * m[:, :, np.newaxis]
        segmentation = segmentation + color_mat

    # Overlapping masks can push the sum above 1.0; clip so the uint8 cast
    # below never sees an out-of-range value.
    segmentation = np.clip(segmentation, 0, 1)
    return Image.fromarray(np.uint8(segmentation * 255))
|
|
|
def load_mask_ui(input_folder="example_tmp",load_edit = False):
    """Load masks and their labels from ``input_folder`` as numpy arrays.

    Args:
        input_folder: Folder handed to the mask loader.
        load_edit: When true, ``load_mask_edit`` is used; otherwise
            ``load_mask``.

    Returns:
        ``(mask_np_list, mask_label_list)`` where every mask has been
        converted with ``.cpu().numpy()`` (so the loaders presumably return
        torch tensors -- TODO confirm).
    """
    loader = load_mask_edit if load_edit else load_mask
    mask_list, mask_label_list = loader(input_folder)

    # Move each mask to host memory and expose it as a numpy array.
    mask_np_list = [mask.cpu().numpy() for mask in mask_list]
    return mask_np_list, mask_label_list
|
|
|
def load_image_ui(load_edit, input_folder="example_tmp"):
    """Load ``img_512.png`` and its masks from ``input_folder`` for the UI.

    Args:
        load_edit: Forwarded to ``load_mask_ui`` to choose between the
            original and the edited masks.
        input_folder: Folder that holds ``img_512.png`` and the mask files.

    Returns:
        ``(image, segmentation, mask_np_list, mask_label_list, image)``, where
        ``image`` is the RGB image and ``segmentation`` is the color-coded
        rendering from ``create_segmentation``. If anything fails to load,
        a tuple of five ``None`` values is returned so the number of values
        always matches the five UI outputs.
    """
    failure = (None, None, None, None, None)
    img_path = Path(input_folder) / "img_512.png"

    try:
        # ``convert`` returns an independent copy, so the file handle can be
        # released as soon as the block exits.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit=load_edit)
        segmentation = create_segmentation(mask_np_list)
    except Exception as err:
        # Best-effort loader: report the problem and hand back empty outputs
        # rather than propagating the error to the caller.
        print(
            "Image folder invalid: expected img_512.png and its masks in {!r} ({})".format(
                str(input_folder), err
            )
        )
        return failure

    print("!!", len(mask_np_list))
    return image, segmentation, mask_np_list, mask_label_list, image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
    """Overlay ``foreimg`` on ``backimg`` semi-transparently inside a mask.

    A copy of ``foreimg`` is given a uniform alpha of ``transparency`` and
    pasted over a copy of ``backimg``. The result is then combined with the
    untouched background: pixels where ``mask_np`` is 1 show the blended
    image, pixels where it is 0 show the original background.

    Args:
        backimg: Background PIL image.
        foreimg: Foreground PIL image to overlay.
        mask_np: 2-D array that must broadcast against the image array after
            a trailing channel axis is added (i.e. shaped height x width).
        transparency: Alpha value passed to ``putalpha``.

    Returns:
        A new ``PIL.Image.Image``; neither input image is modified.

    Raises:
        ValueError: If ``mask_np`` does not match the image dimensions
            (raised by numpy broadcasting).
    """
    backimg_solid_np = np.array(backimg)

    bimg = backimg.copy()
    fimg = foreimg.copy()
    fimg.putalpha(transparency)
    # The foreground's own alpha channel is used as the paste mask.
    bimg.paste(fimg, (0, 0), fimg)

    bimg_np = np.array(bimg)
    mask_np = mask_np[:, :, np.newaxis]

    new_img_np = bimg_np * mask_np + (1 - mask_np) * backimg_solid_np
    # Bool or float masks promote the result to a dtype that
    # ``Image.fromarray`` rejects; normalise to 8-bit instead of failing.
    new_img_np = np.clip(new_img_np, 0, 255).astype(np.uint8)
    return Image.fromarray(new_img_np)
|
|
|
def show_segmentation(image, segmentation, flag):
    """Toggle between the plain image and the image with the overlay.

    Args:
        image: PIL image currently loaded.
        segmentation: PIL image with the color-coded segmentation.
        flag: Current toggle state; the overlay is produced only when it is
            exactly ``False``.

    Returns:
        ``(image_to_show, new_flag)``: the overlaid image and ``True`` when
        the overlay was just switched on, otherwise the plain image and
        ``False``.
    """
    if flag is False:
        # PIL reports size as (width, height) whereas numpy arrays are indexed
        # (row, column), so the full-image mask must be (height, width).
        width, height = image.size
        mask_np = np.ones([height, width]).astype(np.uint8)
        image_edit = transparent_paste_with_mask(
            image, segmentation, mask_np, transparency=TRANSPARENCY
        )
        return image_edit, True

    return image, False
|
|
|
def edit_mask_add(canvas, image, idx, mask_np_list):
    """Merge the region drawn on the canvas into mask ``idx``.

    Args:
        canvas: Canvas payload; ``canvas["mask"]`` is read and its first
            channel is scaled from 0-255 down to 0/1.
            NOTE(review): assumes the canvas delivers a dict with a "mask"
            entry -- confirm against the component configuration.
        image: PIL image used as background for the preview.
        idx: Index of the mask that receives the drawn region.
        mask_np_list: Current list of mask arrays (not modified).

    Returns:
        ``(mask_np_list_updated, image_edit)``: the new mask list after
        ``process_mask_to_follow_priority`` was applied with mask ``idx``
        given priority 1 and all others 0, and the image with the new
        segmentation overlaid everywhere.
    """
    selected = mask_np_list[idx]
    drawn = np.uint8(canvas["mask"][:, :, 0]/ 255.)

    # Replace only the selected mask by its union with the drawn region.
    merged = [
        mask_union(selected, drawn) if position == idx else current
        for position, current in enumerate(mask_np_list)
    ]

    priorities = [0] * len(merged)
    priorities[idx] = 1
    mask_np_list_updated = process_mask_to_follow_priority(merged, priorities)

    segmentation = create_segmentation(mask_np_list_updated)
    full_cover = np.ones([selected.shape[0], selected.shape[1]]).astype(np.uint8)
    image_edit = transparent_paste_with_mask(
        image, segmentation, full_cover, transparency=TRANSPARENCY
    )
    return mask_np_list_updated, image_edit
|
|
|
def slider_release(index, image, mask_np_list_updated, mask_label_list):
    """Highlight the mask selected by the slider and return its label.

    Args:
        index: Mask index chosen on the slider.
        image: PIL image currently loaded.
        mask_np_list_updated: List of mask arrays; may be empty or ``None``.
        mask_label_list: List of labels, one per mask.

    Returns:
        ``(new_image, mask_label)`` where ``new_image`` shows the segmentation
        overlay restricted to the selected mask. When ``index`` does not
        address an existing mask, ``(image, "out of range")`` is returned.
    """
    mask_count = len(mask_np_list_updated) if mask_np_list_updated is not None else 0

    # Valid indices stop at mask_count - 1, so ``index == mask_count`` is
    # already out of range.
    if index >= mask_count:
        return image, "out of range"

    mask_np = mask_np_list_updated[index]
    mask_label = mask_label_list[index]
    segmentation = create_segmentation(mask_np_list_updated)
    new_image = transparent_paste_with_mask(
        image, segmentation, mask_np, transparency=TRANSPARENCY
    )
    return new_image, mask_label
|
|
|
def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Save the masks as the original mask set of ``input_folder``.

    Each mask is written to ``mask{index}_{label}.npy`` and a visualization
    is written to ``seg_current.png`` via ``visualize_mask_list_clean``.

    Nothing is written unless the masks sum to exactly 1 at every position;
    in that case ``"please check mask"`` is printed and the function returns.

    Args:
        mask_np_list_updated: List of mask arrays.
        mask_label_list: List of labels, one per mask.
        input_folder: Destination folder.

    Returns:
        None.
    """
    # Explicit validation instead of ``assert`` + debugger: asserts vanish
    # under ``python -O`` and a ``pdb`` prompt would block the running app.
    try:
        masks_are_valid = bool(mask_np_list_updated) and bool(
            np.all(sum(mask_np_list_updated) == 1)
        )
    except (TypeError, ValueError):
        # Non-summable entries or mismatching shapes also count as invalid.
        masks_are_valid = False

    if not masks_are_valid:
        print("please check mask")
        return

    for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
        np.save(os.path.join(input_folder, "mask{}_{}.npy".format(midx, mask_label)), mask)

    savepath = os.path.join(input_folder, "seg_current.png")
    visualize_mask_list_clean(mask_np_list_updated, savepath)
|
|
|
def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Save the masks as the edited mask set of ``input_folder``.

    Each mask is written to ``maskEdited{index}_{label}.npy`` and a
    visualization is written to ``seg_edited.png`` via
    ``visualize_mask_list_clean``.

    Nothing is written unless the masks sum to exactly 1 at every position;
    in that case ``"please check mask"`` is printed and the function returns.

    Args:
        mask_np_list_updated: List of mask arrays.
        mask_label_list: List of labels, one per mask.
        input_folder: Destination folder.

    Returns:
        None.
    """
    # Explicit validation instead of ``assert`` + debugger: asserts vanish
    # under ``python -O`` and a ``pdb`` prompt would block the running app.
    try:
        masks_are_valid = bool(mask_np_list_updated) and bool(
            np.all(sum(mask_np_list_updated) == 1)
        )
    except (TypeError, ValueError):
        # Non-summable entries or mismatching shapes also count as invalid.
        masks_are_valid = False

    if not masks_are_valid:
        print("please check mask")
        return

    for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
        np.save(os.path.join(input_folder, "maskEdited{}_{}.npy".format(midx, mask_label)), mask)

    savepath = os.path.join(input_folder, "seg_edited.png")
    visualize_mask_list_clean(mask_np_list_updated, savepath)
|
|
|
|
|
import shutil

# Remove any ``./example_tmp`` directory left on disk when this module is
# loaded, so the functions that default to that folder start from an empty
# state.
_stale_workdir = "./example_tmp"
if os.path.isdir(_stale_workdir):
    shutil.rmtree(_stale_workdir)
|
|
|
from segment import run_segmentation |
|
# Gradio UI definition. Three tabs: (1) segment the image and edit the masks,
# (2) run the optimization, (3) run text-based editing on the trained model.
# NOTE(review): the layout nesting below was reconstructed from the order of
# the components -- confirm it against the intended page layout.
with gr.Blocks() as demo:
    # Per-session state shared between the event handlers.
    image = gr.State()
    image_loaded = gr.State()
    segmentation = gr.State()

    mask_np_list = gr.State([])
    mask_label_list = gr.State([])
    mask_np_list_updated = gr.State([])
    # Constant inputs used to pass a literal True/False to ``load_image_ui``.
    true = gr.State(True)
    false = gr.State(False)
    # Updated by ``run_segmentation``; its change event relabels the buttons.
    block_flag = gr.State(0)
    num_tokens_global = gr.State(5)
    with gr.Row():
        gr.Markdown("""# D-Edit""")

    with gr.Tab(label="1 Edit mask"):
        with gr.Row():
            with gr.Column():
                canvas = gr.Image(value = "./img.png", type="numpy", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)

                segment_button = gr.Button("1.1 Run segmentation")
                segment_button.click(
                    run_segmentation,
                    [canvas, block_flag],
                    [block_flag],
                )

                text_button = gr.Button("Waiting 1.1 to complete")
                text_button.click(
                    load_image_ui,
                    [false],
                    [image_loaded, segmentation, mask_np_list, mask_label_list, canvas],
                )

                load_edit_button = gr.Button("Waiting 1.1 to complete")
                load_edit_button.click(
                    load_image_ui,
                    [true],
                    [image_loaded, segmentation, mask_np_list, mask_label_list, canvas],
                )

                show_segment = gr.Checkbox(label = "Waiting 1.1 to complete")
                flag = gr.State(False)
                show_segment.select(
                    show_segmentation,
                    [image_loaded, segmentation, flag],
                    [canvas, flag],
                )

                def show_more_buttons():
                    """Return the relabelled controls shown once ``block_flag`` changes."""
                    return gr.Button("1.2 Load original masks"), gr.Button("1.2 Load edited masks") , gr.Checkbox(label = "Show Segmentation")

                block_flag.change(show_more_buttons, [], [text_button, load_edit_button, show_segment])

                # From here on the "updated" list is the very same State as
                # ``mask_np_list``; the State created above is no longer used.
                mask_np_list_updated = mask_np_list

            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Edit Mask (Optional)</p>""")
                slider = gr.Slider(0, 20, step=1, interactive=True)
                label = gr.Textbox()
                slider.release(
                    slider_release,
                    inputs = [slider, image_loaded, mask_np_list_updated, mask_label_list],
                    outputs= [canvas, label],
                )
                add_button = gr.Button("Add")
                add_button.click(
                    edit_mask_add,
                    [canvas, image_loaded, slider, mask_np_list_updated],
                    [mask_np_list_updated, canvas],
                )

                save_button2 = gr.Button("Set and Save as edited masks")
                save_button2.click(
                    save_as_edit_mask,
                    [mask_np_list_updated, mask_label_list],
                    [],
                )

                save_button = gr.Button("Set and Save as original masks")
                save_button.click(
                    save_as_orig_mask,
                    [mask_np_list_updated, mask_label_list],
                    [],
                )

                back_button = gr.Button("Back to current seg")
                back_button.click(
                    load_mask_ui,
                    [],
                    [mask_np_list_updated, mask_label_list],
                )

                add_mask_button = gr.Button("Add new empty mask")
                add_mask_button.click(
                    add_mask,
                    [mask_np_list_updated, mask_label_list],
                    [mask_np_list_updated, mask_label_list],
                )

    with gr.Tab(label="2 Optimization"):
        with gr.Row():
            with gr.Column():

                txt_box = gr.Textbox("Click to start optimization...", interactive = False)

                # Incremented when optimization returns; its change event
                # updates ``txt_box``.
                opt_flag = gr.State(0)
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
                num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive= True)
                # Rebind so the editing tab can use this Number component as
                # an event input and receive the value currently entered.
                num_tokens_global = num_tokens
                embedding_learning_rate = gr.Textbox(value="0.0001", label="Embedding optimization: Learning rate", interactive= True )
                max_emb_train_steps = gr.Number(value="200", label="embedding optimization: Training steps", interactive= True )

                diffusion_model_learning_rate = gr.Textbox(value="0.00005", label="UNet Optimization: Learning rate", interactive= True )
                max_diffusion_train_steps = gr.Number(value="200", label="UNet Optimization: Learning rate: Training steps", interactive= True )

                train_batch_size = gr.Number(value="5", label="Batch size", interactive= True )
                gradient_accumulation_steps=gr.Number(value="5", label="Gradient accumulation", interactive= True )

                add_button = gr.Button("Run optimization")

                def run_optimization_wrapper (
                    opt_flag,
                    num_tokens,
                    embedding_learning_rate ,
                    max_emb_train_steps ,
                    diffusion_model_learning_rate ,
                    max_diffusion_train_steps,
                    train_batch_size,
                    gradient_accumulation_steps
                ):
                    """Convert the UI values, call ``run_main`` and bump ``opt_flag``.

                    Returns ``opt_flag + 1`` so that the ``opt_flag`` change
                    event fires once ``run_main`` has returned.
                    """
                    run_optimization = partial(
                        run_main,
                        num_tokens=int(num_tokens),
                        embedding_learning_rate = float(embedding_learning_rate),
                        max_emb_train_steps = int(max_emb_train_steps),
                        diffusion_model_learning_rate= float(diffusion_model_learning_rate),
                        max_diffusion_train_steps = int(max_diffusion_train_steps),
                        train_batch_size=int(train_batch_size),
                        gradient_accumulation_steps=int(gradient_accumulation_steps)
                    )
                    run_optimization()
                    return opt_flag+1

                add_button.click(
                    run_optimization_wrapper,
                    inputs = [
                        opt_flag,
                        num_tokens,
                        embedding_learning_rate ,
                        max_emb_train_steps ,
                        diffusion_model_learning_rate ,
                        max_diffusion_train_steps,
                        train_batch_size,
                        gradient_accumulation_steps
                    ],
                    outputs = [opt_flag]
                )

                def change_text(txt_box):
                    """Return the status textbox shown after optimization."""
                    return gr.Textbox("Optimization Finished!", interactive = False)

                def change_text2(txt_box):
                    """Return the status textbox shown when optimization starts."""
                    return gr.Textbox("Start optimization, check logs for progress...", interactive = False)

                add_button.click(change_text2, txt_box, txt_box)
                opt_flag.change(change_text, txt_box, txt_box)

    with gr.Tab(label="3 Editing"):
        with gr.Tab(label="3.1 Text-based editing"):

            with gr.Row():
                with gr.Column():
                    canvas_text_edit = gr.Image(value = None, type = "pil", label="Editing results", show_label=True)

                with gr.Column():
                    gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")

                    tgt_prompt = gr.Textbox(value="White bag", label="Editing: Text prompt", interactive= True )
                    tgt_index = gr.Number(value="0", label="Editing: Object index", interactive= True )
                    guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive= True )
                    num_sampling_steps = gr.Number(value="50", label="Editing: Sampling steps", interactive= True )
                    edge_thickness = gr.Number(value="10", label="Editing: Edge thickness", interactive= True )
                    strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )

                    add_button = gr.Button("Run Editing")

                    def run_edit_text_wrapper(
                        num_tokens,
                        guidance_scale,
                        num_sampling_steps ,
                        strength ,
                        edge_thickness,
                        tgt_prompt ,
                        tgt_index
                    ):
                        """Convert the UI values and run text-based editing via ``run_main``.

                        Returns whatever ``run_main`` returns; the click
                        handler routes it to ``canvas_text_edit``.
                        """
                        run_edit_text = partial(
                            run_main,
                            load_trained=True,
                            text=True,
                            # Use the value delivered with the event; reading
                            # ``num_tokens_global.value`` would only give the
                            # component's construction-time value.
                            num_tokens = int(num_tokens),
                            guidance_scale = float(guidance_scale),
                            num_sampling_steps = int(num_sampling_steps),
                            strength = float(strength),
                            edge_thickness = int(edge_thickness),
                            num_imgs = 1,
                            tgt_prompt = tgt_prompt,
                            tgt_index = int(tgt_index)
                        )
                        return run_edit_text()

                    add_button.click(
                        run_edit_text_wrapper,
                        inputs = [
                            num_tokens_global,
                            guidance_scale,
                            num_sampling_steps,
                            strength ,
                            edge_thickness,
                            tgt_prompt ,
                            tgt_index
                        ],
                        outputs = [canvas_text_edit]
                    )

                    def load_pil_img():
                        """Open the first text-editing result image from disk."""
                        from PIL import Image
                        return Image.open("example_tmp/text/out_text_0.png")

                    load_button = gr.Button("Load results")
                    load_button.click(
                        load_pil_img,
                        inputs = [],
                        outputs = [canvas_text_edit]
                    )


demo.queue().launch(share=True, debug=True)
|
|