"""Fine-tune parts of a StableDiffuser with a negative-guidance loss.

Logs losses, prompts and sample images to TensorBoard and periodically saves
the finetuner state dict.
"""
import collections
import math
import multiprocessing
import os.path
import random

import torch
from accelerate.utils import set_seed
from diffusers import StableDiffusionPipeline
from torch.cuda.amp import autocast
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm

from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
from isolate_rng import isolate_rng
from memory_efficiency import MemoryEfficiencyWrapper

# Cancellation flag polled by train() and validate() with a non-blocking
# acquire(). Releasing it once (presumably from a controlling process/UI)
# requests that the current run stop at its next check.
training_should_cancel = multiprocessing.Semaphore(0)

# Number of scheduler steps used when producing partially-diffused latents.
_DIFFUSION_STEPS = 50
# Timestep count the scheduler is reset to after diffusing; predict_noise() is
# called with timestep indices rescaled onto this range.
_TRAIN_TIMESTEPS = 1000


def _cancel_requested() -> bool:
    """Consume a pending cancellation request, if any; return True if there was one."""
    if training_should_cancel.acquire(block=False):
        print("cancel requested, bailing")
        return True
    return False


def _scaled_timestep(iteration: int, nsteps: int) -> int:
    """Map a step index on an ``nsteps``-step schedule onto the ``_TRAIN_TIMESTEPS`` range."""
    return int(iteration / nsteps * _TRAIN_TIMESTEPS)


def _predict_noise_triplet(diffuser, finetuner, timestep, latents, prompt_embeddings,
                           neutral_embeddings, use_amp):
    """Return ``(positive, neutral, negative)`` predictions from ``diffuser.predict_noise``.

    ``positive`` (prompt) and ``neutral`` (neutral embeddings) are computed
    outside the finetuner context and without gradients. ``negative`` is
    computed for the prompt with ``finetuner`` active and keeps its autograd
    graph unless the caller has disabled gradients.
    """
    with torch.no_grad(), autocast(enabled=use_amp):
        positive = diffuser.predict_noise(timestep, latents, prompt_embeddings, guidance_scale=1)
        neutral = diffuser.predict_noise(timestep, latents, neutral_embeddings, guidance_scale=1)
    with finetuner, autocast(enabled=use_amp):
        negative = diffuser.predict_noise(timestep, latents, prompt_embeddings, guidance_scale=1)
    return positive, neutral, negative


def _negative_guidance_loss(criteria, positive, neutral, negative, negative_guidance):
    """Return ``criteria(negative, neutral - negative_guidance * (positive - neutral))``."""
    return criteria(negative, neutral - (negative_guidance * (positive - neutral)))


def _sample_batch_embeds(sample_embeddings, batch_index, sample_batch_size):
    """Return ``(negative_embeds, positive_embeds)`` for one sample-image batch.

    Rows of ``sample_embeddings`` are treated as interleaved pairs: even rows
    are used as negative-prompt embeddings, odd rows as prompt embeddings.
    Each batch covers at most ``sample_batch_size`` pairs; the last batch may
    be shorter.
    """
    start = batch_index * sample_batch_size * 2
    stop = min(start + sample_batch_size * 2, sample_embeddings.shape[0])
    negative_embeds = sample_embeddings[start:stop:2].contiguous()
    positive_embeds = sample_embeddings[start + 1:stop:2].contiguous()
    return negative_embeds, positive_embeds


def _log_validation_losses(diffuser, finetuner, validation_embeddings, neutral_embeddings,
                           logger, use_amp, global_step, batch_size) -> bool:
    """Log the loss for each validation batch and the mean over all batches.

    Returns False if cancellation was requested part-way through, else True.
    """
    criteria = torch.nn.MSELoss()
    negative_guidance = 1
    nsteps = _DIFFUSION_STEPS

    rows_per_batch = batch_size * 2  # embeddings are sliced in pairs of rows
    num_validation_batches = validation_embeddings.shape[0] // rows_per_batch
    if num_validation_batches == 0:
        # No validation rows (e.g. no validation prompts): nothing to average,
        # and the divisions below would fail.
        return True

    # Aim for roughly 5 loss samples in total, but at least one per batch.
    val_count = max(1, 5 // num_validation_batches)

    val_total_loss = 0
    for batch_index in tqdm(range(num_validation_batches)):
        if _cancel_requested():
            return False
        batch_embeddings = validation_embeddings[batch_index * rows_per_batch:
                                                 (batch_index + 1) * rows_per_batch]
        batch_loss = 0
        for _ in range(val_count):
            # Same step range and timestep rescaling as _train_step, so the
            # validation loss measures the training objective.
            iteration = random.randint(1, nsteps - 2)
            diffused_latents = get_diffused_latents(diffuser, nsteps, batch_embeddings,
                                                    iteration, use_amp)
            timestep = _scaled_timestep(iteration, nsteps)
            positive, neutral, negative = _predict_noise_triplet(
                diffuser, finetuner, timestep, diffused_latents,
                batch_embeddings, neutral_embeddings, use_amp)
            loss = _negative_guidance_loss(criteria, positive, neutral, negative,
                                           negative_guidance).item()
            batch_loss += loss
            val_total_loss += loss
        logger.add_scalar(f"loss/val_{batch_index}", batch_loss / val_count,
                          global_step=global_step)
    logger.add_scalar("loss/_val_all_combined",
                      val_total_loss / (val_count * num_validation_batches),
                      global_step=global_step)
    return True


def _log_sample_images(diffuser, finetuner, sample_embeddings, logger, global_step,
                       sample_batch_size) -> bool:
    """Render sample images with the finetuner active and log them to TensorBoard.

    Returns False if cancellation was requested part-way through, else True.
    """
    num_sample_batches = math.ceil(sample_embeddings.shape[0] / (sample_batch_size * 2))
    for batch_index in tqdm(range(num_sample_batches)):
        print(f'making sample batch {batch_index}...')
        if _cancel_requested():
            return False
        negative_embeds, positive_embeds = _sample_batch_embeds(
            sample_embeddings, batch_index, sample_batch_size)
        with finetuner:
            pipeline = StableDiffusionPipeline(vae=diffuser.vae,
                                               text_encoder=diffuser.text_encoder,
                                               tokenizer=diffuser.tokenizer,
                                               unet=diffuser.unet,
                                               scheduler=diffuser.scheduler,
                                               safety_checker=None,
                                               feature_extractor=None,
                                               requires_safety_checker=False)
            images = pipeline(prompt_embeds=positive_embeds,
                              negative_prompt_embeds=negative_embeds,
                              num_inference_steps=50)
        for image_index, image in enumerate(images.images):
            image_tensor = transforms.ToTensor()(image)
            logger.add_image(f"samples/{batch_index * sample_batch_size + image_index}",
                             img_tensor=image_tensor, global_step=global_step)
    return True


def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,
             validation_embeddings: torch.FloatTensor,
             neutral_embeddings: torch.FloatTensor,
             sample_embeddings: torch.FloatTensor,
             logger: SummaryWriter, use_amp: bool,
             global_step: int,
             validation_seed: int = 555,
             batch_size: int = 1,
             sample_batch_size: int = 1  # might need to be smaller than batch_size
             ) -> bool:
    """Log validation losses and sample images for the current finetuner state.

    Runs under ``isolate_rng(include_cuda=True)`` and ``torch.no_grad()`` after
    ``set_seed(validation_seed)``. The intent appears to be that repeated calls
    draw the same random timesteps/noise without disturbing the training RNG
    (TODO confirm isolate_rng restores state).

    Args:
        validation_embeddings: rows consumed in pairs, ``batch_size * 2`` per batch.
        neutral_embeddings: embeddings used for the neutral prediction.
        sample_embeddings: interleaved (negative, positive) rows for sample images.
        global_step: step recorded with every logged scalar/image.

    Returns:
        True if validation ran to completion, False if cancellation was
        requested via ``training_should_cancel`` (the request is consumed, so
        callers should stop when they get False).
    """
    print("validating...")
    assert batch_size == 1, "batch_size != 1 not implemented work"
    try:
        with isolate_rng(include_cuda=True), torch.no_grad():
            set_seed(validation_seed)
            if not _log_validation_losses(diffuser, finetuner, validation_embeddings,
                                          neutral_embeddings, logger, use_amp,
                                          global_step, batch_size):
                return False
            return _log_sample_images(diffuser, finetuner, sample_embeddings, logger,
                                      global_step, sample_batch_size)
    finally:
        torch.cuda.empty_cache()


def _make_optimizer(parameters, lr, use_adamw8bit):
    """Build either a bitsandbytes AdamW8bit or a plain torch Adam optimizer."""
    if use_adamw8bit:
        print("using AdamW 8Bit optimizer")
        import bitsandbytes as bnb
        return bnb.optim.AdamW8bit(parameters,
                                   lr=lr,
                                   betas=(0.9, 0.999),
                                   weight_decay=0.010,
                                   eps=1e-8)
    print("using Adam optimizer")
    return torch.optim.Adam(parameters, lr=lr)


def _log_prompt_texts(logger, validation_prompts, sample_positive_prompts,
                      sample_negative_prompts):
    """Record the validation and sample prompts as TensorBoard text entries."""
    for i, validation_prompt in enumerate(validation_prompts):
        logger.add_text(f"val/{i}", f"validation prompt: \"{validation_prompt}\"")
    for i, positive_prompt in enumerate(sample_positive_prompts):
        # Missing negative prompts are shown as empty strings.
        negative_prompt = sample_negative_prompts[i] if i < len(sample_negative_prompts) else ""
        logger.add_text(f"sample/{i}",
                        f"sample prompt: \"{positive_prompt}\", negative: \"{negative_prompt}\"")


def _train_step(diffuser, finetuner, optimizer, memory_efficiency_wrapper, criteria,
                positive_text_embeddings, neutral_text_embeddings, nsteps,
                negative_guidance, use_amp):
    """Run one optimisation step for a single prompt; return the detached loss."""
    optimizer.zero_grad()
    with torch.no_grad():
        # torch.randint's upper bound is exclusive: iteration is in [1, nsteps - 2].
        iteration = torch.randint(1, nsteps - 1, (1,)).item()
        with finetuner:
            diffused_latents = get_diffused_latents(diffuser, nsteps, positive_text_embeddings,
                                                    iteration, use_amp)
    timestep = _scaled_timestep(iteration, nsteps)
    positive, neutral, negative = _predict_noise_triplet(
        diffuser, finetuner, timestep, diffused_latents,
        positive_text_embeddings, neutral_text_embeddings, use_amp)
    # loss = criteria(e_n, e_0) works the best try 5000 epochs
    loss = _negative_guidance_loss(criteria, positive, neutral, negative, negative_guidance)
    memory_efficiency_wrapper.backward(loss)
    memory_efficiency_wrapper.step(optimizer)
    return loss.detach()


def train(repo_id_or_path,
          img_size,
          prompts,
          modules,
          freeze_modules,
          iterations,
          negative_guidance,
          lr,
          save_path,
          use_adamw8bit=True,
          use_xformers=True,
          use_amp=True,
          use_gradient_checkpointing=False,
          seed=-1,
          batch_size=1,
          sample_batch_size=1,
          save_every_n_steps=-1,
          validate_every_n_steps=-1,
          validation_prompts=None,
          sample_positive_prompts=None,
          sample_negative_prompts=None):
    """Fine-tune ``modules`` of a StableDiffuser with the negative-guidance loss.

    For each of ``iterations`` steps, every prompt in ``prompts`` gets one
    optimisation step. Losses go to TensorBoard under ``logs/<save_path stem>``.

    Args:
        seed: training seed, or -1 to pick a random one.
        save_every_n_steps: if > 0, also save ``<save_path>__step_<n>.pt`` every n steps.
        validate_every_n_steps: if > 0, run :func:`validate` every n steps
            (validation also runs once before training starts).
        validation_prompts / sample_positive_prompts / sample_negative_prompts:
            optional prompt lists (``None`` means empty).

    Returns:
        ``save_path`` after saving the final state dict, or ``None`` if
        cancellation was requested through ``training_should_cancel``.
    """
    validation_prompts = [] if validation_prompts is None else validation_prompts
    sample_positive_prompts = [] if sample_positive_prompts is None else sample_positive_prompts
    sample_negative_prompts = [] if sample_negative_prompts is None else sample_negative_prompts

    nsteps = _DIFFUSION_STEPS
    max_prev_loss_count = 10  # window size of the printed moving-average loss

    print(f"using img_size of {img_size}")
    diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path,
                              native_img_size=img_size).to('cuda')
    logger = SummaryWriter(log_dir=f"logs/{os.path.splitext(os.path.basename(save_path))[0]}")
    try:
        memory_efficiency_wrapper = MemoryEfficiencyWrapper(
            diffuser=diffuser,
            use_amp=use_amp,
            use_xformers=use_xformers,
            use_gradient_checkpointing=use_gradient_checkpointing)
        with memory_efficiency_wrapper:
            finetuner = None
            optimizer = None
            try:
                diffuser.train()
                finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
                optimizer = _make_optimizer(finetuner.parameters(), lr, use_adamw8bit)
                criteria = torch.nn.MSELoss()

                with torch.no_grad():
                    neutral_text_embeddings = diffuser.get_cond_and_uncond_embeddings(
                        [''], n_imgs=1)
                    all_positive_text_embeddings = diffuser.get_cond_and_uncond_embeddings(
                        prompts, n_imgs=1)
                    validation_embeddings = diffuser.get_cond_and_uncond_embeddings(
                        validation_prompts, n_imgs=1)
                    sample_embeddings = diffuser.get_cond_and_uncond_embeddings(
                        sample_positive_prompts, sample_negative_prompts, n_imgs=1)
                _log_prompt_texts(logger, validation_prompts, sample_positive_prompts,
                                  sample_negative_prompts)
                torch.cuda.empty_cache()

                if seed == -1:
                    seed = random.randint(0, 2 ** 30)
                set_seed(int(seed))

                # Validation is always run with use_amp=False.
                validation_kwargs = dict(validation_embeddings=validation_embeddings,
                                         sample_embeddings=sample_embeddings,
                                         neutral_embeddings=neutral_text_embeddings,
                                         logger=logger,
                                         use_amp=False,
                                         batch_size=batch_size,
                                         sample_batch_size=sample_batch_size)
                if not validate(diffuser, finetuner, global_step=0, **validation_kwargs):
                    return None

                recent_losses = collections.deque(maxlen=max_prev_loss_count)
                start_loss = None
                # Embedding rows come in pairs; one pair per prompt.
                num_prompts = all_positive_text_embeddings.shape[0] // 2
                for i in tqdm(range(iterations)):
                    for j in tqdm(range(num_prompts)):
                        if _cancel_requested():
                            return None
                        loss = _train_step(diffuser, finetuner, optimizer,
                                           memory_efficiency_wrapper, criteria,
                                           all_positive_text_embeddings[j * 2:j * 2 + 2],
                                           neutral_text_embeddings, nsteps,
                                           negative_guidance, use_amp)
                        logger.add_scalar("loss", loss.item(), global_step=i)

                        # print moving average loss
                        recent_losses.append(loss)
                        if start_loss is None:
                            start_loss = loss
                        if len(recent_losses) >= max_prev_loss_count:
                            moving_average_loss = sum(recent_losses) / len(recent_losses)
                            print(
                                f"step {i}: loss={loss.item()} (avg={moving_average_loss.item()}, start ∆={(moving_average_loss - start_loss).item()}")
                        else:
                            print(f"step {i}: loss={loss.item()}")

                    if save_every_n_steps > 0 and (i + 1) % save_every_n_steps == 0:
                        torch.save(finetuner.state_dict(), save_path + f"__step_{i+1}.pt")
                    if validate_every_n_steps > 0 and (i + 1) % validate_every_n_steps == 0:
                        if not validate(diffuser, finetuner, global_step=i, **validation_kwargs):
                            return None

                torch.save(finetuner.state_dict(), save_path)
                return save_path
            finally:
                # Drop references to GPU-resident objects before releasing cached memory.
                del diffuser, optimizer, finetuner
                torch.cuda.empty_cache()
    finally:
        logger.close()


def get_diffused_latents(diffuser, nsteps, text_embeddings, end_iteration, use_amp):
    """Diffuse fresh initial latents from step 0 to ``end_iteration`` of an ``nsteps`` schedule.

    The scheduler is switched to ``nsteps`` timesteps for the diffusion and is
    always reset to ``_TRAIN_TIMESTEPS`` afterwards, even if diffusion raises.
    ``len(text_embeddings) // 2`` is passed as the first argument of
    ``get_initial_latents`` (presumably one latent per embedding pair).
    """
    diffuser.set_scheduler_timesteps(nsteps)
    try:
        latents = diffuser.get_initial_latents(len(text_embeddings) // 2, n_prompts=1)
        latents_steps, _ = diffuser.diffusion(
            latents,
            text_embeddings,
            start_iteration=0,
            end_iteration=end_iteration,
            guidance_scale=3,
            show_progress=False,
            use_amp=use_amp
        )
        # because return_latents is not passed to diffuser.diffusion(), latents_steps should have only 1 entry
        # but we take the "last" (-1) entry because paranoia
        return latents_steps[-1]
    finally:
        diffuser.set_scheduler_timesteps(_TRAIN_TIMESTEPS)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--repo_id_or_path", required=True)
    parser.add_argument("--img_size", type=int, required=False, default=512)
    # train() expects a list named `prompts`; --prompt is kept as an accepted spelling.
    parser.add_argument('--prompts', '--prompt', dest='prompts', nargs='+', required=True)
    parser.add_argument('--modules', required=True)
    parser.add_argument('--freeze_modules', nargs='+', required=True)
    parser.add_argument('--save_path', required=True)
    parser.add_argument('--iterations', type=int, required=True)
    parser.add_argument('--lr', type=float, required=True)
    parser.add_argument('--negative_guidance', type=float, required=True)
    parser.add_argument('--seed', type=int, required=False, default=-1,
                        help='Training seed for reproducible results, or -1 to pick a random seed')
    parser.add_argument('--use_adamw8bit', action='store_true')
    parser.add_argument('--use_xformers', action='store_true')
    parser.add_argument('--use_amp', action='store_true')
    parser.add_argument('--use_gradient_checkpointing', action='store_true')

    train(**vars(parser.parse_args()))