|
import copy |
|
import functools |
|
import json |
|
import os |
|
from pathlib import Path |
|
from pdb import set_trace as st |
|
|
|
import blobfile as bf |
|
import imageio |
|
import numpy as np |
|
import torch as th |
|
import torch.distributed as dist |
|
import torchvision |
|
from PIL import Image |
|
from torch.nn.parallel.distributed import DistributedDataParallel as DDP |
|
from torch.optim import AdamW |
|
from torch.utils.tensorboard.writer import SummaryWriter |
|
from tqdm import tqdm |
|
import matplotlib.pyplot as plt |
|
|
|
from guided_diffusion.gaussian_diffusion import _extract_into_tensor |
|
from guided_diffusion import dist_util, logger |
|
from guided_diffusion.fp16_util import MixedPrecisionTrainer |
|
from guided_diffusion.nn import update_ema |
|
from guided_diffusion.resample import LossAwareSampler, UniformSampler |
|
|
|
from guided_diffusion.train_util import (TrainLoop, calc_average_loss, |
|
find_ema_checkpoint, |
|
find_resume_checkpoint, |
|
get_blob_logdir, log_loss_dict, |
|
log_rec3d_loss_dict, |
|
parse_resume_step_from_filename) |
|
|
|
import dnnlib |
|
|
|
from nsr.camera_utils import FOV_to_intrinsics, LookAtPoseSampler |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TrainLoopDiffusionWithRec(TrainLoop):
    """Diffusion ``TrainLoop`` with the helpers that need a reconstruction model.

    Implements the ``rec_model``-facing utilities shared by the 3D diffusion
    trainers: decoding latents into videos / meshes, evaluation loops,
    checkpoint save/load and DDPM sampling. Attributes such as
    ``ddp_rec_model`` are expected to be provided by subclasses.
    """

    def __init__(
        self,
        *,
        model,
        diffusion,
        loss_class,
        data,
        eval_data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        eval_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=0.001,
        weight_decay=0,
        lr_anneal_steps=0,
        iterations=10001,
        triplane_scaling_divider=1,
        use_amp=False,
        diffusion_input_size=224,
        schedule_sampler=None,
        model_name='ddpm',
        **kwargs,
    ):
        """Set up the base ``TrainLoop`` plus evaluation/rendering state.

        Args (in addition to those forwarded to ``TrainLoop``):
            loss_class: callable ``(pred, target, test_mode=...) -> (loss, loss_dict)``.
            eval_data: evaluation batches iterated by the eval loops.
            eval_interval: stored for subclasses' training loops.
            iterations: maximum number of training steps.
            triplane_scaling_divider: latents are divided by this before
                diffusion and multiplied back before decoding.
            diffusion_input_size: spatial size of the diffusion input.
        """
        super().__init__(
            model=model,
            diffusion=diffusion,
            data=data,
            batch_size=batch_size,
            microbatch=microbatch,
            lr=lr,
            ema_rate=ema_rate,
            log_interval=log_interval,
            save_interval=save_interval,
            resume_checkpoint=resume_checkpoint,
            use_fp16=use_fp16,
            fp16_scale_growth=fp16_scale_growth,
            schedule_sampler=schedule_sampler,
            weight_decay=weight_decay,
            lr_anneal_steps=lr_anneal_steps,
            use_amp=use_amp,
            model_name=model_name,
            **kwargs,
        )

        # Dict key under which the diffusion latent is exchanged with rec_model.
        self.latent_name = 'latent_normalized'
        self.diffusion_input_size = diffusion_input_size
        # rec_model `behaviour` used to render a latent into images.
        self.render_latent_behaviour = 'triplane_dec'

        self.loss_class = loss_class

        self.eval_interval = eval_interval
        self.eval_data = eval_data
        self.iterations = iterations

        self.triplane_scaling_divider = triplane_scaling_divider

        # Only rank 0 owns a TensorBoard writer, so methods that log through
        # self.writer must run on rank 0.
        if dist_util.get_rank() == 0:
            self.writer = SummaryWriter(log_dir=f'{logger.get_dir()}/runs')

    def _planes_to_ddpm_latent(self, planes, rec_model):
        """Wrap rescaled ``planes`` in rec_model's latent dict and VAE-decode it.

        With 16 channels the first 12 become the latent and the last 4 the
        ``'bg_plane'`` entry. The dict is then updated with the output of
        ``rec_model(..., behaviour='decode_after_vae_no_render')``.
        """
        if planes.shape[1] == 16:
            ddpm_latent = {
                self.latent_name: planes[:, :12],
                'bg_plane': planes[:, 12:16],
            }
        else:
            ddpm_latent = {self.latent_name: planes}

        ddpm_latent.update(
            rec_model(latent=ddpm_latent,
                      behaviour='decode_after_vae_no_render'))
        return ddpm_latent

    @th.inference_mode()
    def _export_mesh(self,
                     planes,
                     rec_model,
                     name_prefix,
                     mesh_size=256,
                     mesh_thres=10):
        """Extract a marching-cubes mesh and write ``<logdir>/mesh/<name_prefix>.ply``.

        ``planes`` must already be multiplied by ``triplane_scaling_divider``.
        ``mesh_size`` is the density-grid resolution requested from rec_model
        and ``mesh_thres`` the iso-level applied to ``grid_out['sigma']``.
        """
        # Imported lazily: only needed when a mesh is exported.
        import mcubes
        import trimesh

        dump_path = f'{logger.get_dir()}/mesh/'
        os.makedirs(dump_path, exist_ok=True)

        ddpm_latent = self._planes_to_ddpm_latent(planes, rec_model)
        grid_out = rec_model(
            latent=ddpm_latent,
            grid_size=mesh_size,
            behaviour='triplane_decode_grid',
        )

        vtx, faces = mcubes.marching_cubes(
            grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(),
            mesh_thres)
        # Map grid indices [0, mesh_size - 1] to [-1, 1].
        vtx = vtx / (mesh_size - 1) * 2 - 1

        mesh = trimesh.Trimesh(vertices=vtx, faces=faces)
        mesh_dump_path = os.path.join(dump_path, f'{name_prefix}.ply')
        mesh.export(mesh_dump_path, 'ply')
        print(f"Mesh dumped to {dump_path}")

        del grid_out, mesh
        th.cuda.empty_cache()

    @th.inference_mode()
    def render_video_given_triplane(self,
                                    planes,
                                    rec_model,
                                    name_prefix='0',
                                    save_img=False,
                                    render_reference=None,
                                    export_mesh=False,
                                    render_all=False):
        """Decode ``planes`` with ``rec_model`` and write a rendered video.

        The video goes to ``<logdir>/triplane_<name_prefix>.mp4``.

        Args:
            planes: normalized latent batch. It is multiplied by
                ``self.triplane_scaling_divider`` before decoding; the
                caller's tensor is left untouched.
            rec_model: reconstruction model used to decode and render.
            name_prefix: tag used in output file names. Must be
                int-convertible when ``save_img`` is True.
            save_img: also save every generated image to ``<logdir>/FID_Cals/``.
            render_reference: dict of per-view tensors providing the cameras
                (``'c'``). Defaults to ``self.eval_data``. Its ``'ins'``,
                ``'bbox'`` and ``'caption'`` entries are ignored, and the
                caller's dict is not modified.
            export_mesh: also write a marching-cubes mesh
                (see ``_export_mesh``).
            render_all: render every view in ``render_reference`` instead of
                at most the first 40.

        Raises:
            ValueError: if rec_model returns an ``image_sr`` whose width is
                neither 512 nor 128.
        """
        # Out-of-place so the caller's tensor is not rescaled as a side effect.
        planes = planes * self.triplane_scaling_divider
        batch_size = planes.shape[0]

        if export_mesh:
            self._export_mesh(planes, rec_model, name_prefix)

        video_path = f'{logger.get_dir()}/triplane_{name_prefix}.mp4'
        video_out = imageio.get_writer(video_path,
                                       mode='I',
                                       fps=15,
                                       codec='libx264')

        ddpm_latent = self._planes_to_ddpm_latent(planes, rec_model)

        if render_reference is None:
            render_reference = self.eval_data
        else:
            # Filtered copy instead of popping keys from the caller's dict.
            render_reference = {
                k: v
                for k, v in render_reference.items()
                if k not in ('ins', 'bbox', 'caption')
            }

        num_views = render_reference['c'].shape[0]
        if not render_all:
            # Never slice past the last view (that would yield empty batches).
            num_views = min(40, num_views)

        views = [{k: v[idx:idx + 1]
                  for k, v in render_reference.items()}
                 for idx in range(num_views)]

        for i, view in enumerate(tqdm(views)):
            micro = {
                k: v.to(dist_util.dev()) if isinstance(v, th.Tensor) else v
                for k, v in view.items()
            }

            pred = rec_model(img=None,
                             c=micro['c'],
                             latent=ddpm_latent,
                             behaviour='triplane_dec')

            pred_depth = pred['image_depth']
            pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
                                                            pred_depth.min())
            # Colour-map sample 0's depth with viridis and rescale to [-1, 1].
            # The result already has 3 channels, so it must not be repeated.
            # NOTE(review): only sample 0 is colourised, so the concatenations
            # below assume a single-sample batch -- confirm for batch_size > 1.
            pred_depth = pred_depth.cpu()[0].permute(1, 2, 0).numpy()
            pred_depth = plt.cm.viridis(pred_depth[..., 0])[..., :3] * 2 - 1
            pred_depth = th.from_numpy(pred_depth).to(
                pred['image_raw']).permute(2, 0, 1).unsqueeze(0)

            if 'image_sr' in pred:
                gen_img = pred['image_sr']
                sr_res = gen_img.shape[-1]
                if sr_res == 512:
                    pool = self.pool_512
                elif sr_res == 128:
                    pool = self.pool_128
                else:
                    raise ValueError(
                        f'unsupported image_sr resolution: {sr_res}')
                pred_vis = th.cat([
                    micro['img_sr'],
                    pool(pred['image_raw']),
                    gen_img,
                    pool(pred_depth),
                ],
                                  dim=-1)
            else:
                gen_img = pred['image_raw']
                pred_vis = th.cat(
                    [self.pool_128(gen_img),
                     self.pool_128(pred_depth)],
                    dim=-1)

            if save_img:
                fid_dir = f'{logger.get_dir()}/FID_Cals'
                os.makedirs(fid_dir, exist_ok=True)
                for batch_idx in range(gen_img.shape[0]):
                    sampled_img = Image.fromarray(
                        (gen_img[batch_idx].permute(1, 2, 0).cpu().numpy() *
                         127.5 + 127.5).clip(0, 255).astype(np.uint8))
                    if sampled_img.size != (512, 512):
                        sampled_img = sampled_img.resize(
                            (128, 128), Image.HAMMING)
                    sampled_img.save('{}/{}_{}.png'.format(
                        fid_dir,
                        int(name_prefix) * batch_size + batch_idx, i))

            # [-1, 1] float -> uint8 frames.
            vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
            vis = (vis * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
            for frame in vis:
                video_out.append_data(frame)

        video_out.close()
        print('logged video to: ', video_path)

    def _init_optim_groups(self, rec_model, freeze_decoder=False):
        """Build AdamW parameter groups for the reconstruction model.

        The ViT encoder group is always present. Unless ``freeze_decoder`` is
        set, groups for the ViT decoder, its prediction head, the triplane
        decoder and (when present) the super-resolution module are added.
        Learning rates / weight decays are read from ``self.kwargs``.
        """
        kwargs = self.kwargs
        optim_groups = [
            {
                'name': 'vit_encoder',
                'params': rec_model.encoder.parameters(),
                'lr': kwargs['encoder_lr'],
                'weight_decay': kwargs['encoder_weight_decay']
            },
        ]

        if not freeze_decoder:
            optim_groups += [
                {
                    'name': 'vit_decoder',
                    'params': rec_model.decoder.vit_decoder.parameters(),
                    'lr': kwargs['vit_decoder_lr'],
                    'weight_decay': kwargs['vit_decoder_wd']
                },
                {
                    'name': 'vit_decoder_pred',
                    'params': rec_model.decoder.decoder_pred.parameters(),
                    'lr': kwargs['vit_decoder_lr'],
                    'weight_decay': kwargs['vit_decoder_wd']
                },
                {
                    'name': 'triplane_decoder',
                    'params': rec_model.decoder.triplane_decoder.parameters(),
                    'lr': kwargs['triplane_decoder_lr'],
                },
            ]

            if rec_model.decoder.superresolution is not None:
                optim_groups.append({
                    'name':
                    'triplane_decoder_superresolution',
                    'params':
                    rec_model.decoder.superresolution.parameters(),
                    'lr':
                    kwargs['super_resolution_lr'],
                })

        return optim_groups

    @th.no_grad()
    def eval_novelview_loop(self, rec_model):
        """Evaluate novel-view reconstruction over ``self.eval_data``.

        Sample 0 of the first eval batch is encoded and rendered under every
        batch's cameras. Writes a comparison video, appends averaged losses
        (one JSON object per line) to ``<logdir>/scores_novelview.json`` and
        logs them to TensorBoard via ``self.writer``.
        """
        video_out = imageio.get_writer(
            f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}.mp4',
            mode='I',
            fps=60,
            codec='libx264')

        all_loss_dict = []
        novel_view_micro = {}

        for i, batch in enumerate(tqdm(self.eval_data)):
            micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}

            # Sample 0 of the first batch, tiled to the current batch size.
            source = batch if i == 0 else novel_view_micro
            novel_view_micro = {
                k: v[0:1].to(dist_util.dev()).repeat_interleave(
                    micro['img'].shape[0], 0)
                for k, v in source.items()
            }

            pred = rec_model(img=novel_view_micro['img_to_encoder'],
                             c=micro['c'])

            _, loss_dict = self.loss_class(pred, micro, test_mode=True)
            all_loss_dict.append(loss_dict)

            pred_depth = pred['image_depth']
            pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
                                                            pred_depth.min())
            if 'image_sr' in pred:
                if pred['image_sr'].shape[-1] == 512:
                    pred_vis = th.cat([
                        micro['img_sr'],
                        self.pool_512(pred['image_raw']), pred['image_sr'],
                        self.pool_512(pred_depth).repeat_interleave(3, dim=1)
                    ],
                                      dim=-1)
                elif pred['image_sr'].shape[-1] == 256:
                    pred_vis = th.cat([
                        micro['img_sr'],
                        self.pool_256(pred['image_raw']), pred['image_sr'],
                        self.pool_256(pred_depth).repeat_interleave(3, dim=1)
                    ],
                                      dim=-1)
                else:
                    pred_vis = th.cat([
                        micro['img_sr'],
                        self.pool_128(pred['image_raw']),
                        self.pool_128(pred['image_sr']),
                        self.pool_128(pred_depth).repeat_interleave(3, dim=1)
                    ],
                                      dim=-1)
            else:
                pred_vis = th.cat([
                    self.pool_128(micro['img']),
                    self.pool_128(pred['image_raw']),
                    self.pool_128(pred_depth).repeat_interleave(3, dim=1)
                ],
                                  dim=-1)

            vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
            vis = (vis * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
            for frame in vis:
                video_out.append_data(frame)

        video_out.close()

        val_scores_for_logging = calc_average_loss(all_loss_dict)
        with open(os.path.join(logger.get_dir(), 'scores_novelview.json'),
                  'a') as f:
            json.dump({'step': self.step, **val_scores_for_logging}, f)
            # One record per line so repeated evaluations stay parseable.
            f.write('\n')

        for k, v in val_scores_for_logging.items():
            self.writer.add_scalar(f'Eval/NovelView/{k}', v,
                                   self.step + self.resume_step)
        del video_out

        th.cuda.empty_cache()

    @th.no_grad()
    def eval_loop(self, rec_model):
        """Evaluate same-view reconstruction over ``self.eval_data``.

        Writes a comparison video, appends averaged losses (one JSON object
        per line) to ``<logdir>/scores.json``, logs them to TensorBoard and
        then runs ``eval_novelview_loop``.
        """
        video_out = imageio.get_writer(
            f'{logger.get_dir()}/video_{self.step+self.resume_step}.mp4',
            mode='I',
            fps=60,
            codec='libx264')
        all_loss_dict = []

        for i, batch in enumerate(tqdm(self.eval_data)):
            micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}

            pred = rec_model(img=micro['img_to_encoder'], c=micro['c'])

            # Accumulate losses as eval_novelview_loop does; otherwise the
            # averaged scores below would be computed over an empty list.
            _, loss_dict = self.loss_class(pred, micro, test_mode=True)
            all_loss_dict.append(loss_dict)

            pred_depth = pred['image_depth']
            pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
                                                            pred_depth.min())

            if 'image_sr' in pred:
                if pred['image_sr'].shape[-1] == 512:
                    pred_vis = th.cat([
                        micro['img_sr'],
                        self.pool_512(pred['image_raw']), pred['image_sr'],
                        self.pool_512(pred_depth).repeat_interleave(3, dim=1)
                    ],
                                      dim=-1)
                else:
                    assert pred['image_sr'].shape[-1] == 128
                    pred_vis = th.cat([
                        micro['img_sr'],
                        self.pool_128(pred['image_raw']), pred['image_sr'],
                        self.pool_128(pred_depth).repeat_interleave(3, dim=1)
                    ],
                                      dim=-1)
            else:
                pred_vis = th.cat([
                    self.pool_128(micro['img']),
                    self.pool_128(pred['image_raw']),
                    self.pool_128(pred_depth).repeat_interleave(3, dim=1)
                ],
                                  dim=-1)

            vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
            vis = (vis * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
            for frame in vis:
                video_out.append_data(frame)

        video_out.close()

        val_scores_for_logging = calc_average_loss(all_loss_dict)
        with open(os.path.join(logger.get_dir(), 'scores.json'), 'a') as f:
            json.dump({'step': self.step, **val_scores_for_logging}, f)
            f.write('\n')

        for k, v in val_scores_for_logging.items():
            self.writer.add_scalar(f'Eval/Rec/{k}', v,
                                   self.step + self.resume_step)

        # Release the last batch's tensors before the novel-view pass. Rebind
        # instead of `del` so this also works when eval_data was empty.
        vis = pred_vis = pred = None
        del video_out
        th.cuda.empty_cache()
        self.eval_novelview_loop(rec_model)

    def save(self, mp_trainer=None, model_name='ddpm'):
        """Save master params, plus EMA copies when ``model_name == 'ddpm'``.

        Files are written to the blob log dir by rank 0 only. All ranks then
        synchronize.
        """
        if mp_trainer is None:
            mp_trainer = self.mp_trainer

        def save_checkpoint(rate, params):
            state_dict = mp_trainer.master_params_to_state_dict(params)
            if dist_util.get_rank() == 0:
                logger.log(f"saving model {model_name} {rate}...")
                if not rate:
                    filename = f"model_{model_name}{(self.step+self.resume_step):07d}.pt"
                else:
                    filename = f"ema_{model_name}_{rate}_{(self.step+self.resume_step):07d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename),
                                 "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, mp_trainer.master_params)
        if model_name == 'ddpm':
            for rate, params in zip(self.ema_rate, self.ema_params):
                save_checkpoint(rate, params)

        th.cuda.empty_cache()
        dist_util.synchronize()

    def _load_and_sync_parameters(self,
                                  model=None,
                                  model_name='ddpm',
                                  resume_checkpoint=None):
        """Load matching weights from a checkpoint on rank 0, then sync all ranks.

        When ``resume_checkpoint`` is None it is looked up with
        ``find_resume_checkpoint`` (which also sets ``self.resume_step``),
        falling back to ``self.resume_checkpoint``. Checkpoint entries whose
        key is missing from ``model`` or whose shape differs are skipped and
        reported.
        """
        if resume_checkpoint is None:
            found = find_resume_checkpoint(self.resume_checkpoint, model_name)
            if found:
                resume_checkpoint, self.resume_step = found
            else:
                # Fall back to the configured path instead of unpacking it.
                resume_checkpoint = self.resume_checkpoint

        if model is None:
            model = self.model

        if resume_checkpoint and Path(resume_checkpoint).exists():
            if dist_util.get_rank() == 0:
                logger.log(
                    f"loading model from checkpoint: {resume_checkpoint}...")
                map_location = {
                    'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
                }

                logger.log(f'mark {model_name} loading ')
                resume_state_dict = dist_util.load_state_dict(
                    resume_checkpoint, map_location=map_location)
                logger.log(f'mark {model_name} loading finished')

                model_state_dict = model.state_dict()

                for k, v in resume_state_dict.items():
                    if k in model_state_dict and v.size(
                    ) == model_state_dict[k].size():
                        model_state_dict[k] = v
                    else:
                        print(
                            '!!!! ignore key: ',
                            k,
                            ": ",
                            v.size(),
                        )
                        if k in model_state_dict:
                            print('shape in model: ',
                                  model_state_dict[k].size())
                        else:
                            print(k, ' not in model')

                model.load_state_dict(model_state_dict, strict=True)
                del model_state_dict
        else:
            logger.log(f'{resume_checkpoint} not found.')

        if dist_util.get_world_size() > 1:
            dist_util.sync_params(model.parameters())
            print(f'synced {model_name} params')

    @th.inference_mode()
    def apply_model_inference(self,
                              x_noisy,
                              t,
                              c=None,
                              model_kwargs=None):
        """Run the DDP-wrapped denoiser on ``x_noisy`` at ``t`` without autograd.

        ``c`` is accepted for interface compatibility and is ignored.
        ``model_kwargs`` are forwarded to the model (None means none).
        """
        if model_kwargs is None:
            model_kwargs = {}
        return self.ddp_model(x_noisy, t, **model_kwargs)

    @th.inference_mode()
    def eval_ddpm_sample(self, rec_model, **kwargs):
        """Draw one unconditional DDPM sample and render it to a video.

        ``self.model`` is put in eval mode for the duration and is restored
        to train mode even if sampling or rendering raises. ``kwargs`` is
        unused.
        """
        self.model.eval()
        try:
            args = dnnlib.EasyDict(
                dict(batch_size=1,
                     image_size=self.diffusion_input_size,
                     denoise_in_channels=self.ddpm_model.in_channels,
                     clip_denoised=False,
                     class_cond=False,
                     use_ddim=False))

            # Unconditional sampling only (class_cond is hard-coded False).
            model_kwargs = {}

            diffusion = self.diffusion
            sample_fn = (diffusion.p_sample_loop
                         if not args.use_ddim else diffusion.ddim_sample_loop)

            for i in range(1):
                # NOTE(review): `self` is passed as the model; presumably the
                # sampler calls back into apply_model_inference -- confirm.
                triplane_sample = sample_fn(
                    self,
                    (args.batch_size, args.denoise_in_channels,
                     self.diffusion_input_size, self.diffusion_input_size),
                    clip_denoised=args.clip_denoised,
                    mixing_normal=True,
                    device=dist_util.dev(),
                    **model_kwargs)

                th.cuda.empty_cache()
                self.render_video_given_triplane(
                    triplane_sample,
                    rec_model,
                    name_prefix=f'{self.step + self.resume_step}_{i}')
                th.cuda.empty_cache()
        finally:
            self.model.train()

    def _vis_timesteps(self, stride):
        """Return timesteps ``0, stride, 2*stride, ...`` valid for the schedule.

        The upper bound is the diffusion's ``num_timesteps`` (when exposed),
        capped at the historical limit of 1001. This ensures
        ``_extract_into_tensor`` never indexes past the schedule.
        """
        num_timesteps = getattr(self.diffusion, 'num_timesteps', 1001)
        return th.arange(0,
                         min(num_timesteps, 1001),
                         stride,
                         dtype=th.long,
                         device=dist_util.dev())

    @th.inference_mode()
    def render_video_noise_schedule(self, name_prefix='0'):
        """Visualize the forward noising process on eval latents.

        For every 10th eval batch, sample 0 of the first batch is encoded and
        noised with ``q_sample`` at timesteps 0, 125, 250, ... Each noisy
        latent is rendered under the current batch's cameras, and one grid
        frame per batch is written to
        ``<logdir>/triplane_visnoise_<name_prefix>.mp4``.
        """
        video_path = f'{logger.get_dir()}/triplane_visnoise_{name_prefix}.mp4'
        video_out = imageio.get_writer(video_path,
                                       mode='I',
                                       fps=30,
                                       codec='libx264')

        novel_view_micro = {}
        for i, batch in enumerate(tqdm(self.eval_data)):
            if i % 10 != 0:
                continue
            micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}

            source = batch if i == 0 else novel_view_micro
            novel_view_micro = {
                k: v[0:1].to(dist_util.dev()).repeat_interleave(
                    micro['img'].shape[0], 0)
                for k, v in source.items()
            }

            latent = self.ddp_rec_model(
                img=novel_view_micro['img_to_encoder'],
                c=micro['c'])[self.latent_name]
            x_start = latent / self.triplane_scaling_divider

            all_pred_vis = []
            for t in self._vis_timesteps(125):
                noise = th.randn_like(x_start)
                x_t = self.diffusion.q_sample(x_start, t, noise=noise)
                planes_x_t = (x_t * self.triplane_scaling_divider).clamp(
                    -50, 50)

                pred = self.ddp_rec_model(
                    img=None,
                    c=micro['c'],
                    latent=planes_x_t,
                    behaviour=self.render_latent_behaviour)
                all_pred_vis.append(pred['image_raw'])

            grid = torchvision.utils.make_grid(th.cat(all_pred_vis, 0),
                                               nrow=len(all_pred_vis),
                                               normalize=True,
                                               value_range=(-1, 1),
                                               scale_each=True)
            vis = grid.permute(1, 2, 0).cpu().numpy()
            vis = (vis * 255).clip(0, 255).astype(np.uint8)
            video_out.append_data(vis)

        video_out.close()
        print('logged video to: ', video_path)

        th.cuda.empty_cache()

    @th.inference_mode()
    def plot_noise_nsr_curve(self, name_prefix='0'):
        """Record forward-process signal/noise statistics on eval latents.

        For every 10th eval batch, encodes it and then, for timesteps
        0, 5, 10, ..., computes
        ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.
        The per-timestep mean ``|signal / noise|``, the data-free ratio
        ``sqrt(alpha_bar_t) / sqrt(1 - alpha_bar_t)`` and the mean/std of
        ``x_t`` are printed and saved to ``<logdir>/snr_<i>.pt``.
        ``name_prefix`` is unused.
        """
        for i, batch in enumerate(tqdm(self.eval_data)):
            if i % 10 != 0:
                continue
            micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}

            latent = self.ddp_rec_model(img=micro['img_to_encoder'],
                                        c=micro['c'],
                                        behaviour='enc_dec_wo_triplane')
            x_start = latent[self.latent_name] / self.triplane_scaling_divider

            snr_list = []
            snr_wo_data_list = []
            xt_mean = []
            xt_std = []

            for t in self._vis_timesteps(5):
                noise = th.randn_like(x_start)

                sqrt_alpha_bar = _extract_into_tensor(
                    self.diffusion.sqrt_alphas_cumprod, t, x_start.shape)
                sqrt_one_minus_alpha_bar = _extract_into_tensor(
                    self.diffusion.sqrt_one_minus_alphas_cumprod, t,
                    x_start.shape)

                signal_t = sqrt_alpha_bar * x_start
                noise_t = sqrt_one_minus_alpha_bar * noise
                x_t = signal_t + noise_t

                snr = signal_t / (noise_t + 1e-6)
                snr_wo_data = sqrt_alpha_bar / (sqrt_one_minus_alpha_bar +
                                                1e-6)

                snr_list.append(abs(snr).mean().cpu().numpy())
                snr_wo_data_list.append(abs(snr_wo_data).mean().cpu().numpy())
                xt_mean.append(x_t.mean().cpu().numpy())
                xt_std.append(x_t.std().cpu().numpy())

            print('xt_mean', xt_mean)
            print('xt_std', xt_std)
            print('snr', snr_list)

            th.save(
                {
                    'xt_mean': xt_mean,
                    'xt_std': xt_std,
                    'snr': snr_list,
                    'snr_wo_data': snr_wo_data_list,
                },
                Path(logger.get_dir()) / f'snr_{i}.pt')

        th.cuda.empty_cache()
|
|
|
|
|
|
|
class TrainLoop3DDiffusion(TrainLoopDiffusionWithRec):
    """Train a latent diffusion model on the latents of a 3D autoencoder.

    ``rec_model`` encodes input views into a latent dict. The diffusion model
    (``denoise_model``) is trained on ``latent[self.latent_name]`` divided by
    ``triplane_scaling_divider``. Unless ``freeze_ae`` is set, the
    autoencoder also receives its own reconstruction loss, optimizer and EMA.
    With ``denoised_ae`` the renders of the denoised latent are supervised as
    well.
    """

    def __init__(
            self,
            *,
            rec_model,
            denoise_model,
            diffusion,
            loss_class,
            data,
            eval_data,
            batch_size,
            microbatch,
            lr,
            ema_rate,
            log_interval,
            eval_interval,
            save_interval,
            resume_checkpoint,
            use_fp16=False,
            fp16_scale_growth=0.001,
            schedule_sampler=None,
            weight_decay=0,
            lr_anneal_steps=0,
            iterations=10001,
            ignore_resume_opt=False,
            freeze_ae=False,
            denoised_ae=True,
            triplane_scaling_divider=10,
            use_amp=False,
            diffusion_input_size=224,
            **kwargs):
        super().__init__(
            model=denoise_model,
            diffusion=diffusion,
            loss_class=loss_class,
            data=data,
            eval_data=eval_data,
            batch_size=batch_size,
            microbatch=microbatch,
            lr=lr,
            ema_rate=ema_rate,
            log_interval=log_interval,
            eval_interval=eval_interval,
            save_interval=save_interval,
            resume_checkpoint=resume_checkpoint,
            use_fp16=use_fp16,
            fp16_scale_growth=fp16_scale_growth,
            weight_decay=weight_decay,
            lr_anneal_steps=lr_anneal_steps,
            iterations=iterations,
            triplane_scaling_divider=triplane_scaling_divider,
            use_amp=use_amp,
            diffusion_input_size=diffusion_input_size,
            schedule_sampler=schedule_sampler,
            # Forwarded so that self.kwargs (read by _init_optim_groups for
            # the per-group learning rates) is populated.
            **kwargs,
        )

        # Everything below operates on the autoencoder, so keep a handle on it.
        self.rec_model = rec_model

        self._load_and_sync_parameters(model=self.rec_model, model_name='rec')

        self.mp_trainer_rec = MixedPrecisionTrainer(
            model=self.rec_model,
            use_fp16=self.use_fp16,
            use_amp=use_amp,
            fp16_scale_growth=fp16_scale_growth,
            model_name='rec',
        )
        self.denoised_ae = denoised_ae

        if not freeze_ae:
            self.opt_rec = AdamW(
                self._init_optim_groups(self.mp_trainer_rec.model))
        else:
            print('!! freezing AE !!')

        if self.resume_step:
            if not ignore_resume_opt:
                self._load_optimizer_state()
            else:
                logger.warn("Ignoring optimizer state from checkpoint.")

            self.ema_params_rec = [
                self._load_ema_parameters(
                    rate,
                    self.rec_model,
                    self.mp_trainer_rec,
                    model_name=self.mp_trainer_rec.model_name)
                for rate in self.ema_rate
            ]
        elif not freeze_ae:
            self.ema_params_rec = [
                copy.deepcopy(self.mp_trainer_rec.master_params)
                for _ in range(len(self.ema_rate))
            ]

        if self.use_ddp:
            self.rec_model = th.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.rec_model)
            self.ddp_rec_model = DDP(
                self.rec_model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            self.ddp_rec_model = self.rec_model

        if freeze_ae:
            self.ddp_rec_model.eval()
            self.ddp_rec_model.requires_grad_(False)
        self.freeze_ae = freeze_ae

    def _update_ema_rec(self):
        """Update every autoencoder EMA copy from the current master params."""
        for rate, params in zip(self.ema_rate, self.ema_params_rec):
            update_ema(params, self.mp_trainer_rec.master_params, rate=rate)

    def _save_ddpm_and_rec(self):
        """Save the diffusion checkpoint and, if trainable, the autoencoder's."""
        self.save()
        if not self.freeze_ae:
            self.save(self.mp_trainer_rec, 'rec')

    def run_loop(self, batch=None):
        """Main training loop.

        Samples/renders on rank 0 every ``eval_interval`` steps, logs every
        ``log_interval`` steps and checkpoints every ``save_interval`` steps.
        The process exits (``SystemExit``) once ``self.iterations`` is
        exceeded. ``batch`` is unused; batches come from ``self.data``.
        """
        th.cuda.empty_cache()
        while (not self.lr_anneal_steps
               or self.step + self.resume_step < self.lr_anneal_steps):

            dist_util.synchronize()

            if self.step % self.eval_interval == 0:
                if dist_util.get_rank() == 0:
                    self.eval_ddpm_sample(self.ddp_rec_model)
                dist_util.synchronize()
                th.cuda.empty_cache()

            batch = next(self.data)
            self.run_step(batch)
            if self.step % self.log_interval == 0 and dist_util.get_rank(
            ) == 0:
                out = logger.dumpkvs()
                for k, v in out.items():
                    self.writer.add_scalar(f'Loss/{k}', v,
                                           self.step + self.resume_step)

            if self.step % self.save_interval == 0 and self.step != 0:
                self._save_ddpm_and_rec()
                dist_util.synchronize()
                th.cuda.empty_cache()

                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST",
                                  "") and self.step > 0:
                    return

            self.step += 1

            if self.step > self.iterations:
                print('reached maximum iterations, exiting')
                # Save unless the last step was already checkpointed.
                if (self.step - 1) % self.save_interval != 0:
                    self._save_ddpm_and_rec()
                # Same effect as the `exit()` builtin, without relying on `site`.
                raise SystemExit()

        if (self.step - 1) % self.save_interval != 0:
            self._save_ddpm_and_rec()

    def run_step(self, batch, cond=None):
        """One optimization step for the denoiser and (if trainable) the AE."""
        self.forward_backward(batch, cond)
        took_step_ddpm = self.mp_trainer.optimize(self.opt)
        if took_step_ddpm:
            self._update_ema()

        if not self.freeze_ae:
            took_step_rec = self.mp_trainer_rec.optimize(self.opt_rec)
            if took_step_rec:
                self._update_ema_rec()

        self._anneal_lr()
        self.log_step()

    def forward_backward(self, batch, *args, **kwargs):
        """Accumulate gradients over microbatches of ``batch``.

        Per microbatch the total loss is
        ``ae_loss + denoise_loss + denoised_ae_loss``:

        * ``ae_loss``: reconstruction loss of the AE (0 if frozen).
        * ``denoise_loss``: weighted diffusion training loss on the
          normalized latent.
        * ``denoised_ae_loss``: reconstruction loss of renders of the
          predicted x_0 (0 unless ``denoised_ae``).

        Every 500 steps rank 1 additionally saves a GT / noised / denoised
        visualisation image.
        """
        self.mp_trainer.zero_grad()
        if not self.freeze_ae:
            # The AE's gradients come from the same backward pass below, so
            # they must be cleared every step as well.
            self.mp_trainer_rec.zero_grad()

        batch_size = batch['img'].shape[0]

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k: v[i:i + self.microbatch].to(dist_util.dev())
                for k, v in batch.items()
            }
            last_batch = (i + self.microbatch) >= batch_size

            # ---- autoencoder: encode (and reconstruct if trainable) ----
            with th.cuda.amp.autocast(dtype=th.float16,
                                      enabled=self.mp_trainer_rec.use_amp
                                      and not self.freeze_ae):
                latent = self.ddp_rec_model(img=micro['img_to_encoder'],
                                            c=micro['c'],
                                            behaviour='enc_dec_wo_triplane')

                if not self.freeze_ae:
                    target = micro
                    pred = self.rec_model(latent=latent,
                                          c=micro['c'],
                                          behaviour='triplane_dec')

                    # NOTE(review): no_sync is taken on the denoiser's DDP
                    # wrapper, not ddp_rec_model -- confirm this is intended.
                    if last_batch or not self.use_ddp:
                        ae_loss, loss_dict = self.loss_class(pred,
                                                             target,
                                                             test_mode=False)
                    else:
                        with self.ddp_model.no_sync():
                            ae_loss, loss_dict = self.loss_class(
                                pred, target, test_mode=False)

                    log_rec3d_loss_dict(loss_dict)
                else:
                    ae_loss = th.tensor(0.0).to(dist_util.dev())

            # ---- diffusion loss on the normalized latent ----
            micro_to_denoise = latent[
                self.latent_name] / self.triplane_scaling_divider

            t, weights = self.schedule_sampler.sample(
                micro_to_denoise.shape[0], dist_util.dev())

            model_kwargs = {}

            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro_to_denoise,
                t,
                model_kwargs=model_kwargs,
            )

            with th.cuda.amp.autocast(dtype=th.float16,
                                      enabled=self.mp_trainer.use_amp):
                if last_batch or not self.use_ddp:
                    losses = compute_losses()
                else:
                    with self.ddp_model.no_sync():
                        losses = compute_losses()

                if isinstance(self.schedule_sampler, LossAwareSampler):
                    self.schedule_sampler.update_with_local_losses(
                        t, losses["loss"].detach())

                denoise_loss = (losses["loss"] * weights).mean()

                # Strip non-loss tensors before logging.
                x_t = losses.pop('x_t')
                losses.pop('model_output')

                log_loss_dict(self.diffusion, t,
                              {k: v * weights
                               for k, v in losses.items()})

            # Predicts x_0 (and the posterior) from x_t. Shared by the
            # denoised-AE loss and the visualisation below; it must be
            # defined before either uses it.
            denoised_fn = functools.partial(self.diffusion.p_mean_variance,
                                            self.ddp_model,
                                            x_t,
                                            t,
                                            model_kwargs=model_kwargs)

            # ---- reconstruction loss on renders of the predicted x_0 ----
            if self.denoised_ae:
                with th.cuda.amp.autocast(dtype=th.float16,
                                          enabled=self.mp_trainer_rec.use_amp
                                          and not self.freeze_ae):
                    denoised_out = denoised_fn()

                    denoised_ae_pred = self.ddp_rec_model(
                        img=None,
                        c=micro['c'],
                        latent=denoised_out['pred_xstart'] *
                        self.triplane_scaling_divider,
                        behaviour=self.render_latent_behaviour)

                    if last_batch or not self.use_ddp:
                        denoised_ae_loss, loss_dict = self.loss_class(
                            denoised_ae_pred, micro, test_mode=False)
                    else:
                        with self.ddp_model.no_sync():
                            denoised_ae_loss, loss_dict = self.loss_class(
                                denoised_ae_pred, micro, test_mode=False)

                    log_rec3d_loss_dict({
                        f'{k}_denoised': v.mean()
                        for k, v in loss_dict.items()
                    })
            else:
                denoised_ae_loss = th.tensor(0.0).to(dist_util.dev())

            loss = ae_loss + denoise_loss + denoised_ae_loss
            self.mp_trainer.backward(loss)

            # ---- periodic visualisation (rank 1 only, so never in
            # single-process runs) ----
            if dist_util.get_rank() == 1 and self.step % 500 == 0:
                with th.no_grad():
                    gt_depth = micro['depth']
                    if gt_depth.ndim == 3:
                        gt_depth = gt_depth.unsqueeze(1)
                    gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
                                                              gt_depth.min())

                    # First sample's latent. It is needed in both branches:
                    # below for its 'sr_w_code'.
                    latent_micro = {
                        k: v[0:1].to(dist_util.dev()) if v is not None else v
                        for k, v in latent.items()
                    }

                    if self.freeze_ae:
                        pred = self.rec_model(latent=latent_micro,
                                              c=micro['c'][0:1],
                                              behaviour='triplane_dec')
                    else:
                        assert pred is not None

                    pred_depth = pred['image_depth']
                    pred_depth = (pred_depth - pred_depth.min()) / (
                        pred_depth.max() - pred_depth.min())
                    pred_img = pred['image_raw']
                    gt_img = micro['img']

                    gt_vis = th.cat(
                        [
                            gt_img, micro['img'], micro['img'],
                            gt_depth.repeat_interleave(3, dim=1)
                        ],
                        dim=-1)[0:1]

                    sr_w_code = latent_micro.get('sr_w_code', None)
                    if sr_w_code is not None:
                        sr_w_code = sr_w_code[0:1]

                    noised_ae_pred = self.ddp_rec_model(
                        img=None,
                        c=micro['c'][0:1],
                        latent={
                            'latent_normalized':
                            x_t[0:1] * self.triplane_scaling_divider,
                            'sr_w_code': sr_w_code
                        },
                        behaviour=self.render_latent_behaviour)

                    denoised_out = denoised_fn()

                    denoised_ae_pred = self.ddp_rec_model(
                        img=None,
                        c=micro['c'][0:1],
                        latent={
                            'latent_normalized':
                            denoised_out['pred_xstart'][0:1] *
                            self.triplane_scaling_divider,
                            'sr_w_code':
                            sr_w_code
                        },
                        behaviour=self.render_latent_behaviour)

                    assert denoised_ae_pred is not None

                    # Row 1: ground truth; row 2: AE recon | noised x_t |
                    # denoised x_0 | depth.
                    pred_vis = th.cat([
                        pred_img[0:1], noised_ae_pred['image_raw'],
                        denoised_ae_pred['image_raw'],
                        pred_depth[0:1].repeat_interleave(3, dim=1)
                    ],
                                      dim=-1)

                    vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
                        1, 2, 0).cpu()
                    vis = vis.numpy() * 127.5 + 127.5
                    vis = vis.clip(0, 255).astype(np.uint8)

                    vis_path = f'{logger.get_dir()}/{self.step+self.resume_step}denoised_{t[0].item()}.jpg'
                    Image.fromarray(vis).save(vis_path)
                    print('log denoised vis to: ', vis_path)

                    th.cuda.empty_cache()
|
|
|
|
|
|
|
|