import argparse
import datetime
import logging
import inspect
import math
import os
import random
import gc
import copy

from typing import Dict, Optional, Tuple
from omegaconf import OmegaConf

import cv2
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision.transforms as T
import diffusers
import transformers

from torchvision import transforms
from tqdm.auto import tqdm

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed

from models.unet_3d_condition import UNet3DConditionModel
from diffusers.models import AutoencoderKL
from diffusers import DPMSolverMultistepScheduler, DDPMScheduler, TextToVideoSDPipeline
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, export_to_video
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention_processor import AttnProcessor2_0, Attention
from diffusers.models.attention import BasicTransformerBlock

from transformers import CLIPTextModel, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPEncoder

from utils.dataset import VideoJsonDataset, SingleVideoDataset, \
    ImageDataset, VideoFolderDataset, CachedDataset

from einops import rearrange, repeat

from utils.lora import (
    extract_lora_ups_down,
    inject_trainable_lora,
    inject_trainable_lora_extended,
    save_lora_weight,
    train_patch_pipe,
    monkeypatch_or_replace_lora,
    monkeypatch_or_replace_lora_extended
)
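
# NOTE (editorial assumption, not stated in the script): models.unet_3d_condition,
# utils.dataset and utils.lora are repo-local modules, so the script is expected to be
# run from the repository root where those packages are importable.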
already_printed_trainables = False

# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")

logger = get_logger(__name__, log_level="INFO")
def create_logging(logging, logger, accelerator):
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)


def accelerate_set_verbose(accelerator):
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()
def get_train_dataset(dataset_types, train_data, tokenizer):
    train_datasets = []

    # Loop through all available datasets, get the name, then add to list of data to process.
    for DataSet in [VideoJsonDataset, SingleVideoDataset, ImageDataset, VideoFolderDataset]:
        for dataset in dataset_types:
            if dataset == DataSet.__getname__():
                train_datasets.append(DataSet(**train_data, tokenizer=tokenizer))

    if len(train_datasets) > 0:
        return train_datasets
    else:
        raise ValueError("Dataset type not found: 'json', 'single_video', 'folder', 'image'")
def extend_datasets(datasets, dataset_items, extend=False):
    biggest_data_len = max(x.__len__() for x in datasets)
    extended = []
    for dataset in datasets:
        if dataset.__len__() == 0:
            del dataset
            continue
        if dataset.__len__() < biggest_data_len:
            for item in dataset_items:
                if extend and item not in extended and hasattr(dataset, item):
                    print(f"Extending {item}")

                    value = getattr(dataset, item)
                    value *= biggest_data_len
                    value = value[:biggest_data_len]

                    setattr(dataset, item, value)

                    print(f"New {item} dataset length: {dataset.__len__()}")
                    extended.append(item)
def export_to_video(video_frames, output_video_path, fps):
    # Overrides the `export_to_video` imported from diffusers so the output FPS can be set.
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    h, w, _ = video_frames[0].shape
    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
    for i in range(len(video_frames)):
        img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
        video_writer.write(img)
    # Release the writer so the file is finalized on disk.
    video_writer.release()
def create_output_folders(output_dir, config):
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    out_dir = os.path.join(output_dir, f"train_{now}")

    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(f"{out_dir}/samples", exist_ok=True)
    OmegaConf.save(config, os.path.join(out_dir, 'config.yaml'))

    return out_dir
def load_primary_models(pretrained_model_path):
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    unet = UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")

    return noise_scheduler, tokenizer, text_encoder, vae, unet


def unet_and_text_g_c(unet, text_encoder, unet_enable, text_enable):
    unet._set_gradient_checkpointing(value=unet_enable)
    text_encoder._set_gradient_checkpointing(CLIPEncoder, value=text_enable)


def freeze_models(models_to_freeze):
    for model in models_to_freeze:
        if model is not None:
            model.requires_grad_(False)
def is_attn(name):
    # The modules we patch are the ModuleLists of BasicTransformerBlocks (named
    # `transformer_blocks` in diffusers), which hold the attn1/attn2 layers.
    # The original expression `('attn1' or 'attn2' == ...)` always evaluated truthy.
    return name.split('.')[-1] in ('attn1', 'attn2', 'transformer_blocks')


def set_processors(attentions):
    for attn in attentions:
        attn.set_processor(AttnProcessor2_0())


def set_torch_2_attn(unet):
    optim_count = 0

    for name, module in unet.named_modules():
        if is_attn(name):
            if isinstance(module, torch.nn.ModuleList):
                for m in module:
                    if isinstance(m, BasicTransformerBlock):
                        set_processors([m.attn1, m.attn2])
                        optim_count += 1
    if optim_count > 0:
        print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")
def handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet):
    try:
        is_torch_2 = hasattr(F, 'scaled_dot_product_attention')

        if enable_xformers_memory_efficient_attention and not is_torch_2:
            if is_xformers_available():
                from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
                unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
            else:
                raise ValueError("xformers is not available. Make sure it is installed correctly")

        if enable_torch_2_attn and is_torch_2:
            set_torch_2_attn(unet)

    except Exception as e:
        print(f"Could not enable memory efficient attention for xformers or Torch 2.0: {e}")
def inject_lora(use_lora, model, replace_modules, is_extended=False, dropout=0.0, lora_path='', r=16):
    injector = (
        inject_trainable_lora if not is_extended
        else
        inject_trainable_lora_extended
    )

    params = None
    negation = None

    if os.path.exists(lora_path):
        try:
            for f in os.listdir(lora_path):
                if f.endswith('.pt'):
                    lora_file = os.path.join(lora_path, f)

                    if 'text_encoder' in f and isinstance(model, CLIPTextModel):
                        monkeypatch_or_replace_lora(
                            model,
                            torch.load(lora_file),
                            target_replace_module=replace_modules,
                            r=r
                        )
                        print("Successfully loaded Text Encoder LoRA.")

                    if 'unet' in f and isinstance(model, UNet3DConditionModel):
                        monkeypatch_or_replace_lora_extended(
                            model,
                            torch.load(lora_file),
                            target_replace_module=replace_modules,
                            r=r
                        )
                        print("Successfully loaded UNET LoRA.")

        except Exception as e:
            print(e)
            print("Could not load LoRAs. Injecting new ones instead...")

    if use_lora:
        REPLACE_MODULES = replace_modules
        injector_args = {
            "model": model,
            "target_replace_module": REPLACE_MODULES,
            "r": r
        }
        if not is_extended:
            injector_args['dropout_p'] = dropout

        params, negation = injector(**injector_args)
        for _up, _down in extract_lora_ups_down(
                model,
                target_replace_module=REPLACE_MODULES):

            if all(x is not None for x in [_up, _down]):
                print(f"Lora successfully injected into {model.__class__.__name__}.")
                break

    return params, negation
def save_lora(model, name, condition, replace_modules, step, save_path):
    if condition and replace_modules is not None:
        save_path = f"{save_path}/{step}_{name}.pt"
        save_lora_weight(model, save_path, replace_modules)


def handle_lora_save(
    use_unet_lora,
    use_text_lora,
    model,
    save_path,
    checkpoint_step,
    unet_target_modules,
    text_encoder_target_modules
):
    save_path = f"{save_path}/lora"
    os.makedirs(save_path, exist_ok=True)

    save_lora(
        model.unet,
        'unet',
        use_unet_lora,
        unet_target_modules,
        checkpoint_step,
        save_path,
    )
    save_lora(
        model.text_encoder,
        'text_encoder',
        use_text_lora,
        text_encoder_target_modules,
        checkpoint_step,
        save_path
    )

    train_patch_pipe(model, use_unet_lora, use_text_lora)
def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
    return {
        "model": model,
        "condition": condition,
        "extra_params": extra_params,
        "is_lora": is_lora,
        "negation": negation
    }


def create_optim_params(name='param', params=None, lr=5e-6, extra_params=None):
    params = {
        "name": name,
        "params": params,
        "lr": lr
    }

    if extra_params is not None:
        for k, v in extra_params.items():
            params[k] = v

    return params


def negate_params(name, negation):
    # We have to do this if we are co-training with LoRA.
    # This ensures that parameter groups aren't duplicated.
    if negation is None:
        return False
    for n in negation:
        if n in name and 'temp' not in name:
            return True
    return False
def create_optimizer_params(model_list, lr):
    import itertools
    optimizer_params = []

    for optim in model_list:
        model, condition, extra_params, is_lora, negation = optim.values()
        # Check if we are doing LoRA training.
        if is_lora and condition:
            params = create_optim_params(
                params=itertools.chain(*model),
                extra_params=extra_params
            )
            optimizer_params.append(params)
            continue

        # If this is true, we can train it.
        if condition:
            for n, p in model.named_parameters():
                should_negate = 'lora' in n
                if should_negate:
                    continue

                params = create_optim_params(n, p, lr, extra_params)
                optimizer_params.append(params)

    return optimizer_params


def get_optimizer(use_8bit_adam):
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        return bnb.optim.AdamW8bit
    else:
        return torch.optim.AdamW
def is_mixed_precision(accelerator):
    weight_dtype = torch.float32

    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    return weight_dtype


def cast_to_gpu_and_type(model_list, accelerator, weight_dtype):
    for model in model_list:
        if model is not None:
            model.to(accelerator.device, dtype=weight_dtype)
def handle_cache_latents(
    should_cache,
    output_dir,
    train_dataloader,
    train_batch_size,
    vae,
    cached_latent_dir=None
):
    # Cache latents by encoding each batch through the VAE once and saving the result to disk.
    # Speeds up training and saves memory by not encoding during the train loop.
    if not should_cache:
        return None
    vae.to('cuda', dtype=torch.float16)
    vae.enable_slicing()

    cached_latent_dir = (
        os.path.abspath(cached_latent_dir) if cached_latent_dir is not None else None
    )

    if cached_latent_dir is None:
        cache_save_dir = f"{output_dir}/cached_latents"
        os.makedirs(cache_save_dir, exist_ok=True)

        for i, batch in enumerate(tqdm(train_dataloader, desc="Caching Latents.")):

            save_name = f"cached_{i}"
            full_out_path = f"{cache_save_dir}/{save_name}.pt"

            pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float16)
            batch['pixel_values'] = tensor_to_vae_latent(pixel_values, vae)
            for k, v in batch.items():
                batch[k] = v[0]

            torch.save(batch, full_out_path)
            del pixel_values
            del batch

            # We do this to avoid fragmentation from casting latents between devices.
            torch.cuda.empty_cache()
    else:
        cache_save_dir = cached_latent_dir

    return torch.utils.data.DataLoader(
        CachedDataset(cache_dir=cache_save_dir),
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=0
    )
def handle_trainable_modules(model, trainable_modules=None, is_enabled=True, negation=None):
    global already_printed_trainables

    # This can most definitely be refactored :-)
    unfrozen_params = 0
    if trainable_modules is not None:
        for name, module in model.named_modules():
            for tm in tuple(trainable_modules):
                if tm == 'all':
                    model.requires_grad_(is_enabled)
                    unfrozen_params = len(list(model.parameters()))
                    break

                if tm in name and 'lora' not in name:
                    for m in module.parameters():
                        m.requires_grad_(is_enabled)
                        if is_enabled:
                            unfrozen_params += 1

    if unfrozen_params > 0 and not already_printed_trainables:
        already_printed_trainables = True
        print(f"{unfrozen_params} params have been unfrozen for training.")
def tensor_to_vae_latent(t, vae):
    video_length = t.shape[1]

    t = rearrange(t, "b f c h w -> (b f) c h w")
    latents = vae.encode(t).latent_dist.sample()
    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
    # Scale by the Stable Diffusion VAE scaling factor.
    latents = latents * 0.18215

    return latents


def sample_noise(latents, noise_strength, use_offset_noise):
    b, c, f, *_ = latents.shape
    noise_latents = torch.randn_like(latents, device=latents.device)
    offset_noise = None

    if use_offset_noise:
        # Offset noise: add a per-(batch, channel, frame) constant so the model can
        # learn global brightness/color shifts.
        offset_noise = torch.randn(b, c, f, 1, 1, device=latents.device)
        noise_latents = noise_latents + noise_strength * offset_noise

    return noise_latents


def should_sample(global_step, validation_steps, validation_data):
    return (global_step % validation_steps == 0 or global_step == 1) \
        and validation_data.sample_preview
def save_pipe(
    path,
    global_step,
    accelerator,
    unet,
    text_encoder,
    vae,
    output_dir,
    use_unet_lora,
    use_text_lora,
    unet_target_replace_module=None,
    text_target_replace_module=None,
    is_checkpoint=False,
):
    if is_checkpoint:
        save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
        os.makedirs(save_path, exist_ok=True)
    else:
        save_path = output_dir

    # Save the dtypes so we can continue training at the same precision.
    u_dtype, t_dtype, v_dtype = unet.dtype, text_encoder.dtype, vae.dtype

    # Copy the model without creating a reference to it. This allows keeping the state of our lora training if enabled.
    unet_out = copy.deepcopy(accelerator.unwrap_model(unet, keep_fp32_wrapper=False))
    text_encoder_out = copy.deepcopy(accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False))

    pipeline = TextToVideoSDPipeline.from_pretrained(
        path,
        unet=unet_out,
        text_encoder=text_encoder_out,
        vae=vae,
    ).to(torch_dtype=torch.float16)

    handle_lora_save(
        use_unet_lora,
        use_text_lora,
        pipeline,
        output_dir,
        global_step,
        unet_target_replace_module,
        text_target_replace_module
    )

    pipeline.save_pretrained(save_path)

    if is_checkpoint:
        unet, text_encoder = accelerator.prepare(unet, text_encoder)
        models_to_cast_back = [(unet, u_dtype), (text_encoder, t_dtype), (vae, v_dtype)]
        [x[0].to(accelerator.device, dtype=x[1]) for x in models_to_cast_back]

    logger.info(f"Saved model at {save_path} on step {global_step}")

    del pipeline
    del unet_out
    del text_encoder_out

    torch.cuda.empty_cache()
    gc.collect()


def replace_prompt(prompt, token, wlist):
    for w in wlist:
        if w in prompt:
            return prompt.replace(w, token)
    return prompt
def main(
    pretrained_model_path: str,
    output_dir: str,
    train_data: Dict,
    validation_data: Dict,
    dataset_types: Tuple[str] = ('json',),
    validation_steps: int = 100,
    trainable_modules: Tuple[str] = ("attn1", "attn2"),
    trainable_text_modules: Tuple[str] = ("all",),
    extra_unet_params=None,
    extra_text_encoder_params=None,
    train_batch_size: int = 1,
    max_train_steps: int = 500,
    learning_rate: float = 5e-5,
    scale_lr: bool = False,
    lr_scheduler: str = "constant",
    lr_warmup_steps: int = 0,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 1e-2,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    gradient_accumulation_steps: int = 1,
    gradient_checkpointing: bool = False,
    text_encoder_gradient_checkpointing: bool = False,
    checkpointing_steps: int = 500,
    resume_from_checkpoint: Optional[str] = None,
    mixed_precision: Optional[str] = "fp16",
    use_8bit_adam: bool = False,
    enable_xformers_memory_efficient_attention: bool = True,
    enable_torch_2_attn: bool = False,
    seed: Optional[int] = None,
    train_text_encoder: bool = False,
    use_offset_noise: bool = False,
    offset_noise_strength: float = 0.1,
    extend_dataset: bool = False,
    cache_latents: bool = False,
    cached_latent_dir=None,
    use_unet_lora: bool = False,
    use_text_lora: bool = False,
    unet_lora_modules: Tuple[str] = ("ResnetBlock2D",),
    text_encoder_lora_modules: Tuple[str] = ("CLIPEncoderLayer",),
    lora_rank: int = 16,
    lora_path: str = '',
    **kwargs
):
    *_, config = inspect.getargvalues(inspect.currentframe())

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision,
        log_with="tensorboard",
        logging_dir=output_dir
    )

    # Make one log on every process with the configuration for debugging.
    create_logging(logging, logger, accelerator)

    # Initialize accelerate, transformers, and diffusers warnings
    accelerate_set_verbose(accelerator)

    # If passed along, set the training seed now.
    if seed is not None:
        set_seed(seed)

    # Handle the output folder creation
    if accelerator.is_main_process:
        output_dir = create_output_folders(output_dir, config)

    # Load scheduler, tokenizer and models.
    noise_scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(pretrained_model_path)

    # Freeze any necessary models
    freeze_models([vae, text_encoder, unet])

    # Enable xformers if available
    handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet)

    if scale_lr:
        learning_rate = (
            learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer_cls = get_optimizer(use_8bit_adam)

    # Use LoRA if enabled.
    unet_lora_params, unet_negation = inject_lora(
        use_unet_lora, unet, unet_lora_modules, is_extended=True,
        r=lora_rank, lora_path=lora_path
    )

    text_encoder_lora_params, text_encoder_negation = inject_lora(
        use_text_lora, text_encoder, text_encoder_lora_modules,
        r=lora_rank, lora_path=lora_path
    )
    # Create parameters to optimize over with a condition (if "condition" is true, optimize it)
    optim_params = [
        param_optim(unet, trainable_modules is not None, extra_params=extra_unet_params, negation=unet_negation),
        param_optim(text_encoder, train_text_encoder and not use_text_lora, extra_params=extra_text_encoder_params,
                    negation=text_encoder_negation
                    ),
        param_optim(text_encoder_lora_params, use_text_lora, is_lora=True, extra_params={"lr": 1e-5}),
        param_optim(unet_lora_params, use_unet_lora, is_lora=True, extra_params={"lr": 1e-5})
    ]

    params = create_optimizer_params(optim_params, learning_rate)

    # Create Optimizer
    optimizer = optimizer_cls(
        params,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    # Scheduler
    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

    # Get the training dataset based on types (json, single_video, image)
    train_datasets = get_train_dataset(dataset_types, train_data, tokenizer)

    # Extend datasets that are less than the greatest one. This allows for more balanced training.
    attrs = ['train_data', 'frames', 'image_dir', 'video_files']
    extend_datasets(train_datasets, attrs, extend=extend_dataset)

    # Process one dataset
    if len(train_datasets) == 1:
        train_dataset = train_datasets[0]

    # Process many datasets
    else:
        train_dataset = torch.utils.data.ConcatDataset(train_datasets)

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True
    )

    # Latents caching
    cached_data_loader = handle_cache_latents(
        cache_latents,
        output_dir,
        train_dataloader,
        train_batch_size,
        vae,
        cached_latent_dir
    )

    if cached_data_loader is not None:
        train_dataloader = cached_data_loader

    # Prepare everything with our `accelerator`.
    unet, optimizer, train_dataloader, lr_scheduler, text_encoder = accelerator.prepare(
        unet,
        optimizer,
        train_dataloader,
        lr_scheduler,
        text_encoder
    )
    # Use Gradient Checkpointing if enabled.
    unet_and_text_g_c(
        unet,
        text_encoder,
        gradient_checkpointing,
        text_encoder_gradient_checkpointing
    )

    # Enable VAE slicing to save memory.
    vae.enable_slicing()

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = is_mixed_precision(accelerator)

    # Move text encoders, and VAE to GPU
    models_to_cast = [text_encoder, vae]
    cast_to_gpu_and_type(models_to_cast, accelerator, weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)

    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("text2video-fine-tune")

    # Train!
    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {max_train_steps}")
    global_step = 0
    first_epoch = 0
    # Step to resume from. Checkpoint resumption is not fully implemented here, so default to 0
    # (the training loop references this when `resume_from_checkpoint` is set).
    resume_step = 0

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    def finetune_unet(batch, train_encoder=False):
        # The flag is a module-level global; declare it so the reset below actually takes effect.
        global already_printed_trainables

        # Check if we are training the text encoder
        text_trainable = (train_text_encoder or use_text_lora)

        # Unfreeze UNET Layers
        if global_step == 0:
            already_printed_trainables = False
            unet.train()
            handle_trainable_modules(
                unet,
                trainable_modules,
                is_enabled=True,
                negation=unet_negation
            )

        # Convert videos to latent space
        pixel_values = batch["pixel_values"]

        if not cache_latents:
            latents = tensor_to_vae_latent(pixel_values, vae)
        else:
            latents = pixel_values

        # Get video length
        video_length = latents.shape[2]

        # Sample noise that we'll add to the latents
        noise = sample_noise(latents, offset_noise_strength, use_offset_noise)
        bsz = latents.shape[0]

        # Sample a random timestep for each video
        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Enable text encoder training
        if text_trainable:
            text_encoder.train()

            if use_text_lora:
                text_encoder.text_model.embeddings.requires_grad_(True)

            if global_step == 0 and train_text_encoder:
                handle_trainable_modules(
                    text_encoder,
                    trainable_modules=trainable_text_modules,
                    negation=text_encoder_negation
                )
            cast_to_gpu_and_type([text_encoder], accelerator, torch.float32)

        # Fixes gradient checkpointing training.
        # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb
        if gradient_checkpointing or text_encoder_gradient_checkpointing:
            unet.eval()
            text_encoder.eval()

        # Encode text embeddings
        token_ids = batch['prompt_ids']
        encoder_hidden_states = text_encoder(token_ids)[0]

        # Get the target for loss depending on the prediction type
        if noise_scheduler.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.prediction_type}")

        # Here we do two passes for video and text training.
        # If we are on the second iteration of the loop, get one frame.
        # This allows us to train text information only on the spatial layers.
        losses = []
        should_truncate_video = (video_length > 1 and text_trainable)

        # We detach the encoder hidden states for the first pass (video frames > 1)
        # Then we make a clone of the initial state to ensure we can train it in the loop.
        detached_encoder_state = encoder_hidden_states.clone().detach()
        trainable_encoder_state = encoder_hidden_states.clone()

        for i in range(2):

            should_detach = noisy_latents.shape[2] > 1 and i == 0

            if should_truncate_video and i == 1:
                noisy_latents = noisy_latents[:, :, 1, :, :].unsqueeze(2)
                target = target[:, :, 1, :, :].unsqueeze(2)

            encoder_hidden_states = (
                detached_encoder_state if should_detach else trainable_encoder_state
            )

            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            losses.append(loss)

            # This was most likely single frame training or a single image.
            if video_length == 1 and i == 0:
                break

        loss = losses[0] if len(losses) == 1 else losses[0] + losses[1]

        return loss, latents
    for epoch in range(first_epoch, num_train_epochs):
        train_loss = 0.0

        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):

                text_prompt = batch['text_prompt'][0]

                with accelerator.autocast():
                    loss, latents = finetune_unet(batch, train_encoder=train_text_encoder)

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
                train_loss += avg_loss.item() / gradient_accumulation_steps

                # Backpropagate
                try:
                    accelerator.backward(loss)
                    params_to_clip = (
                        unet.parameters() if not train_text_encoder
                        else
                        list(unet.parameters()) + list(text_encoder.parameters())
                    )
                    accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)

                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)

                except Exception as e:
                    print(f"An error has occurred during backpropagation! {e}")
                    continue

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % checkpointing_steps == 0:
                    save_pipe(
                        pretrained_model_path,
                        global_step,
                        accelerator,
                        unet,
                        text_encoder,
                        vae,
                        output_dir,
                        use_unet_lora,
                        use_text_lora,
                        unet_lora_modules,
                        text_encoder_lora_modules,
                        is_checkpoint=True
                    )

                if should_sample(global_step, validation_steps, validation_data):
                    if global_step == 1:
                        print("Performing validation prompt.")

                    if accelerator.is_main_process:

                        with accelerator.autocast():
                            unet.eval()
                            text_encoder.eval()
                            unet_and_text_g_c(unet, text_encoder, False, False)

                            pipeline = TextToVideoSDPipeline.from_pretrained(
                                pretrained_model_path,
                                text_encoder=text_encoder,
                                vae=vae,
                                unet=unet
                            )

                            diffusion_scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
                            pipeline.scheduler = diffusion_scheduler

                            prompt = text_prompt if len(validation_data.prompt) <= 0 else validation_data.prompt

                            curr_dataset_name = batch['dataset']
                            save_filename = f"{global_step}_dataset-{curr_dataset_name}_{prompt}"

                            out_file = f"{output_dir}/samples/{save_filename}.mp4"

                            with torch.no_grad():
                                video_frames = pipeline(
                                    prompt,
                                    width=validation_data.width,
                                    height=validation_data.height,
                                    num_frames=validation_data.num_frames,
                                    num_inference_steps=validation_data.num_inference_steps,
                                    guidance_scale=validation_data.guidance_scale
                                ).frames
                            export_to_video(video_frames, out_file, train_data.get('fps', 8))

                            del pipeline
                            torch.cuda.empty_cache()

                        logger.info(f"Saved a new sample to {out_file}")

                        unet_and_text_g_c(
                            unet,
                            text_encoder,
                            gradient_checkpointing,
                            text_encoder_gradient_checkpointing
                        )

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            accelerator.log({"training_loss": loss.detach().item()}, step=step)
            progress_bar.set_postfix(**logs)

            if global_step >= max_train_steps:
                break
    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        save_pipe(
            pretrained_model_path,
            global_step,
            accelerator,
            unet,
            text_encoder,
            vae,
            output_dir,
            use_unet_lora,
            use_text_lora,
            unet_lora_modules,
            text_encoder_lora_modules,
            is_checkpoint=False
        )
    accelerator.end_training()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/my_config.yaml")
    args = parser.parse_args()

    main(**OmegaConf.load(args.config))
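
# Example invocation (a sketch; assumes this file is saved as train.py and that the
# referenced YAML config supplies the keys consumed by main(), e.g. pretrained_model_path,
# output_dir, train_data and validation_data):
#
#   python train.py --config ./configs/my_config.yaml
#
# or, to let Accelerate manage devices and mixed precision:
#
#   accelerate launch train.py --config ./configs/my_config.yaml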