import copy
import functools
import json
import os
import traceback
from pathlib import Path
from pdb import set_trace as st

import blobfile as bf
import imageio
import numpy as np
import torch as th
import torch.distributed as dist
import torchvision
import webdataset as wds
from einops import rearrange
from PIL import Image
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from guided_diffusion import dist_util, logger
from guided_diffusion.fp16_util import MixedPrecisionTrainer
from guided_diffusion.nn import update_ema
from guided_diffusion.resample import LossAwareSampler, UniformSampler
from guided_diffusion.train_util import (calc_average_loss,
                                         find_ema_checkpoint,
                                         find_resume_checkpoint,
                                         get_blob_logdir, log_rec3d_loss_dict,
                                         parse_resume_step_from_filename)

from .camera_utils import LookAtPoseSampler, FOV_to_intrinsics
from .train_util import TrainLoop3DRec
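
# This module builds novel-view (NV) supervised reconstruction train loops on
# top of TrainLoop3DRec: full-image NV supervision (TrainLoop3DRecNV),
# patch-based NV supervision (TrainLoop3DRecNVPatch and its single-forward /
# multi-view variants), and an adversarial variant that adds a discriminator
# (TrainLoop3DRecNVPatchSingleForwardMVAdvLoss).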


class TrainLoop3DRecNV(TrainLoop3DRec):
    # Supervise training with novel-view reconstruction.

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)
        self.rec_cano = True

    def forward_backward(self, batch, *args, **kwargs):
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k: v[i:i + self.microbatch].to(dist_util.dev())
                for k, v in batch.items()
            }

            # TODO: concat the novel view, add self reconstruction and a
            # patch-based loss in the next version; verify novel-view
            # prediction first.

            # wrap the forward pass in amp autocast
            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                target_nvs = {}
                target_cano = {}

                latent = self.rec_model(img=micro['img_to_encoder'],
                                        behaviour='enc_dec_wo_triplane')

                pred = self.rec_model(
                    latent=latent,
                    c=micro['nv_c'],  # predict the novel view here
                    behaviour='triplane_dec')

                for k, v in micro.items():
                    if k[:2] == 'nv':
                        orig_key = k.replace('nv_', '')
                        target_nvs[orig_key] = v
                        target_cano[orig_key] = micro[orig_key]

                with self.rec_model.no_sync():  # type: ignore
                    loss, loss_dict, fg_mask = self.loss_class(
                        pred,
                        target_nvs,
                        step=self.step + self.resume_step,
                        test_mode=False,
                        return_fg_mask=True,
                        conf_sigma_l1=None,
                        conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

                if self.rec_cano:

                    pred_cano = self.rec_model(latent=latent,
                                               c=micro['c'],
                                               behaviour='triplane_dec')

                    with self.rec_model.no_sync():  # type: ignore

                        fg_mask = target_cano['depth_mask'].unsqueeze(
                            1).repeat_interleave(3, 1).float()

                        loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss(
                            pred_cano['image_raw'],
                            target_cano['img'],
                            fg_mask,
                            step=self.step + self.resume_step,
                            test_mode=False,
                        )

                    loss = loss + loss_cano

                    # log the canonical-view losses under a separate prefix
                    log_rec3d_loss_dict({
                        f'cano_{k}': v
                        for k, v in loss_cano_dict.items()
                    })

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0:
                if self.rec_cano:
                    self.log_img(micro, pred, pred_cano)
                else:
                    self.log_img(micro, pred, None)

    def log_img(self, micro, pred, pred_cano):

        def norm_depth(pred_depth):  # map depth to [-1, 1]
            pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
                                                            pred_depth.min())
            return -(pred_depth * 2 - 1)
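
        # For example, a depth map with values in [2.0, 4.0] is first scaled
        # to [0, 1] and then mapped to [1, -1], so the nearest points come
        # out brightest when saved with value_range=(-1, 1) below.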

        pred_img = pred['image_raw']
        gt_img = micro['img']

        gt_depth = micro['depth']
        if gt_depth.ndim == 3:
            gt_depth = gt_depth.unsqueeze(1)
        gt_depth = norm_depth(gt_depth)

        fg_mask = pred['image_mask'] * 2 - 1  # [0, 1] -> [-1, 1]
        input_fg_mask = pred_cano['image_mask'] * 2 - 1  # [0, 1] -> [-1, 1]

        if 'image_depth' in pred:
            pred_depth = norm_depth(pred['image_depth'])
            pred_nv_depth = norm_depth(pred_cano['image_depth'])
        else:
            pred_depth = th.zeros_like(gt_depth)
            pred_nv_depth = th.zeros_like(gt_depth)

        if 'image_sr' in pred:
            if pred['image_sr'].shape[-1] == 512:
                pred_img = th.cat([self.pool_512(pred_img), pred['image_sr']],
                                  dim=-1)
                gt_img = th.cat([self.pool_512(micro['img']), micro['img_sr']],
                                dim=-1)
                pred_depth = self.pool_512(pred_depth)
                gt_depth = self.pool_512(gt_depth)
            elif pred['image_sr'].shape[-1] == 256:
                pred_img = th.cat([self.pool_256(pred_img), pred['image_sr']],
                                  dim=-1)
                gt_img = th.cat([self.pool_256(micro['img']), micro['img_sr']],
                                dim=-1)
                pred_depth = self.pool_256(pred_depth)
                gt_depth = self.pool_256(gt_depth)
            else:
                pred_img = th.cat([self.pool_128(pred_img), pred['image_sr']],
                                  dim=-1)
                gt_img = th.cat([self.pool_128(micro['img']), micro['img_sr']],
                                dim=-1)
                gt_depth = self.pool_128(gt_depth)
                pred_depth = self.pool_128(pred_depth)
        else:
            gt_img = self.pool_64(gt_img)
            gt_depth = self.pool_64(gt_depth)

        pred_vis = th.cat([
            pred_img,
            pred_depth.repeat_interleave(3, dim=1),
            fg_mask.repeat_interleave(3, dim=1),
        ],
                          dim=-1)  # B, 3, H, W

        pred_vis_nv = th.cat([
            pred_cano['image_raw'],
            pred_nv_depth.repeat_interleave(3, dim=1),
            input_fg_mask.repeat_interleave(3, dim=1),
        ],
                             dim=-1)  # B, 3, H, W

        pred_vis = th.cat([pred_vis, pred_vis_nv], dim=-2)  # cat in H dim

        gt_vis = th.cat([
            gt_img,
            gt_depth.repeat_interleave(3, dim=1),
            th.zeros_like(gt_img)
        ],
                        dim=-1)  # TODO: depth fails to load; range [0, 1]

        if 'conf_sigma' in pred:
            gt_vis = th.cat([gt_vis, fg_mask], dim=-1)  # placeholder

        vis = th.cat([gt_vis, pred_vis], dim=-2)
        vis_tensor = torchvision.utils.make_grid(vis, nrow=vis.shape[-1] //
                                                 64)  # (C, H, W)

        torchvision.utils.save_image(
            vis_tensor,
            f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
            value_range=(-1, 1),
            normalize=True)

        logger.log('log vis to: ',
                   f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')


class TrainLoop3DRecNVPatch(TrainLoop3DRecNV):
    # add patch rendering

    def __init__(self, **kwargs):
        # all training-loop arguments are forwarded to the parent unchanged
        super().__init__(**kwargs)
        # the renderer
        self.eg3d_model = self.rec_model.module.decoder.triplane_decoder  # type: ignore
        self.rec_cano = True

    def forward_backward(self, batch, *args, **kwargs):
        # add patch sampling
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k: v[i:i + self.microbatch].to(dist_util.dev())
                for k, v in batch.items()
            }

            # ! sample a rendering patch (ray origins / directions)
            target = {
                **self.eg3d_model(
                    c=micro['nv_c'],  # type: ignore
                    ws=None,
                    planes=None,
                    sample_ray_only=True,
                    fg_bbox=micro['nv_bbox']),
            }

            patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
                'patch_rendering_resolution']  # type: ignore
            cropped_target = {
                k:
                th.empty_like(v)
                [..., :patch_rendering_resolution, :patch_rendering_resolution]
                if k not in [
                    'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
                    'nv_img_sr', 'c'
                ] else v
                for k, v in micro.items()
            }

            # crop the novel-view targets according to the uv sampling
            for j in range(micro['img'].shape[0]):
                top, left, height, width = target['ray_bboxes'][
                    j]  # list of tuples
                for key in ('img', 'depth_mask', 'depth'):  # type: ignore
                    cropped_target[f'{key}'][  # ! no nv_ prefix here
                        j:j + 1] = torchvision.transforms.functional.crop(
                            micro[f'nv_{key}'][j:j + 1], top, left, height,
                            width)
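            # torchvision.transforms.functional.crop(img, top, left, h, w)
            # returns img[..., top:top + h, left:left + w], so each cropped
            # target patch covers exactly the pixel window spanned by the
            # sampled rays in `ray_bboxes` (assuming that convention holds
            # for the eg3d_model ray sampler).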

            # wrap the forward pass in amp autocast
            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                latent = self.rec_model(img=micro['img_to_encoder'],
                                        behaviour='enc_dec_wo_triplane')

                pred_nv = self.rec_model(
                    latent=latent,
                    c=micro['nv_c'],  # predict the novel view here
                    behaviour='triplane_dec',
                    ray_origins=target['ray_origins'],
                    ray_directions=target['ray_directions'],
                )

                with self.rec_model.no_sync():  # type: ignore
                    loss, loss_dict, _ = self.loss_class(pred_nv,
                                                         cropped_target,
                                                         step=self.step +
                                                         self.resume_step,
                                                         test_mode=False,
                                                         return_fg_mask=True,
                                                         conf_sigma_l1=None,
                                                         conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

                if self.rec_cano:

                    cano_target = {
                        **self.eg3d_model(
                            c=micro['c'],  # type: ignore
                            ws=None,
                            planes=None,
                            sample_ray_only=True,
                            fg_bbox=micro['bbox']),  # rays o / dir
                    }

                    cano_cropped_target = {
                        k: th.empty_like(v)
                        for k, v in cropped_target.items()
                    }

                    for j in range(micro['img'].shape[0]):
                        top, left, height, width = cano_target['ray_bboxes'][
                            j]  # list of tuples
                        for key in ('img', 'depth_mask',
                                    'depth'):  # type: ignore
                            cano_cropped_target[key][
                                j:j +
                                1] = torchvision.transforms.functional.crop(
                                    micro[key][j:j + 1], top, left, height,
                                    width)

                    pred_cano = self.rec_model(
                        latent=latent,
                        c=micro['c'],
                        behaviour='triplane_dec',
                        ray_origins=cano_target['ray_origins'],
                        ray_directions=cano_target['ray_directions'],
                    )

                    with self.rec_model.no_sync():  # type: ignore

                        fg_mask = cano_cropped_target['depth_mask'].unsqueeze(
                            1).repeat_interleave(3, 1).float()

                        loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss(
                            pred_cano['image_raw'],
                            cano_cropped_target['img'],
                            fg_mask,
                            step=self.step + self.resume_step,
                            test_mode=False,
                        )

                    loss = loss + loss_cano

                    # log the canonical-view losses under a separate prefix
                    log_rec3d_loss_dict({
                        f'cano_{k}': v
                        for k, v in loss_cano_dict.items()
                    })

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0:
                self.log_patch_img(cropped_target, pred_nv, pred_cano)

    def log_patch_img(self, micro, pred, pred_cano):

        def norm_depth(pred_depth):  # map depth to [-1, 1]
            pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
                                                            pred_depth.min())
            return -(pred_depth * 2 - 1)

        pred_img = pred['image_raw']
        gt_img = micro['img']

        gt_depth = micro['depth']
        if gt_depth.ndim == 3:
            gt_depth = gt_depth.unsqueeze(1)
        gt_depth = norm_depth(gt_depth)

        fg_mask = pred['image_mask'] * 2 - 1  # [0, 1] -> [-1, 1]
        input_fg_mask = pred_cano['image_mask'] * 2 - 1  # [0, 1] -> [-1, 1]

        if 'image_depth' in pred:
            pred_depth = norm_depth(pred['image_depth'])
            pred_cano_depth = norm_depth(pred_cano['image_depth'])
        else:
            pred_depth = th.zeros_like(gt_depth)
            pred_cano_depth = th.zeros_like(gt_depth)

        pred_vis = th.cat([
            pred_img,
            pred_depth.repeat_interleave(3, dim=1),
            fg_mask.repeat_interleave(3, dim=1),
        ],
                          dim=-1)  # B, 3, H, W

        pred_vis_nv = th.cat([
            pred_cano['image_raw'],
            pred_cano_depth.repeat_interleave(3, dim=1),
            input_fg_mask.repeat_interleave(3, dim=1),
        ],
                             dim=-1)  # B, 3, H, W

        pred_vis = th.cat([pred_vis, pred_vis_nv], dim=-2)  # cat in H dim

        gt_vis = th.cat([
            gt_img,
            gt_depth.repeat_interleave(3, dim=1),
            th.zeros_like(gt_img)
        ],
                        dim=-1)  # TODO: depth fails to load; range [0, 1]

        vis = th.cat([gt_vis, pred_vis], dim=-2)
        vis_tensor = torchvision.utils.make_grid(vis, nrow=vis.shape[-1] //
                                                 64)  # (C, H, W)

        torchvision.utils.save_image(
            vis_tensor,
            f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
            value_range=(-1, 1),
            normalize=True)

        logger.log('log vis to: ',
                   f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')


class TrainLoop3DRecNVPatchSingleForward(TrainLoop3DRecNVPatch):
    # the parent constructor is inherited as-is; the original __init__ only
    # forwarded all arguments unchanged.

    def forward_backward(self, batch, *args, **kwargs):
        # add patch sampling
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]
        batch.pop('caption')  # not required
        batch.pop('ins')  # not required

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k:
                v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
                    v, th.Tensor) else v[i:i + self.microbatch]
                for k, v in batch.items()
            }

            # ! sample a rendering patch (ray origins / directions)
            target = {
                **self.eg3d_model(
                    c=micro['nv_c'],  # type: ignore
                    ws=None,
                    planes=None,
                    sample_ray_only=True,
                    fg_bbox=micro['nv_bbox']),
            }

            patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
                'patch_rendering_resolution']  # type: ignore
            cropped_target = {
                k:
                th.empty_like(v)
                [..., :patch_rendering_resolution, :patch_rendering_resolution]
                if k not in [
                    'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
                    'nv_img_sr', 'c', 'caption', 'nv_caption'
                ] else v
                for k, v in micro.items()
            }

            # crop the novel-view targets according to the uv sampling
            for j in range(micro['img'].shape[0]):
                top, left, height, width = target['ray_bboxes'][
                    j]  # list of tuples
                for key in ('img', 'depth_mask', 'depth'):  # type: ignore
                    cropped_target[f'{key}'][  # ! no nv_ prefix here
                        j:j + 1] = torchvision.transforms.functional.crop(
                            micro[f'nv_{key}'][j:j + 1], top, left, height,
                            width)

            # ! canonical-view rays; the canonical-view loss branch is
            # currently disabled, so these are kept for reference only
            cano_target = {
                **self.eg3d_model(
                    c=micro['c'],  # type: ignore
                    ws=None,
                    planes=None,
                    sample_ray_only=True,
                    fg_bbox=micro['bbox']),
            }

            # ! run the ViT encoder outside of amp
            latent = self.rec_model(img=micro['img_to_encoder'],
                                    behaviour='enc_dec_wo_triplane')

            # wrap the rendering forward pass in amp autocast
            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                instance_mv_num = batch_size // 4  # 4 pairs by default

                # ! roll views along the batch dim for multi-view supervision
                c = th.cat([
                    micro['nv_c'].roll(instance_mv_num * i, dims=0)
                    for i in range(1, 4)
                ])  # predict the novel views here

                ray_origins = th.cat([
                    target['ray_origins'].roll(instance_mv_num * i, dims=0)
                    for i in range(1, 4)
                ], 0)

                ray_directions = th.cat([
                    target['ray_directions'].roll(instance_mv_num * i, dims=0)
                    for i in range(1, 4)
                ])
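
                # A minimal sketch of the roll-based pairing, assuming the
                # batch is laid out in groups of `instance_mv_num` views:
                # with batch_size 8 and instance_mv_num 2, roll(2), roll(4)
                # and roll(6) shift the cameras, rays and (below) the ground
                # truth by the same offsets, so every latent, repeated 3x,
                # is rendered and supervised under three other views while
                # all tensors stay row-aligned:
                #
                #   idx = th.arange(8)
                #   [idx.roll(2 * k, dims=0)[0].item() for k in range(1, 4)]
                #   # -> [6, 4, 2]  (the views paired with sample 0)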

                pred_nv_cano = self.rec_model(
                    latent={
                        'latent_after_vit':  # ! triplane for rendering
                        latent['latent_after_vit'].repeat(3, 1, 1, 1)
                    },
                    c=c,
                    behaviour='triplane_dec',
                    ray_origins=ray_origins,
                    ray_directions=ray_directions,
                )
                pred_nv_cano.update(latent)

                # roll the ground truth with the same offsets as the cameras
                gt = {
                    k:
                    th.cat([
                        v.roll(instance_mv_num * i, dims=0)
                        for i in range(1, 4)
                    ], 0)
                    for k, v in cropped_target.items()
                }

                with self.rec_model.no_sync():  # type: ignore
                    loss, loss_dict, _ = self.loss_class(
                        pred_nv_cano,
                        gt,  # the merged multi-view data
                        step=self.step + self.resume_step,
                        test_mode=False,
                        return_fg_mask=True,
                        conf_sigma_l1=None,
                        conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0:
                micro_bs = micro['img_to_encoder'].shape[0]
                self.log_patch_img(  # record one cano view and one novel view
                    cropped_target,
                    {
                        k: pred_nv_cano[k][-micro_bs:]
                        for k in ['image_raw', 'image_depth', 'image_mask']
                    },
                    {
                        k: pred_nv_cano[k][:micro_bs]
                        for k in ['image_raw', 'image_depth', 'image_mask']
                    },
                )

    def eval_novelview_loop_old(self, camera=None):
        # novel view synthesis given an evaluation camera trajectory
        all_loss_dict = []
        novel_view_micro = {}

        export_mesh = True
        if export_mesh:
            Path(f'{logger.get_dir()}/FID_Cals/').mkdir(parents=True,
                                                        exist_ok=True)

        for eval_idx, render_reference in enumerate(tqdm(self.eval_data)):

            if eval_idx > 500:
                break

            video_out = imageio.get_writer(
                f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}_{eval_idx}.mp4',
                mode='I',
                fps=25,
                codec='libx264')

            with open(
                    f'{logger.get_dir()}/triplane_{self.step+self.resume_step}_{eval_idx}_caption.txt',
                    'w') as f:
                f.write(render_reference['caption'])

            for key in ['ins', 'bbox', 'caption']:
                if key in render_reference:
                    render_reference.pop(key)

            real_flag = False
            mv_flag = False  # TODO: use the full instance for evaluation and calculate the metrics.
            if render_reference['c'].shape[:2] == (1, 40):
                real_flag = True
                # real-image monocular reconstruction;
                # unpack into a list so it can be enumerated below
                render_reference = [{
                    k: v[0][idx:idx + 1]
                    for k, v in render_reference.items()
                } for idx in range(40)]

            elif render_reference['c'].shape[0] == 8:
                mv_flag = True

                render_reference = {
                    k: v[:4]
                    for k, v in render_reference.items()
                }

                # save gt
                torchvision.utils.save_image(
                    render_reference['img'],
                    logger.get_dir() + '/FID_Cals/{}_inp.png'.format(eval_idx),
                    padding=0,
                    normalize=True,
                    value_range=(-1, 1),
                )

            else:
                # unpack into a list so it can be enumerated below
                render_reference = [{
                    k: v[idx:idx + 1]
                    for k, v in render_reference.items()
                } for idx in range(40)]

                # ! single-view version: encode a side view
                render_reference[0]['img_to_encoder'] = render_reference[14][
                    'img_to_encoder']
                render_reference[0]['img'] = render_reference[14]['img']

                # save gt
                torchvision.utils.save_image(
                    render_reference[0]['img'],
                    logger.get_dir() + '/FID_Cals/{}_gt.png'.format(eval_idx),
                    padding=0,
                    normalize=True,
                    value_range=(-1, 1))

            # ! TODO: merge with render_video_given_triplane later
            for i, batch in enumerate(render_reference):

                micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}

                if i == 0:
                    if mv_flag:
                        # TODO: the multi-view path is unfinished;
                        # novel_view_micro stays None here.
                        novel_view_micro = None
                    else:
                        novel_view_micro = {
                            k:
                            v[0:1].to(dist_util.dev()).repeat_interleave(
                                micro['img'].shape[0],
                                0) if isinstance(v, th.Tensor) else v[0:1]
                            for k, v in batch.items()
                        }
                else:
                    if i == 1:
                        # ! export the mesh once per instance
                        if export_mesh:
                            mesh_size = 384
                            mesh_thres = 10  # TODO: requires tuning
                            import mcubes
                            import trimesh
                            dump_path = f'{logger.get_dir()}/mesh/'
                            os.makedirs(dump_path, exist_ok=True)

                            grid_out = self.rec_model(
                                latent=pred,
                                grid_size=mesh_size,
                                behaviour='triplane_decode_grid',
                            )

                            vtx, faces = mcubes.marching_cubes(
                                grid_out['sigma'].squeeze(0).squeeze(
                                    -1).cpu().numpy(), mesh_thres)
                            vtx = vtx / (mesh_size - 1) * 2 - 1
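                            # grid index 0 maps to -1 and index
                            # mesh_size - 1 maps to +1, so the marching-cubes
                            # vertices land in the renderer's [-1, 1] cube.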

                            mesh = trimesh.Trimesh(
                                vertices=vtx,
                                faces=faces,
                            )

                            mesh_dump_path = os.path.join(
                                dump_path, f'{eval_idx}.ply')
                            mesh.export(mesh_dump_path, 'ply')

                            print(f"Mesh dumped to {dump_path}")
                            del grid_out, mesh
                            th.cuda.empty_cache()

                    novel_view_micro = {
                        k:
                        v[0:1].to(dist_util.dev()).repeat_interleave(
                            micro['img'].shape[0], 0)
                        for k, v in novel_view_micro.items()
                    }

                pred = self.rec_model(img=novel_view_micro['img_to_encoder'],
                                      c=micro['c'])  # pred: (B, 3, 64, 64)

                if not real_flag:
                    _, loss_dict = self.loss_class(pred, micro, test_mode=True)
                    all_loss_dict.append(loss_dict)

                # normalize the depth for visualization; pooled_depth is
                # also reused by the FID export below
                pred_depth = pred['image_depth']
                pred_depth = (pred_depth - pred_depth.min()) / (
                    pred_depth.max() - pred_depth.min())
                pooled_depth = self.pool_128(pred_depth).repeat_interleave(
                    3, dim=1)

                if 'image_sr' in pred:
                    if pred['image_sr'].shape[-1] == 512:
                        pred_vis = th.cat([
                            micro['img_sr'],
                            self.pool_512(pred['image_raw']), pred['image_sr'],
                            self.pool_512(pred_depth).repeat_interleave(3,
                                                                        dim=1)
                        ],
                                          dim=-1)
                    elif pred['image_sr'].shape[-1] == 256:
                        pred_vis = th.cat([
                            micro['img_sr'],
                            self.pool_256(pred['image_raw']), pred['image_sr'],
                            self.pool_256(pred_depth).repeat_interleave(3,
                                                                        dim=1)
                        ],
                                          dim=-1)
                    else:
                        pred_vis = th.cat([
                            micro['img_sr'],
                            self.pool_128(pred['image_raw']),
                            self.pool_128(pred['image_sr']),
                            self.pool_128(pred_depth).repeat_interleave(3,
                                                                        dim=1)
                        ],
                                          dim=-1)
                else:
                    pred_vis = th.cat(
                        [
                            self.pool_128(novel_view_micro['img']
                                          ),  # use the input here
                            self.pool_128(pred['image_raw']),
                            pooled_depth,
                        ],
                        dim=-1)  # B, 3, H, W

                vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
                vis = vis * 127.5 + 127.5
                vis = vis.clip(0, 255).astype(np.uint8)

                if export_mesh:
                    # save the image and depth for FID calculation
                    torchvision.utils.save_image(
                        pred['image_raw'],
                        logger.get_dir() +
                        '/FID_Cals/{}_{}.png'.format(eval_idx, i),
                        padding=0,
                        normalize=True,
                        value_range=(-1, 1))

                    torchvision.utils.save_image(
                        pooled_depth,
                        logger.get_dir() +
                        '/FID_Cals/{}_{}_depth.png'.format(eval_idx, i),
                        padding=0,
                        normalize=True,
                        value_range=(0, 1))

                for j in range(vis.shape[0]):
                    video_out.append_data(vis[j])

            video_out.close()

        if not real_flag or mv_flag:
            val_scores_for_logging = calc_average_loss(all_loss_dict)
            with open(os.path.join(logger.get_dir(), 'scores_novelview.json'),
                      'a') as f:
                json.dump({'step': self.step, **val_scores_for_logging}, f)

            # * log to tensorboard
            for k, v in val_scores_for_logging.items():
                self.writer.add_scalar(f'Eval/NovelView/{k}', v,
                                       self.step + self.resume_step)

        del video_out
        th.cuda.empty_cache()

    def eval_novelview_loop(self, camera=None, save_latent=False):
        # novel view synthesis given an evaluation camera trajectory
        if save_latent:  # for diffusion learning
            latent_dir = Path(f'{logger.get_dir()}/latent_dir')
            latent_dir.mkdir(exist_ok=True, parents=True)

        eval_batch_size = 40  # ! for i23d

        for eval_idx, micro in enumerate(tqdm(self.eval_data)):

            latent = self.rec_model(
                img=micro['img_to_encoder'],
                behaviour='encoder_vae')  # pred: (B, 3, 64, 64)

            if save_latent:
                latent_save_dir = f'{logger.get_dir()}/latent_dir/{micro["ins"][0]}'
                Path(latent_save_dir).mkdir(parents=True, exist_ok=True)

                np.save(f'{latent_save_dir}/latent.npy',
                        latent[self.latent_name][0].cpu().numpy())

                assert all([
                    micro['ins'][0] == micro['ins'][i]
                    for i in range(micro['c'].shape[0])
                ])  # ! assert all samples come from the same instance

            if eval_idx < 50:
                self.render_video_given_triplane(
                    latent[self.latent_name],  # B 12 32 32
                    self.rec_model,  # compatible with join_model
                    name_prefix=f'{self.step + self.resume_step}_{eval_idx}',
                    save_img=False,
                    render_reference={'c': camera},
                    save_mesh=True)


class TrainLoop3DRecNVPatchSingleForwardMV(TrainLoop3DRecNVPatchSingleForward):
    # the parent constructor is inherited as-is; the original __init__ only
    # forwarded all arguments unchanged.

    def forward_backward(self, batch, behaviour='g_step', *args, **kwargs):
        # add patch sampling
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]
        batch.pop('caption')  # not required
        batch.pop('nv_caption')  # not required
        batch.pop('ins')  # not required
        batch.pop('nv_ins')  # not required

        if '__key__' in batch.keys():
            batch.pop('__key__')

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k:
                v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
                    v, th.Tensor) else v[i:i + self.microbatch]
                for k, v in batch.items()
            }

            # ! sample rendering patches for novel and canonical views
            nv_c = th.cat([micro['nv_c'], micro['c']])
            target = {
                **self.eg3d_model(
                    c=nv_c,  # type: ignore
                    ws=None,
                    planes=None,
                    sample_ray_only=True,
                    fg_bbox=th.cat([micro['nv_bbox'], micro['bbox']])),  # rays o / dir
            }

            patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
                'patch_rendering_resolution']  # type: ignore
            cropped_target = {
                k:
                th.empty_like(v).repeat_interleave(2, 0)
                [..., :patch_rendering_resolution, :patch_rendering_resolution]
                if k not in [
                    'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
                    'nv_img_sr', 'c', 'caption', 'nv_caption'
                ] else v
                for k, v in micro.items()
            }

            # crop according to the uv sampling: the first microbatch rows
            # take the novel-view targets, the second the canonical ones
            for j in range(2 * self.microbatch):
                top, left, height, width = target['ray_bboxes'][
                    j]  # list of tuples
                for key in ('img', 'depth_mask', 'depth'):  # type: ignore
                    if j < self.microbatch:
                        cropped_target[f'{key}'][  # ! no nv_ prefix here
                            j:j + 1] = torchvision.transforms.functional.crop(
                                micro[f'nv_{key}'][j:j + 1], top, left, height,
                                width)
                    else:
                        cropped_target[f'{key}'][
                            j:j + 1] = torchvision.transforms.functional.crop(
                                micro[f'{key}'][j - self.microbatch:j -
                                                self.microbatch + 1], top,
                                left, height, width)
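
            # Resulting layout (a sketch, with microbatch size B): rows
            # [0, B) of nv_c and cropped_target hold the novel views, rows
            # [B, 2B) the canonical views, matching the th.cat order of
            # nv_c above, so predictions and targets stay row-aligned.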

            # ! run the ViT encoder outside of amp
            latent = self.rec_model(img=micro['img_to_encoder'],
                                    behaviour='enc_dec_wo_triplane')

            # wrap the rendering forward pass in amp autocast
            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                ray_origins = target['ray_origins']
                ray_directions = target['ray_directions']

                pred_nv_cano = self.rec_model(
                    latent={
                        'latent_after_vit':  # ! triplane for rendering
                        latent['latent_after_vit'].repeat_interleave(
                            4, dim=0).repeat(2, 1, 1, 1)  # NV=4
                    },
                    c=nv_c,
                    behaviour='triplane_dec',
                    ray_origins=ray_origins,
                    ray_directions=ray_directions,
                )

                pred_nv_cano.update(latent)
                gt = cropped_target

                with self.rec_model.no_sync():  # type: ignore
                    loss, loss_dict, _ = self.loss_class(
                        pred_nv_cano,
                        gt,  # the merged novel-view / canonical data
                        step=self.step + self.resume_step,
                        test_mode=False,
                        return_fg_mask=True,
                        behaviour=behaviour,
                        conf_sigma_l1=None,
                        conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0 and i == 0:
                try:
                    torchvision.utils.save_image(
                        th.cat(
                            [cropped_target['img'], pred_nv_cano['image_raw']
                             ], ),
                        f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
                        normalize=True)
                    logger.log(
                        'log vis to: ',
                        f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
                except Exception as e:
                    logger.log(e)


class TrainLoop3DRecNVPatchSingleForwardMVAdvLoss(
        TrainLoop3DRecNVPatchSingleForwardMV):

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)

        # create the discriminator and its own mixed-precision trainer
        disc_params = self.loss_class.get_trainable_parameters()
        self.mp_trainer_disc = MixedPrecisionTrainer(
            model=self.loss_class.discriminator,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
            model_name='disc',
            use_amp=use_amp,
            model_params=disc_params)

        self.opt_disc = AdamW(
            self.mp_trainer_disc.master_params,
            lr=self.lr,  # follow the sd code base
            betas=(0, 0.999),
            eps=1e-8)
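
        # betas=(0, 0.999) disables the first-moment running average for the
        # discriminator optimizer, a choice common in GAN discriminator
        # training (e.g. the StyleGAN family uses beta1 = 0).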

        # TODO: is the loss class already wrapped in DDP?
        if self.use_ddp:
            self.ddp_disc = DDP(
                self.loss_class.discriminator,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            self.ddp_disc = self.loss_class.discriminator

    def save(self, mp_trainer=None, model_name='rec'):
        if mp_trainer is None:
            mp_trainer = self.mp_trainer_rec

        def save_checkpoint(rate, params):
            state_dict = mp_trainer.master_params_to_state_dict(params)
            if dist_util.get_rank() == 0:
                logger.log(f"saving model {model_name} {rate}...")
                if not rate:
                    filename = f"model_{model_name}{(self.step+self.resume_step):07d}.pt"
                else:
                    filename = f"ema_{model_name}_{rate}_{(self.step+self.resume_step):07d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename),
                                 "wb") as f:
                    th.save(state_dict, f)

        save_checkpoint(0, mp_trainer.master_params)
        dist.barrier()

    def run_step(self, batch, step='g_step'):

        if step == 'g_step':
            self.forward_backward(batch, behaviour='g_step')
            took_step_g_rec = self.mp_trainer_rec.optimize(self.opt)

            if took_step_g_rec:
                self._update_ema()  # g_ema

        elif step == 'd_step':
            self.forward_backward(batch, behaviour='d_step')
            _ = self.mp_trainer_disc.optimize(self.opt_disc)

        self._anneal_lr()
        self.log_step()
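
    # run_loop below alternates one generator step and one discriminator
    # step per iteration, each on a fresh batch, i.e.:
    #
    #   self.run_step(next(self.data), 'g_step')  # update the reconstructor
    #   self.run_step(next(self.data), 'd_step')  # update the discriminator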

    def run_loop(self, batch=None):
        while (not self.lr_anneal_steps
               or self.step + self.resume_step < self.lr_anneal_steps):

            batch = next(self.data)
            self.run_step(batch, 'g_step')

            batch = next(self.data)
            self.run_step(batch, 'd_step')

            if self.step % 1000 == 0:
                dist_util.synchronize()
                if self.step % 10000 == 0:
                    th.cuda.empty_cache()  # avoid memory leaks

            if self.step % self.log_interval == 0 and dist_util.get_rank(
            ) == 0:
                out = logger.dumpkvs()
                # * log to tensorboard
                for k, v in out.items():
                    self.writer.add_scalar(f'Loss/{k}', v,
                                           self.step + self.resume_step)

            if self.step % self.eval_interval == 0 and self.step != 0:
                if dist_util.get_rank() == 0:
                    try:
                        self.eval_loop()
                    except Exception as e:
                        logger.log(e)
                dist_util.synchronize()

            if self.step % self.save_interval == 0:
                self.save()
                self.save(self.mp_trainer_disc,
                          self.mp_trainer_disc.model_name)
                dist_util.synchronize()

            # Run for a finite amount of time in integration tests.
            if os.environ.get("DIFFUSION_TRAINING_TEST",
                              "") and self.step > 0:
                return

            self.step += 1

            if self.step > self.iterations:
                logger.log('reached maximum iterations, exiting')

                # Save the last checkpoint if it wasn't already saved.
                if (self.step -
                        1) % self.save_interval != 0 and self.step != 1:
                    self.save()

                exit()

        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()  # save the reconstructor
            self.save(self.mp_trainer_disc, self.mp_trainer_disc.model_name)
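

# A minimal usage sketch (hypothetical names; the actual construction lives
# in the project's training script, which this file does not show):
#
#   loop = TrainLoop3DRecNVPatchSingleForwardMVAdvLoss(
#       rec_model=rec_model,        # DDP-wrapped encoder / triplane decoder
#       loss_class=loss_class,      # provides discriminator + rec losses
#       data=train_loader, eval_data=eval_loader,
#       batch_size=8, microbatch=8, lr=1e-4, ema_rate='0.9999',
#       log_interval=50, eval_interval=5000, save_interval=10000,
#       resume_checkpoint='')
#   loop.run_loop()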