import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
import yaml
import shutil
import collections
import torch
import torch.utils.data
import torch.nn.functional as F
import numpy as np
import cv2 as cv
import glob
import datetime
import trimesh
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import importlib
import json

import config
from network.lpips import LPIPS
from dataset.dataset_pose import PoseDataset
import utils.net_util as net_util
import utils.visualize_util as visualize_util
from utils.renderer import Renderer
from utils.net_util import to_cuda
from utils.obj_io import save_mesh_as_ply
from gaussians.obj_io import save_gaussians_as_ply
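
# Entry script for the pose-conditioned Gaussian avatar model: it supports pretraining
# the canonical Gaussians, full training, and test-time animation / rendering.
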
def safe_exists(path):
    if path is None:
        return False
    return os.path.exists(path)


class AvatarTrainer:
    def __init__(self, opt):
        self.opt = opt
        self.patch_size = 512
        self.iter_idx = 0
        self.iter_num = 800000
        self.lr_init = float(self.opt['train'].get('lr_init', 5e-4))

        avatar_module = self.opt['model'].get('module', 'network.avatar')
        print('Import AvatarNet from %s' % avatar_module)
        AvatarNet = importlib.import_module(avatar_module).AvatarNet
        self.avatar_net = AvatarNet(self.opt['model']).to(config.device)
        self.optm = torch.optim.Adam(
            self.avatar_net.parameters(), lr = self.lr_init
        )

        self.random_bg_color = self.opt['train'].get('random_bg_color', True)
        self.bg_color = (1., 1., 1.)
        self.bg_color_cuda = torch.from_numpy(np.asarray(self.bg_color)).to(torch.float32).to(config.device)
        self.loss_weight = self.opt['train']['loss_weight']
        self.finetune_color = self.opt['train']['finetune_color']

        print('# Parameter number of AvatarNet is %d' % sum(p.numel() for p in self.avatar_net.parameters()))
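
    # Cosine-annealed learning rate: decays from lr_init down to alpha * lr_init over iter_num iterations.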
    def update_lr(self):
        alpha = 0.05
        progress = self.iter_idx / self.iter_num
        learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * (1 - alpha) + alpha
        lr = self.lr_init * learning_factor
        for param_group in self.optm.param_groups:
            param_group['lr'] = lr
        return lr
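
    # Toggle requires_grad for all parameters of a sub-network (used to freeze parts of the model).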
    @staticmethod
    def requires_net_grad(net: torch.nn.Module, flag = True):
        for p in net.parameters():
            p.requires_grad = flag
    def crop_image(self, gt_mask, patch_size, randomly, *args):
        """
        :param gt_mask: (H, W) ground-truth foreground mask
        :param patch_size: resize the cropped patch to the given patch_size
        :param randomly: whether to randomly sample the patch
        :param args: input images with shape of (C, H, W)
        """
        mask_uv = torch.argwhere(gt_mask > 0.)
        min_v, min_u = mask_uv.min(0)[0]
        max_v, max_u = mask_uv.max(0)[0]
        len_v = max_v - min_v
        len_u = max_u - min_u
        max_size = max(len_v, len_u)

        cropped_images = []
        if randomly and max_size > patch_size:
            random_v = torch.randint(0, max_size - patch_size + 1, (1,)).to(max_size)
            random_u = torch.randint(0, max_size - patch_size + 1, (1,)).to(max_size)
        for image in args:
            cropped_image = self.bg_color_cuda[:, None, None] * torch.ones((3, max_size, max_size), dtype = image.dtype, device = image.device)
            if len_v > len_u:
                start_u = (max_size - len_u) // 2
                cropped_image[:, :, start_u: start_u + len_u] = image[:, min_v: max_v, min_u: max_u]
            else:
                start_v = (max_size - len_v) // 2
                cropped_image[:, start_v: start_v + len_v, :] = image[:, min_v: max_v, min_u: max_u]

            if randomly and max_size > patch_size:
                cropped_image = cropped_image[:, random_v: random_v + patch_size, random_u: random_u + patch_size]
            else:
                cropped_image = F.interpolate(cropped_image[None], size = (patch_size, patch_size), mode = 'bilinear')[0]
            cropped_images.append(cropped_image)

            # cv.imshow('cropped_image', cropped_image.detach().cpu().numpy().transpose(1, 2, 0))
            # cv.imshow('cropped_gt_image', cropped_gt_image.detach().cpu().numpy().transpose(1, 2, 0))
            # cv.waitKey(0)

        if len(cropped_images) > 1:
            return cropped_images
        else:
            return cropped_images[0]
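
    # Images here follow the OpenCV BGR convention while LPIPS expects RGB,
    # hence the [2, 1, 0] channel flip below.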
    def compute_lpips_loss(self, image, gt_image):
        assert image.shape[1] == image.shape[2] and gt_image.shape[1] == gt_image.shape[2]
        lpips_loss = self.lpips.forward(
            image[None, [2, 1, 0]],
            gt_image[None, [2, 1, 0]],
            normalize = True
        ).mean()
        return lpips_loss
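
    # Pretraining pass: distill the canonical Gaussian attributes (position, opacity,
    # scale, rotation) into the pose-conditioned prediction networks via L1 losses.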
    def forward_one_pass_pretrain(self, items):
        total_loss = 0
        batch_losses = {}
        l1_loss = torch.nn.L1Loss()

        items = net_util.delete_batch_idx(items)
        pose_map = items['smpl_pos_map'][:3]

        position_loss = l1_loss(self.avatar_net.get_positions(pose_map), self.avatar_net.cano_gaussian_model.get_xyz)
        total_loss += position_loss
        batch_losses.update({
            'position': position_loss.item()
        })

        opacity, scales, rotations = self.avatar_net.get_others(pose_map)
        opacity_loss = l1_loss(opacity, self.avatar_net.cano_gaussian_model.get_opacity)
        total_loss += opacity_loss
        batch_losses.update({
            'opacity': opacity_loss.item()
        })

        scale_loss = l1_loss(scales, self.avatar_net.cano_gaussian_model.get_scaling)
        total_loss += scale_loss
        batch_losses.update({
            'scale': scale_loss.item()
        })

        rotation_loss = l1_loss(rotations, self.avatar_net.cano_gaussian_model.get_rotation)
        total_loss += rotation_loss
        batch_losses.update({
            'rotation': rotation_loss.item()
        })

        total_loss.backward()
        self.optm.step()
        self.optm.zero_grad()

        return total_loss, batch_losses
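
    # Main training pass: render the avatar, compare it against the ground-truth image
    # with L1 / mask / LPIPS / offset losses, and take one optimizer step.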
    def forward_one_pass(self, items):
        # forward_start = torch.cuda.Event(enable_timing = True)
        # forward_end = torch.cuda.Event(enable_timing = True)
        # backward_start = torch.cuda.Event(enable_timing = True)
        # backward_end = torch.cuda.Event(enable_timing = True)
        # step_start = torch.cuda.Event(enable_timing = True)
        # step_end = torch.cuda.Event(enable_timing = True)

        if self.random_bg_color:
            self.bg_color = np.random.rand(3)
            self.bg_color_cuda = torch.from_numpy(np.asarray(self.bg_color)).to(torch.float32).to(config.device)

        total_loss = 0
        batch_losses = {}
        items = net_util.delete_batch_idx(items)

        """ Optimize generator """
        if self.finetune_color:
            self.requires_net_grad(self.avatar_net.color_net, True)
            self.requires_net_grad(self.avatar_net.position_net, False)
            self.requires_net_grad(self.avatar_net.other_net, True)
        else:
            self.requires_net_grad(self.avatar_net, True)

        # forward_start.record()
        render_output = self.avatar_net.render(items, self.bg_color)
        image = render_output['rgb_map'].permute(2, 0, 1)
        offset = render_output['offset']

        # mask image & set bg color
        items['color_img'][~items['mask_img']] = self.bg_color_cuda
        gt_image = items['color_img'].permute(2, 0, 1)
        mask_img = items['mask_img'].to(torch.float32)
        boundary_mask_img = 1. - items['boundary_mask_img'].to(torch.float32)
        image = image * boundary_mask_img[None] + (1. - boundary_mask_img[None]) * self.bg_color_cuda[:, None, None]
        gt_image = gt_image * boundary_mask_img[None] + (1. - boundary_mask_img[None]) * self.bg_color_cuda[:, None, None]

        # cv.imshow('image', image.detach().permute(1, 2, 0).cpu().numpy())
        # cv.imshow('gt_image', gt_image.permute(1, 2, 0).cpu().numpy())
        # cv.waitKey(0)

        if self.loss_weight['l1'] > 0.:
            l1_loss = torch.abs(image - gt_image).mean()
            total_loss += self.loss_weight['l1'] * l1_loss
            batch_losses.update({
                'l1_loss': l1_loss.item()
            })

        if self.loss_weight.get('mask', 0.) and 'mask_map' in render_output:
            rendered_mask = render_output['mask_map'].squeeze(-1) * boundary_mask_img
            gt_mask = mask_img * boundary_mask_img
            # cv.imshow('rendered_mask', rendered_mask.detach().cpu().numpy())
            # cv.imshow('gt_mask', gt_mask.detach().cpu().numpy())
            # cv.waitKey(0)
            mask_loss = torch.abs(rendered_mask - gt_mask).mean()
            # mask_loss = torch.nn.BCELoss()(rendered_mask, gt_mask)
            total_loss += self.loss_weight.get('mask', 0.) * mask_loss
            batch_losses.update({
                'mask_loss': mask_loss.item()
            })

        if self.loss_weight['lpips'] > 0.:
            # crop images
            random_patch_flag = self.iter_idx >= 300000
            image, gt_image = self.crop_image(mask_img, self.patch_size, random_patch_flag, image, gt_image)
            # cv.imshow('image', image.detach().permute(1, 2, 0).cpu().numpy())
            # cv.imshow('gt_image', gt_image.permute(1, 2, 0).cpu().numpy())
            # cv.waitKey(0)
            lpips_loss = self.compute_lpips_loss(image, gt_image)
            total_loss += self.loss_weight['lpips'] * lpips_loss
            batch_losses.update({
                'lpips_loss': lpips_loss.item()
            })
        if self.loss_weight['offset'] > 0.:
            offset_loss = torch.linalg.norm(offset, dim = -1).mean()
            total_loss += self.loss_weight['offset'] * offset_loss
            batch_losses.update({
                'offset_loss': offset_loss.item()
            })

        # forward_end.record()

        # backward_start.record()
        total_loss.backward()
        # backward_end.record()

        # step_start.record()
        self.optm.step()
        self.optm.zero_grad()
        # step_end.record()

        # torch.cuda.synchronize()
        # print(f'Forward costs: {forward_start.elapsed_time(forward_end) / 1000.}, ',
        #       f'Backward costs: {backward_start.elapsed_time(backward_end) / 1000.}, ',
        #       f'Step costs: {step_start.elapsed_time(step_end) / 1000.}')

        return total_loss, batch_losses
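
    # Pretraining loop: fits the pose-conditioned networks to the canonical Gaussians,
    # then saves a "pretrained" checkpoint at iteration 5000 and returns.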
    def pretrain(self):
        dataset_module = self.opt['train'].get('dataset', 'MvRgbDatasetAvatarReX')
        MvRgbDataset = importlib.import_module('dataset.dataset_mv_rgb').__getattribute__(dataset_module)
        self.dataset = MvRgbDataset(**self.opt['train']['data'])
        batch_size = self.opt['train']['batch_size']
        num_workers = self.opt['train']['num_workers']
        batch_num = len(self.dataset) // batch_size
        dataloader = torch.utils.data.DataLoader(self.dataset,
                                                 batch_size = batch_size,
                                                 shuffle = True,
                                                 num_workers = num_workers,
                                                 drop_last = True)

        # tb writer
        log_dir = self.opt['train']['net_ckpt_dir'] + '/' + datetime.datetime.now().strftime('pretrain_%Y_%m_%d_%H_%M_%S')
        writer = SummaryWriter(log_dir)
        smooth_interval = 10
        smooth_count = 0
        smooth_losses = {}

        for epoch_idx in range(0, 9999999):
            self.epoch_idx = epoch_idx
            for batch_idx, items in enumerate(dataloader):
                self.iter_idx = batch_idx + epoch_idx * batch_num
                items = to_cuda(items)

                # one_step_start.record()
                total_loss, batch_losses = self.forward_one_pass_pretrain(items)
                # one_step_end.record()
                # torch.cuda.synchronize()
                # print('One step costs %f secs' % (one_step_start.elapsed_time(one_step_end) / 1000.))

                # record batch loss
                for key, loss in batch_losses.items():
                    if key in smooth_losses:
                        smooth_losses[key] += loss
                    else:
                        smooth_losses[key] = loss
                smooth_count += 1

                if self.iter_idx % smooth_interval == 0:
                    log_info = 'epoch %d, batch %d, iter %d, ' % (epoch_idx, batch_idx, self.iter_idx)
                    for key in smooth_losses.keys():
                        smooth_losses[key] /= smooth_count
                        writer.add_scalar('%s/Iter' % key, smooth_losses[key], self.iter_idx)
                        log_info = log_info + ('%s: %f, ' % (key, smooth_losses[key]))
                        smooth_losses[key] = 0.
                    smooth_count = 0
                    print(log_info)
                    with open(os.path.join(log_dir, 'loss.txt'), 'a') as fp:
                        fp.write(log_info + '\n')

                if self.iter_idx % 200 == 0 and self.iter_idx != 0:
                    self.mini_test(pretraining = True)

                if self.iter_idx == 5000:
                    model_folder = self.opt['train']['net_ckpt_dir'] + '/pretrained'
                    os.makedirs(model_folder, exist_ok = True)
                    self.save_ckpt(model_folder, save_optm = True)
                    self.iter_idx = 0
                    return
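
    # Full training loop: resumes from the given / latest checkpoint (or from the
    # pretrained one), then alternates optimization, periodic evaluation and checkpointing.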
    def train(self):
        dataset_module = self.opt['train'].get('dataset', 'MvRgbDatasetAvatarReX')
        MvRgbDataset = importlib.import_module('dataset.dataset_mv_rgb').__getattribute__(dataset_module)
        self.dataset = MvRgbDataset(**self.opt['train']['data'])
        batch_size = self.opt['train']['batch_size']
        num_workers = self.opt['train']['num_workers']
        batch_num = len(self.dataset) // batch_size
        dataloader = torch.utils.data.DataLoader(self.dataset,
                                                 batch_size = batch_size,
                                                 shuffle = True,
                                                 num_workers = num_workers,
                                                 drop_last = True)

        if 'lpips' in self.opt['train']['loss_weight']:
            self.lpips = LPIPS(net = 'vgg').to(config.device)
            for p in self.lpips.parameters():
                p.requires_grad = False

        if self.opt['train']['prev_ckpt'] is not None:
            start_epoch, self.iter_idx = self.load_ckpt(self.opt['train']['prev_ckpt'], load_optm = True)
            start_epoch += 1
            self.iter_idx += 1
        else:
            prev_ckpt_path = self.opt['train']['net_ckpt_dir'] + '/epoch_latest'
            if safe_exists(prev_ckpt_path):
                start_epoch, self.iter_idx = self.load_ckpt(prev_ckpt_path, load_optm = True)
                start_epoch += 1
                self.iter_idx += 1
            else:
                if safe_exists(self.opt['train']['pretrained_dir']):
                    self.load_ckpt(self.opt['train']['pretrained_dir'], load_optm = False)
                elif safe_exists(self.opt['train']['net_ckpt_dir'] + '/pretrained'):
                    self.load_ckpt(self.opt['train']['net_ckpt_dir'] + '/pretrained', load_optm = False)
                else:
                    raise FileNotFoundError('Cannot find pretrained checkpoint!')
                self.optm.state = collections.defaultdict(dict)
                start_epoch = 0
                self.iter_idx = 0

        # one_step_start = torch.cuda.Event(enable_timing = True)
        # one_step_end = torch.cuda.Event(enable_timing = True)

        # tb writer
        log_dir = self.opt['train']['net_ckpt_dir'] + '/' + datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
        writer = SummaryWriter(log_dir)
        yaml.dump(self.opt, open(log_dir + '/config_bk.yaml', 'w'), sort_keys = False)
        smooth_interval = 10
        smooth_count = 0
        smooth_losses = {}

        for epoch_idx in range(start_epoch, 9999999):
            self.epoch_idx = epoch_idx
            for batch_idx, items in enumerate(dataloader):
                lr = self.update_lr()
                items = to_cuda(items)

                # one_step_start.record()
                total_loss, batch_losses = self.forward_one_pass(items)
                # one_step_end.record()
                # torch.cuda.synchronize()
                # print('One step costs %f secs' % (one_step_start.elapsed_time(one_step_end) / 1000.))

                # record batch loss
                for key, loss in batch_losses.items():
                    if key in smooth_losses:
                        smooth_losses[key] += loss
                    else:
                        smooth_losses[key] = loss
                smooth_count += 1

                if self.iter_idx % smooth_interval == 0:
                    log_info = 'epoch %d, batch %d, iter %d, lr %e, ' % (epoch_idx, batch_idx, self.iter_idx, lr)
                    for key in smooth_losses.keys():
                        smooth_losses[key] /= smooth_count
                        writer.add_scalar('%s/Iter' % key, smooth_losses[key], self.iter_idx)
                        log_info = log_info + ('%s: %f, ' % (key, smooth_losses[key]))
                        smooth_losses[key] = 0.
                    smooth_count = 0
                    print(log_info)
                    with open(os.path.join(log_dir, 'loss.txt'), 'a') as fp:
                        fp.write(log_info + '\n')
                    torch.cuda.empty_cache()

                if self.iter_idx % self.opt['train']['eval_interval'] == 0 and self.iter_idx != 0:
                    eval_cano_pts = self.iter_idx % (10 * self.opt['train']['eval_interval']) == 0
                    self.mini_test(eval_cano_pts = eval_cano_pts)

                if self.iter_idx % self.opt['train']['ckpt_interval']['batch'] == 0 and self.iter_idx != 0:
                    for folder in glob.glob(self.opt['train']['net_ckpt_dir'] + '/batch_*'):
                        shutil.rmtree(folder)
                    model_folder = self.opt['train']['net_ckpt_dir'] + '/batch_%d' % self.iter_idx
                    os.makedirs(model_folder, exist_ok = True)
                    self.save_ckpt(model_folder, save_optm = True)

                if self.iter_idx == self.iter_num:
                    print('# Training is done.')
                    return

                self.iter_idx += 1

            """ End of epoch """
            if epoch_idx % self.opt['train']['ckpt_interval']['epoch'] == 0 and epoch_idx != 0:
                model_folder = self.opt['train']['net_ckpt_dir'] + '/epoch_%d' % epoch_idx
                os.makedirs(model_folder, exist_ok = True)
                self.save_ckpt(model_folder)
            if batch_num > 50:
                latest_folder = self.opt['train']['net_ckpt_dir'] + '/epoch_latest'
                os.makedirs(latest_folder, exist_ok = True)
                self.save_ckpt(latest_folder)
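
    # Quick qualitative check: renders one training view and one held-out (testing)
    # view, concatenates each with the ground truth when available, and saves JPEGs.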
    def mini_test(self, pretraining = False, eval_cano_pts = False):
        self.avatar_net.eval()
        img_factor = self.opt['train'].get('eval_img_factor', 1.0)

        # training data
        pose_idx, view_idx = self.opt['train'].get('eval_training_ids', (310, 19))
        intr = self.dataset.intr_mats[view_idx].copy()
        intr[:2] *= img_factor
        item = self.dataset.getitem(0,
                                    pose_idx = pose_idx,
                                    view_idx = view_idx,
                                    training = False,
                                    eval = True,
                                    img_h = int(self.dataset.img_heights[view_idx] * img_factor),
                                    img_w = int(self.dataset.img_widths[view_idx] * img_factor),
                                    extr = self.dataset.extr_mats[view_idx],
                                    intr = intr,
                                    exact_hand_pose = True)
        items = net_util.to_cuda(item, add_batch = False)

        gs_render = self.avatar_net.render(items, self.bg_color)
        # gs_render = self.avatar_net.render_debug(items)
        rgb_map = gs_render['rgb_map']
        rgb_map.clip_(0., 1.)
        rgb_map = (rgb_map.cpu().numpy() * 255).astype(np.uint8)
        # cv.imshow('rgb_map', rgb_map.cpu().numpy())
        # cv.waitKey(0)
        if not pretraining:
            output_dir = self.opt['train']['net_ckpt_dir'] + '/eval/training'
        else:
            output_dir = self.opt['train']['net_ckpt_dir'] + '/eval_pretrain/training'
        gt_image, _ = self.dataset.load_color_mask_images(pose_idx, view_idx)
        if gt_image is not None:
            gt_image = cv.resize(gt_image, (0, 0), fx = img_factor, fy = img_factor)
            rgb_map = np.concatenate([rgb_map, gt_image], 1)
        os.makedirs(output_dir, exist_ok = True)
        cv.imwrite(output_dir + '/iter_%d.jpg' % self.iter_idx, rgb_map)
        if eval_cano_pts:
            os.makedirs(output_dir + '/cano_pts', exist_ok = True)
            save_mesh_as_ply(output_dir + '/cano_pts/iter_%d.ply' % self.iter_idx, (self.avatar_net.init_points + gs_render['offset']).cpu().numpy())

        # testing data
        pose_idx, view_idx = self.opt['train'].get('eval_testing_ids', (310, 19))
        intr = self.dataset.intr_mats[view_idx].copy()
        intr[:2] *= img_factor
        item = self.dataset.getitem(0,
                                    pose_idx = pose_idx,
                                    view_idx = view_idx,
                                    training = False,
                                    eval = True,
                                    img_h = int(self.dataset.img_heights[view_idx] * img_factor),
                                    img_w = int(self.dataset.img_widths[view_idx] * img_factor),
                                    extr = self.dataset.extr_mats[view_idx],
                                    intr = intr,
                                    exact_hand_pose = True)
        items = net_util.to_cuda(item, add_batch = False)

        gs_render = self.avatar_net.render(items, bg_color = self.bg_color)
        # gs_render = self.avatar_net.render_debug(items)
        rgb_map = gs_render['rgb_map']
        rgb_map.clip_(0., 1.)
        rgb_map = (rgb_map.cpu().numpy() * 255).astype(np.uint8)
        # cv.imshow('rgb_map', rgb_map.cpu().numpy())
        # cv.waitKey(0)
        if not pretraining:
            output_dir = self.opt['train']['net_ckpt_dir'] + '/eval/testing'
        else:
            output_dir = self.opt['train']['net_ckpt_dir'] + '/eval_pretrain/testing'
        gt_image, _ = self.dataset.load_color_mask_images(pose_idx, view_idx)
        if gt_image is not None:
            gt_image = cv.resize(gt_image, (0, 0), fx = img_factor, fy = img_factor)
            rgb_map = np.concatenate([rgb_map, gt_image], 1)
        os.makedirs(output_dir, exist_ok = True)
        cv.imwrite(output_dir + '/iter_%d.jpg' % self.iter_idx, rgb_map)
        if eval_cano_pts:
            os.makedirs(output_dir + '/cano_pts', exist_ok = True)
            save_mesh_as_ply(output_dir + '/cano_pts/iter_%d.ply' % self.iter_idx, (self.avatar_net.init_points + gs_render['offset']).cpu().numpy())

        self.avatar_net.train()
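
    # Dumps camera metadata in the style of the original 3D Gaussian Splatting codebase:
    # a 'cfg_args' Namespace string plus a 'cameras.json' with per-view poses and intrinsics.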
    def dump_renderer_info(self, dump_dir, extrs, intrs, img_heights, img_widths):
        with open(os.path.join(dump_dir, 'cfg_args'), 'w') as fp:
            outstr = "Namespace(sh_degree=%d, source_path='%s', model_path='%s', images='images', resolution=-1, " \
                     "white_background=False, data_device='cuda', eval=False)" % (
                         3, self.opt['train']['data']['data_dir'], dump_dir)
            fp.write(outstr)

        with open(os.path.join(dump_dir, 'cameras.json'), 'w') as fp:
            cam_jsons = []
            for ci in range(len(extrs)):
                extr, intr = extrs[ci], intrs[ci]
                img_h, img_w = img_heights[ci], img_widths[ci]
                w2c = extr
                c2w = np.linalg.inv(w2c)
                pos = c2w[:3, 3]
                rot = c2w[:3, :3]
                serializable_array_2d = [x.tolist() for x in rot]
                camera_entry = {
                    'id': ci,
                    'img_name': '%08d' % ci,
                    'width': int(img_w),
                    'height': int(img_h),
                    'position': pos.tolist(),
                    'rotation': serializable_array_2d,
                    'fy': float(intr[1, 1]),
                    'fx': float(intr[0, 0]),
                }
                cam_jsons.append(camera_entry)
            json.dump(cam_jsons, fp)
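
    # Test-time animation: loads a checkpoint, builds per-frame cameras according to
    # the configured view setting, then renders and saves images, masks and Gaussians.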
    def test(self):
        self.avatar_net.eval()

        dataset_module = self.opt['train'].get('dataset', 'MvRgbDatasetAvatarReX')
        MvRgbDataset = importlib.import_module('dataset.dataset_mv_rgb').__getattribute__(dataset_module)
        training_dataset = MvRgbDataset(**self.opt['train']['data'], training = False)
        if self.opt['test'].get('n_pca', -1) >= 1:
            training_dataset.compute_pca(n_components = self.opt['test']['n_pca'])
        if 'pose_data' in self.opt['test']:
            testing_dataset = PoseDataset(**self.opt['test']['pose_data'], smpl_shape = training_dataset.smpl_data['betas'][0])
            dataset_name = testing_dataset.dataset_name
            seq_name = testing_dataset.seq_name
        else:
            testing_dataset = MvRgbDataset(**self.opt['test']['data'], training = False)
            dataset_name = 'training'
            seq_name = ''

        self.dataset = testing_dataset
        iter_idx = self.load_ckpt(self.opt['test']['prev_ckpt'], False)[1]

        output_dir = self.opt['test'].get('output_dir', None)
        if output_dir is None:
            view_setting = config.opt['test'].get('view_setting', 'free')
            if view_setting == 'camera':
                view_folder = 'cam_%03d' % config.opt['test']['render_view_idx']
            else:
                view_folder = view_setting + '_view'
            exp_name = os.path.basename(os.path.dirname(self.opt['test']['prev_ckpt']))
            output_dir = f'./test_results/{training_dataset.subject_name}/{exp_name}/{dataset_name}_{seq_name}_{view_folder}' + '/batch_%06d' % iter_idx

        use_pca = self.opt['test'].get('n_pca', -1) >= 1
        if use_pca:
            output_dir += '/pca_%d_sigma_%.2f' % (self.opt['test'].get('n_pca', -1), float(self.opt['test'].get('sigma_pca', 1.)))
        else:
            output_dir += '/vanilla'
        print('# Output dir: \033[1;31m%s\033[0m' % output_dir)

        os.makedirs(output_dir + '/live_skeleton', exist_ok = True)
        os.makedirs(output_dir + '/rgb_map', exist_ok = True)
        os.makedirs(output_dir + '/rgb_map_wo_hand', exist_ok = True)
        os.makedirs(output_dir + '/torso_map', exist_ok = True)
        os.makedirs(output_dir + '/mask_map', exist_ok = True)
        os.makedirs(output_dir + '/posed_gaussians', exist_ok = True)
        os.makedirs(output_dir + '/posed_params', exist_ok = True)
        os.makedirs(output_dir + '/full_body_mask', exist_ok = True)
        os.makedirs(output_dir + '/hand_only_mask', exist_ok = True)

        geo_renderer = None
        item_0 = self.dataset.getitem(0, training = False)
        object_center = item_0['live_bounds'].mean(0)
        global_orient = item_0['global_orient'].cpu().numpy() if isinstance(item_0['global_orient'], torch.Tensor) else item_0['global_orient']
        # set x and z to 0, keeping only the yaw component of the global orientation
        global_orient[0] = 0
        global_orient[2] = 0
        global_orient = cv.Rodrigues(global_orient)[0]
        # print('object_center: ', object_center.tolist())
        # print('global_orient: ', global_orient.tolist())

        time_start = torch.cuda.Event(enable_timing = True)
        time_start_all = torch.cuda.Event(enable_timing = True)
        time_end = torch.cuda.Event(enable_timing = True)

        data_num = len(self.dataset)
        if self.opt['test'].get('fix_hand', False):
            self.avatar_net.generate_mean_hands()
        log_time = False

        # extr = visualize_util.calc_free_mv(object_center,
        #                                    tar_pos = np.array([0, 0, 2.5]),
        #                                    rot_Y = 0.,
        #                                    rot_X = 0.,
        #                                    global_orient = global_orient if self.opt['test'].get('global_orient', False) else None)
        # intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
        # img_scale = self.opt['test'].get('img_scale', 1.0)
        # intr[:2] *= img_scale
        # img_h = int(1024 * img_scale)
        # img_w = int(1024 * img_scale)
        # self.dump_renderer_info(output_dir, [extr], [intr], [img_h], [img_w])
        extr_list = []
        intr_list = []
        img_h_list = []
        img_w_list = []

        for idx in tqdm(range(data_num), desc = 'Rendering avatars...'):
            if log_time:
                time_start.record()
                time_start_all.record()

            img_scale = self.opt['test'].get('img_scale', 1.0)
            view_setting = config.opt['test'].get('view_setting', 'free')
            if view_setting == 'camera':
                # training view setting
                cam_id = config.opt['test']['render_view_idx']
                intr = self.dataset.intr_mats[cam_id].copy()
                intr[:2] *= img_scale
                extr = self.dataset.extr_mats[cam_id].copy()
                img_h, img_w = int(self.dataset.img_heights[cam_id] * img_scale), int(self.dataset.img_widths[cam_id] * img_scale)
            elif view_setting.startswith('free'):
                # free view setting: orbit the camera around the subject
                frame_num_per_circle = 360
                rot_Y = (idx % frame_num_per_circle) / float(frame_num_per_circle) * 2 * np.pi

                extr = visualize_util.calc_free_mv(object_center,
                                                   tar_pos = np.array([0, 0, 2.5]),
                                                   rot_Y = rot_Y,
                                                   rot_X = 0.3 if view_setting.endswith('bird') else 0.,
                                                   global_orient = global_orient if self.opt['test'].get('global_orient', False) else None)
                intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
                intr[:2] *= img_scale
                img_h = int(1024 * img_scale)
                img_w = int(1024 * img_scale)

                extr_list.append(extr)
                intr_list.append(intr)
                img_h_list.append(img_h)
                img_w_list.append(img_w)
            elif view_setting.startswith('degree120'):
                # sweep the yaw within +-60 degrees (a 120-degree arc) as a triangle wave
                frame_per_cycle = 480
                max_degree = 60
                frame_half_cycle = frame_per_cycle // 2
                if idx % frame_per_cycle < frame_half_cycle:
                    rot_Y = -max_degree + (2 * max_degree / frame_half_cycle) * (idx % frame_half_cycle)
                else:
                    rot_Y = max_degree - (2 * max_degree / frame_half_cycle) * (idx % frame_half_cycle)
                # to radians, wrapped into [0, 2 * pi)
                rot_Y = rot_Y * np.pi / 180
                if rot_Y < 0:
                    rot_Y = rot_Y + 2 * np.pi

                extr = visualize_util.calc_free_mv(object_center,
                                                   tar_pos = np.array([0, 0, 2.5]),
                                                   rot_Y = rot_Y,
                                                   rot_X = 0.3 if view_setting.endswith('bird') else 0.,
                                                   global_orient = global_orient if self.opt['test'].get('global_orient', False) else None)
                intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
                intr[:2] *= img_scale
                img_h = int(1024 * img_scale)
                img_w = int(1024 * img_scale)

                extr_list.append(extr)
                intr_list.append(intr)
                img_h_list.append(img_h)
                img_w_list.append(img_w)
            elif view_setting.startswith('degree90'):
                # sweep the yaw within +-45 degrees (a 90-degree arc) as a triangle wave
                frame_per_cycle = 360
                max_degree = 45
                frame_half_cycle = frame_per_cycle // 2
                if idx % frame_per_cycle < frame_half_cycle:
                    rot_Y = -max_degree + (2 * max_degree / frame_half_cycle) * (idx % frame_half_cycle)
                else:
                    rot_Y = max_degree - (2 * max_degree / frame_half_cycle) * (idx % frame_half_cycle)
                # to radians, wrapped into [0, 2 * pi)
                rot_Y = rot_Y * np.pi / 180
                if rot_Y < 0:
                    rot_Y = rot_Y + 2 * np.pi

                extr = visualize_util.calc_free_mv(object_center,
                                                   tar_pos = np.array([0, 0, 2.5]),
                                                   rot_Y = rot_Y,
                                                   rot_X = 0.3 if view_setting.endswith('bird') else 0.,
                                                   global_orient = global_orient if self.opt['test'].get('global_orient', False) else None)
                intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
                intr[:2] *= img_scale
                img_h = int(1024 * img_scale)
                img_w = int(1024 * img_scale)

                extr_list.append(extr)
                intr_list.append(intr)
                img_h_list.append(img_h)
                img_w_list.append(img_w)
            elif view_setting.startswith('front'):
                # front view setting
                extr = visualize_util.calc_free_mv(object_center,
                                                   tar_pos = np.array([0, 0, 2.5]),
                                                   rot_Y = 0.,
                                                   rot_X = 0.3 if view_setting.endswith('bird') else 0.,
                                                   global_orient = global_orient if self.opt['test'].get('global_orient', False) else None)
                intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
                intr[:2] *= img_scale
                img_h = int(1024 * img_scale)
                img_w = int(1024 * img_scale)

                extr_list.append(extr)
                intr_list.append(intr)
                img_h_list.append(img_h)
                img_w_list.append(img_w)
            elif view_setting.startswith('back'):
                # back view setting
                extr = visualize_util.calc_free_mv(object_center,
                                                   tar_pos = np.array([0, 0, 2.5]),
                                                   rot_Y = np.pi,
                                                   rot_X = 0.5 * np.pi / 4. if view_setting.endswith('bird') else 0.,
                                                   global_orient = global_orient if self.opt['test'].get('global_orient', False) else None)
                intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
                intr[:2] *= img_scale
                img_h = int(1024 * img_scale)
                img_w = int(1024 * img_scale)
            elif view_setting.startswith('moving'):
                # moving camera setting
                extr = visualize_util.calc_free_mv(object_center,
                                                   # tar_pos = np.array([0, 0, 3.0]),
                                                   # rot_Y = -0.3,
                                                   tar_pos = np.array([0, 0, 2.5]),
                                                   rot_Y = 0.,
                                                   rot_X = 0.3 if view_setting.endswith('bird') else 0.,
                                                   global_orient = global_orient if self.opt['test'].get('global_orient', False) else None)
                intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
                intr[:2] *= img_scale
                img_h = int(1024 * img_scale)
                img_w = int(1024 * img_scale)
            elif view_setting.startswith('cano'):
                cano_center = self.dataset.cano_bounds.mean(0)
                extr = np.identity(4, np.float32)
                extr[:3, 3] = -cano_center
                rot_x = np.identity(4, np.float32)
                rot_x[:3, :3] = cv.Rodrigues(np.array([np.pi, 0, 0], np.float32))[0]
                extr = rot_x @ extr
                f_len = 5000
                extr[2, 3] += f_len / 512
                intr = np.array([[f_len, 0, 512], [0, f_len, 512], [0, 0, 1]], np.float32)
                # item = self.dataset.getitem(idx,
                #                             training = False,
                #                             extr = extr,
                #                             intr = intr,
                #                             img_w = 1024,
                #                             img_h = 1024)
                img_w, img_h = 1024, 1024
                # item['live_smpl_v'] = item['cano_smpl_v']
                # item['cano2live_jnt_mats'] = torch.eye(4, dtype = torch.float32)[None].expand(item['cano2live_jnt_mats'].shape[0], -1, -1)
                # item['live_bounds'] = item['cano_bounds']
            else:
                raise ValueError('Invalid view setting for animation!')

            self.dump_renderer_info(output_dir, extr_list, intr_list, img_h_list, img_w_list)
            # also save the extr, intr, img_h and img_w to json
            camera_info = []
            for i in range(len(extr_list)):
                camera = {
                    'extr': extr_list[i].tolist(),
                    'intr': intr_list[i].tolist(),
                    'img_h': img_h_list[i],
                    'img_w': img_w_list[i],
                }
                camera_info.append(camera)
            with open(os.path.join(output_dir, 'camera_info.json'), 'w') as fp:
                json.dump(camera_info, fp)

            getitem_func = self.dataset.getitem_fast if hasattr(self.dataset, 'getitem_fast') else self.dataset.getitem
            item = getitem_func(
                idx,
                training = False,
                extr = extr,
                intr = intr,
                img_w = img_w,
                img_h = img_h
            )
            items = to_cuda(item, add_batch = False)

            if view_setting.startswith('moving') or view_setting == 'free_moving':
                current_center = items['live_bounds'].cpu().numpy().mean(0)
                delta = current_center - object_center
                object_center[0] += delta[0]
                # object_center[1] += delta[1]
                # object_center[2] += delta[2]

            if log_time:
                time_end.record()
                torch.cuda.synchronize()
                print('Loading data costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
                time_start.record()

            if self.opt['test'].get('render_skeleton', False):
                from utils.visualize_skeletons import construct_skeletons
                skel_vertices, skel_faces = construct_skeletons(item['joints'].cpu().numpy(), item['kin_parent'].cpu().numpy())
                skel_mesh = trimesh.Trimesh(skel_vertices, skel_faces, process = False)

                if geo_renderer is None:
                    geo_renderer = Renderer(item['img_w'], item['img_h'], shader_name = 'phong_geometry', bg_color = (1, 1, 1))
                extr, intr = item['extr'], item['intr']
                geo_renderer.set_camera(extr, intr)
                geo_renderer.set_model(skel_vertices[skel_faces.reshape(-1)], skel_mesh.vertex_normals.astype(np.float32)[skel_faces.reshape(-1)])
                skel_img = geo_renderer.render()[:, :, :3]
                skel_img = (skel_img * 255).astype(np.uint8)
                cv.imwrite(output_dir + '/live_skeleton/%08d.jpg' % item['data_idx'], skel_img)

            if log_time:
                time_end.record()
                torch.cuda.synchronize()
                print('Rendering skeletons costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
                time_start.record()

            if 'smpl_pos_map' not in items:
                self.avatar_net.get_pose_map(items)

            # pca
            if use_pca:
                mask = training_dataset.pos_map_mask
                live_pos_map = items['smpl_pos_map'].permute(1, 2, 0).cpu().numpy()
                front_live_pos_map, back_live_pos_map = np.split(live_pos_map, [3], 2)
                pose_conds = front_live_pos_map[mask]
                new_pose_conds = training_dataset.transform_pca(pose_conds, sigma_pca = float(self.opt['test'].get('sigma_pca', 2.)))
                front_live_pos_map[mask] = new_pose_conds
                live_pos_map = np.concatenate([front_live_pos_map, back_live_pos_map], 2)
                items.update({
                    'smpl_pos_map_pca': torch.from_numpy(live_pos_map).to(config.device).permute(2, 0, 1)
                })

            if log_time:
                time_end.record()
                torch.cuda.synchronize()
                print('Rendering pose conditions costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
                time_start.record()

            output = self.avatar_net.render(items, bg_color = self.bg_color, use_pca = use_pca)
            output_wo_hand = self.avatar_net.render_wo_hand(items, bg_color = self.bg_color, use_pca = use_pca)
            mask_output = self.avatar_net.render_mask(items, bg_color = self.bg_color, use_pca = use_pca)

            if log_time:
                time_end.record()
                torch.cuda.synchronize()
                print('Rendering avatar costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
                time_start.record()

            if 'rgb_map' in output_wo_hand:
                rgb_map_wo_hand = output_wo_hand['rgb_map']

            if 'full_body_rgb_map' in mask_output:
                os.makedirs(output_dir + '/full_body_mask', exist_ok = True)
                full_body_mask = mask_output['full_body_rgb_map']
                full_body_mask.clip_(0., 1.)
                full_body_mask = (full_body_mask * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/full_body_mask/%08d.png' % item['data_idx'], full_body_mask.cpu().numpy())
            if 'hand_only_rgb_map' in mask_output:
                os.makedirs(output_dir + '/hand_only_mask', exist_ok = True)
                hand_only_mask = mask_output['hand_only_rgb_map']
                hand_only_mask.clip_(0., 1.)
                hand_only_mask = (hand_only_mask * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/hand_only_mask/%08d.png' % item['data_idx'], hand_only_mask.cpu().numpy())

            if 'full_body_rgb_map' in mask_output and 'hand_only_rgb_map' in mask_output:
                # the mask only covers the hands
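                # Assumption inferred from the variable names below: the mask renderer
                # paints the right hand red ([1, 0, 0]) and the left hand blue ([0, 0, 1]);
                # each squared-distance-to-color test recovers a per-hand binary mask.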
                body_red_mask = (mask_output['full_body_rgb_map'] - torch.tensor([1., 0., 0.], device = mask_output['full_body_rgb_map'].device))
                body_red_mask = (body_red_mask * body_red_mask).sum(dim = 2) < 0.01  # saved below
                hand_red_mask = (mask_output['hand_only_rgb_map'] - torch.tensor([1., 0., 0.], device = mask_output['hand_only_rgb_map'].device))
                hand_red_mask = (hand_red_mask * hand_red_mask).sum(dim = 2) < 0.01
                if_mask_r_hand = abs(body_red_mask.sum() - hand_red_mask.sum()) / hand_red_mask.sum() > 0.95
                if_mask_r_hand = if_mask_r_hand.cpu().numpy()

                body_blue_mask = (mask_output['full_body_rgb_map'] - torch.tensor([0., 0., 1.], device = mask_output['full_body_rgb_map'].device))
                body_blue_mask = (body_blue_mask * body_blue_mask).sum(dim = 2) < 0.01  # saved below
                hand_blue_mask = (mask_output['hand_only_rgb_map'] - torch.tensor([0., 0., 1.], device = mask_output['hand_only_rgb_map'].device))
                hand_blue_mask = (hand_blue_mask * hand_blue_mask).sum(dim = 2) < 0.01
                if_mask_l_hand = abs(body_blue_mask.sum() - hand_blue_mask.sum()) / hand_blue_mask.sum() > 0.95
                if_mask_l_hand = if_mask_l_hand.cpu().numpy()

                # save masks of the occluded parts of the left and right hands
                red_mask = hand_red_mask ^ (hand_red_mask & body_red_mask)
                blue_mask = hand_blue_mask ^ (hand_blue_mask & body_blue_mask)
                all_mask = red_mask | blue_mask

                # save the three masks to three folders
                os.makedirs(output_dir + '/hand_mask', exist_ok = True)
                os.makedirs(output_dir + '/r_hand_mask', exist_ok = True)
                os.makedirs(output_dir + '/l_hand_mask', exist_ok = True)
                os.makedirs(output_dir + '/hand_visual', exist_ok = True)
                all_mask = (all_mask * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/hand_mask/%08d.png' % item['data_idx'], all_mask.cpu().numpy())
                r_hand_mask = (body_red_mask * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/r_hand_mask/%08d.png' % item['data_idx'], r_hand_mask.cpu().numpy())
                l_hand_mask = (body_blue_mask * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/l_hand_mask/%08d.png' % item['data_idx'], l_hand_mask.cpu().numpy())

                hand_visual = [if_mask_r_hand, if_mask_l_hand]
                # save the per-hand visibility flags to npy
                with open(output_dir + '/hand_visual/%08d.npy' % item['data_idx'], 'wb') as f:
                    np.save(f, hand_visual)

                # now build the sleeve masks
                if 'left_hand_rgb_map' in mask_output and 'right_hand_rgb_map' in mask_output:
                    os.makedirs(output_dir + '/left_sleeve_mask', exist_ok = True)
                    os.makedirs(output_dir + '/right_sleeve_mask', exist_ok = True)
                    mask = (r_hand_mask > 128) | (l_hand_mask > 128) | (all_mask > 128)
                    mask = mask.cpu().numpy().astype(np.uint8)
                    # structuring element; adjust its size to control the amount of dilation
                    kernel = np.ones((5, 5), np.uint8)
                    # apply dilation
                    mask = cv.dilate(mask, kernel, iterations = 3)
                    # convert back to a boolean tensor so the logical ops below behave as intended
                    mask = torch.from_numpy(mask > 0).to(config.device)

                    left_hand_mask = mask_output['left_hand_rgb_map']
                    left_hand_mask.clip_(0., 1.)
                    # the non-white part is the mask
                    left_hand_mask = (torch.tensor([1., 1., 1.], device = left_hand_mask.device) - left_hand_mask)
                    left_hand_mask = (left_hand_mask * left_hand_mask).sum(dim = 2) > 0.01
                    # remove the two-hand mask
                    left_hand_mask = left_hand_mask & ~mask

                    right_hand_mask = mask_output['right_hand_rgb_map']
                    right_hand_mask.clip_(0., 1.)
                    right_hand_mask = (torch.tensor([1., 1., 1.], device = right_hand_mask.device) - right_hand_mask)
                    right_hand_mask = (right_hand_mask * right_hand_mask).sum(dim = 2) > 0.01
                    right_hand_mask = right_hand_mask & ~mask

                    # save
                    left_hand_mask = (left_hand_mask * 255).to(torch.uint8)
                    cv.imwrite(output_dir + '/left_sleeve_mask/%08d.png' % item['data_idx'], left_hand_mask.cpu().numpy())
                    right_hand_mask = (right_hand_mask * 255).to(torch.uint8)
                    cv.imwrite(output_dir + '/right_sleeve_mask/%08d.png' % item['data_idx'], right_hand_mask.cpu().numpy())

            rgb_map = output['rgb_map']
            rgb_map.clip_(0., 1.)
            rgb_map = (rgb_map * 255).to(torch.uint8).cpu().numpy()
            cv.imwrite(output_dir + '/rgb_map/%08d.jpg' % item['data_idx'], rgb_map)

            # use r_hand_mask and l_hand_mask to composite the masked regions of the
            # hand-free rendering over rgb_map
            if 'rgb_map' in output_wo_hand and 'full_body_rgb_map' in mask_output and 'hand_only_rgb_map' in mask_output:
                rgb_map_wo_hand = output_wo_hand['rgb_map']
                rgb_map_wo_hand.clip_(0., 1.)
                rgb_map_wo_hand = (rgb_map_wo_hand * 255).to(torch.uint8).cpu().numpy()
                r_mask = (r_hand_mask > 128).cpu().numpy()
                l_mask = (l_hand_mask > 128).cpu().numpy()
                mask = r_mask | l_mask
                mask = mask.astype(np.uint8)
                # structuring element; adjust its size to control the amount of dilation
                kernel = np.ones((5, 5), np.uint8)
                # apply dilation
                mask = cv.dilate(mask, kernel, iterations = 3)
                mask = mask.astype(np.bool_)
                mask = np.expand_dims(mask, axis = 2)
                mix = rgb_map_wo_hand.copy() * mask + rgb_map * ~mask
                cv.imwrite(output_dir + '/rgb_map_wo_hand/%08d.jpg' % item['data_idx'], mix)

            if 'torso_map' in output:
                os.makedirs(output_dir + '/torso_map', exist_ok = True)
                torso_map = output['torso_map'][:, :, 0]
                torso_map.clip_(0., 1.)
                torso_map = (torso_map * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/torso_map/%08d.png' % item['data_idx'], torso_map.cpu().numpy())

            if 'mask_map' in output:
                os.makedirs(output_dir + '/mask_map', exist_ok = True)
                mask_map = output['mask_map'][:, :, 0]
                mask_map.clip_(0., 1.)
                mask_map = (mask_map * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/mask_map/%08d.png' % item['data_idx'], mask_map.cpu().numpy())

            if self.opt['test'].get('save_tex_map', False):
                os.makedirs(output_dir + '/cano_tex_map', exist_ok = True)
                cano_tex_map = output['cano_tex_map']
                cano_tex_map.clip_(0., 1.)
                cano_tex_map = (cano_tex_map * 255).to(torch.uint8)
                cv.imwrite(output_dir + '/cano_tex_map/%08d.jpg' % item['data_idx'], cano_tex_map.cpu().numpy())

            if self.opt['test'].get('save_ply', False):
                if item['data_idx'] == 0:
                    save_gaussians_as_ply(output_dir + '/posed_gaussians/%08d.ply' % item['data_idx'], output['posed_gaussians'])
                for k in output['posed_gaussians'].keys():
                    if isinstance(output['posed_gaussians'][k], torch.Tensor):
                        output['posed_gaussians'][k] = output['posed_gaussians'][k].detach().cpu().numpy()
                np.savez(output_dir + '/posed_gaussians/%08d.npz' % item['data_idx'], **output['posed_gaussians'])
                np.savez(output_dir + ('/posed_params/%08d.npz' % item['data_idx']),
                         betas = training_dataset.smpl_data['betas'].reshape([-1]).detach().cpu().numpy(),
                         global_orient = item['global_orient'].reshape([-1]).detach().cpu().numpy(),
                         transl = item['transl'].reshape([-1]).detach().cpu().numpy(),
                         body_pose = item['body_pose'].reshape([-1]).detach().cpu().numpy())

            if log_time:
                time_end.record()
                torch.cuda.synchronize()
                print('Saving images costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
                print('Animating one frame costs %.4f secs' % (time_start_all.elapsed_time(time_end) / 1000.))

            torch.cuda.empty_cache()
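
    # Serializes the network (and optionally the optimizer) state, together with the
    # current epoch and iteration indices, into <path>/net.pt and <path>/optm.pt.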
    def save_ckpt(self, path, save_optm = True):
        os.makedirs(path, exist_ok = True)
        net_dict = {
            'epoch_idx': self.epoch_idx,
            'iter_idx': self.iter_idx,
            'avatar_net': self.avatar_net.state_dict(),
        }
        print('Saving networks to ', path + '/net.pt')
        torch.save(net_dict, path + '/net.pt')

        if save_optm:
            optm_dict = {
                'avatar_net': self.optm.state_dict(),
            }
            print('Saving optimizers to ', path + '/optm.pt')
            torch.save(optm_dict, path + '/optm.pt')
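
    # Counterpart of save_ckpt: restores the network (and optionally the optimizer)
    # state and returns the stored (epoch_idx, iter_idx) so training can resume.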
    def load_ckpt(self, path, load_optm = True):
        print('Loading networks from ', path + '/net.pt')
        net_dict = torch.load(path + '/net.pt')
        if 'avatar_net' in net_dict:
            self.avatar_net.load_state_dict(net_dict['avatar_net'])
        else:
            print('[WARNING] Cannot find "avatar_net" in the network checkpoint!')
        epoch_idx = net_dict['epoch_idx']
        iter_idx = net_dict['iter_idx']

        if load_optm and os.path.exists(path + '/optm.pt'):
            print('Loading optimizers from ', path + '/optm.pt')
            optm_dict = torch.load(path + '/optm.pt')
            if 'avatar_net' in optm_dict:
                self.optm.load_state_dict(optm_dict['avatar_net'])
            else:
                print('[WARNING] Cannot find "avatar_net" in the optimizer checkpoint!')

        return epoch_idx, iter_idx


if __name__ == '__main__':
    torch.manual_seed(31359)
    np.random.seed(31359)
    # torch.autograd.set_detect_anomaly(True)

    from argparse import ArgumentParser
    arg_parser = ArgumentParser()
    arg_parser.add_argument('-c', '--config_path', type = str, help = 'Configuration file path.')
    arg_parser.add_argument('-m', '--mode', type = str, help = 'Running mode.', default = 'train')
    args = arg_parser.parse_args()

    config.load_global_opt(args.config_path)
    if args.mode is not None:
        config.opt['mode'] = args.mode

    trainer = AvatarTrainer(config.opt)
    if config.opt['mode'] == 'train':
        if not safe_exists(config.opt['train']['net_ckpt_dir'] + '/pretrained') \
                and not safe_exists(config.opt['train']['pretrained_dir']) \
                and not safe_exists(config.opt['train']['prev_ckpt']):
            trainer.pretrain()
        trainer.train()
    elif config.opt['mode'] == 'test':
        trainer.test()
    else:
        raise NotImplementedError('Invalid running mode!')
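
# Example usage (the script and config file names below are placeholders, not fixed
# by this file):
#   python main_avatar.py -c configs/example.yaml --mode train
#   python main_avatar.py -c configs/example.yaml --mode test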