|
import os |
|
import torch |
|
import numpy as np |
|
import argparse |
|
from peft import LoraConfig |
|
from pipeline_dedit_sdxl import DEditSDXLPipeline |
|
from pipeline_dedit_sd import DEditSDPipeline |
|
from utils import load_image, load_mask, load_mask_edit |
|
from utils_mask import process_mask_move_torch, process_mask_remove_torch, mask_union_torch, mask_substract_torch, create_outer_edge_mask_torch |
|
from utils_mask import check_mask_overlap_torch, check_cover_all_torch, visualize_mask_list, get_mask_difference_torch, save_mask_list_to_npys |
|
|
|
def _load_segmented_image(folder, resolution):
    """Load the segmentation masks and the source image stored in *folder*.

    The image is read from ``img_{resolution}.png`` when that file exists and
    from ``img_{resolution}.jpg`` otherwise.

    Returns:
        ``(mask_list, mask_label_list, image)`` as produced by ``load_mask`` and
        ``load_image``.

    Raises:
        ValueError: if the first mask's height does not equal *resolution*.
    """
    mask_list, mask_label_list = load_mask(folder)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if mask_list[0].shape[0] != resolution:
        raise ValueError("Segmentation should be done on size {}".format(resolution))
    png_path = os.path.join(folder, "img_{}.png".format(resolution))
    jpg_path = os.path.join(folder, "img_{}.jpg".format(resolution))
    # Choose the file by existence rather than a bare ``except:`` so a genuine
    # error while reading the PNG is reported instead of silently masked.
    image_path = png_path if os.path.exists(png_path) else jpg_path
    return mask_list, mask_label_list, load_image(image_path, size=resolution)


def _save_trained_weights(pipe, output_dir, dpm):
    """Write the UNet and text-encoder state dicts into *output_dir*."""
    torch.save(pipe.unet.state_dict(), os.path.join(output_dir, "unet.pt"))
    torch.save(pipe.text_encoder.state_dict(), os.path.join(output_dir, "text_encoder1.pt"))
    if dpm == "sdxl":
        torch.save(pipe.text_encoder_2.state_dict(), os.path.join(output_dir, "text_encoder2.pt"))


def _load_trained_weights(pipe, output_dir, dpm, lora_rank, lora_alpha):
    """Restore the weights written by ``_save_trained_weights`` into *pipe*.

    When any UNet key contains ``"lora"``, a LoRA adapter with the given rank
    and alpha is attached to ``pipe.unet`` before the state dict is loaded.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints this
    # script produced.
    unet_state_dict = torch.load(os.path.join(output_dir, "unet.pt"))
    text_encoder1_state_dict = torch.load(os.path.join(output_dir, "text_encoder1.pt"))
    text_encoder2_state_dict = None
    if dpm == "sdxl":
        text_encoder2_state_dict = torch.load(os.path.join(output_dir, "text_encoder2.pt"))

    # Test each key on its own; joining all keys could match "lora" across a
    # boundary between two unrelated keys.
    if any("lora" in key for key in unet_state_dict):
        unet_lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            init_lora_weights="gaussian",
            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
        pipe.unet.add_adapter(unet_lora_config)

    pipe.unet.load_state_dict(unet_state_dict)
    pipe.text_encoder.load_state_dict(text_encoder1_state_dict)
    if text_encoder2_state_dict is not None:
        pipe.text_encoder_2.load_state_dict(text_encoder2_state_dict)


def _apply_edited_masks(folder, mask_list, mask_active):
    """Load user-edited masks from *folder* and mark every changed pixel active.

    Returns:
        ``(mask_list_edited, mask_active)`` where the returned *mask_active* is
        the union of the given one and the edited-vs-original mask difference.
    """
    mask_list_edited, _ = load_mask_edit(folder)
    mask_diff = get_mask_difference_torch(mask_list_edited, mask_list)
    return mask_list_edited, mask_union_torch(mask_active, mask_diff)


def _run_recon(pipe, run_dir, mask_list, set_string_list, sampling,
               recon_an_item, tgt_index, recon_prompt):
    """Reconstruct the image (or a single item) into ``run_dir/recon``."""
    out_dir = os.path.join(run_dir, "recon")
    os.makedirs(out_dir, exist_ok=True)
    if recon_an_item:
        # One full-frame mask driven by the target item's prompt, with "*"
        # in recon_prompt replaced by that item's prompt string.
        mask_list = [torch.from_numpy(np.ones_like(mask_list[0].numpy()))]
        set_string_list = [recon_prompt.replace("*", set_string_list[tgt_index])]
        print(set_string_list)
    pipe.inference_with_mask(
        os.path.join(out_dir, "out_recon.png"),
        set_string_list=set_string_list,
        mask_list=mask_list,
        **sampling,
    )


def _run_text(pipe, run_dir, input_folder, image_gt, mask_list, set_string_list,
              sampling, strength, tgt_index, tgt_prompt, active_mask_list,
              load_edited_mask, edge_thickness):
    """Replace item *tgt_index*'s prompt with *tgt_prompt*; write to ``run_dir/text``."""
    print("*** Text-guided editing ")
    out_dir = os.path.join(run_dir, "text")
    os.makedirs(out_dir, exist_ok=True)
    save_path = os.path.join(out_dir, "out_text.png")

    # Copy so the pipeline's own prompt list is not modified for later modes.
    strings = list(set_string_list)
    strings[tgt_index] = tgt_prompt

    mask_active = mask_union_torch(torch.zeros_like(mask_list[0]), mask_list[tgt_index])
    if active_mask_list is not None:
        for midx in active_mask_list:
            mask_active = mask_union_torch(mask_active, mask_list[midx])

    if load_edited_mask:
        mask_list, mask_active = _apply_edited_masks(input_folder, mask_list, mask_active)
        save_path = os.path.join(out_dir, "out_textEdited.png")

    # Hard region = everything not active, minus the soft band around the active area.
    mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness=edge_thickness)
    mask_hard = mask_substract_torch(torch.ones_like(mask_list[0]), mask_active)
    mask_hard = mask_substract_torch(mask_hard, mask_soft)

    pipe.inference_with_mask(
        save_path,
        orig_image=image_gt,
        set_string_list=strings,
        strength=strength,
        mask_hard=mask_hard,
        mask_soft=mask_soft,
        mask_list=mask_list,
        **sampling,
    )


def _run_remove(pipe, run_dir, input_folder, image_gt, mask_list, mask_label_list,
                sampling, strength, tgt_index, active_mask_list, load_edited_mask,
                load_edited_processed_mask, edge_thickness):
    """Remove item *tgt_index* and inpaint the hole; write to ``run_dir/remove``."""
    out_dir = os.path.join(run_dir, "remove")
    save_path = os.path.join(out_dir, "out_remove.png")
    os.makedirs(out_dir, exist_ok=True)
    mask_active = torch.zeros_like(mask_list[0])

    if load_edited_mask:
        mask_list, mask_active = _apply_edited_masks(input_folder, mask_list, mask_active)

    if load_edited_processed_mask:
        # Masks edited by hand after the removal step, loaded from the output folder.
        mask_list_processed, _ = load_mask_edit(out_dir)
        mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
    else:
        mask_list_processed, mask_remain = process_mask_remove_torch(mask_list, tgt_index)
        save_mask_list_to_npys(out_dir, mask_list_processed, mask_label_list, name="mask")
        visualize_mask_list(mask_list_processed, os.path.join(out_dir, "seg_removed.png"))
    check_cover_all_torch(*mask_list_processed)

    mask_active = mask_union_torch(mask_active, mask_remain)
    if active_mask_list is not None:
        for midx in active_mask_list:
            mask_active = mask_union_torch(mask_active, mask_list[midx])

    mask_soft = create_outer_edge_mask_torch(mask_remain, edge_thickness=edge_thickness)
    mask_hard = mask_substract_torch(1 - mask_active, mask_soft)

    pipe.inference_with_mask(
        save_path,
        orig_image=image_gt,
        strength=strength,
        mask_hard=mask_hard,
        mask_soft=mask_soft,
        mask_list=mask_list_processed,
        **sampling,
    )


def _run_image(pipe, run_dir, name, name_2, input_folder, input_folder_2,
               image_gt, image_gt_2, mask_list, mask_list_2, set_string_list,
               set_string_list_2, sampling, strength, tgt_name, src_index,
               tgt_index, load_edited_mask, edge_thickness):
    """Transfer item *src_index* of one image onto item *tgt_index* of *tgt_name*.

    Output goes to ``run_dir/image``. When any of *tgt_name*, *src_index* or
    *tgt_index* is None only the output folder is created.

    Raises:
        ValueError: if *tgt_name* is neither *name* nor *name_2*.
    """
    out_dir = os.path.join(run_dir, "image")
    save_path = os.path.join(out_dir, "out_image.png")
    os.makedirs(out_dir, exist_ok=True)
    if None in (tgt_name, src_index, tgt_index):
        return

    if tgt_name == name:
        strings_tgt, strings_src = list(set_string_list), set_string_list_2
        image_tgt, folder_tgt, mask_list_tgt = image_gt, input_folder, mask_list
    elif tgt_name == name_2:
        strings_tgt, strings_src = list(set_string_list_2), set_string_list
        image_tgt, folder_tgt, mask_list_tgt = image_gt_2, input_folder_2, mask_list_2
    else:
        raise ValueError("tgt_name should be either name or name_2")

    mask_active = torch.zeros_like(mask_list_tgt[0])
    if load_edited_mask:
        mask_list_tgt, mask_active = _apply_edited_masks(folder_tgt, mask_list_tgt, mask_active)
        save_path = os.path.join(out_dir, "out_imageEdited.png")

    strings_tgt[tgt_index] = strings_src[src_index]

    # Union (not overwrite) so regions changed by edited masks stay editable,
    # matching the text-editing mode.
    mask_active = mask_union_torch(mask_active, mask_list_tgt[tgt_index])
    mask_frozen = (1 - mask_active.float()).to(mask_active.device)
    mask_soft = create_outer_edge_mask_torch(mask_active.cpu(), edge_thickness=edge_thickness)
    mask_hard = mask_substract_torch(mask_frozen.cpu(), mask_soft.cpu())

    mask_list_tgt = [m.cuda() for m in mask_list_tgt]

    pipe.inference_with_mask(
        save_path,
        set_string_list=strings_tgt,
        mask_list=mask_list_tgt,
        mask_hard=mask_hard.cuda(),
        mask_soft=mask_soft.cuda(),
        orig_image=image_tgt,
        strength=strength,
        **sampling,
    )


def _run_move_resize(pipe, run_dir, input_folder, image_gt, mask_list, mask_label_list,
                     sampling, strength, tgt_indices_list, delta_x_list, delta_y_list,
                     priority_list, force_mask_remain, resize_list, active_mask_list,
                     load_edited_mask, load_edited_processed_mask, edge_thickness):
    """Move/resize the items in *tgt_indices_list*; write to ``run_dir/move_resize``."""
    out_dir = os.path.join(run_dir, "move_resize")
    os.makedirs(out_dir, exist_ok=True)
    save_path = os.path.join(out_dir, "out_moveresize.png")
    mask_active = torch.zeros_like(mask_list[0])

    if load_edited_mask:
        mask_list, mask_active = _apply_edited_masks(input_folder, mask_list, mask_active)

    if load_edited_processed_mask:
        mask_list_processed, _ = load_mask_edit(out_dir)
        mask_remain = get_mask_difference_torch(mask_list_processed, mask_list)
    else:
        mask_list_processed, mask_remain = process_mask_move_torch(
            mask_list,
            tgt_indices_list,
            delta_x_list,
            delta_y_list,
            priority_list,
            force_mask_remain=force_mask_remain,
            resize_list=resize_list,
        )
        save_mask_list_to_npys(out_dir, mask_list_processed, mask_label_list, name="mask")
        visualize_mask_list(mask_list_processed, os.path.join(out_dir, "seg_move_resize.png"))

    moved_masks = [m for midx, m in enumerate(mask_list_processed) if midx in tgt_indices_list]
    mask_active = mask_union_torch(mask_active, *moved_masks)
    mask_active = mask_union_torch(mask_remain, mask_active)
    if active_mask_list is not None:
        for midx in active_mask_list:
            mask_active = mask_union_torch(mask_active, mask_list_processed[midx])

    mask_soft = create_outer_edge_mask_torch(mask_active, edge_thickness=edge_thickness)
    mask_hard = mask_substract_torch(1 - mask_active.float(), mask_soft)
    check_mask_overlap_torch(mask_hard, mask_soft)

    pipe.inference_with_mask(
        save_path,
        strength=strength,
        orig_image=image_gt,
        mask_hard=mask_hard,
        mask_soft=mask_soft,
        mask_list=mask_list_processed,
        **sampling,
    )


def run_main(
    name="example_tmp",
    name_2=None,
    dpm="sd",
    resolution=512,
    seed=42,
    embedding_learning_rate=1e-4,
    max_emb_train_steps=200,
    diffusion_model_learning_rate=5e-5,
    max_diffusion_train_steps=200,
    train_batch_size=1,
    gradient_accumulation_steps=1,
    num_tokens=1,
    load_trained=False,
    num_sampling_steps=50,
    guidance_scale=3,
    strength=0.8,
    train_full_lora=False,
    lora_rank=4,
    lora_alpha=4,
    prompt_auxin_list=None,
    prompt_auxin_idx_list=None,
    load_edited_mask=False,
    load_edited_processed_mask=False,
    edge_thickness=20,
    num_imgs=1,
    active_mask_list=None,
    tgt_index=None,
    recon=False,
    recon_an_item=False,
    recon_prompt=None,
    text=False,
    tgt_prompt=None,
    image=False,
    src_index=None,
    tgt_name=None,
    move_resize=False,
    tgt_indices_list=None,
    delta_x_list=None,
    delta_y_list=None,
    priority_list=None,
    force_mask_remain=None,
    resize_list=None,
    remove=False,
    load_edited_removemask=False,
):
    """Train (or load) a D-Edit pipeline for one or two images, then run edits.

    Inputs are read from ``./<name>`` (and ``./<name_2>`` when *image* is set);
    trained weights and results go to ``./<name>`` or ``./<name>_<name_2>``.
    Each enabled mode (*recon*, *text*, *remove*, *image*, *move_resize*)
    writes into its own sub-folder of that run folder.

    Args:
        dpm: ``"sd"`` or ``"sdxl"``; selects the pipeline class.
        load_trained: load ``unet.pt`` / ``text_encoder1.pt`` (and
            ``text_encoder2.pt`` for sdxl) instead of training and saving them.
        prompt_auxin_list / prompt_auxin_idx_list: templates whose ``"*"`` is
            replaced by the prompt of the item at the matching index.
        load_edited_removemask: accepted for interface compatibility; not used.

    Raises:
        ValueError: mask size does not match *resolution*, or *tgt_name* is
            neither *name* nor *name_2* in image mode.
        NotImplementedError: *dpm* is not ``"sd"`` or ``"sdxl"``.
    """
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    base_input_folder = "."
    base_output_folder = "."

    input_folder = os.path.join(base_input_folder, name)
    mask_list, mask_label_list, image_gt = _load_segmented_image(input_folder, resolution)

    input_folder_2 = mask_list_2 = mask_label_list_2 = image_gt_2 = None
    if image:
        input_folder_2 = os.path.join(base_input_folder, name_2)
        mask_list_2, mask_label_list_2, image_gt_2 = _load_segmented_image(input_folder_2, resolution)
        run_dir = os.path.join(base_output_folder, name + "_" + name_2)
    else:
        run_dir = os.path.join(base_output_folder, name)
    os.makedirs(run_dir, exist_ok=True)

    if dpm == "sd":
        pipeline_cls = DEditSDPipeline
    elif dpm == "sdxl":
        pipeline_cls = DEditSDXLPipeline
    else:
        raise NotImplementedError("Unsupported dpm {!r}; expected 'sd' or 'sdxl'".format(dpm))
    if image:
        pipe = pipeline_cls(mask_list, mask_label_list, mask_list_2, mask_label_list_2,
                            resolution=resolution, num_tokens=num_tokens)
    else:
        pipe = pipeline_cls(mask_list, mask_label_list, resolution=resolution, num_tokens=num_tokens)

    # Auxin templates are applied in place, i.e. to pipe.set_string_list itself
    # (presumably read by inference calls that do not pass set_string_list).
    set_string_list = pipe.set_string_list
    if prompt_auxin_list is not None:
        for auxin_idx, auxin_prompt in zip(prompt_auxin_idx_list, prompt_auxin_list):
            set_string_list[auxin_idx] = auxin_prompt.replace("*", set_string_list[auxin_idx])
    print(set_string_list)

    set_string_list_2 = None
    if image:
        set_string_list_2 = pipe.set_string_list_2
        print(set_string_list_2)

    if load_trained:
        _load_trained_weights(pipe, run_dir, dpm, lora_rank, lora_alpha)
    else:
        emb_kwargs = dict(
            gradient_accumulation_steps=gradient_accumulation_steps,
            embedding_learning_rate=embedding_learning_rate,
            max_emb_train_steps=max_emb_train_steps,
            train_batch_size=train_batch_size,
        )
        model_kwargs = dict(
            gradient_accumulation_steps=gradient_accumulation_steps,
            max_diffusion_train_steps=max_diffusion_train_steps,
            diffusion_model_learning_rate=diffusion_model_learning_rate,
            train_batch_size=train_batch_size,
            train_full_lora=train_full_lora,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
        )
        pipe.mask_list = [m.cuda() for m in pipe.mask_list]
        if image:
            pipe.mask_list_2 = [m.cuda() for m in pipe.mask_list_2]
            pipe.train_emb_2imgs(image_gt, image_gt_2, set_string_list, set_string_list_2, **emb_kwargs)
            pipe.train_model_2imgs(image_gt, image_gt_2, set_string_list, set_string_list_2, **model_kwargs)
        else:
            pipe.train_emb(image_gt, set_string_list, **emb_kwargs)
            pipe.train_model(image_gt, set_string_list, **model_kwargs)
        _save_trained_weights(pipe, run_dir, dpm)

    # Every mode derives its folder from run_dir and rebinds masks/prompts only
    # locally, so enabling several modes neither nests output folders nor
    # leaks one mode's edits into the next.
    sampling = dict(
        guidance_scale=guidance_scale,
        num_sampling_steps=num_sampling_steps,
        num_imgs=num_imgs,
        seed=seed,
    )

    if recon:
        _run_recon(
            pipe, run_dir, mask_list, set_string_list, sampling,
            recon_an_item=recon_an_item, tgt_index=tgt_index, recon_prompt=recon_prompt,
        )

    if text:
        _run_text(
            pipe, run_dir, input_folder, image_gt, mask_list, set_string_list, sampling,
            strength=strength, tgt_index=tgt_index, tgt_prompt=tgt_prompt,
            active_mask_list=active_mask_list, load_edited_mask=load_edited_mask,
            edge_thickness=edge_thickness,
        )

    if remove:
        _run_remove(
            pipe, run_dir, input_folder, image_gt, mask_list, mask_label_list, sampling,
            strength=strength, tgt_index=tgt_index, active_mask_list=active_mask_list,
            load_edited_mask=load_edited_mask,
            load_edited_processed_mask=load_edited_processed_mask,
            edge_thickness=edge_thickness,
        )

    if image:
        _run_image(
            pipe, run_dir, name, name_2, input_folder, input_folder_2,
            image_gt, image_gt_2, mask_list, mask_list_2,
            set_string_list, set_string_list_2, sampling,
            strength=strength, tgt_name=tgt_name, src_index=src_index,
            tgt_index=tgt_index, load_edited_mask=load_edited_mask,
            edge_thickness=edge_thickness,
        )

    if move_resize:
        _run_move_resize(
            pipe, run_dir, input_folder, image_gt, mask_list, mask_label_list, sampling,
            strength=strength, tgt_indices_list=tgt_indices_list,
            delta_x_list=delta_x_list, delta_y_list=delta_y_list,
            priority_list=priority_list, force_mask_remain=force_mask_remain,
            resize_list=resize_list, active_mask_list=active_mask_list,
            load_edited_mask=load_edited_mask,
            load_edited_processed_mask=load_edited_processed_mask,
            edge_thickness=edge_thickness,
        )
|
|