Spaces:
Runtime error
Runtime error
File size: 14,576 Bytes
640a27b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 |
from omegaconf import OmegaConf
import torch
from PIL import Image
from torchvision import transforms
import os
from tqdm import tqdm
from einops import rearrange
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
import random
import glob
import re
import shutil
import pdb
import argparse
from convertModels import savemodelDiffusers
# Util Functions
def load_model_from_config(config, ckpt, device="cpu", verbose=False):
"""Loads a model from config and a ckpt
if config is a path will use omegaconf to load
"""
if isinstance(config, (str, Path)):
config = OmegaConf.load(config)
pl_sd = torch.load(ckpt, map_location="cpu")
global_step = pl_sd["global_step"]
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
model.to(device)
model.eval()
model.cond_stage_model.device = device
return model
@torch.no_grad()
def sample_model(model, sampler, c, h, w, ddim_steps, scale, ddim_eta, start_code=None, n_samples=1,t_start=-1,log_every_t=None,till_T=None,verbose=True):
"""Sample the model"""
uc = None
if scale != 1.0:
uc = model.get_learned_conditioning(n_samples * [""])
log_t = 100
if log_every_t is not None:
log_t = log_every_t
shape = [4, h // 8, w // 8]
samples_ddim, inters = sampler.sample(S=ddim_steps,
conditioning=c,
batch_size=n_samples,
shape=shape,
verbose=False,
x_T=start_code,
unconditional_guidance_scale=scale,
unconditional_conditioning=uc,
eta=ddim_eta,
verbose_iter = verbose,
t_start=t_start,
log_every_t = log_t,
till_T = till_T
)
if log_every_t is not None:
return samples_ddim, inters
return samples_ddim
def load_img(path, target_size=512):
"""Load an image, resize and output -1..1"""
image = Image.open(path).convert("RGB")
tform = transforms.Compose([
transforms.Resize(target_size),
transforms.CenterCrop(target_size),
transforms.ToTensor(),
])
image = tform(image)
return 2.*image - 1.
def moving_average(a, n=3) :
ret = np.cumsum(a, dtype=float)
ret[n:] = ret[n:] - ret[:-n]
return ret[n - 1:] / n
def plot_loss(losses, path,word, n=100):
v = moving_average(losses, n)
plt.plot(v, label=f'{word}_loss')
plt.legend(loc="upper left")
plt.title('Average loss in trainings', fontsize=20)
plt.xlabel('Data point', fontsize=16)
plt.ylabel('Loss value', fontsize=16)
plt.savefig(path)
##################### ESD Functions
def get_models(config_path, ckpt_path, devices):
model_orig = load_model_from_config(config_path, ckpt_path, devices[1])
sampler_orig = DDIMSampler(model_orig)
model = load_model_from_config(config_path, ckpt_path, devices[0])
sampler = DDIMSampler(model)
return model_orig, sampler_orig, model, sampler
def train_esd(prompt, train_method, start_guidance, negative_guidance, iterations, lr, config_path, ckpt_path, diffusers_config_path, devices, seperator=None, image_size=512, ddim_steps=50):
'''
Function to train diffusion models to erase concepts from model weights
Parameters
----------
prompt : str
The concept to erase from diffusion model (Eg: "Van Gogh").
train_method : str
The parameters to train for erasure (ESD-x, ESD-u, full, selfattn).
start_guidance : float
Guidance to generate images for training.
negative_guidance : float
Guidance to erase the concepts from diffusion model.
iterations : int
Number of iterations to train.
lr : float
learning rate for fine tuning.
config_path : str
config path for compvis diffusion format.
ckpt_path : str
checkpoint path for pre-trained compvis diffusion weights.
diffusers_config_path : str
Config path for diffusers unet in json format.
devices : str
2 devices used to load the models (Eg: '0,1' will load in cuda:0 and cuda:1).
seperator : str, optional
If the prompt has commas can use this to seperate the prompt for individual simulataneous erasures. The default is None.
image_size : int, optional
Image size for generated images. The default is 512.
ddim_steps : int, optional
Number of diffusion time steps. The default is 50.
Returns
-------
None
'''
# PROMPT CLEANING
word_print = prompt.replace(' ','')
if prompt == 'allartist':
prompt = "Kelly Mckernan, Thomas Kinkade, Ajin Demi Human, Alena Aenami, Tyler Edlin, Kilian Eng"
if prompt == 'i2p':
prompt = "hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity, bodily fluids, blood"
if prompt == "artifact":
prompt = "ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, blurred, text, watermark, grainy"
if seperator is not None:
words = prompt.split(seperator)
words = [word.strip() for word in words]
else:
words = [prompt]
print(words)
ddim_eta = 0
# MODEL TRAINING SETUP
model_orig, sampler_orig, model, sampler = get_models(config_path, ckpt_path, devices)
# choose parameters to train based on train_method
parameters = []
for name, param in model.model.diffusion_model.named_parameters():
# train all layers except x-attns and time_embed layers
if train_method == 'noxattn':
if name.startswith('out.') or 'attn2' in name or 'time_embed' in name:
pass
else:
print(name)
parameters.append(param)
# train only self attention layers
if train_method == 'selfattn':
if 'attn1' in name:
print(name)
parameters.append(param)
# train only x attention layers
if train_method == 'xattn':
if 'attn2' in name:
print(name)
parameters.append(param)
# train all layers
if train_method == 'full':
print(name)
parameters.append(param)
# train all layers except time embed layers
if train_method == 'notime':
if not (name.startswith('out.') or 'time_embed' in name):
print(name)
parameters.append(param)
if train_method == 'xlayer':
if 'attn2' in name:
if 'output_blocks.6.' in name or 'output_blocks.8.' in name:
print(name)
parameters.append(param)
if train_method == 'selflayer':
if 'attn1' in name:
if 'input_blocks.4.' in name or 'input_blocks.7.' in name:
print(name)
parameters.append(param)
# set model to train
model.train()
# create a lambda function for cleaner use of sampling code (only denoising till time step t)
quick_sample_till_t = lambda x, s, code, t: sample_model(model, sampler,
x, image_size, image_size, ddim_steps, s, ddim_eta,
start_code=code, till_T=t, verbose=False)
losses = []
opt = torch.optim.Adam(parameters, lr=lr)
criteria = torch.nn.MSELoss()
history = []
name = f'compvis-word_{word_print}-method_{train_method}-sg_{start_guidance}-ng_{negative_guidance}-iter_{iterations}-lr_{lr}'
# TRAINING CODE
pbar = tqdm(range(iterations))
for i in pbar:
word = random.sample(words,1)[0]
# get text embeddings for unconditional and conditional prompts
emb_0 = model.get_learned_conditioning([''])
emb_p = model.get_learned_conditioning([word])
emb_n = model.get_learned_conditioning([f'{word}'])
opt.zero_grad()
t_enc = torch.randint(ddim_steps, (1,), device=devices[0])
# time step from 1000 to 0 (0 being good)
og_num = round((int(t_enc)/ddim_steps)*1000)
og_num_lim = round((int(t_enc+1)/ddim_steps)*1000)
t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=devices[0])
start_code = torch.randn((1, 4, 64, 64)).to(devices[0])
with torch.no_grad():
# generate an image with the concept from ESD model
z = quick_sample_till_t(emb_p.to(devices[0]), start_guidance, start_code, int(t_enc)) # emb_p seems to work better instead of emb_0
# get conditional and unconditional scores from frozen model at time step t and image z
e_0 = model_orig.apply_model(z.to(devices[1]), t_enc_ddpm.to(devices[1]), emb_0.to(devices[1]))
e_p = model_orig.apply_model(z.to(devices[1]), t_enc_ddpm.to(devices[1]), emb_p.to(devices[1]))
# breakpoint()
# get conditional score from ESD model
e_n = model.apply_model(z.to(devices[0]), t_enc_ddpm.to(devices[0]), emb_n.to(devices[0]))
e_0.requires_grad = False
e_p.requires_grad = False
# reconstruction loss for ESD objective from frozen model and conditional score of ESD model
loss = criteria(e_n.to(devices[0]), e_0.to(devices[0]) - (negative_guidance*(e_p.to(devices[0]) - e_0.to(devices[0])))) #loss = criteria(e_n, e_0) works the best try 5000 epochs
# update weights to erase the concept
loss.backward()
losses.append(loss.item())
pbar.set_postfix({"loss": loss.item()})
history.append(loss.item())
opt.step()
# save checkpoint and loss curve
if (i+1) % 500 == 0 and i+1 != iterations and i+1>= 500:
save_model(model, name, i-1, save_compvis=True, save_diffusers=False)
if i % 100 == 0:
save_history(losses, name, word_print)
model.eval()
save_model(model, name, None, save_compvis=True, save_diffusers=True, compvis_config_file=config_path, diffusers_config_file=diffusers_config_path)
save_history(losses, name, word_print)
def save_model(model, name, num, compvis_config_file=None, diffusers_config_file=None, device='cpu', save_compvis=True, save_diffusers=True):
# SAVE MODEL
# PATH = f'{FOLDER}/{model_type}-word_{word_print}-method_{train_method}-sg_{start_guidance}-ng_{neg_guidance}-iter_{i+1}-lr_{lr}-startmodel_{start_model}-numacc_{numacc}.pt'
folder_path = f'models/{name}'
os.makedirs(folder_path, exist_ok=True)
if num is not None:
path = f'{folder_path}/{name}-epoch_{num}.pt'
else:
path = f'{folder_path}/{name}.pt'
if save_compvis:
torch.save(model.state_dict(), path)
if save_diffusers:
print('Saving Model in Diffusers Format')
savemodelDiffusers(name, compvis_config_file, diffusers_config_file, device=device )
def save_history(losses, name, word_print):
folder_path = f'models/{name}'
os.makedirs(folder_path, exist_ok=True)
with open(f'{folder_path}/loss.txt', 'w') as f:
f.writelines([str(i) for i in losses])
plot_loss(losses,f'{folder_path}/loss.png' , word_print, n=3)
if __name__ == '__main__':
parser = argparse.ArgumentParser(
prog = 'TrainESD',
description = 'Finetuning stable diffusion model to erase concepts using ESD method')
parser.add_argument('--prompt', help='prompt corresponding to concept to erase', type=str, required=True)
parser.add_argument('--train_method', help='method of training', type=str, required=True)
parser.add_argument('--start_guidance', help='guidance of start image used to train', type=float, required=False, default=3)
parser.add_argument('--negative_guidance', help='guidance of negative training used to train', type=float, required=False, default=1)
parser.add_argument('--iterations', help='iterations used to train', type=int, required=False, default=1000)
parser.add_argument('--lr', help='learning rate used to train', type=int, required=False, default=1e-5)
parser.add_argument('--config_path', help='config path for stable diffusion v1-4 inference', type=str, required=False, default='configs/stable-diffusion/v1-inference.yaml')
parser.add_argument('--ckpt_path', help='ckpt path for stable diffusion v1-4', type=str, required=False, default='models/ldm/stable-diffusion-v1/sd-v1-4-full-ema.ckpt')
parser.add_argument('--diffusers_config_path', help='diffusers unet config json path', type=str, required=False, default='diffusers_unet_config.json')
parser.add_argument('--devices', help='cuda devices to train on', type=str, required=False, default='0,0')
parser.add_argument('--seperator', help='separator if you want to train bunch of words separately', type=str, required=False, default=None)
parser.add_argument('--image_size', help='image size used to train', type=int, required=False, default=512)
parser.add_argument('--ddim_steps', help='ddim steps of inference used to train', type=int, required=False, default=50)
args = parser.parse_args()
prompt = args.prompt
train_method = args.train_method
start_guidance = args.start_guidance
negative_guidance = args.negative_guidance
iterations = args.iterations
lr = args.lr
config_path = args.config_path
ckpt_path = args.ckpt_path
diffusers_config_path = args.diffusers_config_path
devices = [f'cuda:{int(d.strip())}' for d in args.devices.split(',')]
seperator = args.seperator
image_size = args.image_size
ddim_steps = args.ddim_steps
train_esd(prompt=prompt, train_method=train_method, start_guidance=start_guidance, negative_guidance=negative_guidance, iterations=iterations, lr=lr, config_path=config_path, ckpt_path=ckpt_path, diffusers_config_path=diffusers_config_path, devices=devices, seperator=seperator, image_size=image_size, ddim_steps=ddim_steps)
|