""" | |
Modified from: | |
https://github.com/NVlabs/LSGM/blob/main/training_obj_joint.py | |
""" | |
import copy
import functools
import json
import os
from pathlib import Path
from pdb import set_trace as st
from typing import Any

import blobfile as bf
import imageio
import numpy as np
import torch as th
import torch.distributed as dist
import torchvision
from PIL import Image
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm

from guided_diffusion import dist_util, logger
from guided_diffusion.fp16_util import MixedPrecisionTrainer
from guided_diffusion.nn import update_ema
from guided_diffusion.resample import LossAwareSampler, UniformSampler
from guided_diffusion.train_util import (TrainLoop, calc_average_loss,
                                         find_ema_checkpoint,
                                         find_resume_checkpoint,
                                         get_blob_logdir, log_loss_dict,
                                         log_rec3d_loss_dict,
                                         parse_resume_step_from_filename)
from guided_diffusion.gaussian_diffusion import ModelMeanType

import dnnlib
from dnnlib.util import calculate_adaptive_weight

from ..train_util_diffusion import TrainLoop3DDiffusion
from ..cvD.nvsD_canoD import TrainLoop3DcvD_nvsD_canoD
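
# `calculate_adaptive_weight` above is imported from dnnlib.util. For readers
# without that module, here is a minimal sketch of the usual VQGAN-style
# heuristic such a helper follows: an illustrative assumption about the
# imported implementation, not a copy of it. It balances the GAN gradient
# against the reconstruction gradient at the decoder's last layer, then clamps.
def _reference_adaptive_weight(nll_loss, g_loss, last_layer, disc_weight_max=1.0):
    # Gradients of each scalar loss w.r.t. the last decoder layer only.
    nll_grads = th.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = th.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    # The ratio of gradient norms decides how strongly the GAN term may pull.
    d_weight = th.norm(nll_grads) / (th.norm(g_grads) + 1e-4)
    return th.clamp(d_weight, 0.0, disc_weight_max).detach()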
class TrainLoop3DDiffusionLSGM(TrainLoop3DDiffusion, TrainLoop3DcvD_nvsD_canoD):

    def __init__(self, *, rec_model, denoise_model, diffusion, loss_class, data,
                 eval_data, batch_size, microbatch, lr, ema_rate, log_interval,
                 eval_interval, save_interval, resume_checkpoint, use_fp16=False,
                 fp16_scale_growth=0.001, schedule_sampler=None, weight_decay=0,
                 lr_anneal_steps=0, iterations=10001, ignore_resume_opt=False,
                 freeze_ae=False, denoised_ae=True, triplane_scaling_divider=10,
                 use_amp=False, diffusion_input_size=224, **kwargs):
        super().__init__(rec_model=rec_model, denoise_model=denoise_model,
                         diffusion=diffusion, loss_class=loss_class, data=data,
                         eval_data=eval_data, batch_size=batch_size,
                         microbatch=microbatch, lr=lr, ema_rate=ema_rate,
                         log_interval=log_interval, eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint, use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         schedule_sampler=schedule_sampler,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps, iterations=iterations,
                         ignore_resume_opt=ignore_resume_opt,
                         freeze_ae=freeze_ae, denoised_ae=denoised_ae,
                         triplane_scaling_divider=triplane_scaling_divider,
                         use_amp=use_amp,
                         diffusion_input_size=diffusion_input_size, **kwargs)
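
    # Cooperative __init__: assuming the parent loops chain super() in MRO
    # order, this single call lets both TrainLoop3DDiffusion and
    # TrainLoop3DcvD_nvsD_canoD set up the trainers/optimizers used below
    # (mp_trainer, mp_trainer_rec, mp_trainer_cvD, mp_trainer_canonical_cvD).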
    def run_step(self, batch, step='g_step'):

        if step == 'diffusion_step_rec':
            self.forward_diffusion(batch, behaviour='diffusion_step_rec')
            _ = self.mp_trainer_rec.optimize(self.opt_rec)  # TODO, update two groups of parameters
            took_step_ddpm = self.mp_trainer.optimize(self.opt)
            if took_step_ddpm:
                self._update_ema()  # g_ema. TODO: EMA only needs to track the ddpm; remove EMA tracking of rec.

        elif step == 'd_step_rec':
            self.forward_D(batch, behaviour='rec')
            # _ = self.mp_trainer_cvD.optimize(self.opt_cvD)
            _ = self.mp_trainer_canonical_cvD.optimize(self.opt_cano_cvD)

        elif step == 'diffusion_step_nvs':
            self.forward_diffusion(batch, behaviour='diffusion_step_nvs')
            _ = self.mp_trainer_rec.optimize(self.opt_rec)  # TODO, update two groups of parameters
            took_step_ddpm = self.mp_trainer.optimize(self.opt)
            if took_step_ddpm:
                self._update_ema()  # g_ema

        elif step == 'd_step_nvs':
            self.forward_D(batch, behaviour='nvs')
            _ = self.mp_trainer_cvD.optimize(self.opt_cvD)

        self._anneal_lr()
        self.log_step()
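
    # Training schedule: each pass of run_loop's while-body draws four fresh
    # batches and interleaves
    #   1. 'diffusion_step_rec': joint VAE + DDPM step on the input view, with
    #      a generator signal from the canonical-view discriminator;
    #   2. 'd_step_rec': update the canonical-view discriminator;
    #   3. 'diffusion_step_nvs': the same joint step rendered from a novel
    #      view, with a signal from the novel-view discriminator;
    #   4. 'd_step_nvs': update the novel-view discriminator.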
    def run_loop(self):
        while (not self.lr_anneal_steps
               or self.step + self.resume_step < self.lr_anneal_steps):

            # Let all processes sync up before starting a new epoch of training.
            dist_util.synchronize()

            batch = next(self.data)
            self.run_step(batch, step='diffusion_step_rec')

            batch = next(self.data)
            self.run_step(batch, 'd_step_rec')

            batch = next(self.data)
            self.run_step(batch, step='diffusion_step_nvs')

            batch = next(self.data)
            self.run_step(batch, 'd_step_nvs')

            if self.step % self.log_interval == 0 and dist_util.get_rank() == 0:
                out = logger.dumpkvs()
                # * log to tensorboard
                for k, v in out.items():
                    self.writer.add_scalar(f'Loss/{k}', v,
                                           self.step + self.resume_step)

            if self.step % self.eval_interval == 0:
                if dist_util.get_rank() == 0:
                    self.eval_loop()
                    # self.eval_novelview_loop()
                th.cuda.empty_cache()
                dist_util.synchronize()

            if self.step % self.save_interval == 0:
                self.save(self.mp_trainer, self.mp_trainer.model_name)
                self.save(self.mp_trainer_rec, self.mp_trainer_rec.model_name)
                self.save(self.mp_trainer_cvD, 'cvD')
                self.save(self.mp_trainer_canonical_cvD, 'cano_cvD')
                dist_util.synchronize()

                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                    return

            self.step += 1
            if self.step > self.iterations:
                print('reached maximum iterations, exiting')

                # Save the last checkpoint if it wasn't already saved.
                if (self.step - 1) % self.save_interval != 0:
                    self.save(self.mp_trainer, self.mp_trainer.model_name)
                    self.save(self.mp_trainer_rec, self.mp_trainer_rec.model_name)
                    self.save(self.mp_trainer_cvD, 'cvD')
                    self.save(self.mp_trainer_canonical_cvD, 'cano_cvD')

                exit()

        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save(self.mp_trainer, self.mp_trainer.model_name)
            self.save(self.mp_trainer_rec, self.mp_trainer_rec.model_name)
            self.save(self.mp_trainer_cvD, 'cvD')
            self.save(self.mp_trainer_canonical_cvD, 'cano_cvD')
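
    # forward_diffusion data flow: the encoder maps the input views to a
    # triplane latent (vae_out); the triplane decoder renders it back for the
    # VAE/GAN losses; and the same latent serves as x_0 for the DDPM
    # objective, so a single backward pass updates the reconstruction model
    # and the denoiser jointly.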
    def forward_diffusion(self, batch, behaviour='rec', *args, **kwargs):
        """
        Joint step: run the AE and add the denoising (sds-style) grad to all
        AE-predicted x_0.
        """
        self.ddp_cano_cvD.requires_grad_(False)
        self.ddp_nvs_cvD.requires_grad_(False)

        self.ddp_model.requires_grad_(True)
        self.ddp_rec_model.requires_grad_(True)

        # ! Disable triplane_decoder grads independently in each iteration;
        # only the encoder/ViT parts of the reconstruction model train here.
        for param in self.ddp_rec_model.module.decoder.triplane_decoder.parameters():  # type: ignore
            param.requires_grad_(False)

        self.mp_trainer_rec.zero_grad()
        self.mp_trainer.zero_grad()

        # ! No separate 'sds' step now; both behaviours add the sds grad back to the ViT.
        batch_size = batch['img'].shape[0]

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k: v[i:i + self.microbatch].to(dist_util.dev())
                for k, v in batch.items()
            }
            last_batch = (i + self.microbatch) >= batch_size

            vae_nelbo_loss = th.tensor(0.0).to(dist_util.dev())
            vision_aided_loss = th.tensor(0.0).to(dist_util.dev())
            denoise_loss = th.tensor(0.0).to(dist_util.dev())
            d_weight = th.tensor(0.0).to(dist_util.dev())
            # =================================== ae part ===================================
            with th.cuda.amp.autocast(dtype=th.float16,
                                      enabled=self.mp_trainer.use_amp
                                      and not self.freeze_ae):

                # Apply the VAE encoder (without triplane rendering yet).
                vae_out = self.ddp_rec_model(
                    img=micro['img_to_encoder'],
                    c=micro['c'],
                    behaviour='enc_dec_wo_triplane')  # pred: (B, 3, 64, 64)

                if behaviour == 'diffusion_step_rec':
                    target = micro
                    pred = self.ddp_rec_model(latent=vae_out,
                                              c=micro['c'],
                                              behaviour='triplane_dec')

                    # VAE reconstruction loss.
                    if last_batch or not self.use_ddp:
                        vae_nelbo_loss, loss_dict = self.loss_class(
                            pred, target, test_mode=False)
                    else:
                        with self.ddp_model.no_sync():  # type: ignore
                            vae_nelbo_loss, loss_dict = self.loss_class(
                                pred, target, test_mode=False)

                    last_layer = self.ddp_rec_model.module.decoder.triplane_decoder.decoder.net[  # type: ignore
                        -1].weight  # type: ignore

                    if 'image_sr' in pred:
                        vision_aided_loss = self.ddp_cano_cvD(
                            0.5 * pred['image_sr'] +
                            0.5 * th.nn.functional.interpolate(
                                pred['image_raw'],
                                size=pred['image_sr'].shape[2:],
                                mode='bilinear'),
                            for_G=True).mean()  # [B, 1] shape
                    else:
                        vision_aided_loss = self.ddp_cano_cvD(
                            pred['image_raw'],
                            for_G=True).mean()  # [B, 1] shape

                    d_weight = calculate_adaptive_weight(
                        vae_nelbo_loss,
                        vision_aided_loss,
                        last_layer,
                        disc_weight_max=1) * self.loss_class.opt.rec_cvD_lambda
                    vision_aided_loss *= d_weight

                    loss_dict.update({
                        'vision_aided_loss/G_rec': vision_aided_loss,
                        'd_weight_G_rec': d_weight,
                    })

                    log_rec3d_loss_dict(loss_dict)
                elif behaviour == 'diffusion_step_nvs':

                    # Render from a shifted (novel) camera pose within the batch.
                    novel_view_c = th.cat([micro['c'][1:], micro['c'][:1]])

                    pred = self.ddp_rec_model(latent=vae_out,
                                              c=novel_view_c,
                                              behaviour='triplane_dec')

                    if 'image_sr' in pred:
                        vision_aided_loss = self.ddp_nvs_cvD(
                            0.5 * pred['image_sr'] +
                            0.5 * th.nn.functional.interpolate(
                                pred['image_raw'],
                                size=pred['image_sr'].shape[2:],
                                mode='bilinear'),
                            for_G=True).mean()  # [B, 1] shape
                    else:
                        vision_aided_loss = self.ddp_nvs_cvD(
                            pred['image_raw'],
                            for_G=True).mean()  # [B, 1] shape

                    d_weight = self.loss_class.opt.nvs_cvD_lambda
                    vision_aided_loss *= d_weight

                    log_rec3d_loss_dict({
                        'vision_aided_loss/G_nvs': vision_aided_loss,
                    })

                else:
                    raise NotImplementedError(behaviour)
                eps = vae_out[self.latent_name]
                eps.requires_grad_(True)  # single-stage diffusion: grads flow back into the latent

                t, weights = self.schedule_sampler.sample(
                    eps.shape[0], dist_util.dev())
                model_kwargs = {}

                # ------------------------------------------------------------------
                # LSGM joint p/q objective, kept for reference from
                # https://github.com/NVlabs/LSGM/blob/main/training_obj_joint.py.
                # It is not wired into this trainer: `args`, `diffusion` (the LSGM
                # continuous-time object, not self.diffusion), `dae`, `vae`, `utils`,
                # `grad_scalar`, the optimizers, and the dae forward pass producing
                # `pred_params`/`l2_term` are all undefined here, so the block stays
                # commented out. The active DDPM objective follows below.
                #
                # noise = th.randn_like(eps)  # note: this noise is shared between p and q!
                #
                # # get diffusion quantities for p (sgm prior) sampling scheme and
                # # reweighting for q (vae)
                # t_p, var_t_p, m_t_p, obj_weight_t_p, obj_weight_t_q, g2_t_p = \
                #     diffusion.iw_quantities(args.batch_size, args.time_eps,
                #                             args.iw_sample_p, args.iw_subvp_like_vp_sde)
                # eps_t_p = diffusion.sample_q(eps, noise, var_t_p, m_t_p)
                #
                # # in case we want to train q (vae) with another batch using a
                # # different sampling scheme for times t
                # if args.iw_sample_q in ['ll_uniform', 'll_iw']:
                #     t_q, var_t_q, m_t_q, obj_weight_t_q, _, g2_t_q = \
                #         diffusion.iw_quantities(args.batch_size, args.time_eps,
                #                                 args.iw_sample_q, args.iw_subvp_like_vp_sde)
                #     eps_t_q = diffusion.sample_q(eps, noise, var_t_q, m_t_q)
                #     eps_t_p = eps_t_p.detach().requires_grad_(True)
                #     eps_t = th.cat([eps_t_p, eps_t_q], dim=0)
                #     var_t = th.cat([var_t_p, var_t_q], dim=0)
                #     t = th.cat([t_p, t_q], dim=0)
                #     noise = th.cat([noise, noise], dim=0)
                # else:
                #     eps_t, m_t, var_t, t, g2_t = eps_t_p, m_t_p, var_t_p, t_p, g2_t_p
                #
                # # run the diffusion model with the mixing-normal trick
                # mixing_component = diffusion.mixing_component(
                #     eps_t, var_t, t, enabled=dae.mixed_prediction)
                # params = utils.get_mixed_prediction(dae.mixed_prediction, pred_params,
                #                                     dae.mixing_logit, mixing_component)
                #
                # # unpack separate objectives, in case q (vae) is trained with a
                # # different time-sampling scheme
                # if args.iw_sample_q in ['ll_uniform', 'll_iw']:
                #     l2_term_p, l2_term_q = th.chunk(l2_term, chunks=2, dim=0)
                #     p_objective = th.sum(obj_weight_t_p * l2_term_p, dim=[1, 2, 3])
                #     cross_entropy_per_var = obj_weight_t_q * l2_term_q
                # else:
                #     p_objective = th.sum(obj_weight_t_p * l2_term, dim=[1, 2, 3])
                #     cross_entropy_per_var = obj_weight_t_q * l2_term
                #
                # # nelbo loss with kl balancing: remaining parts of the cross
                # # entropy in likelihood training
                # cross_entropy_per_var += diffusion.cross_entropy_const(args.time_eps)
                # cross_entropy = th.sum(cross_entropy_per_var, dim=[1, 2, 3])
                # cross_entropy += remaining_neg_log_p_total  # for remaining scales, if any
                # all_neg_log_p = vae.decompose_eps(cross_entropy_per_var)
                # all_neg_log_p.extend(remaining_neg_log_p_per_var)  # add the remaining neg_log_p
                # kl_all_list, kl_vals_per_group, kl_diag_list = \
                #     utils.kl_per_group_vada(all_log_q, all_neg_log_p)
                # kl_coeff = 1.0
                #
                # # p/q losses (no spectral regularizer here; grid clipping and
                # # sn could be added later)
                # q_loss = th.mean(nelbo_loss)
                # p_loss = th.mean(p_objective)
                #
                # # backpropagate q_loss for the vae and update vae params, if trained
                # if args.train_vae:
                #     grad_scalar.scale(q_loss).backward(
                #         retain_graph=utils.different_p_q_objectives(
                #             args.iw_sample_p, args.iw_sample_q))
                #     utils.average_gradients(vae.parameters(), args.distributed)
                #     if args.grad_clip_max_norm > 0.:  # apply gradient clipping
                #         grad_scalar.unscale_(vae_optimizer)
                #         th.nn.utils.clip_grad_norm_(vae.parameters(),
                #                                     max_norm=args.grad_clip_max_norm)
                #     grad_scalar.step(vae_optimizer)
                #
                # # with different p/q objectives, or a frozen vae, discard the
                # # current gradients and backpropagate p_loss for the dae
                # if utils.different_p_q_objectives(args.iw_sample_p, args.iw_sample_q) \
                #         or not args.train_vae:
                #     if args.train_vae:
                #         dae_optimizer.zero_grad()
                #     grad_scalar.scale(p_loss).backward()
                #     utils.average_gradients(dae.parameters(), args.distributed)
                #     if args.grad_clip_max_norm > 0.:  # apply gradient clipping
                #         grad_scalar.unscale_(dae_optimizer)
                #         th.nn.utils.clip_grad_norm_(dae.parameters(),
                #                                     max_norm=args.grad_clip_max_norm)
                #     grad_scalar.step(dae_optimizer)
                # ------------------------------------------------------------------
                # ! Active DDPM step on the clean latent.
                compute_losses = functools.partial(
                    self.diffusion.training_losses,
                    self.ddp_model,
                    eps,  # x_start
                    t,
                    model_kwargs=model_kwargs,
                    return_detail=True)

                if last_batch or not self.use_ddp:
                    losses = compute_losses()
                else:
                    with self.ddp_model.no_sync():  # type: ignore
                        losses = compute_losses()

                if isinstance(self.schedule_sampler, LossAwareSampler):
                    self.schedule_sampler.update_with_local_losses(
                        t, losses["loss"].detach())

                # `weights` are the schedule sampler's importance weights: all
                # ones for UniformSampler, loss-adaptive for LossAwareSampler.
                denoise_loss = (losses["loss"] * weights).mean()

                x_t = losses.pop('x_t')
                model_output = losses.pop('model_output')
                diffusion_target = losses.pop('diffusion_target')
                alpha_bar = losses.pop('alpha_bar')

                log_loss_dict(self.diffusion, t,
                              {k: v * weights
                               for k, v in losses.items()})
                # The separate 'sds' step is gone: both behaviours feed the
                # denoising gradient back into the latent through `eps`. The old
                # variant scaled the detached denoise loss by
                # w = (1 - alpha_bar**2) / self.triplane_scaling_divider * self.loss_class.opt.sds_lamdba
                # and added it to the latent's grad via register_hook
                # (cf. https://github.com/ashawkey/stable-dreamfusion/issues/106).

                loss = vae_nelbo_loss + denoise_loss + vision_aided_loss  # calculate the joint loss within AMP

            # Exit AMP before backward.
            # TODO: both trainers call backward() through the same graph; the
            # first call must retain the graph (or the loss be backpropagated
            # once via a shared scaler) to avoid a freed-graph error.
            self.mp_trainer_rec.backward(loss)
            self.mp_trainer.backward(loss)

            # TODO, merge visualization with original AE
            # =================================== denoised AE log part ===================================
            if dist_util.get_rank() == 0 and self.step % 500 == 0 and behaviour != 'diff':
                with th.no_grad():

                    gt_depth = micro['depth']
                    if gt_depth.ndim == 3:
                        gt_depth = gt_depth.unsqueeze(1)
                    gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
                                                              gt_depth.min())

                    pred_depth = pred['image_depth']
                    pred_depth = (pred_depth - pred_depth.min()) / (
                        pred_depth.max() - pred_depth.min())
                    pred_img = pred['image_raw']
                    gt_img = micro['img']

                    gt_vis = th.cat(
                        [
                            gt_img, micro['img'], micro['img'],
                            gt_depth.repeat_interleave(3, dim=1)
                        ],
                        dim=-1)[0:1]  # TODO, fails to load depth. range [0, 1]

                    noised_ae_pred = self.ddp_rec_model(
                        img=None,
                        c=micro['c'][0:1],
                        latent=x_t[0:1] * self.triplane_scaling_divider,  # TODO, how to define the scale automatically?
                        behaviour=self.render_latent_behaviour)

                    if self.diffusion.model_mean_type == ModelMeanType.START_X:
                        pred_xstart = model_output
                    else:  # * used here
                        pred_xstart = self.diffusion._predict_xstart_from_eps(
                            x_t=x_t, t=t, eps=model_output)

                    denoised_ae_pred = self.ddp_rec_model(
                        img=None,
                        c=micro['c'][0:1],
                        latent=pred_xstart[0:1] * self.triplane_scaling_divider,
                        behaviour=self.render_latent_behaviour)

                    pred_vis = th.cat([
                        pred_img[0:1], noised_ae_pred['image_raw'][0:1],
                        denoised_ae_pred['image_raw'][0:1],
                        pred_depth[0:1].repeat_interleave(3, dim=1)
                    ],
                                      dim=-1)  # B, 3, H, W

                    vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
                        1, 2, 0).cpu()  # ! pred in range [-1, 1]

                    vis = vis.numpy() * 127.5 + 127.5
                    vis = vis.clip(0, 255).astype(np.uint8)
                    Image.fromarray(vis).save(
                        f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t[0].item()}_{behaviour}.jpg'
                    )
                    print(
                        'log denoised vis to: ',
                        f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t[0].item()}_{behaviour}.jpg'
                    )
th.cuda.empty_cache() | |
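

# A self-contained sketch of the two LSGM primitives referenced (commented out)
# in forward_diffusion above, written under the usual VP-SDE parameterization.
# Names and exact signatures in NVlabs/LSGM may differ; treat this as an
# illustration, not the upstream implementation.
def _sketch_sample_q(x0, noise, var_t, m_t):
    # Diffuse the clean latent: x_t = m_t * x_0 + sqrt(var_t) * noise.
    return m_t * x0 + var_t.sqrt() * noise


def _sketch_mixed_prediction(pred, mixing_logit, mixing_component):
    # Mixing-normal trick: blend the network output with the analytically
    # known prediction under a standard-Normal prior (mixing_component, e.g.
    # sqrt(var_t) * x_t), gated by a learned logit.
    coeff = th.sigmoid(mixing_logit)
    return (1 - coeff) * mixing_component + coeff * pred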