from ast import dump import os from turtle import left # os.environ['CUDA_LAUNCH_BLOCKING'] = '1' # os.environ['TORCH_USE_CUDA_DSA'] = '1' os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' # from inflection import camelize import yaml import shutil import collections import torch import torch.utils.data import torch.nn.functional as F import numpy as np import cv2 as cv import glob import datetime import trimesh from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm import importlib import json import config from network.lpips import LPIPS from dataset.dataset_pose import PoseDataset import utils.net_util as net_util import utils.visualize_util as visualize_util from utils.renderer import Renderer from utils.net_util import to_cuda from utils.obj_io import save_mesh_as_ply from gaussians.obj_io import save_gaussians_as_ply def safe_exists(path): if path is None: return False return os.path.exists(path) class AvatarTrainer: def __init__(self, opt): self.opt = opt self.patch_size = 512 self.iter_idx = 0 self.iter_num = 800000 self.lr_init = float(self.opt['train'].get('lr_init', 5e-4)) avatar_module = self.opt['model'].get('module', 'network.avatar') print('Import AvatarNet from %s' % avatar_module) AvatarNet = importlib.import_module(avatar_module).AvatarNet self.avatar_net = AvatarNet(self.opt['model']).to(config.device) self.optm = torch.optim.Adam( self.avatar_net.parameters(), lr = self.lr_init ) self.random_bg_color = self.opt['train'].get('random_bg_color', True) self.bg_color = (1., 1., 1.) self.bg_color_cuda = torch.from_numpy(np.asarray(self.bg_color)).to(torch.float32).to(config.device) self.loss_weight = self.opt['train']['loss_weight'] self.finetune_color = self.opt['train']['finetune_color'] print('# Parameter number of AvatarNet is %d' % (sum([p.numel() for p in self.avatar_net.parameters()]))) def update_lr(self): alpha = 0.05 progress = self.iter_idx / self.iter_num learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * (1 - alpha) + alpha lr = self.lr_init * learning_factor for param_group in self.optm.param_groups: param_group['lr'] = lr return lr @staticmethod def requires_net_grad(net: torch.nn.Module, flag = True): for p in net.parameters(): p.requires_grad = flag def crop_image(self, gt_mask, patch_size, randomly, *args): """ :param gt_mask: (H, W) :param patch_size: resize the cropped patch to the given patch_size :param randomly: whether to randomly sample the patch :param args: input images with shape of (C, H, W) """ mask_uv = torch.argwhere(gt_mask > 0.) min_v, min_u = mask_uv.min(0)[0] max_v, max_u = mask_uv.max(0)[0] len_v = max_v - min_v len_u = max_u - min_u max_size = max(len_v, len_u) cropped_images = [] if randomly and max_size > patch_size: random_v = torch.randint(0, max_size - patch_size + 1, (1,)).to(max_size) random_u = torch.randint(0, max_size - patch_size + 1, (1,)).to(max_size) for image in args: cropped_image = self.bg_color_cuda[:, None, None] * torch.ones((3, max_size, max_size), dtype = image.dtype, device = image.device) if len_v > len_u: start_u = (max_size - len_u) // 2 cropped_image[:, :, start_u: start_u + len_u] = image[:, min_v: max_v, min_u: max_u] else: start_v = (max_size - len_v) // 2 cropped_image[:, start_v: start_v + len_v, :] = image[:, min_v: max_v, min_u: max_u] if randomly and max_size > patch_size: cropped_image = cropped_image[:, random_v: random_v + patch_size, random_u: random_u + patch_size] else: cropped_image = F.interpolate(cropped_image[None], size = (patch_size, patch_size), mode = 'bilinear')[0] cropped_images.append(cropped_image) # cv.imshow('cropped_image', cropped_image.detach().cpu().numpy().transpose(1, 2, 0)) # cv.imshow('cropped_gt_image', cropped_gt_image.detach().cpu().numpy().transpose(1, 2, 0)) # cv.waitKey(0) if len(cropped_images) > 1: return cropped_images else: return cropped_images[0] def compute_lpips_loss(self, image, gt_image): assert image.shape[1] == image.shape[2] and gt_image.shape[1] == gt_image.shape[2] lpips_loss = self.lpips.forward( image[None, [2, 1, 0]], gt_image[None, [2, 1, 0]], normalize = True ).mean() return lpips_loss def forward_one_pass_pretrain(self, items): total_loss = 0 batch_losses = {} l1_loss = torch.nn.L1Loss() items = net_util.delete_batch_idx(items) pose_map = items['smpl_pos_map'][:3] position_loss = l1_loss(self.avatar_net.get_positions(pose_map), self.avatar_net.cano_gaussian_model.get_xyz) total_loss += position_loss batch_losses.update({ 'position': position_loss.item() }) opacity, scales, rotations = self.avatar_net.get_others(pose_map) opacity_loss = l1_loss(opacity, self.avatar_net.cano_gaussian_model.get_opacity) total_loss += opacity_loss batch_losses.update({ 'opacity': opacity_loss.item() }) scale_loss = l1_loss(scales, self.avatar_net.cano_gaussian_model.get_scaling) total_loss += scale_loss batch_losses.update({ 'scale': scale_loss.item() }) rotation_loss = l1_loss(rotations, self.avatar_net.cano_gaussian_model.get_rotation) total_loss += rotation_loss batch_losses.update({ 'rotation': rotation_loss.item() }) total_loss.backward() self.optm.step() self.optm.zero_grad() return total_loss, batch_losses def forward_one_pass(self, items): # forward_start = torch.cuda.Event(enable_timing = True) # forward_end = torch.cuda.Event(enable_timing = True) # backward_start = torch.cuda.Event(enable_timing = True) # backward_end = torch.cuda.Event(enable_timing = True) # step_start = torch.cuda.Event(enable_timing = True) # step_end = torch.cuda.Event(enable_timing = True) if self.random_bg_color: self.bg_color = np.random.rand(3) self.bg_color_cuda = torch.from_numpy(np.asarray(self.bg_color)).to(torch.float32).to(config.device) total_loss = 0 batch_losses = {} items = net_util.delete_batch_idx(items) """ Optimize generator """ if self.finetune_color: self.requires_net_grad(self.avatar_net.color_net, True) self.requires_net_grad(self.avatar_net.position_net, False) self.requires_net_grad(self.avatar_net.other_net, True) else: self.requires_net_grad(self.avatar_net, True) # forward_start.record() render_output = self.avatar_net.render(items, self.bg_color) image = render_output['rgb_map'].permute(2, 0, 1) offset = render_output['offset'] # mask image & set bg color items['color_img'][~items['mask_img']] = self.bg_color_cuda gt_image = items['color_img'].permute(2, 0, 1) mask_img = items['mask_img'].to(torch.float32) boundary_mask_img = 1. - items['boundary_mask_img'].to(torch.float32) image = image * boundary_mask_img[None] + (1. - boundary_mask_img[None]) * self.bg_color_cuda[:, None, None] gt_image = gt_image * boundary_mask_img[None] + (1. - boundary_mask_img[None]) * self.bg_color_cuda[:, None, None] # cv.imshow('image', image.detach().permute(1, 2, 0).cpu().numpy()) # cv.imshow('gt_image', gt_image.permute(1, 2, 0).cpu().numpy()) # cv.waitKey(0) if self.loss_weight['l1'] > 0.: l1_loss = torch.abs(image - gt_image).mean() total_loss += self.loss_weight['l1'] * l1_loss batch_losses.update({ 'l1_loss': l1_loss.item() }) if self.loss_weight.get('mask', 0.) and 'mask_map' in render_output: rendered_mask = render_output['mask_map'].squeeze(-1) * boundary_mask_img gt_mask = mask_img * boundary_mask_img # cv.imshow('rendered_mask', rendered_mask.detach().cpu().numpy()) # cv.imshow('gt_mask', gt_mask.detach().cpu().numpy()) # cv.waitKey(0) mask_loss = torch.abs(rendered_mask - gt_mask).mean() # mask_loss = torch.nn.BCELoss()(rendered_mask, gt_mask) total_loss += self.loss_weight.get('mask', 0.) * mask_loss batch_losses.update({ 'mask_loss': mask_loss.item() }) if self.loss_weight['lpips'] > 0.: # crop images random_patch_flag = False if self.iter_idx < 300000 else True image, gt_image = self.crop_image(mask_img, self.patch_size, random_patch_flag, image, gt_image) # cv.imshow('image', image.detach().permute(1, 2, 0).cpu().numpy()) # cv.imshow('gt_image', gt_image.permute(1, 2, 0).cpu().numpy()) # cv.waitKey(0) lpips_loss = self.compute_lpips_loss(image, gt_image) total_loss += self.loss_weight['lpips'] * lpips_loss batch_losses.update({ 'lpips_loss': lpips_loss.item() }) # if self.loss_weight['offset'] > 0.: if True: offset_loss = torch.linalg.norm(offset, dim = -1).mean() total_loss += self.loss_weight['offset'] * offset_loss batch_losses.update({ 'offset_loss': offset_loss.item() }) # forward_end.record() # backward_start.record() total_loss.backward() # backward_end.record() # step_start.record() self.optm.step() self.optm.zero_grad() # step_end.record() # torch.cuda.synchronize() # print(f'Forward costs: {forward_start.elapsed_time(forward_end) / 1000.}, ', # f'Backward costs: {backward_start.elapsed_time(backward_end) / 1000.}, ', # f'Step costs: {step_start.elapsed_time(step_end) / 1000.}') return total_loss, batch_losses def pretrain(self): dataset_module = self.opt['train'].get('dataset', 'MvRgbDatasetAvatarReX') MvRgbDataset = importlib.import_module('dataset.dataset_mv_rgb').__getattribute__(dataset_module) self.dataset = MvRgbDataset(**self.opt['train']['data']) batch_size = self.opt['train']['batch_size'] num_workers = self.opt['train']['num_workers'] batch_num = len(self.dataset) // batch_size dataloader = torch.utils.data.DataLoader(self.dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True) # tb writer log_dir = self.opt['train']['net_ckpt_dir'] + '/' + datetime.datetime.now().strftime('pretrain_%Y_%m_%d_%H_%M_%S') writer = SummaryWriter(log_dir) smooth_interval = 10 smooth_count = 0 smooth_losses = {} for epoch_idx in range(0, 9999999): self.epoch_idx = epoch_idx for batch_idx, items in enumerate(dataloader): self.iter_idx = batch_idx + epoch_idx * batch_num items = to_cuda(items) # one_step_start.record() total_loss, batch_losses = self.forward_one_pass_pretrain(items) # one_step_end.record() # torch.cuda.synchronize() # print('One step costs %f secs' % (one_step_start.elapsed_time(one_step_end) / 1000.)) # record batch loss for key, loss in batch_losses.items(): if key in smooth_losses: smooth_losses[key] += loss else: smooth_losses[key] = loss smooth_count += 1 if self.iter_idx % smooth_interval == 0: log_info = 'epoch %d, batch %d, iter %d, ' % (epoch_idx, batch_idx, self.iter_idx) for key in smooth_losses.keys(): smooth_losses[key] /= smooth_count writer.add_scalar('%s/Iter' % key, smooth_losses[key], self.iter_idx) log_info = log_info + ('%s: %f, ' % (key, smooth_losses[key])) smooth_losses[key] = 0. smooth_count = 0 print(log_info) with open(os.path.join(log_dir, 'loss.txt'), 'a') as fp: fp.write(log_info + '\n') if self.iter_idx % 200 == 0 and self.iter_idx != 0: self.mini_test(pretraining = True) if self.iter_idx == 5000: model_folder = self.opt['train']['net_ckpt_dir'] + '/pretrained' os.makedirs(model_folder, exist_ok = True) self.save_ckpt(model_folder, save_optm = True) self.iter_idx = 0 return def train(self): dataset_module = self.opt['train'].get('dataset', 'MvRgbDatasetAvatarReX') MvRgbDataset = importlib.import_module('dataset.dataset_mv_rgb').__getattribute__(dataset_module) self.dataset = MvRgbDataset(**self.opt['train']['data']) batch_size = self.opt['train']['batch_size'] num_workers = self.opt['train']['num_workers'] batch_num = len(self.dataset) // batch_size dataloader = torch.utils.data.DataLoader(self.dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True) if 'lpips' in self.opt['train']['loss_weight']: self.lpips = LPIPS(net = 'vgg').to(config.device) for p in self.lpips.parameters(): p.requires_grad = False if self.opt['train']['prev_ckpt'] is not None: start_epoch, self.iter_idx = self.load_ckpt(self.opt['train']['prev_ckpt'], load_optm = True) start_epoch += 1 self.iter_idx += 1 else: prev_ckpt_path = self.opt['train']['net_ckpt_dir'] + '/epoch_latest' if safe_exists(prev_ckpt_path): start_epoch, self.iter_idx = self.load_ckpt(prev_ckpt_path, load_optm = True) start_epoch += 1 self.iter_idx += 1 else: if safe_exists(self.opt['train']['pretrained_dir']): self.load_ckpt(self.opt['train']['pretrained_dir'], load_optm = False) elif safe_exists(self.opt['train']['net_ckpt_dir'] + '/pretrained'): self.load_ckpt(self.opt['train']['net_ckpt_dir'] + '/pretrained', load_optm = False) else: raise FileNotFoundError('Cannot find pretrained checkpoint!') self.optm.state = collections.defaultdict(dict) start_epoch = 0 self.iter_idx = 0 # one_step_start = torch.cuda.Event(enable_timing = True) # one_step_end = torch.cuda.Event(enable_timing = True) # tb writer log_dir = self.opt['train']['net_ckpt_dir'] + '/' + datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S') writer = SummaryWriter(log_dir) yaml.dump(self.opt, open(log_dir + '/config_bk.yaml', 'w'), sort_keys = False) smooth_interval = 10 smooth_count = 0 smooth_losses = {} for epoch_idx in range(start_epoch, 9999999): self.epoch_idx = epoch_idx for batch_idx, items in enumerate(dataloader): lr = self.update_lr() items = to_cuda(items) # one_step_start.record() total_loss, batch_losses = self.forward_one_pass(items) # one_step_end.record() # torch.cuda.synchronize() # print('One step costs %f secs' % (one_step_start.elapsed_time(one_step_end) / 1000.)) # record batch loss for key, loss in batch_losses.items(): if key in smooth_losses: smooth_losses[key] += loss else: smooth_losses[key] = loss smooth_count += 1 if self.iter_idx % smooth_interval == 0: log_info = 'epoch %d, batch %d, iter %d, lr %e, ' % (epoch_idx, batch_idx, self.iter_idx, lr) for key in smooth_losses.keys(): smooth_losses[key] /= smooth_count writer.add_scalar('%s/Iter' % key, smooth_losses[key], self.iter_idx) log_info = log_info + ('%s: %f, ' % (key, smooth_losses[key])) smooth_losses[key] = 0. smooth_count = 0 print(log_info) with open(os.path.join(log_dir, 'loss.txt'), 'a') as fp: fp.write(log_info + '\n') torch.cuda.empty_cache() if self.iter_idx % self.opt['train']['eval_interval'] == 0 and self.iter_idx != 0: if self.iter_idx % (10 * self.opt['train']['eval_interval']) == 0: eval_cano_pts = True else: eval_cano_pts = False self.mini_test(eval_cano_pts = eval_cano_pts) if self.iter_idx % self.opt['train']['ckpt_interval']['batch'] == 0 and self.iter_idx != 0: for folder in glob.glob(self.opt['train']['net_ckpt_dir'] + '/batch_*'): shutil.rmtree(folder) model_folder = self.opt['train']['net_ckpt_dir'] + '/batch_%d' % self.iter_idx os.makedirs(model_folder, exist_ok = True) self.save_ckpt(model_folder, save_optm = True) if self.iter_idx == self.iter_num: print('# Training is done.') return self.iter_idx += 1 """ End of epoch """ if epoch_idx % self.opt['train']['ckpt_interval']['epoch'] == 0 and epoch_idx != 0: model_folder = self.opt['train']['net_ckpt_dir'] + '/epoch_%d' % epoch_idx os.makedirs(model_folder, exist_ok = True) self.save_ckpt(model_folder) if batch_num > 50: latest_folder = self.opt['train']['net_ckpt_dir'] + '/epoch_latest' os.makedirs(latest_folder, exist_ok = True) self.save_ckpt(latest_folder) @torch.no_grad() def mini_test(self, pretraining = False, eval_cano_pts = False): self.avatar_net.eval() img_factor = self.opt['train'].get('eval_img_factor', 1.0) # training data pose_idx, view_idx = self.opt['train'].get('eval_training_ids', (310, 19)) intr = self.dataset.intr_mats[view_idx].copy() intr[:2] *= img_factor item = self.dataset.getitem(0, pose_idx = pose_idx, view_idx = view_idx, training = False, eval = True, img_h = int(self.dataset.img_heights[view_idx] * img_factor), img_w = int(self.dataset.img_widths[view_idx] * img_factor), extr = self.dataset.extr_mats[view_idx], intr = intr, exact_hand_pose = True) items = net_util.to_cuda(item, add_batch = False) gs_render = self.avatar_net.render(items, self.bg_color) # gs_render = self.avatar_net.render_debug(items) rgb_map = gs_render['rgb_map'] rgb_map.clip_(0., 1.) rgb_map = (rgb_map.cpu().numpy() * 255).astype(np.uint8) # cv.imshow('rgb_map', rgb_map.cpu().numpy()) # cv.waitKey(0) if not pretraining: output_dir = self.opt['train']['net_ckpt_dir'] + '/eval/training' else: output_dir = self.opt['train']['net_ckpt_dir'] + '/eval_pretrain/training' gt_image, _ = self.dataset.load_color_mask_images(pose_idx, view_idx) if gt_image is not None: gt_image = cv.resize(gt_image, (0, 0), fx = img_factor, fy = img_factor) rgb_map = np.concatenate([rgb_map, gt_image], 1) os.makedirs(output_dir, exist_ok = True) cv.imwrite(output_dir + '/iter_%d.jpg' % self.iter_idx, rgb_map) if eval_cano_pts: os.makedirs(output_dir + '/cano_pts', exist_ok = True) save_mesh_as_ply(output_dir + '/cano_pts/iter_%d.ply' % self.iter_idx, (self.avatar_net.init_points + gs_render['offset']).cpu().numpy()) # training data pose_idx, view_idx = self.opt['train'].get('eval_testing_ids', (310, 19)) intr = self.dataset.intr_mats[view_idx].copy() intr[:2] *= img_factor item = self.dataset.getitem(0, pose_idx = pose_idx, view_idx = view_idx, training = False, eval = True, img_h = int(self.dataset.img_heights[view_idx] * img_factor), img_w = int(self.dataset.img_widths[view_idx] * img_factor), extr = self.dataset.extr_mats[view_idx], intr = intr, exact_hand_pose = True) items = net_util.to_cuda(item, add_batch = False) gs_render = self.avatar_net.render(items, bg_color = self.bg_color) # gs_render = self.avatar_net.render_debug(items) rgb_map = gs_render['rgb_map'] rgb_map.clip_(0., 1.) rgb_map = (rgb_map.cpu().numpy() * 255).astype(np.uint8) # cv.imshow('rgb_map', rgb_map.cpu().numpy()) # cv.waitKey(0) if not pretraining: output_dir = self.opt['train']['net_ckpt_dir'] + '/eval/testing' else: output_dir = self.opt['train']['net_ckpt_dir'] + '/eval_pretrain/testing' gt_image, _ = self.dataset.load_color_mask_images(pose_idx, view_idx) if gt_image is not None: gt_image = cv.resize(gt_image, (0, 0), fx = img_factor, fy = img_factor) rgb_map = np.concatenate([rgb_map, gt_image], 1) os.makedirs(output_dir, exist_ok = True) cv.imwrite(output_dir + '/iter_%d.jpg' % self.iter_idx, rgb_map) if eval_cano_pts: os.makedirs(output_dir + '/cano_pts', exist_ok = True) save_mesh_as_ply(output_dir + '/cano_pts/iter_%d.ply' % self.iter_idx, (self.avatar_net.init_points + gs_render['offset']).cpu().numpy()) self.avatar_net.train() def dump_renderer_info(self, dump_dir, extrs, intrs, img_heights, img_widths): with open(os.path.join(dump_dir, 'cfg_args'), 'w') as fp: outstr = "Namespace(sh_degree=%d, source_path='%s', model_path='%s', images='images', resolution=-1, " \ "white_background=False, data_device='cuda', eval=False)" % ( 3, self.opt['train']['data']['data_dir'], dump_dir) fp.write(outstr) with open(os.path.join(dump_dir, 'cameras.json'), 'w') as fp: cam_jsons = [] for ci in range(len(extrs)): extr, intr = extrs[ci], intrs[ci] img_h, img_w = img_heights[ci], img_widths[ci] w2c = extr c2w = np.linalg.inv(w2c) pos = c2w[:3, 3] rot = c2w[:3, :3] serializable_array_2d = [x.tolist() for x in rot] camera_entry = { 'id': ci, 'img_name': '%08d' % ci, 'width': int(img_w), 'height': int(img_h), 'position': pos.tolist(), 'rotation': serializable_array_2d, 'fy': float(intr[1, 1]), 'fx': float(intr[0, 0]), } cam_jsons.append(camera_entry) json.dump(cam_jsons, fp) return @torch.no_grad() def test(self): self.avatar_net.eval() # ipdb.set_trace() dataset_module = self.opt['train'].get('dataset', 'MvRgbDatasetAvatarReX') MvRgbDataset = importlib.import_module('dataset.dataset_mv_rgb').__getattribute__(dataset_module) training_dataset = MvRgbDataset(**self.opt['train']['data'], training = False) if self.opt['test'].get('n_pca', -1) >= 1: training_dataset.compute_pca(n_components = self.opt['test']['n_pca']) if 'pose_data' in self.opt['test']: testing_dataset = PoseDataset(**self.opt['test']['pose_data'], smpl_shape = training_dataset.smpl_data['betas'][0]) dataset_name = testing_dataset.dataset_name seq_name = testing_dataset.seq_name else: testing_dataset = MvRgbDataset(**self.opt['test']['data'], training = False) dataset_name = 'training' seq_name = '' # print('come here') self.dataset = testing_dataset iter_idx = self.load_ckpt(self.opt['test']['prev_ckpt'], False)[1] output_dir = self.opt['test'].get('output_dir', None) if output_dir is None: view_setting = config.opt['test'].get('view_setting', 'free') if view_setting == 'camera': view_folder = 'cam_%03d' % config.opt['test']['render_view_idx'] else: view_folder = view_setting + '_view' exp_name = os.path.basename(os.path.dirname(self.opt['test']['prev_ckpt'])) output_dir = f'./test_results/{training_dataset.subject_name}/{exp_name}/{dataset_name}_{seq_name}_{view_folder}' + '/batch_%06d' % iter_idx use_pca = self.opt['test'].get('n_pca', -1) >= 1 if use_pca: output_dir += '/pca_%d_sigma_%.2f' % (self.opt['test'].get('n_pca', -1), float(self.opt['test'].get('sigma_pca', 1.))) else: output_dir += '/vanilla' print('# Output dir: \033[1;31m%s\033[0m' % output_dir) os.makedirs(output_dir + '/live_skeleton', exist_ok = True) os.makedirs(output_dir + '/rgb_map', exist_ok = True) os.makedirs(output_dir + '/rgb_map_wo_hand', exist_ok = True) os.makedirs(output_dir + '/torso_map', exist_ok = True) os.makedirs(output_dir + '/mask_map', exist_ok = True) os.makedirs(output_dir + '/posed_gaussians', exist_ok = True) os.makedirs(output_dir + '/posed_params', exist_ok = True) os.makedirs(output_dir + '/full_body_mask', exist_ok = True) os.makedirs(output_dir + '/hand_only_mask', exist_ok = True) geo_renderer = None item_0 = self.dataset.getitem(0, training = False) object_center = item_0['live_bounds'].mean(0) global_orient = item_0['global_orient'].cpu().numpy() if isinstance(item_0['global_orient'], torch.Tensor) else item_0['global_orient'] # set x and z to 0 global_orient[0] = 0 global_orient[2] = 0 global_orient = cv.Rodrigues(global_orient)[0] # print('object_center: ', object_center.tolist()) # print('global_orient: ', global_orient.tolist()) # exit(1) time_start = torch.cuda.Event(enable_timing = True) time_start_all = torch.cuda.Event(enable_timing = True) time_end = torch.cuda.Event(enable_timing = True) data_num = len(self.dataset) if self.opt['test'].get('fix_hand', False): self.avatar_net.generate_mean_hands() log_time = False # extr = visualize_util.calc_free_mv(object_center, # tar_pos = np.array([0, 0, 2.5]), # rot_Y = 0., # rot_X = 0., # global_orient = global_orient if self.opt['test'].get('global_orient', False) else None) # intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32) # img_scale = self.opt['test'].get('img_scale', 1.0) # intr[:2] *= img_scale # img_h = int(1024 * img_scale) # img_w = int(1024 * img_scale) # self.dump_renderer_info(output_dir, [extr], [intr], [img_h], [img_w]) extr_list = [] intr_list = [] img_h_list = [] img_w_list = [] for idx in tqdm(range(data_num), desc = 'Rendering avatars...'): if log_time: time_start.record() time_start_all.record() img_scale = self.opt['test'].get('img_scale', 1.0) view_setting = config.opt['test'].get('view_setting', 'free') if view_setting == 'camera': # training view setting cam_id = config.opt['test']['render_view_idx'] intr = self.dataset.intr_mats[cam_id].copy() intr[:2] *= img_scale extr = self.dataset.extr_mats[cam_id].copy() img_h, img_w = int(self.dataset.img_heights[cam_id] * img_scale), int(self.dataset.img_widths[cam_id] * img_scale) elif view_setting.startswith('free'): # free view setting # frame_num_per_circle = 360 print(self.opt['test'].get('global_orient', False)) frame_num_per_circle = 360 rot_Y = (idx % frame_num_per_circle) / float(frame_num_per_circle) * 2 * np.pi extr = visualize_util.calc_free_mv(object_center, tar_pos = np.array([0, 0, 2.5]), rot_Y = rot_Y, rot_X = 0.3 if view_setting.endswith('bird') else 0., global_orient = global_orient if self.opt['test'].get('global_orient', False) else None) intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32) intr[:2] *= img_scale img_h = int(1024 * img_scale) img_w = int(1024 * img_scale) extr_list.append(extr) intr_list.append(intr) img_h_list.append(img_h) img_w_list.append(img_w) elif view_setting.startswith('degree120'): print('we render 120 degree') # +- 60 degree frame_per_cycle = 480 max_degree = 60 frame_half_cycle = frame_per_cycle // 2 if idx%frame_per_cycle < frame_per_cycle/2: rot_Y = -max_degree + (2 * max_degree / frame_half_cycle) * (idx%frame_half_cycle) # rot_Y = (idx % frame_per_60) / float(frame_per_60) * 2 * np.pi else: rot_Y = max_degree - (2 * max_degree / frame_half_cycle) * (idx%frame_half_cycle) # to radian rot_Y = rot_Y * np.pi / 180 if rot_Y<0: rot_Y = rot_Y + 2 * np.pi # print('rot_Y: ', rot_Y) extr = visualize_util.calc_free_mv(object_center, tar_pos = np.array([0, 0, 2.5]), rot_Y = rot_Y, rot_X = 0.3 if view_setting.endswith('bird') else 0., global_orient = global_orient if self.opt['test'].get('global_orient', False) else None) intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32) intr[:2] *= img_scale img_h = int(1024 * img_scale) img_w = int(1024 * img_scale) extr_list.append(extr) intr_list.append(intr) img_h_list.append(img_h) img_w_list.append(img_w) elif view_setting.startswith('degree90'): print('we render 90 degree') # +- 60 degree frame_per_cycle = 360 max_degree = 45 frame_half_cycle = frame_per_cycle // 2 if idx%frame_per_cycle < frame_per_cycle/2: rot_Y = -max_degree + (2 * max_degree / frame_half_cycle) * (idx%frame_half_cycle) # rot_Y = (idx % frame_per_60) / float(frame_per_60) * 2 * np.pi else: rot_Y = max_degree - (2 * max_degree / frame_half_cycle) * (idx%frame_half_cycle) # to radian rot_Y = rot_Y * np.pi / 180 if rot_Y<0: rot_Y = rot_Y + 2 * np.pi # print('rot_Y: ', rot_Y) extr = visualize_util.calc_free_mv(object_center, tar_pos = np.array([0, 0, 2.5]), rot_Y = rot_Y, rot_X = 0.3 if view_setting.endswith('bird') else 0., global_orient = global_orient if self.opt['test'].get('global_orient', False) else None) intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32) intr[:2] *= img_scale img_h = int(1024 * img_scale) img_w = int(1024 * img_scale) extr_list.append(extr) intr_list.append(intr) img_h_list.append(img_h) img_w_list.append(img_w) elif view_setting.startswith('front'): # front view setting extr = visualize_util.calc_free_mv(object_center, tar_pos = np.array([0, 0, 2.5]), rot_Y = 0., rot_X = 0.3 if view_setting.endswith('bird') else 0., global_orient = global_orient if self.opt['test'].get('global_orient', False) else None) intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32) intr[:2] *= img_scale img_h = int(1024 * img_scale) img_w = int(1024 * img_scale) extr_list.append(extr) intr_list.append(intr) img_h_list.append(img_h) img_w_list.append(img_w) # print('extr: ', extr) # print('intr: ', intr) # print('img_h: ', img_h) # print('img_w: ', img_w) # exit() elif view_setting.startswith('back'): # back view setting extr = visualize_util.calc_free_mv(object_center, tar_pos = np.array([0, 0, 2.5]), rot_Y = np.pi, rot_X = 0.5 * np.pi / 4. if view_setting.endswith('bird') else 0., global_orient = global_orient if self.opt['test'].get('global_orient', False) else None) intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32) intr[:2] *= img_scale img_h = int(1024 * img_scale) img_w = int(1024 * img_scale) elif view_setting.startswith('moving'): # moving camera setting extr = visualize_util.calc_free_mv(object_center, # tar_pos = np.array([0, 0, 3.0]), # rot_Y = -0.3, tar_pos = np.array([0, 0, 2.5]), rot_Y = 0., rot_X = 0.3 if view_setting.endswith('bird') else 0., global_orient = global_orient if self.opt['test'].get('global_orient', False) else None) intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32) intr[:2] *= img_scale img_h = int(1024 * img_scale) img_w = int(1024 * img_scale) elif view_setting.startswith('cano'): cano_center = self.dataset.cano_bounds.mean(0) extr = np.identity(4, np.float32) extr[:3, 3] = -cano_center rot_x = np.identity(4, np.float32) rot_x[:3, :3] = cv.Rodrigues(np.array([np.pi, 0, 0], np.float32))[0] extr = rot_x @ extr f_len = 5000 extr[2, 3] += f_len / 512 intr = np.array([[f_len, 0, 512], [0, f_len, 512], [0, 0, 1]], np.float32) # item = self.dataset.getitem(idx, # training = False, # extr = extr, # intr = intr, # img_w = 1024, # img_h = 1024) img_w, img_h = 1024, 1024 # item['live_smpl_v'] = item['cano_smpl_v'] # item['cano2live_jnt_mats'] = torch.eye(4, dtype = torch.float32)[None].expand(item['cano2live_jnt_mats'].shape[0], -1, -1) # item['live_bounds'] = item['cano_bounds'] else: raise ValueError('Invalid view setting for animation!') self.dump_renderer_info(output_dir, extr_list, intr_list, img_h_list, img_w_list) # also save the extr and intr and img_h and img_w to json camera_info = [] for i in range(len(extr_list)): camera = {} camera['extr'] = extr_list[i].tolist() camera['intr'] = intr_list[i].tolist() camera['img_h'] = img_h_list[i] camera['img_w'] = img_w_list[i] camera_info.append(camera) with open(os.path.join(output_dir, 'camera_info.json'), 'w') as fp: json.dump(camera_info, fp) getitem_func = self.dataset.getitem_fast if hasattr(self.dataset, 'getitem_fast') else self.dataset.getitem item = getitem_func( idx, training = False, extr = extr, intr = intr, img_w = img_w, img_h = img_h ) items = to_cuda(item, add_batch = False) if view_setting.startswith('moving') or view_setting == 'free_moving': current_center = items['live_bounds'].cpu().numpy().mean(0) delta = current_center - object_center object_center[0] += delta[0] # object_center[1] += delta[1] # object_center[2] += delta[2] if log_time: time_end.record() torch.cuda.synchronize() print('Loading data costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.)) time_start.record() if self.opt['test'].get('render_skeleton', False): from utils.visualize_skeletons import construct_skeletons skel_vertices, skel_faces = construct_skeletons(item['joints'].cpu().numpy(), item['kin_parent'].cpu().numpy()) skel_mesh = trimesh.Trimesh(skel_vertices, skel_faces, process = False) if geo_renderer is None: geo_renderer = Renderer(item['img_w'], item['img_h'], shader_name = 'phong_geometry', bg_color = (1, 1, 1)) extr, intr = item['extr'], item['intr'] geo_renderer.set_camera(extr, intr) geo_renderer.set_model(skel_vertices[skel_faces.reshape(-1)], skel_mesh.vertex_normals.astype(np.float32)[skel_faces.reshape(-1)]) skel_img = geo_renderer.render()[:, :, :3] skel_img = (skel_img * 255).astype(np.uint8) cv.imwrite(output_dir + '/live_skeleton/%08d.jpg' % item['data_idx'], skel_img) if log_time: time_end.record() torch.cuda.synchronize() print('Rendering skeletons costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.)) time_start.record() if 'smpl_pos_map' not in items: self.avatar_net.get_pose_map(items) # pca if use_pca: mask = training_dataset.pos_map_mask live_pos_map = items['smpl_pos_map'].permute(1, 2, 0).cpu().numpy() front_live_pos_map, back_live_pos_map = np.split(live_pos_map, [3], 2) pose_conds = front_live_pos_map[mask] new_pose_conds = training_dataset.transform_pca(pose_conds, sigma_pca = float(self.opt['test'].get('sigma_pca', 2.))) front_live_pos_map[mask] = new_pose_conds live_pos_map = np.concatenate([front_live_pos_map, back_live_pos_map], 2) items.update({ 'smpl_pos_map_pca': torch.from_numpy(live_pos_map).to(config.device).permute(2, 0, 1) }) if log_time: time_end.record() torch.cuda.synchronize() print('Rendering pose conditions costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.)) time_start.record() output = self.avatar_net.render(items, bg_color = self.bg_color, use_pca = use_pca) output_wo_hand = self.avatar_net.render_wo_hand(items, bg_color = self.bg_color, use_pca = use_pca) mask_output = self.avatar_net.render_mask(items, bg_color = self.bg_color, use_pca = use_pca) if log_time: time_end.record() torch.cuda.synchronize() print('Rendering avatar costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.)) time_start.record() if 'rgb_map' in output_wo_hand: rgb_map_wo_hand = output_wo_hand['rgb_map'] if 'full_body_rgb_map' in mask_output: os.makedirs(output_dir + '/full_body_mask', exist_ok = True) full_body_mask = mask_output['full_body_rgb_map'] full_body_mask.clip_(0., 1.) full_body_mask = (full_body_mask * 255).to(torch.uint8) cv.imwrite(output_dir + '/full_body_mask/%08d.png' % item['data_idx'], full_body_mask.cpu().numpy()) if 'hand_only_rgb_map' in mask_output: os.makedirs(output_dir + '/hand_only_mask', exist_ok = True) hand_only_mask = mask_output['hand_only_rgb_map'] hand_only_mask.clip_(0., 1.) hand_only_mask = (hand_only_mask * 255).to(torch.uint8) cv.imwrite(output_dir + '/hand_only_mask/%08d.png' % item['data_idx'], hand_only_mask.cpu().numpy()) if 'full_body_rgb_map' in mask_output and 'hand_only_rgb_map' in mask_output: # mask only covers hand body_red_mask = (mask_output['full_body_rgb_map'] - torch.tensor([1., 0., 0.], device = mask_output['full_body_rgb_map'].device)) body_red_mask = (body_red_mask*body_red_mask).sum(dim=2) < 0.01 # need save hand_red_mask = (mask_output['hand_only_rgb_map'] - torch.tensor([1., 0., 0.], device = mask_output['hand_only_rgb_map'].device)) hand_red_mask = (hand_red_mask*hand_red_mask).sum(dim=2) < 0.01 if_mask_r_hand = abs(body_red_mask.sum() - hand_red_mask.sum()) / hand_red_mask.sum() > 0.95 if_mask_r_hand = if_mask_r_hand.cpu().numpy() body_blue_mask = (mask_output['full_body_rgb_map'] - torch.tensor([0., 0., 1.], device = mask_output['full_body_rgb_map'].device)) body_blue_mask = (body_blue_mask*body_blue_mask).sum(dim=2) < 0.01 # need save hand_blue_mask = (mask_output['hand_only_rgb_map'] - torch.tensor([0., 0., 1.], device = mask_output['hand_only_rgb_map'].device)) hand_blue_mask = (hand_blue_mask*hand_blue_mask).sum(dim=2) < 0.01 if_mask_l_hand = abs(body_blue_mask.sum() - hand_blue_mask.sum()) / hand_blue_mask.sum() > 0.95 if_mask_l_hand = if_mask_l_hand.cpu().numpy() # 保存左右手被遮挡部分的mask red_mask = hand_red_mask ^ (hand_red_mask & body_red_mask) blue_mask = hand_blue_mask ^ (hand_blue_mask & body_blue_mask) all_mask = red_mask | blue_mask # now save 3 mask to 3 folders os.makedirs(output_dir + '/hand_mask', exist_ok = True) os.makedirs(output_dir + '/r_hand_mask', exist_ok = True) os.makedirs(output_dir + '/l_hand_mask', exist_ok = True) os.makedirs(output_dir + '/hand_visual', exist_ok = True) all_mask = (all_mask * 255).to(torch.uint8) cv.imwrite(output_dir + '/hand_mask/%08d.png' % item['data_idx'], all_mask.cpu().numpy()) r_hand_mask = (body_red_mask * 255).to(torch.uint8) cv.imwrite(output_dir + '/r_hand_mask/%08d.png' % item['data_idx'], r_hand_mask.cpu().numpy()) l_hand_mask = (body_blue_mask * 255).to(torch.uint8) cv.imwrite(output_dir + '/l_hand_mask/%08d.png' % item['data_idx'], l_hand_mask.cpu().numpy()) hand_visual = [if_mask_r_hand, if_mask_l_hand] # save to npy with open(output_dir + '/hand_visual/%08d.npy' % item['data_idx'], 'wb') as f: np.save(f, hand_visual) # now build sleeve_mask if 'left_hand_rgb_map' in mask_output and 'right_hand_rgb_map' in mask_output: os.makedirs(output_dir + '/left_sleeve_mask', exist_ok = True) os.makedirs(output_dir + '/right_sleeve_mask', exist_ok = True) mask = (r_hand_mask>128) | (l_hand_mask>128)| (all_mask>128) mask = mask.cpu().numpy().astype(np.uint8) # 定义一个结构元素,可以调整其大小以改变膨胀的程度 kernel = np.ones((5, 5), np.uint8) # 应用膨胀操作 mask = cv.dilate(mask, kernel, iterations=3) mask = torch.tensor(mask).to(config.device) left_hand_mask = mask_output['left_hand_rgb_map'] left_hand_mask.clip_(0., 1.) # non white part is mask left_hand_mask = (torch.tensor([1., 1., 1.], device = left_hand_mask.device) - left_hand_mask) left_hand_mask = (left_hand_mask*left_hand_mask).sum(dim=2) > 0.01 # dele two hand mask left_hand_mask = left_hand_mask & ~mask right_hand_mask = mask_output['right_hand_rgb_map'] right_hand_mask.clip_(0., 1.) right_hand_mask = (torch.tensor([1., 1., 1.], device = right_hand_mask.device) - right_hand_mask) right_hand_mask = (right_hand_mask*right_hand_mask).sum(dim=2) > 0.01 right_hand_mask = right_hand_mask & ~mask # save left_hand_mask = (left_hand_mask * 255).to(torch.uint8) cv.imwrite(output_dir + '/left_sleeve_mask/%08d.png' % item['data_idx'], left_hand_mask.cpu().numpy()) right_hand_mask = (right_hand_mask * 255).to(torch.uint8) cv.imwrite(output_dir + '/right_sleeve_mask/%08d.png' % item['data_idx'], right_hand_mask.cpu().numpy()) rgb_map = output['rgb_map'] rgb_map.clip_(0., 1.) rgb_map = (rgb_map * 255).to(torch.uint8).cpu().numpy() cv.imwrite(output_dir + '/rgb_map/%08d.jpg' % item['data_idx'], rgb_map) # 利用 r_hand_mask 和 l_hand_mask,将wo_hand图像中的mask部分覆盖rgb_map if 'rgb_map' in output_wo_hand and 'full_body_rgb_map' in mask_output and 'hand_only_rgb_map' in mask_output: rgb_map_wo_hand = output_wo_hand['rgb_map'] rgb_map_wo_hand.clip_(0., 1.) rgb_map_wo_hand = (rgb_map_wo_hand * 255).to(torch.uint8).cpu().numpy() r_mask = (r_hand_mask>128).cpu().numpy() l_mask = (l_hand_mask>128).cpu().numpy() mask = r_mask | l_mask mask = mask.astype(np.uint8) # 定义一个结构元素,可以调整其大小以改变膨胀的程度 kernel = np.ones((5, 5), np.uint8) # 应用膨胀操作 mask = cv.dilate(mask, kernel, iterations=3) mask = mask.astype(np.bool_) mask = np.expand_dims(mask, axis=2) # print('mask shape: ', mask.shape) import ipdb # ipdb.set_trace() mix = rgb_map_wo_hand.copy() * mask + rgb_map * ~mask cv.imwrite(output_dir + '/rgb_map_wo_hand/%08d.jpg' % item['data_idx'], mix) if 'torso_map' in output: os.makedirs(output_dir + '/torso_map', exist_ok = True) torso_map = output['torso_map'][:, :, 0] torso_map.clip_(0., 1.) torso_map = (torso_map * 255).to(torch.uint8) cv.imwrite(output_dir + '/torso_map/%08d.png' % item['data_idx'], torso_map.cpu().numpy()) if 'mask_map' in output: os.makedirs(output_dir + '/mask_map', exist_ok = True) mask_map = output['mask_map'][:, :, 0] mask_map.clip_(0., 1.) mask_map = (mask_map * 255).to(torch.uint8) cv.imwrite(output_dir + '/mask_map/%08d.png' % item['data_idx'], mask_map.cpu().numpy()) if self.opt['test'].get('save_tex_map', False): os.makedirs(output_dir + '/cano_tex_map', exist_ok = True) cano_tex_map = output['cano_tex_map'] cano_tex_map.clip_(0., 1.) cano_tex_map = (cano_tex_map * 255).to(torch.uint8) cv.imwrite(output_dir + '/cano_tex_map/%08d.jpg' % item['data_idx'], cano_tex_map.cpu().numpy()) if self.opt['test'].get('save_ply', False): if item['data_idx'] == 0: save_gaussians_as_ply(output_dir + '/posed_gaussians/%08d.ply' % item['data_idx'], output['posed_gaussians']) for k in output['posed_gaussians'].keys(): if isinstance(output['posed_gaussians'][k], torch.Tensor): output['posed_gaussians'][k] = output['posed_gaussians'][k].detach().cpu().numpy() np.savez(output_dir + '/posed_gaussians/%08d.npz' % item['data_idx'], **output['posed_gaussians']) np.savez(output_dir + ('/posed_params/%08d.npz' % item['data_idx']), betas=training_dataset.smpl_data['betas'].reshape([-1]).detach().cpu().numpy(), global_orient=item['global_orient'].reshape([-1]).detach().cpu().numpy(), transl=item['transl'].reshape([-1]).detach().cpu().numpy(), body_pose=item['body_pose'].reshape([-1]).detach().cpu().numpy()) if log_time: time_end.record() torch.cuda.synchronize() print('Saving images costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.)) print('Animating one frame costs %.4f secs' % (time_start_all.elapsed_time(time_end) / 1000.)) torch.cuda.empty_cache() def save_ckpt(self, path, save_optm = True): os.makedirs(path, exist_ok = True) net_dict = { 'epoch_idx': self.epoch_idx, 'iter_idx': self.iter_idx, 'avatar_net': self.avatar_net.state_dict(), } print('Saving networks to ', path + '/net.pt') torch.save(net_dict, path + '/net.pt') if save_optm: optm_dict = { 'avatar_net': self.optm.state_dict(), } print('Saving optimizers to ', path + '/optm.pt') torch.save(optm_dict, path + '/optm.pt') def load_ckpt(self, path, load_optm = True): print('Loading networks from ', path + '/net.pt') net_dict = torch.load(path + '/net.pt') if 'avatar_net' in net_dict: self.avatar_net.load_state_dict(net_dict['avatar_net']) else: print('[WARNING] Cannot find "avatar_net" from the network checkpoint!') epoch_idx = net_dict['epoch_idx'] iter_idx = net_dict['iter_idx'] if load_optm and os.path.exists(path + '/optm.pt'): print('Loading optimizers from ', path + '/optm.pt') optm_dict = torch.load(path + '/optm.pt') if 'avatar_net' in optm_dict: self.optm.load_state_dict(optm_dict['avatar_net']) else: print('[WARNING] Cannot find "avatar_net" from the optimizer checkpoint!') return epoch_idx, iter_idx if __name__ == '__main__': torch.manual_seed(31359) np.random.seed(31359) # torch.autograd.set_detect_anomaly(True) from argparse import ArgumentParser arg_parser = ArgumentParser() arg_parser.add_argument('-c', '--config_path', type = str, help = 'Configuration file path.') arg_parser.add_argument('-m', '--mode', type = str, help = 'Running mode.', default = 'train') args = arg_parser.parse_args() config.load_global_opt(args.config_path) if args.mode is not None: config.opt['mode'] = args.mode trainer = AvatarTrainer(config.opt) if config.opt['mode'] == 'train': if not safe_exists(config.opt['train']['net_ckpt_dir'] + '/pretrained') \ and not safe_exists(config.opt['train']['pretrained_dir'])\ and not safe_exists(config.opt['train']['prev_ckpt']): trainer.pretrain() trainer.train() elif config.opt['mode'] == 'test': trainer.test() else: raise NotImplementedError('Invalid running mode!')