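# NOTE: the text below appears to be web file-viewer residue captured together
# with the source; it is not part of the program and is kept only as a comment.
#   pengc02's picture | all | ec9a6bc | raw | history blame | 6.9 kB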
import os
import torch
import numpy as np
import pytorch3d.ops
import importlib
from base_trainer import BaseTrainer
import config
from network.template import TemplateNet
from network.lpips import LPIPS
import utils.lr_schedule as lr_schedule
import utils.net_util as net_util
import utils.recon_util as recon_util
from utils.net_util import to_cuda
from utils.obj_io import save_mesh_as_ply
class TemplateTrainer(BaseTrainer):
    """Trainer that optimises a ``TemplateNet`` and exports its canonical mesh.

    ``run`` builds the dataset, network, optimizer and learning-rate schedule
    from ``self.opt``, launches the training loop inherited from
    ``BaseTrainer`` and finally writes ``template.ply`` into the data
    directory.  ``mini_test`` dumps an intermediate mesh for inspection.
    """

    def __init__(self, opt):
        """Initialise the base trainer with ``opt`` and set the iteration budget."""
        super().__init__(opt)
        self.iter_num = 150_000

    def update_config_before_epoch(self, epoch_idx):
        """Refresh ``self.iter_idx`` for ``epoch_idx`` and report the trainable size.

        ``self.batch_num`` is not defined in this class; presumably it is
        provided by ``BaseTrainer`` -- TODO confirm.
        """
        self.iter_idx = epoch_idx * self.batch_num
        print('# Optimizable variable number in network: %d' % sum(p.numel() for p in self.network.parameters() if p.requires_grad))

    def forward_one_pass(self, items):
        """Run one optimisation step on a batch.

        Args:
            items: batch dictionary.  When it holds a ``'nerf_random'`` entry,
                that sub-dictionary is merged into ``items`` (in place) and the
                network renders from the merged dictionary.

        Returns:
            ``(total_loss, batch_losses)``: the weighted loss tensor that was
            back-propagated and a dict of the individual loss values as floats.

        Raises:
            RuntimeError: if no loss term could be computed for this batch.
                Previously this surfaced as an ``AttributeError`` from calling
                ``backward()`` on the integer ``0``.
        """
        total_loss = 0
        batch_losses = {}

        # random sampling
        # NOTE(review): the loss terms are assumed to be nested under this
        # condition (indentation was lost in the reviewed copy) -- confirm.
        if 'nerf_random' in items:
            items.update(items['nerf_random'])
            render_output = self.network.render(
                items,
                depth_guided_sampling=self.opt['train']['depth_guided_sampling'],
            )

            # color loss
            if 'rgb_map' in render_output:
                color_loss = torch.nn.L1Loss()(render_output['rgb_map'], items['color_gt'])
                total_loss += self.loss_weight['color'] * color_loss
                batch_losses['color_loss_random'] = color_loss.item()

            # mask loss
            if 'acc_map' in render_output:
                mask_loss = torch.nn.L1Loss()(render_output['acc_map'], items['mask_gt'])
                total_loss += self.loss_weight['mask'] * mask_loss
                batch_losses['mask_loss_random'] = mask_loss.item()

            # eikonal loss: penalise normals whose length deviates from 1
            if 'normal' in render_output:
                normal_len = torch.linalg.norm(render_output['normal'], dim=-1)
                eikonal_loss = ((normal_len - 1.) ** 2).mean()
                total_loss += self.loss_weight['eikonal'] * eikonal_loss
                batch_losses['eikonal_loss'] = eikonal_loss.item()

        # Without any loss term ``total_loss`` is still the plain integer 0 and
        # has no ``backward``; fail with an actionable message instead.
        if not torch.is_tensor(total_loss):
            raise RuntimeError(
                "forward_one_pass: no loss term was computed; the batch needs a "
                "'nerf_random' entry and the renderer must output at least one "
                "of 'rgb_map', 'acc_map' or 'normal'."
            )

        self.zero_grad()
        total_loss.backward()
        self.step()

        return total_loss, batch_losses

    def run(self):
        """Configure dataset, network and optimizer, train, then export the template.

        The dataset class is looked up by name (``opt['train']['dataset']``,
        default ``'MvRgbDatasetAvatarReX'``) in ``dataset.dataset_mv_rgb``.
        After training, the canonical-space mesh is saved as ``template.ply``
        in ``opt['train']['data']['data_dir']``.
        """
        train_opt = self.opt['train']

        dataset_name = train_opt.get('dataset', 'MvRgbDatasetAvatarReX')
        dataset_cls = getattr(importlib.import_module('dataset.dataset_mv_rgb'), dataset_name)
        self.set_dataset(dataset_cls(**train_opt['data']))

        self.set_network(TemplateNet(self.opt['model']).to(config.device))
        self.set_net_dict({
            'network': self.network
        })
        self.set_optm_dict({
            'network': torch.optim.Adam(self.network.parameters(), lr=1e-3)
        })
        self.set_lr_schedule_dict({
            'network': lr_schedule.get_learning_rate_schedules(**train_opt['lr']['network'])
        })
        self.set_update_keys(['network'])

        if train_opt.get('finetune_hand', False):
            print('# Finetune hand')
            # Freeze every parameter that does not belong to a hand sub-module.
            for name, param in self.network.named_parameters():
                if not name.startswith(('left_hand', 'right_hand')):
                    param.requires_grad_(False)

        if 'lpips' in train_opt['loss_weight']:
            # The perceptual network is created with gradients disabled on all
            # of its parameters.
            self.lpips = LPIPS(net='vgg').to(config.device)
            for param in self.lpips.parameters():
                param.requires_grad = False

        self.train()

        # output final cano geometry
        items = to_cuda(self.dataset.getitem(0, training=False), add_batch=True)
        with torch.no_grad():
            self.network.eval()
            vertices, faces, normals = self.test_geometry(items, space='cano', testing_res=(256, 256, 128))
            save_mesh_as_ply(train_opt['data']['data_dir'] + '/template.ply',
                             vertices, faces, normals)

    def test_geometry(self, items, space='live', testing_res=(128, 128, 128)):
        """Evaluate the SDF on a regular grid and reconstruct a mesh from it.

        Args:
            items: batched data dictionary; ``'live_bounds'`` or
                ``'cano_bounds'`` is read depending on ``space`` and
                ``'cano_smpl_v'`` is read in canonical mode.
            space: ``'live'`` to sample in the live bounds and map the points
                into canonical space first; any other value samples the
                canonical bounds directly.
            testing_res: grid resolution forwarded to the volume sampler and
                to the mesh reconstruction.

        Returns:
            ``(vertices, faces, normals)`` as returned by
            ``recon_util.recon_mesh`` at iso-value 0.
        """
        if space == 'live':
            bounds = items['live_bounds'][0]
        else:
            bounds = items['cano_bounds'][0]

        vol_pts = net_util.generate_volume_points(bounds, testing_res)

        # Query the network in fixed-size chunks rather than all points at once.
        chunk_size = 256 * 256 * 4
        sdf_chunks = []
        for start in range(0, vol_pts.shape[0], chunk_size):
            vol_pts_chunk = vol_pts[start: start + chunk_size][None]
            sdf_chunk = torch.zeros(vol_pts_chunk.shape[1]).to(vol_pts_chunk)

            if space == 'live':
                cano_pts_chunk, near_flag = self.network.transform_live2cano(vol_pts_chunk, items, near_thres=0.1)
            else:
                cano_pts_chunk = vol_pts_chunk
                # The distances are compared against 0.1 ** 2, i.e. they are
                # treated as squared distances to the nearest template vertex.
                dists, _, _ = pytorch3d.ops.knn_points(cano_pts_chunk, items['cano_smpl_v'], K=1)
                near_flag = dists[:, :, 0] < (0.1 ** 2)  # (1, N)
                # NOTE(review): this overrides the distance test so that every
                # canonical point is evaluated by the radiance field.  It is
                # assumed to apply to the canonical branch only (indentation
                # was lost in the reviewed copy) -- confirm.
                near_flag.fill_(True)

            far_flag = ~near_flag
            if far_flag.any():
                # Points flagged as far take their SDF from the weight volume.
                sdf_chunk[far_flag[0]] = self.network.cano_weight_volume.forward_sdf(cano_pts_chunk[far_flag][None])[0, :, 0]
            if near_flag.any():
                ret = self.network.forward_cano_radiance_field(cano_pts_chunk[near_flag][None], None, items)
                if self.network.with_hand:
                    # Return value is unused, so ``ret`` is presumably updated
                    # in place -- TODO confirm against ``fuse_hands``.
                    self.network.fuse_hands(ret, vol_pts_chunk[near_flag][None], None, items, space)
                sdf_chunk[near_flag[0]] = ret['sdf'][0, :, 0]

            sdf_chunks.append(sdf_chunk)

        sdf_volume = torch.cat(sdf_chunks, 0)
        vertices, faces, normals = recon_util.recon_mesh(sdf_volume, testing_res, bounds, iso_value=0.)
        return vertices, faces, normals

    @torch.no_grad()
    def mini_test(self):
        """Save the current canonical mesh to ``<net_ckpt_dir>/eval/batch_<iter>.ply``.

        The network is put in eval mode for the extraction and switched back
        to train mode afterwards.
        """
        self.network.eval()

        item = self.dataset.getitem(0, training=False)
        items = to_cuda(item, add_batch=True)

        vertices, faces, normals = self.test_geometry(items, space='cano', testing_res=(256, 256, 128))
        output_dir = self.opt['train']['net_ckpt_dir'] + '/eval'
        os.makedirs(output_dir, exist_ok=True)
        save_mesh_as_ply(output_dir + '/batch_%d.ply' % self.iter_idx, vertices, faces, normals)

        self.network.train()
if __name__ == '__main__':
    from argparse import ArgumentParser

    # Seed both torch and numpy with the same fixed value.
    seed = 31359
    torch.manual_seed(seed)
    np.random.seed(seed)

    # The only command-line option is the path of the configuration file.
    parser = ArgumentParser()
    parser.add_argument(
        '-c', '--config_path',
        type=str,
        help='Configuration file path.',
    )
    cli_args = parser.parse_args()

    # Load the global options, then build the trainer from them and run it.
    config.load_global_opt(cli_args.config_path)
    TemplateTrainer(config.opt).run()