pengc02 committed on
Commit
42fd375
1 Parent(s): 4c1e242
Files changed (1)
  1. avatar_generator.py +597 -0
avatar_generator.py ADDED
@@ -0,0 +1,597 @@
+ import os
+ # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+ # os.environ['TORCH_USE_CUDA_DSA'] = '1'
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+ import yaml
+ import shutil
+ import collections
+ import torch
+ import torch.utils.data
+ import torch.nn.functional as F
+ import numpy as np
+ import cv2 as cv
+ import glob
+ import datetime
+ import trimesh
+ from torch.utils.tensorboard import SummaryWriter
+ from tqdm import tqdm
+ import importlib
+ from omegaconf import OmegaConf
+ import json
+ import math
+ import cv2
+
+ # AnimatableGaussians (body) part
+ from AnimatableGaussians.network.lpips import LPIPS
+ from AnimatableGaussians.dataset.dataset_pose import PoseDataset
+ import AnimatableGaussians.utils.net_util as net_util
+ from AnimatableGaussians.utils.camera_dir import get_camera_dir
+ from AnimatableGaussians.utils.renderer import Renderer
+ from AnimatableGaussians.utils.net_util import to_cuda
+ from AnimatableGaussians.utils.obj_io import save_mesh_as_ply
+ from AnimatableGaussians.gaussians.obj_io import save_gaussians_as_ply
+ import AnimatableGaussians.config as ag_config
+
+ # Gaussian-Head-Avatar (head) part
+ from GHA.config.config import config_reenactment
+ from GHA.lib.dataset.Dataset import ReenactmentDataset
+ from GHA.lib.dataset.DataLoaderX import DataLoaderX
+ from GHA.lib.module.GaussianHeadModule import GaussianHeadModule
+ from GHA.lib.module.SuperResolutionModule import SuperResolutionModule
+ from GHA.lib.module.CameraModule import CameraModule
+ from GHA.lib.recorder.Recorder import ReenactmentRecorder
+ from GHA.lib.apps.Reenactment import Reenactment
+ from GHA.lib.utils.graphics_utils import getWorld2View2, getProjectionMatrix
+
+ # stitching ("cat") utilities
+ from calc_offline_rendering_param import calc_offline_rendering_param
+ from calc_offline_rendering_param import load_camera_data
+ from render_utils.lib.networks.smpl_torch import SmplTorch
+ from render_utils.lib.utils.gaussian_np_utils import load_gaussians_from_ply
+ from render_utils.stitch_body_and_head import load_body_params, load_face_params, get_smpl_verts_and_head_transformation, calc_livehead2livebody
+ from render_utils.stitch_funcs import soften_blending_mask, paste_back_with_linear_interp
+
+ class Avatar:
+     """Combined avatar renderer: an AnimatableGaussians body model plus a
+     Gaussian-Head-Avatar (GHA) head model, stitched together per frame."""
+
+     def __init__(self, config):
+         self.config = config
+         self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+         # AnimatableGaussians (body) part init
+         self.body = config.animatablegaussians
+         self.body.mode = 'test'
+         ag_config.set_opt(self.body)
+         avatar_module = self.body['model'].get('module', 'AnimatableGaussians.network.avatar')
+         print('Import AvatarNet from %s' % avatar_module)
+         AvatarNet = importlib.import_module(avatar_module).AvatarNet
+         self.avatar_net = AvatarNet(self.body.model).to(self.device)
+         self.random_bg_color = self.body['train'].get('random_bg_color', True)
+         self.bg_color = (1., 1., 1.)
+         self.bg_color_cuda = torch.from_numpy(np.asarray(self.bg_color)).to(torch.float32).to(self.device)
+         self.loss_weight = self.body['train']['loss_weight']
+         self.finetune_color = self.body['train']['finetune_color']
+         print('# Parameter number of AvatarNet is %d' % (sum([p.numel() for p in self.avatar_net.parameters()])))
+
+         # Gaussian-Head-Avatar (head) part init
+         self.head = config.gha
+         # stitching ("cat") utilities init
+         self.cat = config.cat
+
+     def build_dataset(self, body_pose_path=None, face_exp_path=None):
+         # build the body dataset
+         if body_pose_path is not None:
+             self.body['test']['pose_data']['data_path'] = body_pose_path
+             body_pose = np.load(body_pose_path, allow_pickle = True)
+             self.body['test']['pose_data']['frame_range'] = [0, body_pose['poses'].shape[0]]
+
+         dataset_module = self.body.get('dataset', 'MvRgbDatasetAvatarReX')
+         MvRgbDataset = importlib.import_module('AnimatableGaussians.dataset.dataset_mv_rgb').__getattribute__(dataset_module)
+         self.body_training_dataset = MvRgbDataset(**self.body['train']['data'], training = False)
+         if self.body['test'].get('n_pca', -1) >= 1:
+             self.body_training_dataset.compute_pca(n_components = self.body['test']['n_pca'])
+         if 'pose_data' in self.body.test:
+             testing_dataset = PoseDataset(**self.body['test']['pose_data'], smpl_shape = self.body_training_dataset.smpl_data['betas'][0])
+             dataset_name = testing_dataset.dataset_name
+             seq_name = testing_dataset.seq_name
+         else:
+             raise ValueError('No pose data in test config')
+         self.body_dataset = testing_dataset
+         iter_idx = self.load_ckpt(self.body['test']['prev_ckpt'], False)[1]
+
+         # build the face (head) dataset
+         self.head_config = config_reenactment()
+         self.head_config.load(self.head.config_path)
+         if face_exp_path is not None:
+             self.head_config.cfg.dataset.exp_path = face_exp_path
+         self.head_config.freeze()
+         self.head_config = self.head_config.get_cfg()
+         self.head_dataset = ReenactmentDataset(self.head_config.dataset)
+         self.head_dataloader = DataLoaderX(self.head_dataset, batch_size=1, shuffle=False, pin_memory=True)
+
+         gaussianhead_state_dict = torch.load(self.head_config.load_gaussianhead_checkpoint, map_location=lambda storage, loc: storage)
+         self.gaussianhead = GaussianHeadModule(self.head_config.gaussianheadmodule,
+                                                xyz=gaussianhead_state_dict['xyz'],
+                                                feature=gaussianhead_state_dict['feature'],
+                                                landmarks_3d_neutral=gaussianhead_state_dict['landmarks_3d_neutral']).to(self.device)
+         self.gaussianhead.load_state_dict(gaussianhead_state_dict)
+
+         self.supres = SuperResolutionModule(self.head_config.supresmodule).to(self.device)
+         self.supres.load_state_dict(torch.load(self.head_config.load_supres_checkpoint, map_location=lambda storage, loc: storage))
+
+         self.head_camera = CameraModule()
+         self.head_recorder = ReenactmentRecorder(self.head_config.recorder)
+
+     def render_all(self):
+         os.makedirs('./output', exist_ok=True)
+         # render the shorter of the two sequences (body poses vs. face expressions)
+         length = min(len(self.body_dataset), len(self.head_dataloader))
+         for idx in tqdm(range(length)):
+             self.render_frame(idx)
+
+     def render_frame(self, idx):
+         # render the body and its masks
+         body_output = self.build_body(idx)
+         # compute the head rendering parameters
+         head_param = self.build_param(idx, body_output)
+         # render the head
+         head_output = self.build_head(idx, head_param)
+         # stitch the head and the body together
+         body_rendering = body_output['rgb_map_wo_hand'].astype(np.float32) / 255.0
+         body_mask = body_output['mask_map'].astype(np.float32) / 255.0
+         body_torso_mask = body_output['torso_map'].astype(np.float32) / 255.0
+         head_rendering = head_output['render_images'].astype(np.float32) / 255.0
+         head_blending_mask = head_output['render_bw'].astype(np.float32) / 255.0
+         body_head_blending_params = np.load(self.cat.body_head_blending_param_path)
+         head_offline_rendering_param = head_param
+         stitch_output = self.stitch_head_body(body_rendering, body_mask, body_torso_mask, head_rendering, head_blending_mask, body_head_blending_params, head_offline_rendering_param)
+         cv.imwrite('./output' + '/%08d.jpg' % idx, stitch_output)
+
+         # TODO: render the hands and their masks, then paste the hands back on
+
+         return stitch_output
+
+     def load_ckpt(self, path, load_optm = True):
+         # load_optm is kept for interface compatibility; optimizer states are not
+         # needed at inference time, so only the network weights are restored here
+         print('Loading networks from ', path + '/net.pt')
+         net_dict = torch.load(path + '/net.pt')
+         if 'avatar_net' in net_dict:
+             self.avatar_net.load_state_dict(net_dict['avatar_net'])
+         else:
+             print('[WARNING] Cannot find "avatar_net" in the network checkpoint!')
+         epoch_idx = net_dict['epoch_idx']
+         iter_idx = net_dict['iter_idx']
+         return epoch_idx, iter_idx
+
+     @torch.no_grad()
+     def build_body(self, idx):
+         self.avatar_net.eval()
+         item_0 = self.body_dataset.getitem(0, training = False)
+         object_center = item_0['live_bounds'].mean(0)
+         global_orient = item_0['global_orient'].cpu().numpy() if isinstance(item_0['global_orient'], torch.Tensor) else item_0['global_orient']
+         use_pca = self.body['test'].get('n_pca', -1) >= 1
+         # zero the x and z components so only the yaw of the first frame is kept
+         global_orient[0] = 0
+         global_orient[2] = 0
+         global_orient = cv.Rodrigues(global_orient)[0]
+
+         if self.body['test'].get('fix_hand', False):
+             self.avatar_net.generate_mean_hands()
+
+         img_scale = self.body['test'].get('img_scale', 1.0)
+         view_setting = self.body['test'].get('view_setting', 'free')
+         extr, intr, img_h, img_w = get_camera_dir(idx, object_center, global_orient, img_scale, view_setting)
+         w2c = extr
+         c2w = np.linalg.inv(w2c)
+         pos = c2w[:3, 3]
+         rot = c2w[:3, :3]
+         camera_entry = {
+             'width': int(img_w),
+             'height': int(img_h),
+             'position': pos.tolist(),
+             'rotation': [x.tolist() for x in rot],
+             'fy': float(intr[1, 1]),
+             'fx': float(intr[0, 0]),
+         }
+
+         getitem_func = self.body_dataset.getitem_fast if hasattr(self.body_dataset, 'getitem_fast') else self.body_dataset.getitem
+         item = getitem_func(
+             idx,
+             training = False,
+             extr = extr,
+             intr = intr,
+             img_w = img_w,
+             img_h = img_h
+         )
+         items = to_cuda(item, add_batch = False)
+
+         if 'smpl_pos_map' not in items:
+             self.avatar_net.get_pose_map(items)
+
+         # optionally project the front pose map onto the PCA basis of the training
+         # poses (sigma_pca bounds the coefficients), pulling unseen poses toward
+         # the training distribution
+         if use_pca:
+             mask = self.body_training_dataset.pos_map_mask
+             live_pos_map = items['smpl_pos_map'].permute(1, 2, 0).cpu().numpy()
+             front_live_pos_map, back_live_pos_map = np.split(live_pos_map, [3], 2)
+             pose_conds = front_live_pos_map[mask]
+             new_pose_conds = self.body_training_dataset.transform_pca(pose_conds, sigma_pca = float(self.body['test'].get('sigma_pca', 2.)))
+             front_live_pos_map[mask] = new_pose_conds
+             live_pos_map = np.concatenate([front_live_pos_map, back_live_pos_map], 2)
+             items.update({
+                 'smpl_pos_map_pca': torch.from_numpy(live_pos_map).to(self.device).permute(2, 0, 1)
+             })
+
+         # render the full body, the body without hands, and the color-coded mask maps
+         output = self.avatar_net.render(items, bg_color = self.bg_color, use_pca = use_pca)
+         output_wo_hand = self.avatar_net.render_wo_hand(items, bg_color = self.bg_color, use_pca = use_pca)
+         mask_output = self.avatar_net.render_mask(items, bg_color = self.bg_color, use_pca = use_pca)
+
+         # postprocess (the in-place clip_ calls also sanitize the maps reused below)
+         full_body_mask = mask_output['full_body_rgb_map']
+         full_body_mask.clip_(0., 1.)
+         full_body_mask = (full_body_mask * 255).to(torch.uint8)
+
+         hand_only_mask = mask_output['hand_only_rgb_map']
+         hand_only_mask.clip_(0., 1.)
+         hand_only_mask = (hand_only_mask * 255).to(torch.uint8)
+
+         # build the covered-hand masks and the hand visibility flags; the mask
+         # renders paint the right hand red and the left hand blue
+         body_red_mask = (mask_output['full_body_rgb_map'] - torch.tensor([1., 0., 0.], device = mask_output['full_body_rgb_map'].device))
+         body_red_mask = (body_red_mask * body_red_mask).sum(dim=2) < 0.01
+
+         hand_red_mask = (mask_output['hand_only_rgb_map'] - torch.tensor([1., 0., 0.], device = mask_output['hand_only_rgb_map'].device))
+         hand_red_mask = (hand_red_mask * hand_red_mask).sum(dim=2) < 0.01
+         # flag the right hand as (mostly) occluded when its visible area in the
+         # full-body render differs from the hand-only render by more than 95%
+         if_mask_r_hand = abs(body_red_mask.sum() - hand_red_mask.sum()) / hand_red_mask.sum() > 0.95
+         if_mask_r_hand = if_mask_r_hand.cpu().numpy()
+
+         body_blue_mask = (mask_output['full_body_rgb_map'] - torch.tensor([0., 0., 1.], device = mask_output['full_body_rgb_map'].device))
+         body_blue_mask = (body_blue_mask * body_blue_mask).sum(dim=2) < 0.01
+
+         hand_blue_mask = (mask_output['hand_only_rgb_map'] - torch.tensor([0., 0., 1.], device = mask_output['hand_only_rgb_map'].device))
+         hand_blue_mask = (hand_blue_mask * hand_blue_mask).sum(dim=2) < 0.01
+         if_mask_l_hand = abs(body_blue_mask.sum() - hand_blue_mask.sum()) / hand_blue_mask.sum() > 0.95
+         if_mask_l_hand = if_mask_l_hand.cpu().numpy()
+
+         # masks of the occluded parts of the left and right hands
+         red_mask = hand_red_mask ^ (hand_red_mask & body_red_mask)
+         blue_mask = hand_blue_mask ^ (hand_blue_mask & body_blue_mask)
+         all_mask = red_mask | blue_mask
+
+         all_mask = (all_mask * 255).to(torch.uint8)
+         r_hand_mask = (body_red_mask * 255).to(torch.uint8)
+         l_hand_mask = (body_blue_mask * 255).to(torch.uint8)
+         hand_visual = [if_mask_r_hand, if_mask_l_hand]
+
+         # build the sleeve masks
+         mask = (r_hand_mask > 128) | (l_hand_mask > 128) | (all_mask > 128)
+         mask = mask.cpu().numpy().astype(np.uint8)
+         # dilate with a 5x5 structuring element; the kernel size controls how far the mask grows
+         kernel = np.ones((5, 5), np.uint8)
+         mask = cv.dilate(mask, kernel, iterations=3)
+         mask = torch.tensor(mask).to(self.device).bool()
+
+         left_hand_mask = mask_output['left_hand_rgb_map']
+         left_hand_mask.clip_(0., 1.)
+         # the non-white part is the mask
+         left_hand_mask = (torch.tensor([1., 1., 1.], device = left_hand_mask.device) - left_hand_mask)
+         left_hand_mask = (left_hand_mask * left_hand_mask).sum(dim=2) > 0.01
+         # remove the dilated two-hand region
+         left_hand_mask = left_hand_mask & ~mask
+
+         right_hand_mask = mask_output['right_hand_rgb_map']
+         right_hand_mask.clip_(0., 1.)
+         right_hand_mask = (torch.tensor([1., 1., 1.], device = right_hand_mask.device) - right_hand_mask)
+         right_hand_mask = (right_hand_mask * right_hand_mask).sum(dim=2) > 0.01
+         right_hand_mask = right_hand_mask & ~mask
+
+         left_sleeve_mask = (left_hand_mask * 255).to(torch.uint8)
+         right_sleeve_mask = (right_hand_mask * 255).to(torch.uint8)
+
+         # use r_hand_mask and l_hand_mask to paste the hand-free rendering over
+         # rgb_map wherever a hand region is detected
+         rgb_map = output['rgb_map']
+         rgb_map.clip_(0., 1.)
+         rgb_map = (rgb_map * 255).to(torch.uint8).cpu().numpy()
+
+         rgb_map_wo_hand = output_wo_hand['rgb_map']
+         rgb_map_wo_hand.clip_(0., 1.)
+         rgb_map_wo_hand = (rgb_map_wo_hand * 255).to(torch.uint8).cpu().numpy()
+
+         r_mask = (r_hand_mask > 128).cpu().numpy()
+         l_mask = (l_hand_mask > 128).cpu().numpy()
+         mask = (r_mask | l_mask).astype(np.uint8)
+         kernel = np.ones((5, 5), np.uint8)
+         mask = cv.dilate(mask, kernel, iterations=3)
+         mask = np.expand_dims(mask.astype(np.bool_), axis=2)
+         # the final rgb map with the hands removed
+         mix = rgb_map_wo_hand.copy() * mask + rgb_map * ~mask
+
+         torso_map = output['torso_map'][:, :, 0]
+         torso_map.clip_(0., 1.)
+         torso_map = (torso_map * 255).to(torch.uint8).cpu().numpy()
+
+         mask_map = output['mask_map'][:, :, 0]
+         mask_map.clip_(0., 1.)
+         mask_map = (mask_map * 255).to(torch.uint8).cpu().numpy()
+
+         output = {
+             # smpl
+             'betas': self.body_training_dataset.smpl_data['betas'].reshape([-1]).detach().cpu().numpy(),
+             'global_orient': item['global_orient'].reshape([-1]).detach().cpu().numpy(),
+             'transl': item['transl'].reshape([-1]).detach().cpu().numpy(),
+             'body_pose': item['body_pose'].reshape([-1]).detach().cpu().numpy(),
+
+             # camera
+             'extr': extr,
+             'intr': intr,
+             'img_h': img_h,
+             'img_w': img_w,
+             'camera_entry': camera_entry,
+
+             # rgb and masks
+             'rgb_map': rgb_map,
+             'rgb_map_wo_hand': mix,
+             'torso_map': torso_map,
+             'mask_map': mask_map,
+             'all_mask': all_mask,
+             'left_sleeve_mask': left_sleeve_mask,
+             'right_sleeve_mask': right_sleeve_mask,
+             'hand_visual': hand_visual
+         }
+
+         return output
+
+     def build_param(self, idx, body_output):
+         head_gaussians = load_gaussians_from_ply(self.cat.ref_head_gaussian_path)
+         head_pose, head_scale, id_coeff, exp_coeff = load_face_params(self.cat.ref_head_param_path)
+         body_head_blending_params = np.load(self.cat.body_head_blending_param_path)
+         smplx_to_faceverse = body_head_blending_params['smplx_to_faceverse']
+         residual_transf = body_head_blending_params['residual_transf']
+         head_color_bw = body_head_blending_params['head_color_bw']
+
+         # chain the head-to-body transforms: live head space -> live body space
+         smpl = SmplTorch(model_file='./AnimatableGaussians/smpl_files/smplx/SMPLX_NEUTRAL.npz')
+         global_orient, transl, body_pose, betas = body_output['global_orient'], body_output['transl'], body_output['body_pose'], body_output['betas']
+         smpl_verts, head_joint_transfmat = get_smpl_verts_and_head_transformation(
+             smpl, global_orient, body_pose, transl, betas)
+         livehead2livebody = calc_livehead2livebody(head_pose, smplx_to_faceverse, head_joint_transfmat)
+         total_transf = np.matmul(livehead2livebody, residual_transf)
+
+         cam, image_size = load_camera_data(body_output['camera_entry'])
+         cam_extr = np.matmul(cam[0], total_transf)
+         cam_intr = np.copy(cam[1])
+
+         # project the head gaussians to find their 2D bounding box
+         pts = np.copy(head_gaussians.xyz)
+         pts_proj = np.matmul(pts, cam_extr[:3, :3].transpose()) + cam_extr[:3, 3]
+         pts_proj = np.matmul(pts_proj, cam_intr.transpose())
+         pts_proj = pts_proj / pts_proj[:, 2:]
+
+         # zoom the intrinsics so the head spans ~350 px, centered in a 512 px crop
+         pts_min, pts_max = np.min(pts_proj, axis=0), np.max(pts_proj, axis=0)
+         pts_center = (pts_min + pts_max) // 2
+         pts_size = np.max(pts_max - pts_min)
+         tgt_pts_size = 350
+         tgt_image_size = 512
+         zoom_scale = tgt_pts_size / pts_size
+         cam_intr_zoom = np.copy(cam_intr)
+         cam_intr_zoom[:2] *= zoom_scale
+         cam_intr_zoom[0, 2] = cam_intr_zoom[0, 2] - (pts_center[0] * zoom_scale - tgt_image_size / 2)
+         cam_intr_zoom[1, 2] = cam_intr_zoom[1, 2] - (pts_center[1] * zoom_scale - tgt_image_size / 2)
+
+         output = {
+             'cam_extr': cam_extr,
+             'cam_intr': cam_intr,
+             'image_size': image_size,
+             'cam_intr_zoom': cam_intr_zoom,
+             'zoom_image_size': [tgt_image_size, tgt_image_size],
+             'zoom_center': pts_center,
+             'zoom_scale': zoom_scale,
+             'head_pose': head_pose,
+             'head_scale': head_scale,
+             'head_color_bw': head_color_bw,
+         }
+
+         return output
+
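+     # A worked example of the zoom mapping above, with illustrative numbers that
+     # are not from the original source: if the projected head bbox is 700 px
+     # wide, zoom_scale = 350 / 700 = 0.5, so cam_intr_zoom halves fx/fy and
+     # shifts the principal point so that the bbox center lands at pixel
+     # (256, 256) of the 512 x 512 head crop; stitch_head_body() inverts this
+     # same mapping (via zoom_center / zoom_scale) when pasting the crop back.
+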
+     def build_head(self, idx, head_offline_rendering_param):
+         cam_extr = head_offline_rendering_param['cam_extr']
+         cam_intr_zoom = head_offline_rendering_param['cam_intr_zoom']
+         head_pose = head_offline_rendering_param['head_pose']
+         head_color_bw = head_offline_rendering_param['head_color_bw']
+         head_pose = torch.from_numpy(head_pose.astype(np.float32)).to(self.device)
+         head_color_bw = torch.from_numpy(head_color_bw.astype(np.float32)).to(self.device)
+         render_size = 512
+
+         data = self.head_dataset[idx]
+         # add a batch dimension (non-tensor entries are dropped)
+         data = {k: v.unsqueeze(0) for k, v in data.items() if isinstance(v, torch.Tensor)}
+
+         new_gs_camera_param_dict = self.prepare_camera_data_for_gs_rendering(cam_extr, cam_intr_zoom, render_size, render_size)
+         for k in new_gs_camera_param_dict.keys():
+             if isinstance(new_gs_camera_param_dict[k], torch.Tensor):
+                 new_gs_camera_param_dict[k] = new_gs_camera_param_dict[k].unsqueeze(0).to(self.device)
+         new_gs_camera_param_dict['pose'] = head_pose.unsqueeze(0).to(self.device)
+
+         keys_to_cuda = ['images', 'intrinsics', 'extrinsics', 'world_view_transform', 'projection_matrix', 'full_proj_transform', 'camera_center',
+                         'pose', 'scale', 'exp_coeff', 'pose_code']
+         for data_item in keys_to_cuda:
+             data[data_item] = data[data_item].to(device=self.device)
+
+         data.update(new_gs_camera_param_dict)
+
+         with torch.no_grad():
+             # first pass: render the head colors, then super-resolve them
+             data = self.gaussianhead.generate(data)
+             data = self.head_camera.render_gaussian(data, render_size)
+             render_images = data['render_images']
+             supres_images = self.supres(render_images)
+             data['supres_images'] = supres_images
+             # second pass: render the blending-weight map (see the note after this method)
+             data['bg_color'] = torch.zeros([1, 32], device=self.device, dtype=torch.float32)
+             data['color_bk'] = data.pop('color')
+             data['color'] = torch.ones_like(data['color_bk']) * head_color_bw.reshape([1, -1, 1]) * 2.0
+             data['color'][:, :, 1] = 1
+             data['color'] = torch.clamp(data['color'], 0., 1.)
+             data = self.head_camera.render_gaussian(data, render_size)
+             render_bw = data['render_images'][:, :3, :, :]
+             data['color'] = data.pop('color_bk')
+             data['render_bw'] = render_bw
+
+         supres_image = data['supres_images'][0].permute(1, 2, 0).detach().cpu().numpy()
+         supres_image = (supres_image * 255).astype(np.uint8)[:, :, ::-1]
+
+         render_bw = data['render_bw'][0].permute(1, 2, 0).detach().cpu().numpy()
+         render_bw = np.clip(render_bw * 255, 0, 255).astype(np.uint8)[:, :, ::-1]
+         # cv2.resize expects (width, height)
+         render_bw = cv2.resize(render_bw, (supres_image.shape[1], supres_image.shape[0]))
+
+         output = {
+             'render_images': supres_image,
+             'render_bw': render_bw,
+         }
+
+         return output
+
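+     # Note on the second pass in build_head(): rasterizing head_color_bw through
+     # the same gaussian splatter yields blending weights that respect occlusion
+     # and the soft splat footprint, which a mesh-space projection would not.
+     # Reading the code, forcing color channel 1 to 1 appears to make the rendered
+     # green channel approximate head coverage, which stitch_head_body() reads
+     # back as head_mask (channel 1), while channel 0 carries the blending weight.
+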
+     def prepare_camera_data_for_gs_rendering(self, extrinsic, intrinsic, original_resolution, new_resolution):
+         extrinsic = np.copy(extrinsic)
+         intrinsic = np.copy(intrinsic)
+         new_intrinsic = np.copy(intrinsic)
+         new_intrinsic[:2] *= new_resolution / original_resolution
+
+         # normalize the intrinsics to NDC and recover the field of view
+         intrinsic[0, 0] = intrinsic[0, 0] * 2 / original_resolution
+         intrinsic[0, 2] = intrinsic[0, 2] * 2 / original_resolution - 1
+         intrinsic[1, 1] = intrinsic[1, 1] * 2 / original_resolution
+         intrinsic[1, 2] = intrinsic[1, 2] * 2 / original_resolution - 1
+         fovx = 2 * math.atan(1 / intrinsic[0, 0])
+         fovy = 2 * math.atan(1 / intrinsic[1, 1])
+
+         world_view_transform = torch.tensor(getWorld2View2(extrinsic[:3, :3].transpose(), extrinsic[:3, 3])).transpose(0, 1)
+         projection_matrix = getProjectionMatrix(
+             znear=0.01, zfar=100, fovX=None, fovY=None,
+             K=new_intrinsic, img_h=new_resolution, img_w=new_resolution).transpose(0, 1)
+         full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
+         camera_center = world_view_transform.inverse()[3, :3]
+
+         c2w = np.linalg.inv(extrinsic)
+         viewdir = np.matmul(c2w[:3, :3], np.array([0, 0, -1], np.float32).reshape([3, 1])).reshape([-1])
+         viewdir = torch.from_numpy(viewdir.astype(np.float32))
+
+         return {
+             'extrinsics': torch.from_numpy(extrinsic.astype(np.float32)),
+             'intrinsics': torch.from_numpy(intrinsic.astype(np.float32)),
+             'viewdir': viewdir,
+             'fovx': torch.Tensor([fovx]),
+             'fovy': torch.Tensor([fovy]),
+             'world_view_transform': world_view_transform,
+             'projection_matrix': projection_matrix,
+             'full_proj_transform': full_proj_transform,
+             'camera_center': camera_center
+         }
+
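+     # The normalization above follows the usual pinhole-to-NDC convention:
+     #   fx_ndc = 2 * fx / W,   cx_ndc = 2 * cx / W - 1   (likewise for y),
+     # so a pixel u = fx * X / Z + cx maps to fx_ndc * X / Z + cx_ndc in [-1, 1],
+     # and the field of view recovers as fovx = 2 * atan(1 / fx_ndc) = 2 * atan(W / (2 * fx)).
+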
+     def stitch_head_body(self, body_rendering, body_mask, body_torso_mask, head_rendering, head_blending_mask, body_head_blending_params, head_offline_rendering_param):
+         color_transfer = body_head_blending_params['color_transfer']
+         zoom_image_size = head_offline_rendering_param['zoom_image_size']
+         zoom_center = head_offline_rendering_param['zoom_center']
+         zoom_scale = head_offline_rendering_param['zoom_scale']
+
+         if len(body_mask.shape) == 3:
+             body_mask = body_mask[:, :, 0]
+         if len(body_torso_mask.shape) == 3:
+             body_torso_mask = body_torso_mask[:, :, 0]
+
+         # channel 1 of the bw render is head coverage, channel 0 the blending weight
+         head_rendering = cv2.resize(head_rendering, (int(zoom_image_size[0]), int(zoom_image_size[1])))
+         head_blending_mask = cv2.resize(head_blending_mask, (int(zoom_image_size[0]), int(zoom_image_size[1])))
+         head_mask = head_blending_mask[:, :, 1]
+         head_blending_mask = head_blending_mask[:, :, 0]
+         head_blending_mask = soften_blending_mask(head_blending_mask, head_mask)
+
+         # undo the zoom: paste the head crop back into body-image coordinates
+         pasteback_center = zoom_center
+         pasteback_scale = zoom_scale
+         head_rendering_back = paste_back_with_linear_interp(pasteback_scale, pasteback_center, head_rendering, [body_rendering.shape[1], body_rendering.shape[0]])
+         head_blending_mask_back = paste_back_with_linear_interp(pasteback_scale, pasteback_center, head_blending_mask, [body_rendering.shape[1], body_rendering.shape[0]])
+         head_mask_back = paste_back_with_linear_interp(pasteback_scale, pasteback_center, head_mask, [body_rendering.shape[1], body_rendering.shape[0]])
+         # never blend the head over the torso
+         head_blending_mask_back = head_blending_mask_back * (1 - body_torso_mask)
+
+         # match the head colors to the body with an affine color transfer
+         head_rendering_back_shape = head_rendering_back.shape
+         head_rendering_back = np.matmul(head_rendering_back.reshape(-1, 3), color_transfer[:3, :3].transpose()) + color_transfer[:3, 3][None]
+         head_rendering_back = head_rendering_back.reshape(head_rendering_back_shape)
+         head_rendering_back = head_rendering_back * head_mask_back[:, :, None] + (1 - head_mask_back[:, :, None])
+
+         body_rendering = body_rendering * (1 - head_blending_mask_back[:, :, None]) + head_rendering_back * head_blending_mask_back[:, :, None]
+
+         return np.uint8(np.clip(body_rendering, 0, 1) * 255)
+
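+     # In matrix form, the color transfer above is an affine map c' = A @ c + b
+     # with A = color_transfer[:3, :3] and b = color_transfer[:3, 3], and the
+     # final composite is out = body * (1 - w) + head * w, where w is the
+     # softened, torso-suppressed blending mask in body-image coordinates.
+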
+     # def build_hand(self, betas, poses, camera):
+     #     # build the hand rendering here (not implemented yet)
+     #     output = {
+     #         'hand_render': render,
+     #         'hand_mask': mask,
+     #     }
+     #     return output
+
+
+ if __name__ == '__main__':
+     conf = OmegaConf.load('configs/example.yaml')
+     avatar = Avatar(conf)
+     avatar.build_dataset()
+     avatar.render_all()
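+
+ # A minimal sketch of the config layout this script assumes, inferred from the
+ # attribute accesses above; the real configs/example.yaml may differ:
+ #
+ #   animatablegaussians:        # passed to AnimatableGaussians (set_opt / AvatarNet)
+ #     model: { ... }
+ #     train: { data: ..., loss_weight: ..., finetune_color: ... }
+ #     test:  { prev_ckpt: ..., pose_data: ..., n_pca: -1 }
+ #   gha:
+ #     config_path: <GHA reenactment config yaml>
+ #   cat:
+ #     ref_head_gaussian_path: <reference head gaussians .ply>
+ #     ref_head_param_path: <reference face params>
+ #     body_head_blending_param_path: <body-head blending params .npz>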