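# Training script for erasing a concept from a Stable Diffusion model via
# negative guidance (ESD-style): a finetuned copy of selected modules is
# trained so that its noise prediction for the target prompt is pushed away
# from the original model's prediction. Losses, validation metrics and sample
# images are logged to TensorBoard.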
import math
import multiprocessing
import os.path
import random

import torch
from accelerate.utils import set_seed
from diffusers import StableDiffusionPipeline
from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm

from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
from isolate_rng import isolate_rng
from memory_efficiency import MemoryEfficiencyWrapper

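# Released (e.g. by a controlling process) to request that training stop;
# polled non-blocking at safe points below.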
training_should_cancel = multiprocessing.Semaphore(0)

def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,
             validation_embeddings: torch.FloatTensor,
             neutral_embeddings: torch.FloatTensor,
             sample_embeddings: torch.FloatTensor,
             logger: SummaryWriter, use_amp: bool,
             global_step: int,
             validation_seed: int = 555,
             batch_size: int = 1,
             sample_batch_size: int = 1  # might need to be smaller than batch_size
             ):
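    """Compute the erasure loss on held-out validation prompts and log sample
    images for the current finetuned weights, without disturbing the training
    RNG state."""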
print("validating...") | |
assert batch_size==1, "batch_size != 1 not implemented work" | |
with isolate_rng(include_cuda=True), torch.no_grad(): | |
set_seed(validation_seed) | |
criteria = torch.nn.MSELoss() | |
negative_guidance = 1 | |
nsteps=50 | |
num_validation_batches = validation_embeddings.shape[0] // (batch_size*2) | |
val_count = max(1, 5 // num_validation_batches) | |
val_total_loss = 0 | |
for i in tqdm(range(num_validation_batches)): | |
if training_should_cancel.acquire(block=False): | |
print("cancel requested, bailing") | |
return | |
            accumulated_loss = None
            this_validation_embeddings = validation_embeddings[i * batch_size * 2:(i + 1) * batch_size * 2]
            for j in range(val_count):
                iteration = random.randint(1, nsteps)
                diffused_latents = get_diffused_latents(diffuser, nsteps, this_validation_embeddings, iteration, use_amp)
                with autocast(enabled=use_amp):
                    positive_latents = diffuser.predict_noise(iteration, diffused_latents, this_validation_embeddings, guidance_scale=1)
                    neutral_latents = diffuser.predict_noise(iteration, diffused_latents, neutral_embeddings, guidance_scale=1)
                with finetuner, autocast(enabled=use_amp):
                    negative_latents = diffuser.predict_noise(iteration, diffused_latents, this_validation_embeddings, guidance_scale=1)
                loss = criteria(negative_latents, neutral_latents - (negative_guidance * (positive_latents - neutral_latents)))
                accumulated_loss = (accumulated_loss or 0) + loss.item()
                val_total_loss += loss.item()
            logger.add_scalar(f"loss/val_{i}", accumulated_loss / val_count, global_step=global_step)
        logger.add_scalar("loss/_val_all_combined", val_total_loss / (val_count * num_validation_batches), global_step=global_step)
        num_sample_batches = int(math.ceil(sample_embeddings.shape[0] / (sample_batch_size * 2)))
        for i in tqdm(range(0, num_sample_batches)):
            print(f'making sample batch {i}...')
            if training_should_cancel.acquire(block=False):
                print("cancel requested, bailing")
                return
            with finetuner:
                pipeline = StableDiffusionPipeline(vae=diffuser.vae,
                                                   text_encoder=diffuser.text_encoder,
                                                   tokenizer=diffuser.tokenizer,
                                                   unet=diffuser.unet,
                                                   scheduler=diffuser.scheduler,
                                                   safety_checker=None,
                                                   feature_extractor=None,
                                                   requires_safety_checker=False)
                batch_start = i * sample_batch_size * 2
                next_batch_start = min(batch_start + sample_batch_size * 2, sample_embeddings.shape[0])
                batch_negative_prompt_embeds = torch.cat([sample_embeddings[k:k + 1] for k in range(batch_start, next_batch_start, 2)])
                batch_prompt_embeds = torch.cat([sample_embeddings[k + 1:k + 2] for k in range(batch_start, next_batch_start, 2)])
                images = pipeline(prompt_embeds=batch_prompt_embeds,
                                  negative_prompt_embeds=batch_negative_prompt_embeds,
                                  num_inference_steps=50)
                for image_index, image in enumerate(images.images):
                    image_tensor = transforms.ToTensor()(image)
                    logger.add_image(f"samples/{i * sample_batch_size + image_index}", img_tensor=image_tensor, global_step=global_step)
""" | |
with finetuner, torch.cuda.amp.autocast(enabled=use_amp): | |
images = diffuser( | |
combined_embeddings=sample_embeddings[i*2:i*2+2], | |
n_steps=50 | |
) | |
logger.add_images(f"samples/{i}", images) | |
""" | |
torch.cuda.empty_cache() | |
def train(repo_id_or_path, img_size, prompts, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
          use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1,
          batch_size=1, sample_batch_size=1,
          save_every_n_steps=-1, validate_every_n_steps=-1,
          validation_prompts=[], sample_positive_prompts=[], sample_negative_prompts=[]):
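    """Finetune `modules` of the model at `repo_id_or_path` so that the concept
    described by `prompts` is erased, saving the finetuner weights to
    `save_path`. Optionally validates and writes checkpoints every
    `validate_every_n_steps` / `save_every_n_steps` steps."""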
    nsteps = 50
    print(f"using img_size of {img_size}")
    diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path, native_img_size=img_size).to('cuda')
    logger = SummaryWriter(log_dir=f"logs/{os.path.splitext(os.path.basename(save_path))[0]}")
    memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
                                                        use_gradient_checkpointing=use_gradient_checkpointing)
    with memory_efficiency_wrapper:
        diffuser.train()
        finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
        if use_adamw8bit:
            print("using AdamW 8Bit optimizer")
            import bitsandbytes as bnb
            optimizer = bnb.optim.AdamW8bit(finetuner.parameters(),
                                            lr=lr,
                                            betas=(0.9, 0.999),
                                            weight_decay=0.010,
                                            eps=1e-8
                                            )
        else:
            print("using Adam optimizer")
            optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
        criteria = torch.nn.MSELoss()
        pbar = tqdm(range(iterations))
        with torch.no_grad():
            neutral_text_embeddings = diffuser.get_cond_and_uncond_embeddings([''], n_imgs=1)
            all_positive_text_embeddings = diffuser.get_cond_and_uncond_embeddings(prompts, n_imgs=1)
            validation_embeddings = diffuser.get_cond_and_uncond_embeddings(validation_prompts, n_imgs=1)
            sample_embeddings = diffuser.get_cond_and_uncond_embeddings(sample_positive_prompts, sample_negative_prompts, n_imgs=1)
        for i, validation_prompt in enumerate(validation_prompts):
            logger.add_text(f"val/{i}", f"validation prompt: \"{validation_prompt}\"")
        for i in range(len(sample_positive_prompts)):
            positive_prompt = sample_positive_prompts[i]
            negative_prompt = "" if i >= len(sample_negative_prompts) else sample_negative_prompts[i]
            logger.add_text(f"sample/{i}", f"sample prompt: \"{positive_prompt}\", negative: \"{negative_prompt}\"")
        #if use_amp:
        #    diffuser.vae = diffuser.vae.to(diffuser.vae.device, dtype=torch.float16)
        #del diffuser.text_encoder
        #del diffuser.tokenizer
        torch.cuda.empty_cache()
        if seed == -1:
            seed = random.randint(0, 2 ** 30)
        set_seed(int(seed))
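        # baseline validation pass before any training steps have run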
        validate(diffuser, finetuner,
                 validation_embeddings=validation_embeddings,
                 sample_embeddings=sample_embeddings,
                 neutral_embeddings=neutral_text_embeddings,
                 logger=logger, use_amp=False, global_step=0,
                 batch_size=batch_size, sample_batch_size=sample_batch_size)
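        # keep a short window of recent losses so a moving average and the
        # drift from the starting loss can be reported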
        prev_losses = []
        start_loss = None
        max_prev_loss_count = 10
        try:
            loss = None
            negative_latents = None
            neutral_latents = None
            positive_latents = None
            num_prompts = all_positive_text_embeddings.shape[0] // 2
            for i in pbar:
                try:
                    loss = None
                    negative_latents = None
                    positive_latents = None
                    neutral_latents = None
                    diffused_latents = None
                    for j in tqdm(range(num_prompts)):
                        positive_text_embeddings = all_positive_text_embeddings[j * 2:j * 2 + 2]
                        if training_should_cancel.acquire(block=False):
                            print("cancel requested, bailing")
                            return None
                        with torch.no_grad():
                            optimizer.zero_grad()
                            iteration = torch.randint(1, nsteps - 1, (1,)).item()
                            with finetuner:
                                diffused_latents = get_diffused_latents(diffuser, nsteps, positive_text_embeddings, iteration, use_amp)
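                            # rescale the step index from the 50-step DDIM
                            # schedule to the 1000-step training schedule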
                            iteration = int(iteration / nsteps * 1000)
                            with autocast(enabled=use_amp):
                                positive_latents = diffuser.predict_noise(iteration, diffused_latents, positive_text_embeddings, guidance_scale=1)
                                neutral_latents = diffuser.predict_noise(iteration, diffused_latents, neutral_text_embeddings, guidance_scale=1)
                        with finetuner:
                            with autocast(enabled=use_amp):
                                negative_latents = diffuser.predict_noise(iteration, diffused_latents, positive_text_embeddings, guidance_scale=1)
                        positive_latents.requires_grad = False
                        neutral_latents.requires_grad = False
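                        # ESD-style objective: train the finetuned prediction
                        # for the target prompt toward the neutral prediction,
                        # shifted away from the original positive prediction
                        # by negative_guidance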
                        # loss = criteria(e_n, e_0) works the best; try 5000 epochs
                        loss = criteria(negative_latents, neutral_latents - (negative_guidance * (positive_latents - neutral_latents)))
                        memory_efficiency_wrapper.backward(loss)
                        logger.add_scalar("loss", loss.item(), global_step=i)
                        # print moving average loss
                        prev_losses.append(loss.detach().clone())
                        if len(prev_losses) > max_prev_loss_count:
                            prev_losses.pop(0)
                        if start_loss is None:
                            start_loss = prev_losses[-1]
                        if len(prev_losses) >= max_prev_loss_count:
                            moving_average_loss = sum(prev_losses) / len(prev_losses)
                            print(
                                f"step {i}: loss={loss.item()} (avg={moving_average_loss.item()}, start ∆={(moving_average_loss - start_loss).item()})")
                        else:
                            print(f"step {i}: loss={loss.item()}")
                        memory_efficiency_wrapper.step(optimizer)
                finally:
                    del loss, negative_latents, positive_latents, neutral_latents, diffused_latents
                if save_every_n_steps > 0 and ((i + 1) % save_every_n_steps) == 0:
                    torch.save(finetuner.state_dict(), save_path + f"__step_{i+1}.pt")
                if validate_every_n_steps > 0 and ((i + 1) % validate_every_n_steps) == 0:
                    validate(diffuser, finetuner,
                             validation_embeddings=validation_embeddings,
                             sample_embeddings=sample_embeddings,
                             neutral_embeddings=neutral_text_embeddings,
                             logger=logger, use_amp=False, global_step=i,
                             batch_size=batch_size, sample_batch_size=sample_batch_size)
            torch.save(finetuner.state_dict(), save_path)
            return save_path
        finally:
            del diffuser, optimizer, finetuner
            torch.cuda.empty_cache()

def get_diffused_latents(diffuser, nsteps, text_embeddings, end_iteration, use_amp):
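    """Diffuse random initial latents up to `end_iteration` on an
    `nsteps`-step schedule and return the partially denoised latents."""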
    diffuser.set_scheduler_timesteps(nsteps)
    latents = diffuser.get_initial_latents(len(text_embeddings) // 2, n_prompts=1)
    latents_steps, _ = diffuser.diffusion(
        latents,
        text_embeddings,
        start_iteration=0,
        end_iteration=end_iteration,
        guidance_scale=3,
        show_progress=False,
        use_amp=use_amp
    )
    # because return_latents is not passed to diffuser.diffusion(), latents_steps
    # should have only one entry; take the last (-1) entry defensively anyway
    diffused_latents = latents_steps[-1]
    diffuser.set_scheduler_timesteps(1000)
    del latents_steps, latents
    return diffused_latents
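# Example invocation (all argument values below are illustrative placeholders,
# not taken from this repo):
#   python train.py --repo_id_or_path runwayml/stable-diffusion-v1-5 \
#       --prompts "van gogh" --modules <module_pattern> \
#       --freeze_modules <module_pattern> --save_path erased.pt \
#       --iterations 1000 --lr 1e-5 --negative_guidance 1.0 --use_amp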
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo_id_or_path", required=True)
    parser.add_argument("--img_size", type=int, required=False, default=512)
    parser.add_argument('--prompts', nargs='+', required=True)
    parser.add_argument('--modules', required=True)
    parser.add_argument('--freeze_modules', nargs='+', required=True)
    parser.add_argument('--save_path', required=True)
    parser.add_argument('--iterations', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--negative_guidance', type=float, required=True)
    parser.add_argument('--seed', type=int, required=False, default=-1,
                        help='Training seed for reproducible results, or -1 to pick a random seed')
    parser.add_argument('--use_adamw8bit', action='store_true')
    parser.add_argument('--use_xformers', action='store_true')
    parser.add_argument('--use_amp', action='store_true')
    parser.add_argument('--use_gradient_checkpointing', action='store_true')
    train(**vars(parser.parse_args()))