# LN3Diff / nsr/train_util_with_eg3d.py
# NIRVANALAN — "release file" (87c126b)
import contextlib
import copy
import functools
import json
import os
import sys
from pathlib import Path
from pdb import set_trace as st

import blobfile as bf
import imageio
import numpy as np
import torch as th
import torch.distributed as dist
import torchvision
from PIL import Image
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from guided_diffusion import dist_util, logger
from guided_diffusion.fp16_util import MixedPrecisionTrainer
from guided_diffusion.nn import update_ema
from guided_diffusion.resample import LossAwareSampler, UniformSampler
from guided_diffusion.train_util import (calc_average_loss,
                                         find_ema_checkpoint,
                                         find_resume_checkpoint,
                                         get_blob_logdir, log_rec3d_loss_dict,
                                         parse_resume_step_from_filename)

from .train_util import TrainLoop3DRec
class TrainLoop3DRecEG3D(TrainLoop3DRec):
    """Train the 3D reconstruction model on pseudo ground truth rendered by EG3D.

    Every step samples latent codes and renders them with the generator ``G``
    (queried without gradients) at the dataset cameras. ``rec_model`` is then
    trained to reproduce, from the pooled generator image:

    - the generator's renderings;
    - its densities along the sampled rays;
    - its feature volume;
    - its last ``ws`` entry (compared with the predicted ``sr_w_code``).
    """

    def __init__(self,
                 *,
                 G,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        """Forward every argument except ``G`` to ``TrainLoop3DRec``.

        Args:
            G: EG3D-style generator exposing ``mapping`` and ``synthesis``.
                It is used only to produce pseudo ground truth via ``run_G``.
        """
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)
        self.G = G
        # Resizes the generator's 'image' output to the encoder input size.
        self.pool_224 = th.nn.AdaptiveAvgPool2d((224, 224))

    @th.no_grad()
    def run_G(
        self,
        z,
        c,
        swapping_prob,
        neural_rendering_resolution,
        update_emas=False,
        return_raw_only=False,
        truncation_psi=0.7,
        truncation_cutoff=None,
    ):
        """Render pseudo ground truth with the generator, without gradients.

        Args:
            z: latent codes, one row per sample.
            c: camera parameters used by ``G.synthesis``. The mapping network
                receives zeros of the same shape instead. Presumably this keeps
                the identity independent of the rendering pose.
            swapping_prob: unused; kept for call-site compatibility.
            neural_rendering_resolution: forwarded to ``G.synthesis``.
            update_emas: forwarded to ``G.mapping`` and ``G.synthesis``.
            return_raw_only: forwarded to ``G.synthesis``.
            truncation_psi: truncation strength passed to ``G.mapping``.
                Defaults to the previously hard-coded 0.7.
            truncation_cutoff: forwarded to ``G.mapping``.

        Returns:
            ``(gen_output, ws)``: the ``G.synthesis`` output and the mapped
            ``ws`` codes.
        """
        c_gen_conditioning = th.zeros_like(c)
        ws = self.G.mapping(
            z,
            c_gen_conditioning,
            truncation_psi=truncation_psi,
            truncation_cutoff=truncation_cutoff,
            update_emas=update_emas,
        )
        # noise_mode='const' fixes the SynthesisLayer modulation noise;
        # otherwise the same latent code may render two different identities.
        gen_output = self.G.synthesis(
            ws,
            c,
            neural_rendering_resolution=neural_rendering_resolution,
            update_emas=update_emas,
            noise_mode='const',
            return_raw_only=return_raw_only,
        )
        return gen_output, ws

    def run_loop(self, batch=None):
        """Run the main training loop.

        The loop runs until ``lr_anneal_steps`` is reached (when it is
        non-zero). Once ``self.step`` exceeds ``self.iterations`` it saves any
        unsaved checkpoint and exits the process. The ``batch`` argument is
        ignored; batches are always drawn from ``self.data``.
        """
        while (not self.lr_anneal_steps
               or self.step + self.resume_step < self.lr_anneal_steps):
            # let all processes sync up before starting with a new step
            dist_util.synchronize()
            batch = next(self.data)
            self.run_step(batch)

            if self.step % self.log_interval == 0 and dist_util.get_rank(
            ) == 0:
                out = logger.dumpkvs()
                # * log to tensorboard
                for k, v in out.items():
                    self.writer.add_scalar(f'Loss/{k}', v,
                                           self.step + self.resume_step)

            if self.step % self.eval_interval == 0 and self.step != 0:
                # Evaluation calls are currently disabled; only the barrier
                # remains.
                dist_util.synchronize()

            if self.step % self.save_interval == 0:
                self.save()
                dist_util.synchronize()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST",
                                  "") and self.step > 0:
                    return

            self.step += 1

            if self.step > self.iterations:
                print('reached maximum iterations, exiting')
                self._save_if_unsaved()
                # sys.exit instead of the interactive-only `exit` builtin.
                sys.exit()

        self._save_if_unsaved()

    def _save_if_unsaved(self):
        """Save a checkpoint unless the last completed step was already saved."""
        # self.step has already been advanced past the last trained step.
        if (self.step - 1) % self.save_interval != 0:
            self.save()

    def run_step(self, batch, *args):
        """Run one step: forward/backward, optimizer, EMA, LR annealing, logging."""
        self.forward_backward(batch)
        took_step = self.mp_trainer_rec.optimize(self.opt)
        # EMA follows only real optimizer steps. optimize() may report that
        # no step was taken; presumably on fp16 overflow.
        if took_step:
            self._update_ema()
        self._anneal_lr()
        self.log_step()

    def forward_backward(self, batch, *args, **kwargs):
        """Accumulate distillation gradients for one batch of cameras.

        For each micro-batch of ``batch['c']`` this method:

        1. samples ``z`` and renders pseudo ground truth with ``run_G``;
        2. reconstructs it with ``rec_model`` from the pooled generator image;
        3. backpropagates the distillation loss.

        On rank 0 every 500 steps it also saves a visualization.

        Args:
            batch: dict with at least ``'c'`` (camera parameters, batch-first).
                All other keys are ignored.

        Returns:
            The prediction dict of the last micro-batch, or ``None`` if the
            batch is empty.
        """
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['c'].shape[0]
        micro = None
        pred_gen_output = None

        for i in range(0, batch_size, self.microbatch):
            # Only this micro-batch's cameras (the slice was missing before).
            micro = {
                'c': batch['c'][i:i + self.microbatch].to(dist_util.dev())
            }
            micro_bs = micro['c'].shape[0]

            # * infer gt with the generator
            with th.no_grad():
                eg3d_batch, ws = self.run_G(
                    z=th.randn(micro_bs, 512).to(dist_util.dev()),
                    c=micro['c'],  # use real img pose here? or synthesized pose.
                    swapping_prob=0,
                    neural_rendering_resolution=128)
            micro.update({
                'img': eg3d_batch['image_raw'],  # gt
                'img_to_encoder': self.pool_224(eg3d_batch['image']),
                'depth': eg3d_batch['image_depth'],
                'img_sr': eg3d_batch['image'],
            })

            last_batch = (i + self.microbatch) >= batch_size
            # DDP skips gradient all-reduce only if the *forward* pass runs
            # inside no_sync(). So forward, loss and backward all live in it.
            if last_batch or not self.use_ddp:
                sync_ctx = contextlib.nullcontext()
            else:
                sync_ctx = self.rec_model.no_sync()  # type: ignore

            with sync_ctx:
                # wrap forward within amp
                with th.autocast(device_type='cuda',
                                 dtype=th.float16,
                                 enabled=self.mp_trainer_rec.use_amp):
                    pred_gen_output = self.rec_model(
                        img=micro['img_to_encoder'], c=micro['c'])
                    loss, loss_dict = self._distillation_loss(
                        pred_gen_output, eg3d_batch, ws)
                log_rec3d_loss_dict(loss_dict)
                self.mp_trainer_rec.backward(loss)

        if (micro is not None and dist_util.get_rank() == 0
                and self.step % 500 == 0):
            self._save_train_vis(micro, pred_gen_output)

        return pred_gen_output

    def _distillation_loss(self, pred_gen_output, eg3d_batch, ws):
        """Compute the total loss of one micro-batch against the generator output.

        The loss is the sum of:

        - ``loss_class``: 2D reconstruction loss;
        - the 3D shape loss on densities re-queried at the generator's
          ``fine_coords``;
        - 0.1 x MSE between the feature volumes;
        - 0.1 x MSE between ``ws[:, -1:, :]`` and ``sr_w_code``.

        Side effects: adds ``'shape_synthesized_query'`` to
        ``pred_gen_output`` and ``'image_depth'`` to
        ``eg3d_batch['shape_synthesized']``.

        Returns:
            ``(loss, loss_dict)``.
        """
        target = dict(
            img=eg3d_batch['image_raw'],
            shape_synthesized=eg3d_batch['shape_synthesized'],
            img_sr=eg3d_batch['image'],
        )
        pred_gen_output['shape_synthesized_query'] = {
            'coarse_densities':
            pred_gen_output['shape_synthesized']['coarse_densities'],
            'image_depth': pred_gen_output['image_depth'],
        }
        eg3d_batch['shape_synthesized']['image_depth'] = eg3d_batch[
            'image_depth']

        # Local name kept distinct from the caller's full batch size.
        micro_bs, num_rays, _, _ = pred_gen_output['shape_synthesized'][
            'coarse_densities'].shape
        for coord_key in ['fine_coords']:  # TODO add surface points
            coords = eg3d_batch['shape_synthesized'][coord_key]
            sigma = self.rec_model(
                latent=pred_gen_output['latent_denormalized'],
                coordinates=coords,
                directions=th.randn_like(coords),
                behaviour='triplane_renderer',
            )['sigma']
            rendering_kwargs = self.rec_model(
                behaviour='get_rendering_kwargs')
            sigma = sigma.reshape(
                micro_bs, num_rays,
                rendering_kwargs['depth_resolution_importance'], 1)
            pred_gen_output['shape_synthesized_query'][
                f"{coord_key.split('_')[0]}_densities"] = sigma

        # * 2D reconstruction loss
        loss, loss_dict = self.loss_class(pred_gen_output,
                                          target,
                                          test_mode=False)

        # * fully mimic 3D geometry output
        loss_shape = self.calc_shape_rec_loss(
            pred_gen_output['shape_synthesized_query'],
            eg3d_batch['shape_synthesized'])
        # Out-of-place additions: never mutate tensors loss_class returned.
        loss = loss + loss_shape.mean()

        # * add feature loss on feature_image
        loss_feature_volume = th.nn.functional.mse_loss(
            eg3d_batch['feature_volume'], pred_gen_output['feature_volume'])
        loss = loss + loss_feature_volume * 0.1

        loss_ws = th.nn.functional.mse_loss(ws[:, -1:, :],
                                            pred_gen_output['sr_w_code'])
        loss = loss + loss_ws * 0.1

        loss_dict.update(
            dict(loss_feature_volume=loss_feature_volume,
                 loss=loss,
                 loss_shape=loss_shape,
                 loss_ws=loss_ws))
        return loss, loss_dict

    @staticmethod
    def _minmax_normalize(x, eps=1e-8):
        """Rescale ``x`` linearly to [0, 1] using its global min/max.

        ``eps`` avoids a 0/0 for constant inputs.
        """
        return (x - x.min()) / (x.max() - x.min()).clamp_min(eps)

    @th.no_grad()
    def _save_train_vis(self, micro, pred_gen_output):
        """Save a GT (top row) vs. prediction (bottom row) image of sample 0.

        The image is written to ``{logger dir}/{step}.jpg``.
        """
        pred_img = pred_gen_output['image_raw']
        pred_depth = self._minmax_normalize(pred_gen_output['image_depth'])

        gt_depth = None
        if 'depth' in micro:
            gt_depth = micro['depth']
            if gt_depth.ndim == 3:
                gt_depth = gt_depth.unsqueeze(1)
            gt_depth = self._minmax_normalize(gt_depth)

        if 'image_sr' in pred_gen_output:
            pred_img = th.cat(
                [self.pool_512(pred_img), pred_gen_output['image_sr']],
                dim=-1)
            pred_depth = self.pool_512(pred_depth)
            gt_cols = [self.pool_512(micro['img']), micro['img_sr']]
            if gt_depth is not None:
                gt_depth = self.pool_512(gt_depth)
        else:
            gt_cols = [micro['img']]
        if gt_depth is not None:
            gt_cols.append(gt_depth.repeat_interleave(3, dim=1))

        gt_vis = th.cat(gt_cols, dim=-1)
        pred_vis = th.cat(
            [pred_img, pred_depth.repeat_interleave(3, dim=1)],
            dim=-1)  # B, 3, H, W
        # ! pred in range[-1, 1]; map to [0, 255]
        vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(1, 2, 0).cpu()
        vis = vis.float().numpy() * 127.5 + 127.5
        vis = vis.clip(0, 255).astype(np.uint8)

        vis_path = f'{logger.get_dir()}/{self.step+self.resume_step}.jpg'
        Image.fromarray(vis).save(vis_path)
        print('log vis to: ', vis_path)

    def calc_shape_rec_loss(
        self,
        pred_shape: dict,
        gt_shape: dict,
    ):
        """Compute the 3D shape loss via ``loss_class.calc_shape_rec_loss``.

        Every sub-loss is logged under ``Loss/3D/<name>``.

        Returns:
            The shape loss returned by ``loss_class``.
        """
        loss_shape, loss_shape_dict = self.loss_class.calc_shape_rec_loss(
            pred_shape,
            gt_shape,
            dist_util.dev(),
        )
        for loss_k, loss_v in loss_shape_dict.items():
            log_rec3d_loss_dict({'Loss/3D/{}'.format(loss_k): loss_v})
        return loss_shape

    def _novelview_frames(self, micro, pred):
        """Build uint8 frames (B, H, W, 3) of [eval GT | prediction | depth]."""
        pred_depth = self._minmax_normalize(pred['image_depth'])
        if 'image_sr' in pred:
            pred_vis = th.cat([
                micro['img_sr'],
                self.pool_512(pred['image_raw']), pred['image_sr'],
                self.pool_512(pred_depth).repeat_interleave(3, dim=1)
            ],
                              dim=-1)
        else:
            pred_vis = th.cat([
                self.pool_128(micro['img']), pred['image_raw'],
                pred_depth.repeat_interleave(3, dim=1)
            ],
                              dim=-1)  # B, 3, H, W
        vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
        vis = vis * 127.5 + 127.5
        return vis.clip(0, 255).astype(np.uint8)

    def _write_novelview_video(self, video_path, make_reference):
        """Render one fixed reference input at every eval camera into a video.

        Args:
            video_path: output ``.mp4`` path.
            make_reference: callable ``(micro) -> encoder input``. It is called
                once, on the first eval batch, and should return a single
                sample (leading dim 1). That sample is repeated to each batch
                size, so the whole video shows one subject.
        """
        video_out = imageio.get_writer(video_path,
                                       mode='I',
                                       fps=60,
                                       codec='libx264')
        try:
            reference_img = None
            for batch in tqdm(self.eval_data):
                micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
                if reference_img is None:
                    reference_img = make_reference(micro)
                n_views = micro['img'].shape[0]
                pred = self.rec_model(
                    img=reference_img.repeat_interleave(n_views, 0),
                    c=micro['c'])
                for frame in self._novelview_frames(micro, pred):
                    video_out.append_data(frame)
        finally:
            # Close the writer even if rendering fails part-way.
            video_out.close()
        th.cuda.empty_cache()

    @th.no_grad()
    def eval_novelview_loop(self):
        """Write a novel-view video of the first real eval sample.

        The first sample's ``img_to_encoder`` is reconstructed at the cameras
        of every eval batch. Without ``no_grad`` the ``.numpy()`` call on the
        predictions would fail.
        """
        self._write_novelview_video(
            f'{logger.get_dir()}/video_novelview_real_{self.step+self.resume_step}.mp4',
            lambda micro: micro['img_to_encoder'][0:1])

    @th.inference_mode()
    def eval_novelview_loop_eg3d(self):
        """Write a novel-view video of one synthetic (generator-rendered) subject.

        A single identity is rendered by ``G`` at the first eval camera. It is
        then reconstructed at the cameras of every eval batch.
        """

        def make_reference(micro):
            eg3d_batch, _ = self.run_G(
                z=th.randn(1, 512).to(dist_util.dev()),
                c=micro['c'][0:1],
                swapping_prob=0,
                neural_rendering_resolution=128)
            return self.pool_224(eg3d_batch['image'])

        self._write_novelview_video(
            f'{logger.get_dir()}/video_novelview_synthetic_{self.step+self.resume_step}.mp4',
            make_reference)