pengc02 committed on
Commit 44925e5
1 Parent(s): 863d3de
This view is limited to 50 files because it contains too many changes.
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ checkpoints/face_0929/gaussianhead_latest filter=lfs diff=lfs merge=lfs -text
37
+ checkpoints/face_0929/supres_latest filter=lfs diff=lfs merge=lfs -text
38
+ checkpoints/face_0929/delta_poses_latest filter=lfs diff=lfs merge=lfs -text
39
+ checkpoints/pos_map_ys/body_mix/smpl_pos_map/cano_smpl_nml_map.exr filter=lfs diff=lfs merge=lfs -text
40
+ checkpoints/pos_map_ys/body_mix/smpl_pos_map/cano_smpl_pos_map.exr filter=lfs diff=lfs merge=lfs -text
41
+ checkpoints/ref_gaussian/head/000000.ply filter=lfs diff=lfs merge=lfs -text
42
+ checkpoints/ filter=lfs diff=lfs merge=lfs -text
__MACOSX/._AnimatableGaussians ADDED
Binary file (220 Bytes).

__MACOSX/._GHA ADDED
Binary file (220 Bytes).

__MACOSX/._avatar_generator.py ADDED
Binary file (220 Bytes).

__MACOSX/._calc_offline_rendering_param.py ADDED
Binary file (220 Bytes).

__MACOSX/._checkpoints ADDED
Binary file (220 Bytes).

__MACOSX/._configs ADDED
Binary file (220 Bytes).

__MACOSX/._gradio_page.py ADDED
Binary file (220 Bytes).

__MACOSX/._render_utils ADDED
Binary file (220 Bytes).

__MACOSX/._test_data ADDED
Binary file (220 Bytes).

__MACOSX/AnimatableGaussians/._.DS_Store ADDED
Binary file (120 Bytes).

__MACOSX/checkpoints/._pos_map_ys ADDED
Binary file (220 Bytes).

__MACOSX/test_data/._.DS_Store ADDED
Binary file (120 Bytes).
 
app.py CHANGED
@@ -1,7 +1,147 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
2
+ import moviepy.editor as mpy
3
+ import numpy as np
4
+ import os
5
+ from omegaconf import OmegaConf
6
+ from tqdm import tqdm
7
+ import shutil
8
+ import time
9
+ from avatar_generator import Avatar
10
 
11
+ # # directory where uploaded files are saved
12
+ # SAVE_DIR = "./uploaded_files"
13
+ # os.makedirs(SAVE_DIR, exist_ok=True)  # create the directory if it does not exist
14
+ # global flag used to signal that the running task should stop
15
+ should_stop = False
16
 
17
+
18
+ # frame-by-frame processing function
19
+ def process_files(file1, file2):
20
+ global should_stop
21
+ should_stop = False  # reset the stop flag
22
+
23
+ yield None, None, None, "Starting Process!"
24
+
25
+ file_path1 = file1.name
26
+ file_path2 = file2.name
27
+ pose_data = np.load(file_path1)
28
+ exp_data = np.load(file_path2)
29
+
30
+ # save
31
+ pose_path = './test_data/AMASS/online_test_pose_data.npz'
32
+ exp_path = './test_data/face_exp/online_test_exp_data.npy'
33
+
34
+ np.savez(pose_path, **pose_data)
35
+ np.save(exp_path, exp_data)
36
+
37
+
38
+ # with open(file1.name, 'rb') as fsrc:
39
+ # with open(file_path1, 'wb') as fdst:
40
+ # shutil.copyfileobj(fsrc, fdst)
41
+
42
+ # with open(file2.name, 'rb') as fsrc:
43
+ # with open(file_path2, 'wb') as fdst:
44
+ # shutil.copyfileobj(fsrc, fdst)
45
+
46
+ conf = OmegaConf.load('configs/example.yaml')
47
+ avatar = Avatar(conf)
48
+ avatar.build_dataset(pose_path, exp_path)
49
+
50
+ lenth = min(len(avatar.body_dataset), len(avatar.head_dataloader),20)
51
+ output_frames = []
52
+
53
+ start_time = time.time()
54
+ for idx in tqdm(range(lenth)):
55
+ if should_stop:
56
+ yield None, None, None, None
57
+ break  # exit the loop when a stop has been requested
58
+ frame = avatar.reder_frame(idx)
59
+ # rgb2bgr
60
+ frame = frame[..., ::-1]
61
+ output_frames.append(frame)
62
+
63
+ elapsed_time = time.time() - start_time
64
+ estimated_total_time = (elapsed_time / (idx + 1)) * lenth
65
+ remaining_time = estimated_total_time - elapsed_time
66
+
67
+ yield frame, None, (idx + 1) / lenth * 100, f"{elapsed_time:.2f} sec/{estimated_total_time:.2f} sec"
68
+
69
+ if not should_stop:
70
+ output_path = "./output/output_video.mp4"
71
+ final_video = mpy.ImageSequenceClip(output_frames, fps=25)
72
+ final_video.write_videofile(output_path, codec='libx264')
73
+
74
+ yield output_frames[-1], output_path, 100.0, "Processing completed!"
75
+
76
+ # clear / stop handler
77
+ def clear_files():
78
+ global should_stop
79
+ should_stop = True  # set the stop flag
80
+
81
+ # return empty values to clear the UI components
82
+ return None, None, None, None, None, None
83
+
84
+
85
+ # build the Gradio interface
86
+ with gr.Blocks(css="""
87
+ .equal-height {
88
+ height: 425px; /* set this to the desired height */
89
+ display: flex;
90
+ flex-direction: column;
91
+ justify-content: center;
92
+ align-items: center;
93
+ }
94
+ .equal-height input {
95
+ height: 100%; /* the input fills the full container height */
96
+ }
97
+ .output-container {
98
+ height: 400px; /* height of the output boxes */
99
+ }
100
+ .custom-text {
101
+ height: 80px; /* height of the output box */
102
+ }
103
+ """) as demo:
104
+ with gr.Row():
105
+ # left column: file inputs
106
+ with gr.Column(scale=1):
107
+ with gr.Row(elem_classes="equal-height"):
108
+ file_input1 = gr.File(label="Upload File (Body Pose)")
109
+ file_input2 = gr.File(label="Upload File (Face EXP)")
110
+
111
+ with gr.Column(scale=2):
112
+ with gr.Row():
113
+ # middle column: current-frame image output
114
+ with gr.Column(scale=1):
115
+ frame_output = gr.Image(label="Current Frame Output", elem_classes="output-container") # current frame image
116
+ # right column: video output
117
+ with gr.Column(scale=1):
118
+ video_output = gr.Video(label="Processed Video Output", elem_classes="output-container") # rendered video output
119
+ # progress_bar = gr.Label(label="Progress")
120
+ with gr.Row():
121
+ with gr.Column(scale=2):
122
+ progress_bar = gr.Slider(visible=True, minimum=0, maximum=100, step=1, label="Progress %", elem_classes="custom-text") # a Slider used as a progress bar
123
+ with gr.Column(scale=1):
124
+ output_time = gr.Textbox(label='Processing Time/Estimate Time', elem_classes="custom-text")
125
+ # time_label = gr.Label(value="", label="Estimated Time Remaining", elem_classes="custom-label")
126
+ # with gr.Row():
127
+ # progress_bar = gr.Progress()  # add a progress bar
128
+ with gr.Row():
129
+ process_button = gr.Button("Start Processing Files")
130
+ clear_button = gr.Button("Clear or Stop Processing")
131
+
132
+ # wire up the button callbacks
133
+ process_button.click(
134
+ fn=process_files,
135
+ inputs=[file_input1, file_input2],
136
+ outputs=[frame_output, video_output, progress_bar, output_time],
137
+ show_progress=False
138
+ )
139
+
140
+ clear_button.click(
141
+ fn= clear_files,
142
+ inputs=[],
143
+ outputs=[file_input1, file_input2, frame_output, video_output, progress_bar, output_time]
144
+ )
145
+
146
+ # launch the app
147
+ demo.launch()
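
A note on the two upload slots wired above: process_files() re-saves the first upload with np.savez (so it must be an .npz archive of pose arrays) and the second with np.save (a single .npy expression array) before handing both paths to Avatar.build_dataset(). Below is a minimal sketch of producing placeholder inputs; the key names and shapes are illustrative assumptions only, since the real format is whatever the AMASS export referenced in configs/example.yaml uses.

import numpy as np

# Hypothetical keys and shapes for illustration; the actual contract is defined
# by avatar_generator.Avatar.build_dataset().
pose_data = {
    'body_pose': np.zeros((200, 63), dtype=np.float32),     # assumed per-frame body pose
    'global_orient': np.zeros((200, 3), dtype=np.float32),  # assumed root orientation
    'transl': np.zeros((200, 3), dtype=np.float32),         # assumed root translation
}
np.savez('my_pose_data.npz', **pose_data)   # goes into "Upload File (Body Pose)"

# 52 only mirrors exp_coeffs_dim in configs/head.yaml; the real requirement may differ.
exp_data = np.zeros((200, 52), dtype=np.float32)
np.save('my_exp_data.npy', exp_data)        # goes into "Upload File (Face EXP)"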
avatar.py ADDED
@@ -0,0 +1,642 @@
1
+ from calendar import c
2
+ import os
3
+ # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
4
+ # os.environ['TORCH_USE_CUDA_DSA'] = '1'
5
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
6
+ import yaml
7
+ import shutil
8
+ import collections
9
+ import torch
10
+ import torch.utils.data
11
+ import torch.nn.functional as F
12
+ import numpy as np
13
+ import cv2 as cv
14
+ import glob
15
+ import datetime
16
+ import trimesh
17
+ from torch.utils.tensorboard import SummaryWriter
18
+ from tqdm import tqdm
19
+ import importlib
20
+ # import config
21
+ from omegaconf import OmegaConf
22
+ import json
23
+
24
+ # AnimatableGaussians part
25
+ from AnimatableGaussians.network.lpips import LPIPS
26
+ from AnimatableGaussians.dataset.dataset_pose import PoseDataset
27
+ import AnimatableGaussians.utils.net_util as net_util
28
+ import AnimatableGaussians.utils.visualize_util as visualize_util
29
+ from AnimatableGaussians.utils.renderer import Renderer
30
+ from AnimatableGaussians.utils.net_util import to_cuda
31
+ from AnimatableGaussians.utils.obj_io import save_mesh_as_ply
32
+ from AnimatableGaussians.gaussians.obj_io import save_gaussians_as_ply
33
+ import AnimatableGaussians.config as ag_config
34
+
35
+ # Gaussian-Head-Avatar part
36
+ from GHA.config.config import config_reenactment
37
+ from GHA.lib.dataset.Dataset import ReenactmentDataset
38
+ from GHA.lib.dataset.DataLoaderX import DataLoaderX
39
+ from GHA.lib.module.GaussianHeadModule import GaussianHeadModule
40
+ from GHA.lib.module.SuperResolutionModule import SuperResolutionModule
41
+ from GHA.lib.module.CameraModule import CameraModule
42
+ from GHA.lib.recorder.Recorder import ReenactmentRecorder
43
+ from GHA.lib.apps.Reenactment import Reenactment
44
+
45
+ # cat utils
46
+ from calc_offline_rendering_param import calc_offline_rendering_param
47
+
48
+ import ipdb
49
+
50
+ class Avatar:
51
+ def __init__(self, config):
52
+ self.config = config
53
+ self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
54
+
55
+ # animatable gaussians part init
56
+ self.body = config.animatablegaussians
57
+ self.body.mode = 'test'
58
+ ag_config.set_opt(self.body)
59
+ avatar_module = self.body['model'].get('module', 'AnimatableGaussians.network.avatar')
60
+ print('Import AvatarNet from %s' % avatar_module)
61
+ AvatarNet = importlib.import_module(avatar_module).AvatarNet
62
+ self.avatar_net = AvatarNet(self.body.model).to(self.device)
63
+ self.random_bg_color = self.body['train'].get('random_bg_color', True)
64
+ self.bg_color = (1., 1., 1.)
65
+ self.bg_color_cuda = torch.from_numpy(np.asarray(self.bg_color)).to(torch.float32).to(self.device)
66
+ self.loss_weight = self.body['train']['loss_weight']
67
+ self.finetune_color = self.body['train']['finetune_color']
68
+ print('# Parameter number of AvatarNet is %d' % (sum([p.numel() for p in self.avatar_net.parameters()])))
69
+
70
+ # gaussian head avatar part init
71
+ self.head = config.gha
72
+ self.head_config = config_reenactment()
73
+ self.head_config.load(self.head.config_path)
74
+ self.head_config = self.head_config.get_cfg()
75
+
76
+ # cat utils part init
77
+ self.cat = config.cat
78
+
79
+ @torch.no_grad()
80
+ def test_body(self):
81
+ # run the animatable gaussian test
82
+ self.avatar_net.eval()
83
+ dataset_module = self.body.get('dataset', 'MvRgbDatasetAvatarReX')
84
+ MvRgbDataset = importlib.import_module('AnimatableGaussians.dataset.dataset_mv_rgb').__getattribute__(dataset_module)
85
+ training_dataset = MvRgbDataset(**self.body['train']['data'], training = False)
86
+ if self.body['test'].get('n_pca', -1) >= 1:
87
+ training_dataset.compute_pca(n_components = self.body['test']['n_pca'])
88
+ if 'pose_data' in self.body.test:
89
+ testing_dataset = PoseDataset(**self.body['test']['pose_data'], smpl_shape = training_dataset.smpl_data['betas'][0])
90
+ dataset_name = testing_dataset.dataset_name
91
+ seq_name = testing_dataset.seq_name
92
+ else:
93
+ # throw an error
94
+ raise ValueError('No pose data in test config')
95
+
96
+ self.dataset = testing_dataset
97
+ # iter_idx = self.load_ckpt(self.body['test']['prev_ckpt'], False)[1]
98
+
99
+ output_dir = self.body['test'].get('output_dir', None)
100
+ if output_dir is None:
101
+ raise ValueError('No output_dir in test config')
102
+ use_pca = self.body['test'].get('n_pca', -1) >= 1
103
+ if use_pca:
104
+ output_dir += '/pca_%d_sigma_%.2f' % (self.body['test'].get('n_pca', -1), float(self.body['test'].get('sigma_pca', 1.)))
105
+ else:
106
+ output_dir += '/vanilla'
107
+ print('# Output dir: \033[1;31m%s\033[0m' % output_dir)
108
+
109
+ os.makedirs(output_dir + '/live_skeleton', exist_ok = True)
110
+ os.makedirs(output_dir + '/rgb_map', exist_ok = True)
111
+ os.makedirs(output_dir + '/rgb_map_wo_hand', exist_ok = True)
112
+ os.makedirs(output_dir + '/torso_map', exist_ok = True)
113
+ os.makedirs(output_dir + '/mask_map', exist_ok = True)
114
+ os.makedirs(output_dir + '/posed_gaussians', exist_ok = True)
115
+ os.makedirs(output_dir + '/posed_params', exist_ok = True)
116
+ os.makedirs(output_dir + '/full_body_mask', exist_ok = True)
117
+ os.makedirs(output_dir + '/hand_only_mask', exist_ok = True)
118
+
119
+ geo_renderer = None
120
+ item_0 = self.dataset.getitem(0, training = False)
121
+ object_center = item_0['live_bounds'].mean(0)
122
+ global_orient = item_0['global_orient'].cpu().numpy() if isinstance(item_0['global_orient'], torch.Tensor) else item_0['global_orient']
123
+
124
+ # set x and z to 0
125
+ global_orient[0] = 0
126
+ global_orient[2] = 0
127
+
128
+ global_orient = cv.Rodrigues(global_orient)[0]
129
+ time_start = torch.cuda.Event(enable_timing = True)
130
+ time_start_all = torch.cuda.Event(enable_timing = True)
131
+ time_end = torch.cuda.Event(enable_timing = True)
132
+
133
+ data_num = len(self.dataset)
134
+ if self.body['test'].get('fix_hand', False):
135
+ self.avatar_net.generate_mean_hands()
136
+ log_time = False
137
+ extr_list = []
138
+ intr_list = []
139
+ img_h_list = []
140
+ img_w_list = []
141
+
142
+
143
+ for idx in tqdm(range(data_num), desc = 'Rendering avatars...'):
144
+ if log_time:
145
+ time_start.record()
146
+ time_start_all.record()
147
+
148
+ img_scale = self.body['test'].get('img_scale', 1.0)
149
+ view_setting = self.body['test'].get('view_setting', 'free')
150
+ if view_setting == 'camera':
151
+ # training view setting
152
+ cam_id = self.body['test']['render_view_idx']
153
+ intr = self.dataset.intr_mats[cam_id].copy()
154
+ intr[:2] *= img_scale
155
+ extr = self.dataset.extr_mats[cam_id].copy()
156
+ img_h, img_w = int(self.dataset.img_heights[cam_id] * img_scale), int(self.dataset.img_widths[cam_id] * img_scale)
157
+ elif view_setting.startswith('free'):
158
+ # free view setting
159
+ # frame_num_per_circle = 360
160
+ # print(self.opt['test'].get('global_orient', False))
161
+ frame_num_per_circle = 360
162
+ rot_Y = (idx % frame_num_per_circle) / float(frame_num_per_circle) * 2 * np.pi
163
+
164
+ extr = visualize_util.calc_free_mv(object_center,
165
+ tar_pos = np.array([0, 0, 2.5]),
166
+ rot_Y = rot_Y,
167
+ rot_X = 0.3 if view_setting.endswith('bird') else 0.,
168
+ global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
169
+ intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
170
+ intr[:2] *= img_scale
171
+ img_h = int(1024 * img_scale)
172
+ img_w = int(1024 * img_scale)
173
+
174
+ extr_list.append(extr)
175
+ intr_list.append(intr)
176
+ img_h_list.append(img_h)
177
+ img_w_list.append(img_w)
178
+
179
+ elif view_setting.startswith('degree120'):
180
+ print('we render 120 degree')
181
+ # +- 60 degree
182
+ frame_per_cycle = 480
183
+ max_degree = 60
184
+ frame_half_cycle = frame_per_cycle // 2
185
+ if idx%frame_per_cycle < frame_per_cycle/2:
186
+ rot_Y = -max_degree + (2 * max_degree / frame_half_cycle) * (idx%frame_half_cycle)
187
+ # rot_Y = (idx % frame_per_60) / float(frame_per_60) * 2 * np.pi
188
+ else:
189
+ rot_Y = max_degree - (2 * max_degree / frame_half_cycle) * (idx%frame_half_cycle)
190
+
191
+ # to radian
192
+ rot_Y = rot_Y * np.pi / 180
193
+ if rot_Y<0:
194
+ rot_Y = rot_Y + 2 * np.pi
195
+ # print('rot_Y: ', rot_Y)
196
+ extr = visualize_util.calc_free_mv(object_center,
197
+ tar_pos = np.array([0, 0, 2.5]),
198
+ rot_Y = rot_Y,
199
+ rot_X = 0.3 if view_setting.endswith('bird') else 0.,
200
+ global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
201
+ intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
202
+ intr[:2] *= img_scale
203
+ img_h = int(1024 * img_scale)
204
+ img_w = int(1024 * img_scale)
205
+
206
+ extr_list.append(extr)
207
+ intr_list.append(intr)
208
+ img_h_list.append(img_h)
209
+ img_w_list.append(img_w)
210
+
211
+ elif view_setting.startswith('degree90'):
212
+ print('we render 90 degree')
213
+ # +- 60 degree
214
+ frame_per_cycle = 360
215
+ max_degree = 45
216
+ frame_half_cycle = frame_per_cycle // 2
217
+ if idx%frame_per_cycle < frame_per_cycle/2:
218
+ rot_Y = -max_degree + (2 * max_degree / frame_half_cycle) * (idx%frame_half_cycle)
219
+ # rot_Y = (idx % frame_per_60) / float(frame_per_60) * 2 * np.pi
220
+ else:
221
+ rot_Y = max_degree - (2 * max_degree / frame_half_cycle) * (idx%frame_half_cycle)
222
+
223
+ # to radian
224
+ rot_Y = rot_Y * np.pi / 180
225
+ if rot_Y<0:
226
+ rot_Y = rot_Y + 2 * np.pi
227
+ # print('rot_Y: ', rot_Y)
228
+ extr = visualize_util.calc_free_mv(object_center,
229
+ tar_pos = np.array([0, 0, 2.5]),
230
+ rot_Y = rot_Y,
231
+ rot_X = 0.3 if view_setting.endswith('bird') else 0.,
232
+ global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
233
+ intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
234
+ intr[:2] *= img_scale
235
+ img_h = int(1024 * img_scale)
236
+ img_w = int(1024 * img_scale)
237
+
238
+ extr_list.append(extr)
239
+ intr_list.append(intr)
240
+ img_h_list.append(img_h)
241
+ img_w_list.append(img_w)
242
+
243
+
244
+ elif view_setting.startswith('front'):
245
+ # front view setting
246
+ extr = visualize_util.calc_free_mv(object_center,
247
+ tar_pos = np.array([0, 0, 2.5]),
248
+ rot_Y = 0.,
249
+ rot_X = 0.3 if view_setting.endswith('bird') else 0.,
250
+ global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
251
+ intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
252
+ intr[:2] *= img_scale
253
+ img_h = int(1024 * img_scale)
254
+ img_w = int(1024 * img_scale)
255
+
256
+ extr_list.append(extr)
257
+ intr_list.append(intr)
258
+ img_h_list.append(img_h)
259
+ img_w_list.append(img_w)
260
+
261
+ # print('extr: ', extr)
262
+ # print('intr: ', intr)
263
+ # print('img_h: ', img_h)
264
+ # print('img_w: ', img_w)
265
+ # exit()
266
+
267
+
268
+
269
+ elif view_setting.startswith('back'):
270
+ # back view setting
271
+ extr = visualize_util.calc_free_mv(object_center,
272
+ tar_pos = np.array([0, 0, 2.5]),
273
+ rot_Y = np.pi,
274
+ rot_X = 0.5 * np.pi / 4. if view_setting.endswith('bird') else 0.,
275
+ global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
276
+ intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
277
+ intr[:2] *= img_scale
278
+ img_h = int(1024 * img_scale)
279
+ img_w = int(1024 * img_scale)
280
+ elif view_setting.startswith('moving'):
281
+ # moving camera setting
282
+ extr = visualize_util.calc_free_mv(object_center,
283
+ # tar_pos = np.array([0, 0, 3.0]),
284
+ # rot_Y = -0.3,
285
+ tar_pos = np.array([0, 0, 2.5]),
286
+ rot_Y = 0.,
287
+ rot_X = 0.3 if view_setting.endswith('bird') else 0.,
288
+ global_orient = global_orient if self.body['test'].get('global_orient', False) else None)
289
+ intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
290
+ intr[:2] *= img_scale
291
+ img_h = int(1024 * img_scale)
292
+ img_w = int(1024 * img_scale)
293
+ elif view_setting.startswith('cano'):
294
+ cano_center = self.dataset.cano_bounds.mean(0)
295
+ extr = np.identity(4, np.float32)
296
+ extr[:3, 3] = -cano_center
297
+ rot_x = np.identity(4, np.float32)
298
+ rot_x[:3, :3] = cv.Rodrigues(np.array([np.pi, 0, 0], np.float32))[0]
299
+ extr = rot_x @ extr
300
+ f_len = 5000
301
+ extr[2, 3] += f_len / 512
302
+ intr = np.array([[f_len, 0, 512], [0, f_len, 512], [0, 0, 1]], np.float32)
303
+ # item = self.dataset.getitem(idx,
304
+ # training = False,
305
+ # extr = extr,
306
+ # intr = intr,
307
+ # img_w = 1024,
308
+ # img_h = 1024)
309
+ img_w, img_h = 1024, 1024
310
+ # item['live_smpl_v'] = item['cano_smpl_v']
311
+ # item['cano2live_jnt_mats'] = torch.eye(4, dtype = torch.float32)[None].expand(item['cano2live_jnt_mats'].shape[0], -1, -1)
312
+ # item['live_bounds'] = item['cano_bounds']
313
+ else:
314
+ raise ValueError('Invalid view setting for animation!')
315
+
316
+
317
+ self.dump_renderer_info(output_dir, extr_list, intr_list, img_h_list, img_w_list)
318
+ # also save the extr and intr and img_h and img_w to json
319
+ camera_info = []
320
+ for i in range(len(extr_list)):
321
+ camera = {}
322
+ camera['extr'] = extr_list[i].tolist()
323
+ camera['intr'] = intr_list[i].tolist()
324
+ camera['img_h'] = img_h_list[i]
325
+ camera['img_w'] = img_w_list[i]
326
+ camera_info.append(camera)
327
+ with open(os.path.join(output_dir, 'camera_info.json'), 'w') as fp:
328
+ json.dump(camera_info, fp)
329
+
330
+
331
+ getitem_func = self.dataset.getitem_fast if hasattr(self.dataset, 'getitem_fast') else self.dataset.getitem
332
+ item = getitem_func(
333
+ idx,
334
+ training = False,
335
+ extr = extr,
336
+ intr = intr,
337
+ img_w = img_w,
338
+ img_h = img_h
339
+ )
340
+ items = to_cuda(item, add_batch = False)
341
+
342
+ if view_setting.startswith('moving') or view_setting == 'free_moving':
343
+ current_center = items['live_bounds'].cpu().numpy().mean(0)
344
+ delta = current_center - object_center
345
+
346
+ object_center[0] += delta[0]
347
+ # object_center[1] += delta[1]
348
+ # object_center[2] += delta[2]
349
+
350
+ if log_time:
351
+ time_end.record()
352
+ torch.cuda.synchronize()
353
+ print('Loading data costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
354
+ time_start.record()
355
+
356
+ if self.body['test'].get('render_skeleton', False):
357
+ from AnimatableGaussians.utils.visualize_skeletons import construct_skeletons
358
+ skel_vertices, skel_faces = construct_skeletons(item['joints'].cpu().numpy(), item['kin_parent'].cpu().numpy())
359
+ skel_mesh = trimesh.Trimesh(skel_vertices, skel_faces, process = False)
360
+
361
+ if geo_renderer is None:
362
+ geo_renderer = Renderer(item['img_w'], item['img_h'], shader_name = 'phong_geometry', bg_color = (1, 1, 1))
363
+ extr, intr = item['extr'], item['intr']
364
+ geo_renderer.set_camera(extr, intr)
365
+ geo_renderer.set_model(skel_vertices[skel_faces.reshape(-1)], skel_mesh.vertex_normals.astype(np.float32)[skel_faces.reshape(-1)])
366
+ skel_img = geo_renderer.render()[:, :, :3]
367
+ skel_img = (skel_img * 255).astype(np.uint8)
368
+ cv.imwrite(output_dir + '/live_skeleton/%08d.jpg' % item['data_idx'], skel_img)
369
+
370
+ if log_time:
371
+ time_end.record()
372
+ torch.cuda.synchronize()
373
+ print('Rendering skeletons costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
374
+ time_start.record()
375
+
376
+ if 'smpl_pos_map' not in items:
377
+ self.avatar_net.get_pose_map(items)
378
+
379
+ # pca
380
+ if use_pca:
381
+ mask = training_dataset.pos_map_mask
382
+ live_pos_map = items['smpl_pos_map'].permute(1, 2, 0).cpu().numpy()
383
+ front_live_pos_map, back_live_pos_map = np.split(live_pos_map, [3], 2)
384
+ pose_conds = front_live_pos_map[mask]
385
+ new_pose_conds = training_dataset.transform_pca(pose_conds, sigma_pca = float(self.body['test'].get('sigma_pca', 2.)))
386
+ front_live_pos_map[mask] = new_pose_conds
387
+ live_pos_map = np.concatenate([front_live_pos_map, back_live_pos_map], 2)
388
+ items.update({
389
+ 'smpl_pos_map_pca': torch.from_numpy(live_pos_map).to(self.device).permute(2, 0, 1)
390
+ })
391
+
392
+ if log_time:
393
+ time_end.record()
394
+ torch.cuda.synchronize()
395
+ print('Rendering pose conditions costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
396
+ time_start.record()
397
+
398
+ output = self.avatar_net.render(items, bg_color = self.bg_color, use_pca = use_pca)
399
+ output_wo_hand = self.avatar_net.render_wo_hand(items, bg_color = self.bg_color, use_pca = use_pca)
400
+ mask_output = self.avatar_net.render_mask(items, bg_color = self.bg_color, use_pca = use_pca)
401
+
402
+ if log_time:
403
+ time_end.record()
404
+ torch.cuda.synchronize()
405
+ print('Rendering avatar costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
406
+ time_start.record()
407
+
408
+ if 'rgb_map' in output_wo_hand:
409
+ rgb_map_wo_hand = output_wo_hand['rgb_map']
410
+
411
+ if 'full_body_rgb_map' in mask_output:
412
+ os.makedirs(output_dir + '/full_body_mask', exist_ok = True)
413
+ full_body_mask = mask_output['full_body_rgb_map']
414
+ full_body_mask.clip_(0., 1.)
415
+ full_body_mask = (full_body_mask * 255).to(torch.uint8)
416
+ cv.imwrite(output_dir + '/full_body_mask/%08d.png' % item['data_idx'], full_body_mask.cpu().numpy())
417
+
418
+ if 'hand_only_rgb_map' in mask_output:
419
+ os.makedirs(output_dir + '/hand_only_mask', exist_ok = True)
420
+ hand_only_mask = mask_output['hand_only_rgb_map']
421
+ hand_only_mask.clip_(0., 1.)
422
+ hand_only_mask = (hand_only_mask * 255).to(torch.uint8)
423
+ cv.imwrite(output_dir + '/hand_only_mask/%08d.png' % item['data_idx'], hand_only_mask.cpu().numpy())
424
+
425
+ if 'full_body_rgb_map' in mask_output and 'hand_only_rgb_map' in mask_output:
426
+ # mask only covers hand
427
+ body_red_mask = (mask_output['full_body_rgb_map'] - torch.tensor([1., 0., 0.], device = mask_output['full_body_rgb_map'].device))
428
+ body_red_mask = (body_red_mask*body_red_mask).sum(dim=2) < 0.01 # need save
429
+
430
+ hand_red_mask = (mask_output['hand_only_rgb_map'] - torch.tensor([1., 0., 0.], device = mask_output['hand_only_rgb_map'].device))
431
+ hand_red_mask = (hand_red_mask*hand_red_mask).sum(dim=2) < 0.01
432
+
433
+ if_mask_r_hand = abs(body_red_mask.sum() - hand_red_mask.sum()) / hand_red_mask.sum() > 0.95
434
+ if_mask_r_hand = if_mask_r_hand.cpu().numpy()
435
+
436
+ body_blue_mask = (mask_output['full_body_rgb_map'] - torch.tensor([0., 0., 1.], device = mask_output['full_body_rgb_map'].device))
437
+ body_blue_mask = (body_blue_mask*body_blue_mask).sum(dim=2) < 0.01 # need save
438
+
439
+ hand_blue_mask = (mask_output['hand_only_rgb_map'] - torch.tensor([0., 0., 1.], device = mask_output['hand_only_rgb_map'].device))
440
+ hand_blue_mask = (hand_blue_mask*hand_blue_mask).sum(dim=2) < 0.01
441
+
442
+ if_mask_l_hand = abs(body_blue_mask.sum() - hand_blue_mask.sum()) / hand_blue_mask.sum() > 0.95
443
+ if_mask_l_hand = if_mask_l_hand.cpu().numpy()
444
+
445
+ # save masks for the occluded parts of the left and right hands
446
+ red_mask = hand_red_mask ^ (hand_red_mask & body_red_mask)
447
+ blue_mask = hand_blue_mask ^ (hand_blue_mask & body_blue_mask)
448
+ all_mask = red_mask | blue_mask
449
+
450
+ # now save 3 mask to 3 folders
451
+ os.makedirs(output_dir + '/hand_mask', exist_ok = True)
452
+ os.makedirs(output_dir + '/r_hand_mask', exist_ok = True)
453
+ os.makedirs(output_dir + '/l_hand_mask', exist_ok = True)
454
+ os.makedirs(output_dir + '/hand_visual', exist_ok = True)
455
+
456
+ all_mask = (all_mask * 255).to(torch.uint8)
457
+ cv.imwrite(output_dir + '/hand_mask/%08d.png' % item['data_idx'], all_mask.cpu().numpy())
458
+ r_hand_mask = (body_red_mask * 255).to(torch.uint8)
459
+ cv.imwrite(output_dir + '/r_hand_mask/%08d.png' % item['data_idx'], r_hand_mask.cpu().numpy())
460
+ l_hand_mask = (body_blue_mask * 255).to(torch.uint8)
461
+ cv.imwrite(output_dir + '/l_hand_mask/%08d.png' % item['data_idx'], l_hand_mask.cpu().numpy())
462
+ hand_visual = [if_mask_r_hand, if_mask_l_hand]
463
+ # save to npy
464
+ with open(output_dir + '/hand_visual/%08d.npy' % item['data_idx'], 'wb') as f:
465
+ np.save(f, hand_visual)
466
+
467
+
468
+ # now build sleeve_mask
469
+ if 'left_hand_rgb_map' in mask_output and 'right_hand_rgb_map' in mask_output:
470
+ os.makedirs(output_dir + '/left_sleeve_mask', exist_ok = True)
471
+ os.makedirs(output_dir + '/right_sleeve_mask', exist_ok = True)
472
+
473
+ mask = (r_hand_mask>128) | (l_hand_mask>128)| (all_mask>128)
474
+ mask = mask.cpu().numpy().astype(np.uint8)
475
+ # structuring element; adjust its size to change the amount of dilation
476
+ kernel = np.ones((5, 5), np.uint8)
477
+ # apply dilation
478
+ mask = cv.dilate(mask, kernel, iterations=3)
479
+ mask = torch.tensor(mask).to(self.device)
480
+
481
+ left_hand_mask = mask_output['left_hand_rgb_map']
482
+ left_hand_mask.clip_(0., 1.)
483
+ # non white part is mask
484
+ left_hand_mask = (torch.tensor([1., 1., 1.], device = left_hand_mask.device) - left_hand_mask)
485
+ left_hand_mask = (left_hand_mask*left_hand_mask).sum(dim=2) > 0.01
486
+ # dele two hand mask
487
+ left_hand_mask = left_hand_mask & ~mask
488
+
489
+ right_hand_mask = mask_output['right_hand_rgb_map']
490
+ right_hand_mask.clip_(0., 1.)
491
+ right_hand_mask = (torch.tensor([1., 1., 1.], device = right_hand_mask.device) - right_hand_mask)
492
+ right_hand_mask = (right_hand_mask*right_hand_mask).sum(dim=2) > 0.01
493
+ right_hand_mask = right_hand_mask & ~mask
494
+
495
+ # save
496
+ left_hand_mask = (left_hand_mask * 255).to(torch.uint8)
497
+ cv.imwrite(output_dir + '/left_sleeve_mask/%08d.png' % item['data_idx'], left_hand_mask.cpu().numpy())
498
+ right_hand_mask = (right_hand_mask * 255).to(torch.uint8)
499
+ cv.imwrite(output_dir + '/right_sleeve_mask/%08d.png' % item['data_idx'], right_hand_mask.cpu().numpy())
500
+
501
+ rgb_map = output['rgb_map']
502
+ rgb_map.clip_(0., 1.)
503
+ rgb_map = (rgb_map * 255).to(torch.uint8).cpu().numpy()
504
+ cv.imwrite(output_dir + '/rgb_map/%08d.jpg' % item['data_idx'], rgb_map)
505
+
506
+ # use r_hand_mask and l_hand_mask to paste the masked regions of the wo_hand render over rgb_map
507
+ if 'rgb_map' in output_wo_hand and 'full_body_rgb_map' in mask_output and 'hand_only_rgb_map' in mask_output:
508
+ rgb_map_wo_hand = output_wo_hand['rgb_map']
509
+ rgb_map_wo_hand.clip_(0., 1.)
510
+ rgb_map_wo_hand = (rgb_map_wo_hand * 255).to(torch.uint8).cpu().numpy()
511
+
512
+ r_mask = (r_hand_mask>128).cpu().numpy()
513
+ l_mask = (l_hand_mask>128).cpu().numpy()
514
+ mask = r_mask | l_mask
515
+ mask = mask.astype(np.uint8)
516
+ # structuring element; adjust its size to change the amount of dilation
517
+ kernel = np.ones((5, 5), np.uint8)
518
+ # apply dilation
519
+ mask = cv.dilate(mask, kernel, iterations=3)
520
+ mask = mask.astype(np.bool_)
521
+ mask = np.expand_dims(mask, axis=2)
522
+ # print('mask shape: ', mask.shape)
523
+ import ipdb
524
+ # ipdb.set_trace()
525
+ mix = rgb_map_wo_hand.copy() * mask + rgb_map * ~mask
526
+ cv.imwrite(output_dir + '/rgb_map_wo_hand/%08d.png' % item['data_idx'], mix)
527
+
528
+ if 'torso_map' in output:
529
+ os.makedirs(output_dir + '/torso_map', exist_ok = True)
530
+ torso_map = output['torso_map'][:, :, 0]
531
+ torso_map.clip_(0., 1.)
532
+ torso_map = (torso_map * 255).to(torch.uint8)
533
+ cv.imwrite(output_dir + '/torso_map/%08d.png' % item['data_idx'], torso_map.cpu().numpy())
534
+
535
+ if 'mask_map' in output:
536
+ os.makedirs(output_dir + '/mask_map', exist_ok = True)
537
+ mask_map = output['mask_map'][:, :, 0]
538
+ mask_map.clip_(0., 1.)
539
+ mask_map = (mask_map * 255).to(torch.uint8)
540
+ cv.imwrite(output_dir + '/mask_map/%08d.png' % item['data_idx'], mask_map.cpu().numpy())
541
+
542
+ if self.body['test'].get('save_tex_map', False):
543
+ os.makedirs(output_dir + '/cano_tex_map', exist_ok = True)
544
+ cano_tex_map = output['cano_tex_map']
545
+ cano_tex_map.clip_(0., 1.)
546
+ cano_tex_map = (cano_tex_map * 255).to(torch.uint8)
547
+ cv.imwrite(output_dir + '/cano_tex_map/%08d.png' % item['data_idx'], cano_tex_map.cpu().numpy())
548
+
549
+ if self.body['test'].get('save_ply', False):
550
+ if item['data_idx'] == 0:
551
+ save_gaussians_as_ply(output_dir + '/posed_gaussians/%08d.ply' % item['data_idx'], output['posed_gaussians'])
552
+ for k in output['posed_gaussians'].keys():
553
+ if isinstance(output['posed_gaussians'][k], torch.Tensor):
554
+ output['posed_gaussians'][k] = output['posed_gaussians'][k].detach().cpu().numpy()
555
+ np.savez(output_dir + '/posed_gaussians/%08d.npz' % item['data_idx'], **output['posed_gaussians'])
556
+ np.savez(output_dir + ('/posed_params/%08d.npz' % item['data_idx']),
557
+ betas=training_dataset.smpl_data['betas'].reshape([-1]).detach().cpu().numpy(),
558
+ global_orient=item['global_orient'].reshape([-1]).detach().cpu().numpy(),
559
+ transl=item['transl'].reshape([-1]).detach().cpu().numpy(),
560
+ body_pose=item['body_pose'].reshape([-1]).detach().cpu().numpy())
561
+
562
+ if log_time:
563
+ time_end.record()
564
+ torch.cuda.synchronize()
565
+ print('Saving images costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
566
+ print('Animating one frame costs %.4f secs' % (time_start_all.elapsed_time(time_end) / 1000.))
567
+
568
+ torch.cuda.empty_cache()
569
+
570
+ def dump_renderer_info(self, dump_dir, extrs, intrs, img_heights, img_widths):
571
+ with open(os.path.join(dump_dir, 'cfg_args'), 'w') as fp:
572
+ outstr = "Namespace(sh_degree=%d, source_path='%s', model_path='%s', images='images', resolution=-1, " \
573
+ "white_background=False, data_device='cuda', eval=False)" % (
574
+ 3, self.body['train']['data']['data_dir'], dump_dir)
575
+ fp.write(outstr)
576
+ with open(os.path.join(dump_dir, 'cameras.json'), 'w') as fp:
577
+ cam_jsons = []
578
+ for ci in range(len(extrs)):
579
+ extr, intr = extrs[ci], intrs[ci]
580
+ img_h, img_w = img_heights[ci], img_widths[ci]
581
+
582
+ w2c = extr
583
+ c2w = np.linalg.inv(w2c)
584
+ pos = c2w[:3, 3]
585
+ rot = c2w[:3, :3]
586
+ serializable_array_2d = [x.tolist() for x in rot]
587
+ camera_entry = {
588
+ 'id': ci,
589
+ 'img_name': '%08d' % ci,
590
+ 'width': int(img_w),
591
+ 'height': int(img_h),
592
+ 'position': pos.tolist(),
593
+ 'rotation': serializable_array_2d,
594
+ 'fy': float(intr[1, 1]),
595
+ 'fx': float(intr[0, 0]),
596
+ }
597
+ cam_jsons.append(camera_entry)
598
+ json.dump(cam_jsons, fp)
599
+ return
600
+
601
+ def test_head(self):
602
+ dataset = ReenactmentDataset(self.head_config.dataset)
603
+ dataloader = DataLoaderX(dataset, batch_size=1, shuffle=False, pin_memory=True)
604
+
605
+ device = torch.device('cuda:%d' % self.head_config.gpu_id)
606
+
607
+ gaussianhead_state_dict = torch.load(self.head_config.load_gaussianhead_checkpoint, map_location=lambda storage, loc: storage)
608
+ gaussianhead = GaussianHeadModule(self.head_config.gaussianheadmodule,
609
+ xyz=gaussianhead_state_dict['xyz'],
610
+ feature=gaussianhead_state_dict['feature'],
611
+ landmarks_3d_neutral=gaussianhead_state_dict['landmarks_3d_neutral']).to(device)
612
+ gaussianhead.load_state_dict(gaussianhead_state_dict)
613
+
614
+ supres = SuperResolutionModule(self.head_config.supresmodule).to(device)
615
+ supres.load_state_dict(torch.load(self.head_config.load_supres_checkpoint, map_location=lambda storage, loc: storage))
616
+
617
+ camera = CameraModule()
618
+ recorder = ReenactmentRecorder(self.head_config.recorder)
619
+
620
+ app = Reenactment(dataloader, gaussianhead, supres, camera, recorder, self.head_config.gpu_id, dataset.freeview)
621
+ if self.head.offline_rendering_param_fpath is None:
622
+ app.run(stop_fid=800)
623
+ else:
624
+ app.run_for_offline_stitching(self.head.offline_rendering_param_fpath)
625
+
626
+ def cal_cat_param(self):
627
+ calc_offline_rendering_param(
628
+ self.cat.body_gaussian_root_dir,
629
+ self.cat.ref_head_gaussian_path,
630
+ self.cat.ref_head_param_path,
631
+ self.cat.render_cam_fpath,
632
+ self.cat.body_head_blending_param_path
633
+ )
634
+
635
+
636
+
637
+
638
+ if __name__ == '__main__':
639
+ conf = OmegaConf.load('configs/example.yaml')
640
+ avatar = Avatar(conf)
641
+ avatar.test_body()
642
+ # avatar.test_head()
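
The three public methods above form a pipeline: test_body() renders the body Gaussians and dumps cameras.json plus posed_params/*.npz, cal_cat_param() derives head_zoomin_render_param.npz from a body-pass cameras.json and the blending parameters, and test_head() consumes that file through gha.offline_rendering_param_fpath. A driver sketch for the full pass (the __main__ block above only runs the body stage), assuming the checkpoint and parameter paths in configs/example.yaml resolve locally:

from omegaconf import OmegaConf
from avatar import Avatar

conf = OmegaConf.load('configs/example.yaml')
avatar = Avatar(conf)
avatar.test_body()      # body rendering; writes cameras.json and posed_params/*.npz
avatar.cal_cat_param()  # computes head_zoomin_render_param.npz from those outputs
avatar.test_head()      # head reenactment using the offline stitching parameters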
calc_offline_rendering_param.py ADDED
@@ -0,0 +1,188 @@
1
+ import numpy as np
2
+ import tqdm
3
+ import os, glob
4
+ import json
5
+ import argparse
6
+
7
+ from render_utils.lib.networks.smpl_torch import SmplTorch
8
+ from render_utils.lib.utils.gaussian_np_utils import load_gaussians_from_ply
9
+ from render_utils.stitch_body_and_head import load_body_params, load_face_params, get_smpl_verts_and_head_transformation, calc_livehead2livebody
10
+
11
+
12
+ def load_rendering_camera(camera_fpath):
13
+ with open(camera_fpath, 'r') as fp:
14
+ camera_data = json.load(fp)
15
+ camera_data = camera_data[0]
16
+ image_size = [camera_data['width'], camera_data['height']]
17
+ cam_f = [camera_data['fx'], camera_data['fy']]
18
+ cam_pos = np.array(camera_data['position'])
19
+ cam_rot = np.array(camera_data['rotation']).reshape(3, 3)
20
+ c2w = np.eye(4)
21
+ c2w[:3, :3] = cam_rot
22
+ c2w[:3, 3] = cam_pos
23
+ cam_extr = np.linalg.inv(c2w)
24
+ cam_intr = np.eye(3)
25
+ cam_intr[0, 0] = cam_f[0]
26
+ cam_intr[1, 1] = cam_f[1]
27
+ cam_intr[0, 2] = image_size[0] / 2
28
+ cam_intr[1, 2] = image_size[1] / 2
29
+ return cam_extr, cam_intr, image_size
30
+
31
+ def load_camera_list(camera_fpath):
32
+ with open(camera_fpath, 'r') as fp:
33
+ camera_data = json.load(fp)
34
+ image_size = [camera_data[0]['width'], camera_data[0]['height']]
35
+ cam_list = []
36
+ for cam in camera_data:
37
+ cam_f = [cam['fx'], cam['fy']]
38
+ cam_pos = np.array(cam['position'])
39
+ cam_rot = np.array(cam['rotation']).reshape(3, 3)
40
+ c2w = np.eye(4)
41
+ c2w[:3, :3] = cam_rot
42
+ c2w[:3, 3] = cam_pos
43
+ cam_extr = np.linalg.inv(c2w)
44
+ cam_intr = np.eye(3)
45
+ cam_intr[0, 0] = cam_f[0]
46
+ cam_intr[1, 1] = cam_f[1]
47
+ cam_intr[0, 2] = image_size[0] / 2
48
+ cam_intr[1, 2] = image_size[1] / 2
49
+ cam_list.append((cam_extr, cam_intr))
50
+ return cam_list, image_size
51
+
52
+ def load_camera_data(cam):
53
+ image_size = [cam['width'], cam['height']]
54
+ cam_f = [cam['fx'], cam['fy']]
55
+ cam_pos = np.array(cam['position'])
56
+ cam_rot = np.array(cam['rotation']).reshape(3, 3)
57
+ c2w = np.eye(4)
58
+ c2w[:3, :3] = cam_rot
59
+ c2w[:3, 3] = cam_pos
60
+ cam_extr = np.linalg.inv(c2w)
61
+ cam_intr = np.eye(3)
62
+ cam_intr[0, 0] = cam_f[0]
63
+ cam_intr[1, 1] = cam_f[1]
64
+ cam_intr[0, 2] = image_size[0] / 2
65
+ cam_intr[1, 2] = image_size[1] / 2
66
+
67
+ return (cam_extr, cam_intr), image_size
68
+
69
+ def calc_offline_rendering_param(
70
+ body_gaussian_root_dir, ref_head_gaussian_path, ref_head_param_path, render_cam_fpath,
71
+ body_head_blending_param_path):
72
+ body_param_flist = sorted(glob.glob(os.path.join(body_gaussian_root_dir, 'posed_params/*.npz')))
73
+
74
+ head_gaussians = load_gaussians_from_ply(ref_head_gaussian_path)
75
+ head_pose, head_scale, id_coeff, exp_coeff = load_face_params(ref_head_param_path)
76
+ # cam_extr_body, cam_intr_body, image_size = load_rendering_camera(render_cam_fpath)
77
+ cam_list, image_size = load_camera_list(render_cam_fpath)
78
+
79
+ body_head_blending_params = np.load(body_head_blending_param_path)
80
+ smplx_to_faceverse = body_head_blending_params['smplx_to_faceverse']
81
+ residual_transf = body_head_blending_params['residual_transf']
82
+ body_nonface_mask = body_head_blending_params['body_nonface_mask']
83
+ head_nonface_mask = body_head_blending_params['head_nonface_mask']
84
+ head_facial_idx = body_head_blending_params['head_facial_idx']
85
+ body_facial_idx = body_head_blending_params['body_facial_idx']
86
+ head_body_corr_idx = body_head_blending_params['head_body_corr_idx']
87
+ head_color_bw = body_head_blending_params['head_color_bw']
88
+ color_transfer = body_head_blending_params['color_transfer']
89
+
90
+ smpl = SmplTorch(model_file='./AnimatableGaussians/smpl_files/smplx/SMPLX_NEUTRAL.npz')
91
+
92
+ head_cam_extr = []
93
+ head_cam_intr = []
94
+ head_cam_intr_zoom = []
95
+ head_zoom_center = []
96
+ head_zoom_scale = []
97
+
98
+ for i, body_param_fpath in enumerate(tqdm.tqdm(body_param_flist)):
99
+ global_orient, transl, body_pose, betas = load_body_params(body_param_fpath)
100
+ # body_gaussians = load_gaussians_from_ply(body_gaussian_fpath)
101
+
102
+ smpl_verts, head_joint_transfmat = get_smpl_verts_and_head_transformation(
103
+ smpl, global_orient, body_pose, transl, betas)
104
+ livehead2livebody = calc_livehead2livebody(head_pose, smplx_to_faceverse, head_joint_transfmat)
105
+ total_transf = np.matmul(livehead2livebody, residual_transf)
106
+
107
+ cam_extr = np.matmul(cam_list[i][0], total_transf)
108
+ cam_intr = np.copy(cam_list[i][1])
109
+
110
+ head_cam_extr.append(cam_extr)
111
+ head_cam_intr.append(cam_intr)
112
+
113
+ pts = np.copy(head_gaussians.xyz)
114
+ pts_proj = np.matmul(pts, cam_extr[:3, :3].transpose()) + cam_extr[:3, 3]
115
+ pts_proj = np.matmul(pts_proj, cam_intr.transpose())
116
+ pts_proj = pts_proj / pts_proj[:, 2:]
117
+ # pts_proj = np.int32(np.round(pts_proj[:, :2]))
118
+
119
+ # img = np.zeros([image_size[1], image_size[0], 3], dtype=np.uint8)
120
+ # for p in pts_proj[::50]:
121
+ # p = np.clip(p, 0, image_size[0] - 1)
122
+ # cv.circle(img, (int(p[0]), int(p[1])), 2, (0, 255, 0), -1)
123
+ # cv.imshow('img', img)
124
+
125
+ pts_min, pts_max = np.min(pts_proj, axis=0), np.max(pts_proj, axis=0)
126
+ pts_center = (pts_min + pts_max) // 2
127
+ pts_size = np.max(pts_max - pts_min)
128
+ tgt_pts_size = 350
129
+ tgt_image_size = 512
130
+ zoom_scale = tgt_pts_size / pts_size
131
+ cam_intr_zoom = np.copy(cam_intr)
132
+ cam_intr_zoom[:2] *= zoom_scale
133
+ cam_intr_zoom[0, 2] = cam_intr_zoom[0, 2] - (pts_center[0]*zoom_scale - tgt_image_size/2)
134
+ cam_intr_zoom[1, 2] = cam_intr_zoom[1, 2] - (pts_center[1]*zoom_scale - tgt_image_size/2)
135
+ head_cam_intr_zoom.append(cam_intr_zoom)
136
+ head_zoom_center.append(pts_center)
137
+ head_zoom_scale.append(zoom_scale)
138
+
139
+ # pts_proj = np.matmul(pts, cam_extr[:3, :3].transpose()) + cam_extr[:3, 3]
140
+ # pts_proj = np.matmul(pts_proj, cam_intr_zoom.transpose())
141
+ # pts_proj = pts_proj / pts_proj[:, 2:]
142
+ # pts_proj = np.int32(np.round(pts_proj[:, :2]))
143
+ # img = np.zeros([512, 512, 3], dtype=np.uint8)
144
+ # for p in pts_proj[::50]:
145
+ # p = np.clip(p, 0, image_size[0] - 1)
146
+ # cv.circle(img, (int(p[0]), int(p[1])), 2, (0, 255, 0), -1)
147
+ # cv.imshow('img_zoom', img)
148
+ # cv.waitKey()
149
+
150
+ np.savez(os.path.join(os.path.dirname(body_head_blending_param_path), 'head_zoomin_render_param.npz'),
151
+ cam_extr=head_cam_extr, cam_intr=head_cam_intr, image_size=image_size,
152
+ cam_intr_zoom=head_cam_intr_zoom, zoom_image_size=[tgt_image_size, tgt_image_size],
153
+ zoom_center=head_zoom_center,
154
+ zoom_scale=head_zoom_scale,
155
+ head_pose=head_pose, head_scale=head_scale, head_color_bw=head_color_bw)
156
+
157
+
158
+
159
+ if __name__ == '__main__':
160
+ parser = argparse.ArgumentParser()
161
+
162
+ """
163
+ body_gaussian_root_dir, ref_head_gaussian_path, ref_head_param_path, render_cam_fpath,
164
+ body_head_blending_param_path
165
+ """
166
+ parser.add_argument('--body_gaussian_root_dir', type=str)
167
+ parser.add_argument('--ref_head_gaussian_path', type=str)
168
+ parser.add_argument('--ref_head_param_path', type=str)
169
+ parser.add_argument('--render_cam_fpath', type=str)
170
+ parser.add_argument('--body_head_blending_param_path', type=str)
171
+ args = parser.parse_args()
172
+ calc_offline_rendering_param(
173
+ args.body_gaussian_root_dir,
174
+ args.ref_head_gaussian_path,
175
+ args.ref_head_param_path,
176
+ args.render_cam_fpath,
177
+ args.body_head_blending_param_path
178
+ )
179
+
180
+ """
181
+ python calc_offline_rendering_param.py ^
182
+ --body_gaussian_root_dir ./AnimatableGaussians/test_results/huawei0425/checkpoints/AMASS__test_poses_ours_front_view/batch_750000/pca_20_sigma_2.00/ ^
183
+ --ref_head_gaussian_path ./Gaussian-Head-Avatar/results/reenactment/huawei0425_self/posed_gaussians/000000.ply ^
184
+ --ref_head_param_path ./Gaussian-Head-Avatar/results/reenactment/huawei0425_self/params/000000_param.npz ^
185
+ --render_cam_fpath ./AnimatableGaussians/test_results/huawei0425/checkpoints/AMASS__test_poses_ours_front_view/batch_750000/pca_20_sigma_2.00/cameras.json ^
186
+ --body_head_blending_param_path ./data/body_face_stitching_sr/body_head_blending_param.npz
187
+
188
+ """
configs/example.yaml ADDED
@@ -0,0 +1,89 @@
1
+ trial_name: "body_head_avatar"
2
+ device: cuda
3
+ animatablegaussians:
4
+ train:
5
+ dataset: MvRgbDatasetAvatarReX
6
+ data:
7
+ subject_name: 1007_slow10
8
+ data_dir: ./checkpoints/pos_map_ys/body_mix
9
+ frame_range: &id001
10
+ - 0
11
+ - 200
12
+ - 1
13
+ used_cam_ids:
14
+ - 0
15
+ - 1
16
+ - 2
17
+ - 3
18
+ - 4
19
+ - 5
20
+ - 6
21
+ - 8
22
+ - 9
23
+ - 10
24
+ - 11
25
+ - 12
26
+ - 14
27
+ - 15
28
+ load_smpl_pos_map: true
29
+ pretrained_dir: null
30
+ net_ckpt_dir: ./results/huawei0425/avatar2
31
+ prev_ckpt: null
32
+ ckpt_interval:
33
+ epoch: 10
34
+ batch: 50000
35
+ eval_interval: 1000
36
+ eval_training_ids:
37
+ - 190
38
+ - 7
39
+ eval_testing_ids:
40
+ - 354
41
+ - 7
42
+ eval_img_factor: 1.0
43
+ lr_init: 0.0005
44
+ loss_weight:
45
+ l1: 1.0
46
+ lpips: 0.1
47
+ offset: 0.005
48
+ finetune_color: false
49
+ batch_size: 1
50
+ num_workers: 8
51
+ random_bg_color: true
52
+ test:
53
+ output_dir: ./test_results/temp_test
54
+ dataset: MvRgbDatasetAvatarReX
55
+ data:
56
+ data_dir: ./checkpoints/pos_map_ys/body_mix
57
+ frame_range: [0, 800]
58
+ subject_name: huawei0425
59
+ pose_data:
60
+ data_path: ./test_data/AMASS/1007_train_data_slow10.npz
61
+ frame_range: [0, 2000]
62
+ view_setting: degree90
63
+ render_view_idx: 13
64
+ global_orient: true
65
+ img_scale: 2.0
66
+ save_mesh: false
67
+ render_skeleton: false
68
+ save_tex_map: false
69
+ save_ply: true
70
+ fix_hand: true
71
+ fix_hand_id: 23
72
+ n_pca: 20
73
+ sigma_pca: 2.0
74
+ prev_ckpt: ./checkpoints/checkpoints/body_ys
75
+ model:
76
+ with_viewdirs: true
77
+ random_style: false
78
+
79
+ gha:
80
+ config_path: configs/head.yaml
81
+ offline_rendering_param_fpath: ./checkpoints/render_param/head_zoomin_render_param.npz
82
+
83
+ cat:
84
+ body_gaussian_root_dir: ./checkpoints/pos_map_ys/body_mix
85
+ ref_head_gaussian_path: ./checkpoints/ref_gaussian/head/000000.ply
86
+ ref_head_param_path: ./checkpoints/ref_gaussian/head/000000_param.npz
87
+ render_cam_fpath: /home/pengc02/pengcheng/projects/gaussian_avatar/avatar_final/AnimatableGaussians/test_results/1007_slow10/checkpoints/AMASS__1007_train_data_slow10_degree90_view/batch_789377/pca_20_sigma_2.00/cameras.json
88
+ body_head_blending_param_path: ./checkpoints/render_param/body_head_blending_param.npz
89
+
configs/head.yaml ADDED
@@ -0,0 +1,39 @@
1
+ gpu_id: 0
2
+ load_supres_checkpoint: './checkpoints/face_0929/supres_latest'
3
+ load_gaussianhead_checkpoint: './checkpoints/face_0929/gaussianhead_latest'
4
+
5
+ dataset:
6
+ dataroot: './test_data/face1001'
7
+ image_files: 'images/*/wrong_image.jpg'
8
+ param_files: 'params/*/params.npz'
9
+ camera_path: './test_data/face1001/cameras/0000/camera_22070938.npz'
10
+ pose_code_path: './test_data/face1001/params/0000/params.npz'
11
+ exp_path: '/home/pengc02/pengcheng/projects/gaussian_avatar/avatar_final/data/1005_thu_slow/thuSlow10.npy'
12
+ freeview: False
13
+ resolution: 2048
14
+ original_resolution: 2048
15
+
16
+ supresmodule:
17
+ input_dim: 32
18
+ output_dim: 3
19
+ network_capacity: 32
20
+
21
+ gaussianheadmodule:
22
+ num_add_mouth_points: 3000
23
+ exp_color_mlp: [180, 256, 256, 32]
24
+ pose_color_mlp: [182, 128, 32]
25
+ exp_deform_mlp: [79, 256, 256, 256, 256, 256, 3]
26
+ pose_deform_mlp: [81, 256, 256, 3]
27
+ exp_attributes_mlp: [180, 256, 256, 256, 8]
28
+ pose_attributes_mlp: [182, 128, 128, 8]
29
+ exp_coeffs_dim: 52
30
+ pos_freq: 4
31
+ dist_threshold_near: 0.05
32
+ dist_threshold_far: 0.12
33
+ deform_scale: 0.3
34
+ attributes_scale: 0.2
35
+
36
+ recorder:
37
+ name: 'thu_exp_slow'
38
+ result_path: 'results/reenactment'
39
+
gradio_debug.py ADDED
@@ -0,0 +1,21 @@
1
+ import gradio as gr
2
+ import os
3
+
4
+ def load_and_display_video(video_path):
5
+ if os.path.exists(video_path):
6
+ return video_path
7
+ else:
8
+ return "Invalid video path."
9
+
10
+ with gr.Blocks() as demo:
11
+ video_input = gr.Textbox(label="Enter Video Path")
12
+ video_output = gr.Video(label="Video Output")
13
+
14
+ load_button = gr.Button("Load Video")
15
+
16
+ load_button.click(fn=load_and_display_video,
17
+ inputs=video_input,
18
+ outputs=video_output)
19
+
20
+ # launch the app
21
+ demo.launch()
other_requirement.sh ADDED
@@ -0,0 +1,19 @@
1
+ # pip install kaolin==0.16.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.4.0_cu121.html
2
+
3
+
4
+ cd AnimatableGaussians
5
+ # install diff-gaussian-rasterization-depth-alpha
6
+ cd gaussians/diff_gaussian_rasterization_depth_alpha
7
+ python setup.py install
8
+ cd ../..
9
+
10
+ # install styleunet
11
+ cd network/styleunet
12
+ python setup.py install
13
+ cd ../..
14
+
15
+ # HTTPS
16
+ git clone https://github.com/graphdeco-inria/gaussian-splatting --recursive
17
+ # Modify "submodules/diff-gaussian-rasterization/cuda_rasterizer/config.h" from "NUM_CHANNELS 3" to "NUM_CHANNELS 32"
18
+ pip install submodules/diff-gaussian-rasterization
19
+ pip install submodules/simple-knn
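
A quick post-install check (a sketch, not part of this script) is to try importing the compiled extensions; the module names below are assumptions inferred from the package directories above and may differ from what each setup.py actually registers.

import importlib

for mod in ('diff_gaussian_rasterization_depth_alpha',  # AnimatableGaussians rasterizer (assumed module name)
            'diff_gaussian_rasterization',               # gaussian-splatting rasterizer
            'simple_knn'):                               # gaussian-splatting k-NN helper
    try:
        importlib.import_module(mod)
        print(mod, 'OK')
    except ImportError as err:
        print(mod, 'FAILED:', err)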
output/00000000.jpg ADDED
output/00000001.jpg ADDED
output/00000002.jpg ADDED
output/00000003.jpg ADDED
output/00000004.jpg ADDED
output/00000005.jpg ADDED
output/00000006.jpg ADDED
output/00000007.jpg ADDED
output/00000008.jpg ADDED
output/00000009.jpg ADDED
output/00000010.jpg ADDED
output/00000011.jpg ADDED
output/00000012.jpg ADDED
output/00000013.jpg ADDED
output/00000014.jpg ADDED
output/00000015.jpg ADDED
output/00000016.jpg ADDED
output/00000017.jpg ADDED
output/00000018.jpg ADDED
output/00000019.jpg ADDED
output/00000020.jpg ADDED
output/00000021.jpg ADDED
output/00000022.jpg ADDED
output/00000023.jpg ADDED
output/00000024.jpg ADDED
output/00000025.jpg ADDED
output/00000026.jpg ADDED
output/00000027.jpg ADDED
output/00000028.jpg ADDED
output/00000029.jpg ADDED