|
import functools |
|
import json |
|
import os |
|
from pathlib import Path |
|
from pdb import set_trace as st |
|
import torchvision |
|
import blobfile as bf |
|
import imageio |
|
import numpy as np |
|
import torch as th |
|
import torch.distributed as dist |
|
import torchvision |
|
from PIL import Image |
|
from torch.nn.parallel.distributed import DistributedDataParallel as DDP |
|
from tqdm import tqdm |
|
|
|
from guided_diffusion.fp16_util import MixedPrecisionTrainer |
|
from guided_diffusion import dist_util, logger |
|
from guided_diffusion.train_util import (calc_average_loss, |
|
log_rec3d_loss_dict, |
|
find_resume_checkpoint) |
|
|
|
from torch.optim import AdamW |
|
|
|
from ..train_util import TrainLoopBasic, TrainLoop3DRec |
|
import vision_aided_loss |
|
from dnnlib.util import calculate_adaptive_weight |
|
|
|
def flip_yaw(pose_matrix):
    """Return a copy of ``pose_matrix`` with selected entries negated.

    The entries at (row, col) = (0, 1), (0, 2), (1, 0), (2, 0) and (0, 3) of
    every matrix in the batch change sign; all other entries are copied
    unchanged.  The input tensor itself is not modified.

    Args:
        pose_matrix: tensor indexed as ``[batch, row, col]``; callers in this
            file pass camera poses reshaped to ``(-1, 4, 4)``.

    Returns:
        A new tensor of the same shape holding the mirrored poses.
    """
    mirrored = pose_matrix.clone()

    # Off-diagonal rotation terms coupled with the first axis, followed by the
    # first translation component.
    negated_entries = ((0, 1), (0, 2), (1, 0), (2, 0), (0, 3))
    for row, col in negated_entries:
        mirrored[:, row, col] *= -1

    return mirrored
|
|
|
|
|
def get_blob_logdir():
    """Return the directory the active ``logger`` writes to.

    Thin indirection over ``logger.get_dir()`` so the checkpoint location can
    be changed in one place.
    """
    log_dir = logger.get_dir()
    return log_dir
|
|
|
|
|
from ..train_util_cvD import TrainLoop3DcvD |
|
|
|
|
|
|
|
class TrainLoop3DcvD_nvsD_canoD(TrainLoop3DcvD):
    """Adversarial 3D-reconstruction loop driven by two discriminators.

    On top of ``TrainLoop3DcvD`` this class builds ``self.cano_cvD``, a
    ``vision_aided_loss.Discriminator`` that scores renderings of the input
    view against the ground-truth image.  A second discriminator, reached
    through ``self.ddp_nvs_cvD`` / ``self.mp_trainer_cvD`` / ``self.opt_cvD``,
    scores novel-view renderings; it is not created in this class and is
    presumably set up by the parent constructor (confirm in
    ``train_util_cvD``).

    One iteration of :meth:`run_loop` performs four optimisation steps, each
    on a fresh batch: generator/reconstruction, canonical discriminator,
    generator/novel-view, novel-view discriminator.
    """

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 use_amp=False,
                 **kwargs):
        """Build the parent loop, then the canonical-view discriminator.

        All arguments are forwarded unchanged to the parent constructor.
        Locally, ``lr`` also scales the canonical discriminator's learning
        rate, and ``fp16_scale_growth`` / ``use_amp`` configure its
        ``MixedPrecisionTrainer``.
        """
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         use_amp=use_amp,
                         **kwargs)

        device = dist_util.dev()

        self.cano_cvD = vision_aided_loss.Discriminator(
            cv_type='clip', loss_type='multilevel_sigmoid_s',
            device=device).to(device)
        # The pretrained feature ensemble stays frozen.
        self.cano_cvD.cv_ensemble.requires_grad_(False)

        cvD_model_params = list(self.cano_cvD.parameters())

        # Hard-coded switch: when enabled, the first conv of the backbone is
        # replaced by a 6-channel one so the discriminator can consume the
        # channel-wise concatenation of a raw and a super-resolved image.
        SR_TRAINING = False
        if SR_TRAINING:
            backbone = self.cano_cvD.cv_ensemble.models[0]
            vision_width, vision_patch_size = [
                backbone.model.conv1.weight.shape[k] for k in [0, -1]
            ]
            backbone.model.conv1 = th.nn.Conv2d(
                in_channels=6,
                out_channels=vision_width,
                kernel_size=vision_patch_size,
                stride=vision_patch_size,
                bias=False).to(dist_util.dev())
            cvD_model_params += list(backbone.model.conv1.parameters())

            # Normalisation statistics are duplicated to cover 6 channels.
            backbone.image_mean = backbone.image_mean.repeat(2)
            backbone.image_std = backbone.image_std.repeat(2)

        # NOTE(review): weights are loaded under the name 'cano_cvD' while the
        # trainer below (and therefore every checkpoint written by run_loop)
        # uses 'canonical_cvD' -- verify that resuming actually finds the
        # saved discriminator.
        self._load_and_sync_parameters(model=self.cano_cvD,
                                       model_name='cano_cvD')

        self.mp_trainer_canonical_cvD = MixedPrecisionTrainer(
            model=self.cano_cvD,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
            model_name='canonical_cvD',
            use_amp=use_amp,
            model_params=cvD_model_params)

        # Discriminator learning rate is 2e-4 when lr == 1e-5 and scales
        # linearly with lr.
        cano_lr = 2e-4 * (lr / 1e-5)
        self.opt_cano_cvD = AdamW(
            self.mp_trainer_canonical_cvD.master_params,
            lr=cano_lr,
            betas=(0, 0.999),
            eps=1e-8)

        logger.log(f'cpt_cano_cvD lr: {cano_lr}')

        if self.use_ddp:
            self.ddp_cano_cvD = DDP(
                self.cano_cvD,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            self.ddp_cano_cvD = self.cano_cvD

        th.cuda.empty_cache()

    # ------------------------------------------------------------------ #
    # private helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _norm_depth_for_vis(depth):
        """Min-max normalise ``depth`` over the whole tensor and flip its sign.

        The minimum maps to 1 and the maximum to -1.
        """
        depth = (depth - depth.min()) / (depth.max() - depth.min())
        return -(depth * 2 - 1)

    @staticmethod
    def _cat_raw_with_sr(raw, sr):
        """Resize ``raw`` to the spatial size of ``sr`` and stack on dim 1."""
        resized = th.nn.functional.interpolate(raw,
                                               size=sr.shape[2:],
                                               mode='bilinear',
                                               align_corners=False,
                                               antialias=True)
        return th.cat([resized, sr], dim=1)

    def _pool_matching_sr(self, sr_width):
        """Pick the pooling module used for a super-resolved image width."""
        if sr_width == 512:
            return self.pool_512
        if sr_width == 256:
            return self.pool_256
        return self.pool_128

    def _align_vis_resolution(self, pred, micro, pred_img, pred_depth,
                              gt_depth):
        """Bring images and depth maps to a common size for visualisation.

        With a super-resolved output, the raw image is pooled and placed next
        to the SR image (same for the ground truth) and both depth maps are
        pooled with the same module.  Afterwards the ground-truth depth is
        pooled to 64 or 128 if the ground-truth image has that width.

        Returns:
            ``(pred_img, pred_depth, gt_img, gt_depth)``.
        """
        gt_img = micro['img']
        if 'image_sr' in pred:
            pool = self._pool_matching_sr(pred['image_sr'].shape[-1])
            pred_img = th.cat([pool(pred_img), pred['image_sr']], dim=-1)
            gt_img = th.cat([pool(gt_img), micro['img_sr']], dim=-1)
            pred_depth = pool(pred_depth)
            gt_depth = pool(gt_depth)

        if gt_img.shape[-1] == 64:
            gt_depth = self.pool_64(gt_depth)
        elif gt_img.shape[-1] == 128:
            gt_depth = self.pool_128(gt_depth)

        return pred_img, pred_depth, gt_img, gt_depth

    def _normalised_gt_depth(self, micro):
        """Return ``micro['depth']`` with a channel dim, normalised for vis."""
        gt_depth = micro['depth']
        if gt_depth.ndim == 3:
            gt_depth = gt_depth.unsqueeze(1)
        return self._norm_depth_for_vis(gt_depth)

    def _save_all_checkpoints(self):
        """Checkpoint the reconstruction model and both discriminators."""
        self.save()
        self.save(self.mp_trainer_cvD, self.mp_trainer_cvD.model_name)
        self.save(self.mp_trainer_canonical_cvD,
                  self.mp_trainer_canonical_cvD.model_name)

    def _to_device_micro(self, batch, start):
        """Slice one microbatch starting at ``start`` and move it to device."""
        return {
            k: v[start:start + self.microbatch].to(
                dist_util.dev()).contiguous()
            for k, v in batch.items()
        }

    # ------------------------------------------------------------------ #
    # training loop
    # ------------------------------------------------------------------ #

    def run_step(self, batch, step='g_step'):
        """Run one optimisation step of the requested kind.

        Args:
            batch: dict of tensors produced by the data iterator.
            step: one of ``'g_step_rec'``, ``'d_step_rec'``, ``'g_step_nvs'``
                or ``'d_step_nvs'``.  Any other value (including the default)
                performs no forward/backward pass; only the learning-rate
                annealing and step logging below are executed.
        """
        if step == 'g_step_rec':
            self.forward_G_rec(batch)
            took_step_g_rec = self.mp_trainer_rec.optimize(self.opt)
            if took_step_g_rec:
                self._update_ema()

        elif step == 'd_step_rec':
            self.forward_D(batch, behaviour='rec')
            _ = self.mp_trainer_canonical_cvD.optimize(self.opt_cano_cvD)

        elif step == 'g_step_nvs':
            self.forward_G_nvs(batch)
            took_step_g_nvs = self.mp_trainer_rec.optimize(self.opt)
            if took_step_g_nvs:
                self._update_ema()

        elif step == 'd_step_nvs':
            self.forward_D(batch, behaviour='nvs')
            _ = self.mp_trainer_cvD.optimize(self.opt_cvD)

        self._anneal_lr()
        self.log_step()

    def run_loop(self):
        """Train until the LR schedule ends or ``self.iterations`` is hit.

        Every iteration draws four batches and runs the four step kinds in
        order.  Logging, evaluation and checkpointing happen at their
        respective intervals.  Reaching ``self.iterations`` terminates the
        process via ``SystemExit`` after a final checkpoint.
        """
        while (not self.lr_anneal_steps
               or self.step + self.resume_step < self.lr_anneal_steps):

            dist_util.synchronize()

            batch = next(self.data)

            # Fixed set of poses for novel-view visualisation: the first
            # batch's cameras rolled by one along the batch dimension.
            if self.novel_view_poses is None:
                self.novel_view_poses = th.roll(batch['c'], 1, 0).to(
                    dist_util.dev())

            self.run_step(batch, 'g_step_rec')

            batch = next(self.data)
            self.run_step(batch, 'd_step_rec')

            batch = next(self.data)
            self.run_step(batch, 'g_step_nvs')

            batch = next(self.data)
            self.run_step(batch, 'd_step_nvs')

            if (self.step % self.log_interval == 0
                    and dist_util.get_rank() == 0):
                out = logger.dumpkvs()
                for k, v in out.items():
                    self.writer.add_scalar(f'Loss/{k}', v,
                                           self.step + self.resume_step)

            if self.step % self.eval_interval == 0:
                if dist_util.get_rank() == 0:
                    self.eval_loop()
                th.cuda.empty_cache()
                # Let the other ranks wait for rank 0's evaluation.
                dist_util.synchronize()

            if self.step % self.save_interval == 0:
                self._save_all_checkpoints()
                dist_util.synchronize()
                # Integration tests stop after the first non-zero save step.
                if os.environ.get("DIFFUSION_TRAINING_TEST",
                                  "") and self.step > 0:
                    return

            self.step += 1

            if self.step > self.iterations:
                print('reached maximum iterations, exiting')

                # Save the last state unless this step was just checkpointed.
                if (self.step - 1) % self.save_interval != 0:
                    self._save_all_checkpoints()

                # Same effect as the interactive-only ``exit()`` builtin,
                # without depending on the ``site`` module.
                raise SystemExit(0)

        # Loop ended through the LR schedule: save the last state unless this
        # step was just checkpointed.  Both discriminators are written under
        # their own names, exactly like the in-loop checkpoints.
        if (self.step - 1) % self.save_interval != 0:
            self._save_all_checkpoints()

    # ------------------------------------------------------------------ #
    # discriminator pass
    # ------------------------------------------------------------------ #

    def forward_D(self, batch, behaviour):
        """Accumulate discriminator gradients for one batch.

        Args:
            batch: dict of tensors; ``'img'``, ``'img_to_encoder'`` and
                ``'c'`` are read, plus ``'img_sr'`` when the model emits a
                super-resolved image.
            behaviour: ``'rec'`` trains the canonical discriminator with the
                ground-truth image as real and the input-view rendering as
                fake.  ``'nvs'`` trains the novel-view discriminator with the
                input-view rendering as real and a rendering under the
                neighbouring sample's camera as fake.

        Raises:
            ValueError: if ``behaviour`` is neither ``'rec'`` nor ``'nvs'``.
        """
        if behaviour not in ('rec', 'nvs'):
            raise ValueError(
                f"behaviour must be 'rec' or 'nvs', got {behaviour!r}")
        train_nvs_D = behaviour == 'nvs'

        self.mp_trainer_canonical_cvD.zero_grad()
        self.mp_trainer_cvD.zero_grad()

        # Only the discriminator being trained receives gradients.
        self.rec_model.requires_grad_(False)
        self.ddp_nvs_cvD.requires_grad_(train_nvs_D)
        self.ddp_cano_cvD.requires_grad_(not train_nvs_D)

        batch_size = batch['img'].shape[0]

        for i in range(0, batch_size, self.microbatch):
            micro = self._to_device_micro(batch, i)

            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_canonical_cvD.use_amp):

                latent = self.rec_model(img=micro['img_to_encoder'],
                                        behaviour='enc_dec_wo_triplane')

                cano_pred = self.rec_model(latent=latent,
                                           c=micro['c'],
                                           behaviour='triplane_dec')

                if not train_nvs_D:
                    if 'image_sr' in cano_pred:
                        real = self._cat_raw_with_sr(micro['img'],
                                                     micro['img_sr'])
                        fake = self._cat_raw_with_sr(cano_pred['image_raw'],
                                                     cano_pred['image_sr'])
                    else:
                        real = micro['img']
                        fake = cano_pred['image_raw']

                    d_loss = self.run_D_Diter(real=real,
                                              fake=fake,
                                              D=self.ddp_cano_cvD)
                    log_rec3d_loss_dict(
                        {'vision_aided_loss/D_cano': d_loss})

                else:
                    # Each sample is rendered under the next sample's camera
                    # (cameras shifted by one, wrapping around).
                    novel_view_c = th.cat([micro['c'][1:], micro['c'][:1]])
                    nvs_pred = self.rec_model(latent=latent,
                                              c=novel_view_c,
                                              behaviour='triplane_dec')

                    if 'image_sr' in nvs_pred:
                        real = self._cat_raw_with_sr(cano_pred['image_raw'],
                                                     cano_pred['image_sr'])
                        fake = self._cat_raw_with_sr(nvs_pred['image_raw'],
                                                     nvs_pred['image_sr'])
                    else:
                        real = cano_pred['image_raw']
                        fake = nvs_pred['image_raw']

                    d_loss = self.run_D_Diter(real=real,
                                              fake=fake,
                                              D=self.ddp_nvs_cvD)
                    log_rec3d_loss_dict(
                        {'vision_aided_loss/D_nvs': d_loss})

            if train_nvs_D:
                self.mp_trainer_cvD.backward(d_loss)
            else:
                self.mp_trainer_canonical_cvD.backward(d_loss)

    # ------------------------------------------------------------------ #
    # generator pass: reconstruction
    # ------------------------------------------------------------------ #

    def _mirrored_view_loss(self, pred, micro, fg_mask):
        """Reconstruction loss of the mirrored view against the flipped image.

        The camera pose (first 16 entries of ``c``) is mirrored with
        ``flip_yaw``; the latents in ``pred`` are re-rendered under that
        camera and compared with the horizontally flipped target image and
        foreground mask.

        Returns:
            ``(loss, loss_dict)`` as returned by
            ``self.loss_class.calc_2d_rec_loss``.
        """
        pose = micro['c'][:, :16].reshape(-1, 4, 4)
        intrinsics = micro['c'][:, 16:]
        mirror_c = th.cat([flip_yaw(pose).reshape(-1, 16), intrinsics], -1)

        latents = {k: v for k, v in pred.items() if 'latent' in k}
        nvs_pred = self.rec_model(latent=latents,
                                  c=mirror_c,
                                  behaviour='triplane_dec',
                                  return_raw_only=True)

        flipped_img = th.flip(micro['img'], [-1])
        flipped_fg_mask = th.flip(fg_mask, [-1])

        if 'conf_sigma' in pred:
            conf_sigma = th.flip(pred['conf_sigma'], [-1])
            conf_sigma = th.nn.AdaptiveAvgPool2d(
                fg_mask.shape[-2:])(conf_sigma)
        else:
            conf_sigma = None

        with self.rec_model.no_sync():
            return self.loss_class.calc_2d_rec_loss(
                nvs_pred['image_raw'],
                flipped_img,
                flipped_fg_mask,
                test_mode=False,
                step=self.step + self.resume_step,
                conf_sigma=conf_sigma,
            )

    def _save_rec_vis(self, pred, micro):
        """Write a ground-truth / prediction / novel-view grid to a JPEG."""
        with th.no_grad():
            pred_img = pred['image_raw'].clip(-1, 1)

            # NOTE(review): novel_view_poses comes from a full batch while
            # micro is a microbatch -- sizes only match when
            # microbatch == batch_size; confirm.
            pred_nv_img = self.rec_model(img=micro['img_to_encoder'],
                                         c=self.novel_view_poses)

            gt_depth = self._normalised_gt_depth(micro)

            if 'image_depth' in pred:
                pred_depth = self._norm_depth_for_vis(pred['image_depth'])
                pred_nv_depth = self._norm_depth_for_vis(
                    pred_nv_img['image_depth'])
            else:
                pred_depth = th.zeros_like(gt_depth)
                pred_nv_depth = th.zeros_like(gt_depth)

            pred_img, pred_depth, gt_img, gt_depth = (
                self._align_vis_resolution(pred, micro, pred_img, pred_depth,
                                           gt_depth))

            # Rows, top to bottom: ground truth, prediction, novel view.
            # Depth is tiled to three channels to sit next to the RGB image.
            pred_vis = th.cat(
                [pred_img, pred_depth.repeat_interleave(3, dim=1)], dim=-1)
            pred_vis_nv = th.cat([
                pred_nv_img['image_raw'].clip(-1, 1),
                pred_nv_depth.repeat_interleave(3, dim=1)
            ],
                                 dim=-1)
            pred_vis = th.cat([pred_vis, pred_vis_nv], dim=-2)

            gt_vis = th.cat(
                [gt_img, gt_depth.repeat_interleave(3, dim=1)], dim=-1)

            vis = th.cat([gt_vis, pred_vis], dim=-2)

            vis_path = f'{logger.get_dir()}/{self.step+self.resume_step}.jpg'
            vis_tensor = torchvision.utils.make_grid(
                vis, nrow=vis.shape[-1] // 64)
            torchvision.utils.save_image(vis_tensor,
                                         vis_path,
                                         normalize=True,
                                         value_range=(-1, 1))

            logger.log('log vis to: ', vis_path)

    def forward_G_rec(self, batch):
        """Accumulate generator gradients for input-view reconstruction.

        The loss is the reconstruction loss from ``self.loss_class``, plus
        (when ``loss_class.opt.symmetry_loss`` is set) the mirrored-view loss,
        plus the canonical discriminator's generator loss weighted by
        ``loss_class.opt.rec_cvD_lambda``.  On rank 0 a visualisation is
        written whenever ``self.step`` is a multiple of 500.
        """
        self.mp_trainer_rec.zero_grad()
        self.rec_model.requires_grad_(True)

        # Discriminators only provide a training signal here.
        self.ddp_cano_cvD.requires_grad_(False)
        self.ddp_nvs_cvD.requires_grad_(False)

        batch_size = batch['img'].shape[0]

        for i in range(0, batch_size, self.microbatch):
            micro = self._to_device_micro(batch, i)

            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                pred = self.rec_model(img=micro['img_to_encoder'],
                                      c=micro['c'])

                with self.rec_model.no_sync():
                    loss, loss_dict, fg_mask = self.loss_class(
                        pred,
                        micro,
                        test_mode=False,
                        step=self.step + self.resume_step,
                        return_fg_mask=True)

                if self.loss_class.opt.symmetry_loss:
                    loss_symm, loss_dict_symm = self._mirrored_view_loss(
                        pred, micro, fg_mask)
                    loss += (loss_symm * 1.0)
                    for k, v in loss_dict_symm.items():
                        loss_dict[f'{k}_symm'] = v

                if 'image_sr' in pred:
                    d_input = self._cat_raw_with_sr(pred['image_raw'],
                                                    pred['image_sr'])
                else:
                    d_input = pred['image_raw']
                g_adv_loss = self.ddp_cano_cvD(d_input, for_G=True).mean()

                d_weight = th.tensor(self.loss_class.opt.rec_cvD_lambda).to(
                    dist_util.dev())

                loss += g_adv_loss * d_weight

                loss_dict.update({
                    'vision_aided_loss/G_rec':
                    (g_adv_loss * d_weight).detach(),
                    'd_weight':
                    d_weight
                })

                log_rec3d_loss_dict(loss_dict)

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0:
                self._save_rec_vis(pred, micro)

    # ------------------------------------------------------------------ #
    # generator pass: novel view
    # ------------------------------------------------------------------ #

    def _save_nvs_vis(self, nvs_pred, micro):
        """Write a ground-truth / novel-view grid to ``*_nvs.jpg``."""
        with th.no_grad():
            gt_depth = self._normalised_gt_depth(micro)
            pred_depth = self._norm_depth_for_vis(nvs_pred['image_depth'])

            pred_img, pred_depth, gt_img, gt_depth = (
                self._align_vis_resolution(nvs_pred, micro,
                                           nvs_pred['image_raw'], pred_depth,
                                           gt_depth))

            gt_vis = th.cat(
                [gt_img, gt_depth.repeat_interleave(3, dim=1)], dim=-1)
            pred_vis = th.cat(
                [pred_img, pred_depth.repeat_interleave(3, dim=1)], dim=-1)

            vis = th.cat([gt_vis, pred_vis], dim=-2)

            grid = torchvision.utils.make_grid(vis,
                                               normalize=True,
                                               scale_each=True,
                                               value_range=(-1, 1))
            # CHW float in [0, 1] -> HWC uint8 for PIL.
            vis = grid.cpu().permute(1, 2, 0)
            vis = vis.numpy() * 255
            vis = vis.clip(0, 255).astype(np.uint8)

            vis_path = (
                f'{logger.get_dir()}/{self.step+self.resume_step}_nvs.jpg')
            Image.fromarray(vis).save(vis_path)
            print('log vis to: ', vis_path)

    def forward_G_nvs(self, batch):
        """Accumulate generator gradients from the novel-view discriminator.

        Every sample is rendered under the next sample's camera (cameras
        shifted by one, wrapping around) and scored by ``self.ddp_nvs_cvD``;
        the result is weighted by ``loss_class.opt.nvs_cvD_lambda``.  On
        rank 0 a visualisation is written whenever ``self.step % 500 == 1``.
        """
        self.mp_trainer_rec.zero_grad()
        self.rec_model.requires_grad_(True)

        self.ddp_cano_cvD.requires_grad_(False)
        self.ddp_nvs_cvD.requires_grad_(False)

        batch_size = batch['img'].shape[0]

        for i in range(0, batch_size, self.microbatch):
            micro = self._to_device_micro(batch, i)

            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                nvs_pred = self.rec_model(
                    img=micro['img_to_encoder'],
                    c=th.cat([
                        micro['c'][1:],
                        micro['c'][:1],
                    ]))

                if 'image_sr' in nvs_pred:
                    d_input = self._cat_raw_with_sr(nvs_pred['image_raw'],
                                                    nvs_pred['image_sr'])
                else:
                    d_input = nvs_pred['image_raw']
                g_adv_loss = self.ddp_nvs_cvD(d_input, for_G=True).mean()

                loss = g_adv_loss * self.loss_class.opt.nvs_cvD_lambda

                log_rec3d_loss_dict({'vision_aided_loss/G_nvs': loss})

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 1:
                self._save_nvs_vis(nvs_pred, micro)
|
|
|
class TrainLoop3DcvD_nvsD_canoD_eg3d(TrainLoop3DcvD_nvsD_canoD):
    """Two-discriminator loop with an orbiting-camera video evaluation.

    Reads the rendering settings from the reconstruction model's triplane
    decoder and precomputes a short camera trajectory
    (``self.all_nvs_params``) that :meth:`eval_novelview_loop` renders into
    one video per evaluation batch.
    """

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 use_amp=False,
                 **kwargs):
        """Forward every argument to the parent, then build the camera path."""
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         use_amp=use_amp,
                         **kwargs)
        # ``.module`` is read here, so rec_model is expected to be wrapped
        # (e.g. by DistributedDataParallel).
        self.rendering_kwargs = (
            self.rec_model.module.decoder.triplane_decoder.rendering_kwargs)
        self._prepare_nvs_pose()

    @th.inference_mode()
    def eval_novelview_loop(self):
        """Render every evaluation batch along the camera path into a video.

        For each batch an MP4 named
        ``video_novelview_<step>_batch_<i>.mp4`` is written to the logger
        directory.  Each frame shows, left to right, the reference image, the
        raw rendering, the super-resolved rendering (if the model produces
        one) and the min-max normalised depth; samples of the batch are
        stacked vertically.
        """
        for i, batch in enumerate(tqdm(self.eval_data)):
            micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
            num_samples = micro['img'].shape[0]

            video_out = imageio.get_writer(
                f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}_batch_{i}.mp4',
                mode='I',
                fps=60,
                codec='libx264')

            # Close the writer even if rendering fails, so the file handle
            # and encoder process are released.
            try:
                for c in self.all_nvs_params:
                    pred = self.rec_model(
                        img=micro['img_to_encoder'],
                        c=c.unsqueeze(0).repeat_interleave(num_samples, 0))

                    pred_depth = pred['image_depth']
                    pred_depth = (pred_depth - pred_depth.min()) / (
                        pred_depth.max() - pred_depth.min())

                    if 'image_sr' in pred:
                        sr_img = pred['image_sr']
                        sr_width = sr_img.shape[-1]
                        if sr_width == 512:
                            pool = self.pool_512
                        elif sr_width == 256:
                            pool = self.pool_256
                        else:
                            # Only this fallback also pools the SR image.
                            pool = self.pool_128
                            sr_img = pool(sr_img)

                        panels = [
                            micro['img_sr'],
                            pool(pred['image_raw']),
                            sr_img,
                            pool(pred_depth).repeat_interleave(3, dim=1),
                        ]
                    else:
                        panels = [
                            self.pool_128(micro['img']),
                            self.pool_128(pred['image_raw']),
                            self.pool_128(pred_depth).repeat_interleave(
                                3, dim=1),
                        ]

                    pred_vis = th.cat(panels, dim=-1)

                    # (B, C, H, W) -> (B*H, W, C): samples stacked vertically.
                    pred_vis = pred_vis.permute(0, 2, 3, 1).flatten(0, 1)

                    # Map [-1, 1] to [0, 255].
                    vis = pred_vis.cpu().numpy()
                    vis = vis * 127.5 + 127.5
                    vis = vis.clip(0, 255).astype(np.uint8)

                    video_out.append_data(vis)
            finally:
                video_out.close()

        th.cuda.empty_cache()

    def _prepare_nvs_pose(self,
                          fov_deg=18.837,
                          pitch_range=0.25,
                          yaw_range=0.35,
                          num_keyframes=10,
                          w_frames=1):
        """Precompute the camera trajectory used for novel-view evaluation.

        The camera looks at ``rendering_kwargs['avg_camera_pivot']`` from
        distance ``rendering_kwargs['avg_camera_radius']`` while yaw and pitch
        oscillate sinusoidally around the frontal view.

        Args:
            fov_deg: field of view, in degrees, passed to
                ``FOV_to_intrinsics``.
            pitch_range: amplitude of the pitch oscillation.
            yaw_range: amplitude of the yaw oscillation.
            num_keyframes: number of camera poses generated.
            w_frames: divisor applied together with ``num_keyframes`` when
                computing the oscillation phase.

        Side effects:
            Sets ``self.all_nvs_params`` to a ``(num_keyframes, 25)`` tensor:
            16 flattened pose entries followed by 9 flattened intrinsics.
        """
        from nsr.camera_utils import LookAtPoseSampler, FOV_to_intrinsics

        device = dist_util.dev()

        intrinsics = FOV_to_intrinsics(fov_deg, device=device)

        cam_pivot = th.Tensor(
            self.rendering_kwargs.get('avg_camera_pivot')).to(device)
        cam_radius = self.rendering_kwargs.get('avg_camera_radius')

        all_nvs_params = []
        for frame_idx in range(num_keyframes):
            # The literal 3.14 (not math.pi) is kept on purpose so the
            # trajectory stays numerically identical to earlier runs.
            phase = 2 * 3.14 * frame_idx / (num_keyframes * w_frames)

            cam2world_pose = LookAtPoseSampler.sample(
                3.14 / 2 + yaw_range * np.sin(phase),
                3.14 / 2 - 0.05 + pitch_range * np.cos(phase),
                cam_pivot,
                radius=cam_radius,
                device=device)

            camera_params = th.cat([
                cam2world_pose.reshape(-1, 16),
                intrinsics.reshape(-1, 9)
            ], 1)

            all_nvs_params.append(camera_params)

        self.all_nvs_params = th.cat(all_nvs_params, 0)