Spaces:
Running
Running
File size: 6,895 Bytes
ec9a6bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import os
import torch
import numpy as np
import pytorch3d.ops
import importlib
from base_trainer import BaseTrainer
import config
from network.template import TemplateNet
from network.lpips import LPIPS
import utils.lr_schedule as lr_schedule
import utils.net_util as net_util
import utils.recon_util as recon_util
from utils.net_util import to_cuda
from utils.obj_io import save_mesh_as_ply
class TemplateTrainer(BaseTrainer):
    """Trainer that fits a ``TemplateNet`` on multi-view RGB data and exports
    the canonical-space template mesh.

    ``run()`` builds the dataset, network, Adam optimizer and LR schedule from
    ``self.opt``, runs the training loop (``self.train()``, provided by
    ``BaseTrainer`` — not shown here), then extracts the zero level set of the
    canonical SDF and writes it to ``<data_dir>/template.ply``.
    """

    def __init__(self, opt):
        super().__init__(opt)
        # Training iteration budget (150k); presumably read by BaseTrainer's
        # training loop — confirm there.
        self.iter_num = 150_000

    def update_config_before_epoch(self, epoch_idx):
        """Sync the global iteration counter to the start of ``epoch_idx`` and
        print the number of trainable network parameters."""
        self.iter_idx = epoch_idx * self.batch_num

        num_trainable = sum(p.numel() for p in self.network.parameters() if p.requires_grad)
        print('# Optimizable variable number in network: %d' % num_trainable)

    def forward_one_pass(self, items):
        """Render the randomly sampled rays, accumulate weighted losses and
        take one optimizer step.

        Args:
            items: batch dict from the dataset. When it contains a
                ``'nerf_random'`` sub-dict, that sub-dict is merged into
                ``items`` in place before rendering.

        Returns:
            ``(total_loss, batch_losses)``: the weighted total loss tensor and a
            dict of the individual, unweighted loss values as Python floats.

        Raises:
            RuntimeError: if no loss term could be computed for this batch
                (no ``'nerf_random'`` entry, or a render output without any of
                ``rgb_map`` / ``acc_map`` / ``normal``).
        """
        total_loss = 0
        batch_losses = {}

        # random ray sampling
        if 'nerf_random' in items:
            items.update(items['nerf_random'])
            render_output = self.network.render(
                items, depth_guided_sampling = self.opt['train']['depth_guided_sampling'])

            # color loss
            if 'rgb_map' in render_output:
                color_loss = torch.nn.functional.l1_loss(render_output['rgb_map'], items['color_gt'])
                total_loss += self.loss_weight['color'] * color_loss
                batch_losses['color_loss_random'] = color_loss.item()

            # mask loss
            if 'acc_map' in render_output:
                mask_loss = torch.nn.functional.l1_loss(render_output['acc_map'], items['mask_gt'])
                total_loss += self.loss_weight['mask'] * mask_loss
                batch_losses['mask_loss_random'] = mask_loss.item()

            # eikonal loss: penalizes the norm of render_output['normal']
            # deviating from 1
            if 'normal' in render_output:
                normal_norm = torch.linalg.norm(render_output['normal'], dim = -1)
                eikonal_loss = ((normal_norm - 1.) ** 2).mean()
                total_loss += self.loss_weight['eikonal'] * eikonal_loss
                batch_losses['eikonal_loss'] = eikonal_loss.item()

        # Without any loss term total_loss is still the int 0, and calling
        # .backward() on it would fail with an opaque AttributeError.
        if not torch.is_tensor(total_loss):
            raise RuntimeError(
                'forward_one_pass: no loss term was computed for this batch '
                "(missing 'nerf_random' or no rgb_map/acc_map/normal in render output)")

        self.zero_grad()
        total_loss.backward()
        self.step()

        return total_loss, batch_losses

    def run(self):
        """Set up data, network, optimizer and LR schedule, train, then export
        the canonical template mesh to ``<data_dir>/template.ply``."""
        train_opt = self.opt['train']

        dataset_module = train_opt.get('dataset', 'MvRgbDatasetAvatarReX')
        MvRgbDataset = getattr(importlib.import_module('dataset.dataset_mv_rgb'), dataset_module)
        self.set_dataset(MvRgbDataset(**train_opt['data']))
        self.set_network(TemplateNet(self.opt['model']).to(config.device))
        self.set_net_dict({
            'network': self.network
        })
        self.set_optm_dict({
            'network': torch.optim.Adam(self.network.parameters(), lr = 1e-3)
        })
        self.set_lr_schedule_dict({
            'network': lr_schedule.get_learning_rate_schedules(**train_opt['lr']['network'])
        })
        self.set_update_keys(['network'])

        if train_opt.get('finetune_hand', False):
            print('# Finetune hand')
            # Freeze every parameter outside the left/right hand sub-networks.
            for name, param in self.network.named_parameters():
                if not name.startswith(('left_hand', 'right_hand')):
                    param.requires_grad_(False)

        if 'lpips' in train_opt['loss_weight']:
            # LPIPS is used as a fixed perceptual metric; never trained.
            self.lpips = LPIPS(net = 'vgg').to(config.device)
            for param in self.lpips.parameters():
                param.requires_grad = False

        # BaseTrainer's training loop (not nn.Module.train).
        self.train()

        # output final cano geometry
        items = to_cuda(self.dataset.getitem(0, training = False), add_batch = True)
        with torch.no_grad():
            self.network.eval()
            vertices, faces, normals = self.test_geometry(items, space = 'cano', testing_res = (256, 256, 128))
        save_mesh_as_ply(os.path.join(train_opt['data']['data_dir'], 'template.ply'),
                         vertices, faces, normals)

    def test_geometry(self, items, space = 'live', testing_res = (128, 128, 128), cano_near_thres = None):
        """Extract a mesh from the network's SDF sampled on a regular grid.

        Args:
            items: batched item dict. Reads ``live_bounds`` or ``cano_bounds``
                and, when ``cano_near_thres`` is set, ``cano_smpl_v``.
            space: ``'live'`` samples inside the live-space bounds and maps the
                points to canonical space with ``transform_live2cano``. Any
                other value samples directly in canonical space.
            testing_res: grid resolution, passed to
                ``net_util.generate_volume_points`` and ``recon_util.recon_mesh``.
            cano_near_thres: canonical space only.
                ``None`` (default) queries the radiance field at every grid
                point. A float restricts the radiance field to points whose
                nearest ``items['cano_smpl_v']`` vertex lies within this
                distance; the remaining points take their SDF from
                ``cano_weight_volume.forward_sdf``.

        Returns:
            ``(vertices, faces, normals)`` from ``recon_util.recon_mesh`` at
            iso-value 0.
        """
        if space == 'live':
            bounds = items['live_bounds'][0]
        else:
            bounds = items['cano_bounds'][0]
        vol_pts = net_util.generate_volume_points(bounds, testing_res)
        # Grid points evaluated per forward pass; limits peak GPU memory.
        chunk_size = 256 * 256 * 4

        sdf_list = []
        for start in range(0, vol_pts.shape[0], chunk_size):
            vol_pts_chunk = vol_pts[start: start + chunk_size][None]
            # Allocate directly with the chunk's dtype/device (no host round trip).
            sdf_chunk = vol_pts_chunk.new_zeros(vol_pts_chunk.shape[1])

            if space == 'live':
                cano_pts_chunk, near_flag = self.network.transform_live2cano(vol_pts_chunk, items, near_thres = 0.1)
            else:
                cano_pts_chunk = vol_pts_chunk
                if cano_near_thres is None:
                    # Route every canonical grid point through the radiance
                    # field; no KNN query is needed for an all-True mask.
                    near_flag = torch.ones(vol_pts_chunk.shape[:2], dtype = torch.bool,
                                           device = vol_pts_chunk.device)
                else:
                    # knn_points returns squared distances, hence thres ** 2.
                    dists, _, _ = pytorch3d.ops.knn_points(cano_pts_chunk, items['cano_smpl_v'], K = 1)
                    near_flag = dists[:, :, 0] < cano_near_thres ** 2  # (1, N)

            far_flag = ~near_flag
            if far_flag.any():
                sdf_chunk[far_flag[0]] = self.network.cano_weight_volume.forward_sdf(
                    cano_pts_chunk[far_flag][None])[0, :, 0]
            if near_flag.any():
                ret = self.network.forward_cano_radiance_field(cano_pts_chunk[near_flag][None], None, items)
                if self.network.with_hand:
                    self.network.fuse_hands(ret, vol_pts_chunk[near_flag][None], None, items, space)
                sdf_chunk[near_flag[0]] = ret['sdf'][0, :, 0]
            sdf_list.append(sdf_chunk)

        sdf = torch.cat(sdf_list, 0)
        vertices, faces, normals = recon_util.recon_mesh(sdf, testing_res, bounds, iso_value = 0.)
        return vertices, faces, normals

    @torch.no_grad()
    def mini_test(self):
        """Export a canonical-space mesh snapshot for the current iteration to
        ``<net_ckpt_dir>/eval/batch_<iter_idx>.ply``.

        The network runs in eval mode for the extraction and is switched back
        to train mode afterwards, even if the extraction raises.
        """
        self.network.eval()
        try:
            item = self.dataset.getitem(0, training = False)
            items = to_cuda(item, add_batch = True)
            vertices, faces, normals = self.test_geometry(items, space = 'cano', testing_res = (256, 256, 128))
            output_dir = os.path.join(self.opt['train']['net_ckpt_dir'], 'eval')
            os.makedirs(output_dir, exist_ok = True)
            save_mesh_as_ply(os.path.join(output_dir, 'batch_%d.ply' % self.iter_idx), vertices, faces, normals)
        finally:
            self.network.train()
if __name__ == '__main__':
    # Fixed seeds so repeated runs with the same config are reproducible.
    torch.manual_seed(31359)
    np.random.seed(31359)

    from argparse import ArgumentParser

    arg_parser = ArgumentParser()
    # Required: without it load_global_opt would receive None and fail deep
    # inside config loading instead of printing a usage error.
    arg_parser.add_argument('-c', '--config_path', type = str, required = True, help = 'Configuration file path.')
    args = arg_parser.parse_args()

    config.load_global_opt(args.config_path)
    trainer = TemplateTrainer(config.opt)
    trainer.run()
|