|
from .. import WarpCore
from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Optional
import torch
from torch import nn
from torch.utils.data import DataLoader
from gdf import GDF
import numpy as np
from tqdm import tqdm
import wandb

import webdataset as wds
from webdataset.handlers import warn_and_continue
from torch.distributed import barrier
from enum import Enum


class TargetReparametrization(Enum):
    EPSILON = 'epsilon'
    X0 = 'x0'
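
# Both targets rely on the usual diffusion identity x_t = a(logSNR) * x0 + b(logSNR) * eps.
# `gdf.undiffuse(noised, logSNR, pred)` is used in forward_pass() below as returning the
# pair (x0_pred, eps_pred) recovered from the model's native prediction. A rough sketch of
# that conversion (illustrative only; `input_scaler` is an assumed GDF accessor, not
# guaranteed by this module):
#
#   a, b = gdf.input_scaler(logSNR)        # signal/noise scales at this logSNR
#   x0_pred = (noised - eps_pred * b) / a  # from an epsilon prediction
#   eps_pred = (noised - x0_pred * a) / b  # from an x0 prediction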
|
|
|
class DiffusionCore(WarpCore):
    @dataclass(frozen=True)
    class Config(WarpCore.Config):
        # TRAINING PARAMS
        lr: float = EXPECTED_TRAIN
        grad_accum_steps: int = EXPECTED_TRAIN
        batch_size: int = EXPECTED_TRAIN
        updates: int = EXPECTED_TRAIN
        warmup_updates: int = EXPECTED_TRAIN
        save_every: int = 500
        backup_every: int = 20000
        use_fsdp: bool = True

        # EMA UPDATE: every `ema_iters` steps the EMA weights move towards the online
        # weights with decay `ema_beta`; before `ema_start_iters` they are copied verbatim.
        ema_start_iters: Optional[int] = None
        ema_iters: Optional[int] = None
        ema_beta: Optional[float] = None

        # GDF target reparametrization (see TargetReparametrization above)
        gdf_target_reparametrization: Optional[TargetReparametrization] = None
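
        # When left as None, forward_pass() keeps whatever target `extras.gdf.diffuse(...)`
        # produced; EPSILON / X0 convert the model's prediction via `gdf.undiffuse` and
        # regress against the raw noise or the clean latents instead. Illustrative usage
        # (assuming the remaining EXPECTED fields are filled in elsewhere):
        #
        #   config = DiffusionCore.Config(
        #       ...,
        #       gdf_target_reparametrization=TargetReparametrization.EPSILON,
        #   )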
|
|
|
    @dataclass
    class Info(WarpCore.Info):
        ema_loss: Optional[float] = None

    @dataclass(frozen=True)
    class Models(WarpCore.Models):
        generator: nn.Module = EXPECTED
        generator_ema: Optional[nn.Module] = None

    @dataclass(frozen=True)
    class Optimizers(WarpCore.Optimizers):
        generator: Any = EXPECTED

    @dataclass(frozen=True)
    class Schedulers(WarpCore.Schedulers):
        generator: Any = None

    @dataclass(frozen=True)
    class Extras(WarpCore.Extras):
        gdf: GDF = EXPECTED
        sampling_configs: dict = EXPECTED

    info: Info
    config: Config

    @abstractmethod
    def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
        """Encode a raw batch into the latents the diffusion process operates on."""
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
        """Decode latents back into sample space (e.g. for visualization while sampling)."""
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False):
        """Build the conditioning kwargs passed to the generator."""
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def webdataset_path(self, extras: Extras):
        """Return the path(s) or URL(s) of the webdataset shards."""
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def webdataset_filters(self, extras: Extras):
        """Return a predicate applied to raw samples via `.select(...)`."""
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def webdataset_preprocessors(self, extras: Extras):
        """Return a list of (source key, transform, output key) triples; see setup_data()."""
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
        """Run a sampling pass (called periodically at checkpoint time)."""
        raise NotImplementedError("This method needs to be overridden")
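
    # A minimal sketch of a concrete subclass (illustrative; the class name, hook bodies,
    # and the `vae` / `text_encoder` models are hypothetical, not defined by this module):
    #
    #   class MyText2ImageCore(DiffusionCore):
    #       def encode_latents(self, batch, models, extras):
    #           return models.vae.encode(batch['images'])
    #       def get_conditions(self, batch, models, extras, is_eval=False, is_unconditional=False):
    #           return {'clip_text': models.text_encoder(batch['captions'])}
    #       ...  # remaining hooks omitted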
|
|
|
|
|
    def setup_data(self, extras: Extras) -> WarpCore.Data:
        dataset_path = self.webdataset_path(extras)
        preprocessors = self.webdataset_preprocessors(extras)
        filters = self.webdataset_filters(extras)

        handler = warn_and_continue
        dataset = wds.WebDataset(
            dataset_path, resampled=True, handler=handler
        ).select(filters).shuffle(690, handler=handler).decode(
            "pilrgb", handler=handler
        ).to_tuple(
            *[p[0] for p in preprocessors], handler=handler
        ).map_tuple(
            *[p[1] for p in preprocessors], handler=handler
        ).map(lambda x: {p[2]: x[i] for i, p in enumerate(preprocessors)})

        # Split the global batch size across data-parallel workers and gradient
        # accumulation steps.
        real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps)
        dataloader = DataLoader(
            dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True
        )

        return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader))
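
    # Illustrative shape of the `webdataset_preprocessors` return value consumed above
    # (keys and transforms are hypothetical examples, not requirements of this module):
    #
    #   [('jpg', my_image_transform, 'images'),
    #    ('txt', identity,           'captions')]
    #
    # Position 0 feeds `.to_tuple(...)` (source key in the tar sample), position 1 feeds
    # `.map_tuple(...)` (per-field transform), and position 2 becomes the batch dict key
    # produced by the final `.map(...)`.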
|
|
|
    def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
        batch = next(data.iterator)

        with torch.no_grad():
            conditions = self.get_conditions(batch, models, extras)
            latents = self.encode_latents(batch, models, extras)
            noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)

        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            pred = models.generator(noised, noise_cond, **conditions)
            if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON:
                pred = extras.gdf.undiffuse(noised, logSNR, pred)[1]  # convert the prediction to epsilon
                target = noise
            elif self.config.gdf_target_reparametrization == TargetReparametrization.X0:
                pred = extras.gdf.undiffuse(noised, logSNR, pred)[0]  # convert the prediction to x0
                target = latents
            loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
            loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps

        return loss, loss_adjusted
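
    # Worked example of the `grad_accum_steps` scaling above: with batch_size=256,
    # world_size=8 and grad_accum_steps=4, each forward pass sees 256/(8*4) = 8 samples,
    # and dividing `loss_adjusted` by 4 makes the 4 accumulated backward() calls sum to
    # the same gradient an un-accumulated batch of 32 per GPU would produce.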
|
|
|
    def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
        start_iter = self.info.iter + 1
        max_iters = self.config.updates * self.config.grad_accum_steps
        if self.is_main_node:
            print(f"STARTING AT STEP: {start_iter}/{max_iters}")

        pbar = tqdm(range(start_iter, max_iters + 1)) if self.is_main_node else range(start_iter, max_iters + 1)
        models.generator.train()
        grad_norm = torch.tensor(0.0)  # defined before the loop so logging below never hits an unbound name
        for i in pbar:
            loss, loss_adjusted = self.forward_pass(data, extras, models)

            # Step the optimizers (and schedulers) only on accumulation boundaries;
            # otherwise accumulate gradients locally without DDP/FSDP sync.
            if i % self.config.grad_accum_steps == 0 or i == max_iters:
                loss_adjusted.backward()
                grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
                optimizers_dict = optimizers.to_dict()
                for k in optimizers_dict:
                    optimizers_dict[k].step()
                schedulers_dict = schedulers.to_dict()
                for k in schedulers_dict:
                    schedulers_dict[k].step()
                models.generator.zero_grad(set_to_none=True)
                self.info.total_steps += 1
            else:
                with models.generator.no_sync():
                    loss_adjusted.backward()
            self.info.iter = i

            if models.generator_ema is not None and i % self.config.ema_iters == 0:
                update_weights_ema(
                    models.generator_ema, models.generator,
                    beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
                )
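
            # Assuming `update_weights_ema` implements the standard update
            # ema = beta * ema + (1 - beta) * online, beta=0 before `ema_start_iters`
            # makes the EMA track the online weights exactly during warmup; e.g. with
            # ema_beta=0.9999, each EMA weight afterwards moves 0.01% towards the
            # online value per update.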
|
|
|
|
|
            self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01

            if self.is_main_node and self.config.wandb_project is not None and (np.isnan(loss.mean().item()) or np.isnan(grad_norm.item())):
                wandb.alert(
                    title=f"NaN value encountered in training run {self.info.wandb_run_id}",
                    text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}",
                    wait_duration=60 * 30
                )

            if self.is_main_node:
                logs = {
                    'loss': self.info.ema_loss,
                    'raw_loss': loss.mean().item(),
                    'grad_norm': grad_norm.item(),
                    'lr': optimizers.generator.param_groups[0]['lr'],
                    'total_steps': self.info.total_steps,
                }

                pbar.set_postfix(logs)
                if self.config.wandb_project is not None:
                    wandb.log(logs)

            if i == 1 or i % (self.config.save_every * self.config.grad_accum_steps) == 0 or i == max_iters:
                if np.isnan(loss.mean().item()):
                    if self.is_main_node and self.config.wandb_project is not None:
                        tqdm.write("Skipping sampling & checkpoint because the loss is NaN")
                        wandb.alert(
                            title=f"Skipping sampling & checkpoint for training run {self.config.run_id}",
                            text=f"Skipping sampling & checkpoint at step {self.info.total_steps} for training run {self.info.wandb_run_id} because the loss is NaN"
                        )
                else:
                    self.save_checkpoints(models, optimizers)
                    if self.is_main_node:
                        create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
                    self.sample(models, data, extras)

    def models_to_save(self):
        return ['generator', 'generator_ema']

    def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
        barrier()
        suffix = '' if suffix is None else suffix
        self.save_info(self.info, suffix=suffix)
        models_dict = models.to_dict()
        optimizers_dict = optimizers.to_dict()
        for key in self.models_to_save():
            model = models_dict[key]
            if model is not None:
                self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp)
        for key in optimizers_dict:
            optimizer = optimizers_dict[key]
            if optimizer is not None:
                self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None)
        # Every `backup_every` steps, additionally write a snapshot with a step-count
        # suffix (e.g. "_20k") alongside the rolling checkpoint.
        if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
            self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps // 1000}k")
        torch.cuda.empty_cache()
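

# A minimal usage sketch (illustrative; `MyText2ImageCore` is the hypothetical subclass
# from the note above, and the config path is an example):
#
#   core = MyText2ImageCore(config_file_path='configs/my_run.yaml')
#   core()  # WarpCore subclasses are assumed to be callable and to run setup + train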
|
|