import os.path
import random
import multiprocessing
import math
from accelerate.utils import set_seed
from diffusers import StableDiffusionPipeline
from torch.cuda.amp import autocast
from torchvision import transforms
from StableDiffuser import StableDiffuser
from finetuning import FineTunedModel
import torch
from tqdm import tqdm
from isolate_rng import isolate_rng
from memory_efficiency import MemoryEfficiencyWrapper
from torch.utils.tensorboard import SummaryWriter
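# Releasing this semaphore (e.g. from a controlling UI or process) asks the
# training and validation loops below to stop at their next cancellation check.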
training_should_cancel = multiprocessing.Semaphore(0)
def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,
validation_embeddings: torch.FloatTensor,
neutral_embeddings: torch.FloatTensor,
sample_embeddings: torch.FloatTensor,
logger: SummaryWriter, use_amp: bool,
global_step: int,
validation_seed: int = 555,
batch_size: int = 1,
sample_batch_size: int = 1 # might need to be smaller than batch_size
):
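    """Run a validation pass and log the results to TensorBoard.

    Computes the same noise-prediction loss as the training loop on the
    validation prompt embeddings (with the RNG state isolated so runs are
    comparable), then renders sample images with the finetuned weights
    applied. Losses and images are written to ``logger`` at ``global_step``.
    """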
print("validating...")
    assert batch_size == 1, "batch_size != 1 is not implemented"
with isolate_rng(include_cuda=True), torch.no_grad():
set_seed(validation_seed)
criteria = torch.nn.MSELoss()
negative_guidance = 1
nsteps=50
num_validation_batches = validation_embeddings.shape[0] // (batch_size*2)
        val_count = max(1, 5 // max(1, num_validation_batches))
val_total_loss = 0
for i in tqdm(range(num_validation_batches)):
if training_should_cancel.acquire(block=False):
print("cancel requested, bailing")
return
accumulated_loss = None
this_validation_embeddings = validation_embeddings[i*batch_size*2:(i+1)*batch_size*2]
for j in range(val_count):
iteration = random.randint(1, nsteps)
diffused_latents = get_diffused_latents(diffuser, nsteps, this_validation_embeddings, iteration, use_amp)
with autocast(enabled=use_amp):
positive_latents = diffuser.predict_noise(iteration, diffused_latents, this_validation_embeddings, guidance_scale=1)
neutral_latents = diffuser.predict_noise(iteration, diffused_latents, neutral_embeddings, guidance_scale=1)
with finetuner, autocast(enabled=use_amp):
negative_latents = diffuser.predict_noise(iteration, diffused_latents, this_validation_embeddings, guidance_scale=1)
loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
accumulated_loss = (accumulated_loss or 0) + loss.item()
val_total_loss += loss.item()
logger.add_scalar(f"loss/val_{i}", accumulated_loss/val_count, global_step=global_step)
logger.add_scalar(f"loss/_val_all_combined", val_total_loss/(val_count*num_validation_batches), global_step=global_step)
num_sample_batches = int(math.ceil(sample_embeddings.shape[0] / (sample_batch_size*2)))
for i in tqdm(range(0, num_sample_batches)):
print(f'making sample batch {i}...')
if training_should_cancel.acquire(block=False):
print("cancel requested, bailing")
return
with finetuner:
pipeline = StableDiffusionPipeline(vae=diffuser.vae,
text_encoder=diffuser.text_encoder,
tokenizer=diffuser.tokenizer,
unet=diffuser.unet,
scheduler=diffuser.scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False)
batch_start = (i * sample_batch_size)*2
                next_batch_start = batch_start + sample_batch_size*2
                batch_negative_prompt_embeds = torch.cat([sample_embeddings[k:k+1] for k in range(batch_start, next_batch_start, 2)])
                batch_prompt_embeds = torch.cat([sample_embeddings[k+1:k+2] for k in range(batch_start, next_batch_start, 2)])
images = pipeline(prompt_embeds=batch_prompt_embeds, #sample_embeddings[i*2+1:i*2+2],
negative_prompt_embeds=batch_negative_prompt_embeds, # sample_embeddings[i*2:i*2+1],
num_inference_steps=50)
for image_index, image in enumerate(images.images):
image_tensor = transforms.ToTensor()(image)
logger.add_image(f"samples/{i*sample_batch_size+image_index}", img_tensor=image_tensor, global_step=global_step)
"""
with finetuner, torch.cuda.amp.autocast(enabled=use_amp):
images = diffuser(
combined_embeddings=sample_embeddings[i*2:i*2+2],
n_steps=50
)
logger.add_images(f"samples/{i}", images)
"""
torch.cuda.empty_cache()
def train(repo_id_or_path, img_size, prompts, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1,
batch_size=1, sample_batch_size=1,
save_every_n_steps=-1, validate_every_n_steps=-1,
validation_prompts=[], sample_positive_prompts=[], sample_negative_prompts=[]):
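    """Finetune the given ``modules`` of the model at ``repo_id_or_path`` on ``prompts``.

    Each step partially diffuses random latents towards a training prompt, then
    trains the finetuned modules so their noise prediction for that prompt
    matches the neutral (unconditional) prediction pushed away from the
    prompt's prediction by ``negative_guidance`` (i.e. the prompts' concepts
    are suppressed rather than reinforced). Losses, validation results and
    sample images are logged to TensorBoard; the finetuned weights are saved
    to ``save_path``.
    """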
nsteps = 50
print(f"using img_size of {img_size}")
diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path, native_img_size=img_size).to('cuda')
logger = SummaryWriter(log_dir=f"logs/{os.path.splitext(os.path.basename(save_path))[0]}")
memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
use_gradient_checkpointing=use_gradient_checkpointing )
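    # The wrapper applies the requested memory optimisations (AMP, xformers,
    # gradient checkpointing) and routes loss.backward() / optimizer.step()
    # through itself (see .backward()/.step() below), letting it manage AMP grad scaling.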
    with memory_efficiency_wrapper:
diffuser.train()
finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
if use_adamw8bit:
print("using AdamW 8Bit optimizer")
import bitsandbytes as bnb
optimizer = bnb.optim.AdamW8bit(finetuner.parameters(),
lr=lr,
betas=(0.9, 0.999),
weight_decay=0.010,
eps=1e-8
)
else:
print("using Adam optimizer")
optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
criteria = torch.nn.MSELoss()
pbar = tqdm(range(iterations))
with torch.no_grad():
neutral_text_embeddings = diffuser.get_cond_and_uncond_embeddings([''], n_imgs=1)
all_positive_text_embeddings = diffuser.get_cond_and_uncond_embeddings(prompts, n_imgs=1)
validation_embeddings = diffuser.get_cond_and_uncond_embeddings(validation_prompts, n_imgs=1)
sample_embeddings = diffuser.get_cond_and_uncond_embeddings(sample_positive_prompts, sample_negative_prompts, n_imgs=1)
for i, validation_prompt in enumerate(validation_prompts):
logger.add_text(f"val/{i}", f"validation prompt: \"{validation_prompt}\"")
for i in range(len(sample_positive_prompts)):
positive_prompt = sample_positive_prompts[i]
negative_prompt = "" if i >= len(sample_negative_prompts) else sample_negative_prompts[i]
logger.add_text(f"sample/{i}", f"sample prompt: \"{positive_prompt}\", negative: \"{negative_prompt}\"")
#if use_amp:
# diffuser.vae = diffuser.vae.to(diffuser.vae.device, dtype=torch.float16)
#del diffuser.text_encoder
#del diffuser.tokenizer
torch.cuda.empty_cache()
if seed == -1:
seed = random.randint(0, 2 ** 30)
set_seed(int(seed))
validate(diffuser, finetuner,
validation_embeddings=validation_embeddings,
sample_embeddings=sample_embeddings,
neutral_embeddings=neutral_text_embeddings,
logger=logger, use_amp=False, global_step=0,
batch_size=batch_size, sample_batch_size=sample_batch_size)
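        # rolling window of recent losses for the moving-average progress report below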
prev_losses = []
start_loss = None
max_prev_loss_count = 10
try:
loss=None
negative_latents=None
neutral_latents=None
positive_latents=None
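            # embeddings are stored as interleaved (uncond, cond) rows: two per prompt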
num_prompts = all_positive_text_embeddings.shape[0] // 2
for i in pbar:
try:
loss = None
negative_latents = None
positive_latents = None
neutral_latents = None
diffused_latents = None
for j in tqdm(range(num_prompts)):
positive_text_embeddings = all_positive_text_embeddings[j*2:j*2+2]
if training_should_cancel.acquire(block=False):
print("cancel requested, bailing")
return None
with torch.no_grad():
optimizer.zero_grad()
iteration = torch.randint(1, nsteps - 1, (1,)).item()
with finetuner:
diffused_latents = get_diffused_latents(diffuser, nsteps, positive_text_embeddings, iteration, use_amp)
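                            # rescale the step index from the nsteps-step DDIM schedule
                            # to the scheduler's native 1000-step range used by predict_noise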
iteration = int(iteration / nsteps * 1000)
with autocast(enabled=use_amp):
positive_latents = diffuser.predict_noise(iteration, diffused_latents, positive_text_embeddings, guidance_scale=1)
neutral_latents = diffuser.predict_noise(iteration, diffused_latents, neutral_text_embeddings, guidance_scale=1)
with finetuner:
with autocast(enabled=use_amp):
negative_latents = diffuser.predict_noise(iteration, diffused_latents, positive_text_embeddings, guidance_scale=1)
positive_latents.requires_grad = False
neutral_latents.requires_grad = False
                        # loss = criteria(e_n, e_0) works the best; try 5000 epochs
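                        # train the finetuned model so its prediction for the prompt matches the
                        # neutral prediction steered *away* from the prompt by negative_guidance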
loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
memory_efficiency_wrapper.backward(loss)
logger.add_scalar("loss", loss.item(), global_step=i)
# print moving average loss
prev_losses.append(loss.detach().clone())
if len(prev_losses) > max_prev_loss_count:
prev_losses.pop(0)
if start_loss is None:
start_loss = prev_losses[-1]
if len(prev_losses) >= max_prev_loss_count:
moving_average_loss = sum(prev_losses) / len(prev_losses)
print(
f"step {i}: loss={loss.item()} (avg={moving_average_loss.item()}, start ∆={(moving_average_loss - start_loss).item()}")
else:
print(f"step {i}: loss={loss.item()}")
memory_efficiency_wrapper.step(optimizer)
finally:
del loss, negative_latents, positive_latents, neutral_latents, diffused_latents
if save_every_n_steps > 0 and ((i+1) % save_every_n_steps) == 0:
torch.save(finetuner.state_dict(), save_path + f"__step_{i+1}.pt")
if validate_every_n_steps > 0 and ((i+1) % validate_every_n_steps) == 0:
validate(diffuser, finetuner,
validation_embeddings=validation_embeddings,
sample_embeddings=sample_embeddings,
neutral_embeddings=neutral_text_embeddings,
logger=logger, use_amp=False, global_step=i,
batch_size=batch_size, sample_batch_size=sample_batch_size)
torch.save(finetuner.state_dict(), save_path)
return save_path
finally:
del diffuser, optimizer, finetuner
torch.cuda.empty_cache()
def get_diffused_latents(diffuser, nsteps, text_embeddings, end_iteration, use_amp):
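    """Partially diffuse fresh latents towards ``text_embeddings``.

    Runs ``end_iteration`` of ``nsteps`` DDIM steps from random initial latents
    and returns the resulting latents for use as input to the noise-prediction
    loss. The scheduler is restored to its native 1000-step range afterwards.
    """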
diffuser.set_scheduler_timesteps(nsteps)
latents = diffuser.get_initial_latents(len(text_embeddings)//2, n_prompts=1)
latents_steps, _ = diffuser.diffusion(
latents,
text_embeddings,
start_iteration=0,
end_iteration=end_iteration,
guidance_scale=3,
show_progress=False,
use_amp=use_amp
)
    # because return_latents is not passed to diffuser.diffusion(), latents_steps should have only 1 entry,
    # but take the last (-1) entry just to be safe
diffused_latents = latents_steps[-1]
diffuser.set_scheduler_timesteps(1000)
del latents_steps, latents
return diffused_latents
if __name__ == '__main__':
import argparse
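    # Example invocation (script name, module names and paths are illustrative placeholders):
    #   python train.py --repo_id_or_path runwayml/stable-diffusion-v1-5 \
    #       --prompts "van gogh style" "van gogh painting" \
    #       --modules "attn2" --freeze_modules "attn2.to_k" "attn2.to_v" \
    #       --save_path models/van_gogh_erased.pt --iterations 1000 --lr 1e-5 \
    #       --negative_guidance 1.0 --use_amp --use_xformers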
parser = argparse.ArgumentParser()
parser.add_argument("--repo_id_or_path", required=True)
parser.add_argument("--img_size", type=int, required=False, default=512)
    parser.add_argument('--prompts', nargs='+', required=True, help='one or more training prompts')
parser.add_argument('--modules', required=True)
parser.add_argument('--freeze_modules', nargs='+', required=True)
parser.add_argument('--save_path', required=True)
parser.add_argument('--iterations', type=int, required=True)
parser.add_argument('--lr', type=float, required=True)
parser.add_argument('--negative_guidance', type=float, required=True)
parser.add_argument('--seed', type=int, required=False, default=-1,
help='Training seed for reproducible results, or -1 to pick a random seed')
parser.add_argument('--use_adamw8bit', action='store_true')
parser.add_argument('--use_xformers', action='store_true')
parser.add_argument('--use_amp', action='store_true')
parser.add_argument('--use_gradient_checkpointing', action='store_true')
train(**vars(parser.parse_args()))