|
import os |
|
import json |
|
import copy |
|
import shutil |
|
import argparse |
|
import numpy as np |
|
|
|
import sys |
|
import torch |
|
|
|
from sparseags.mesh_utils.mesh_renderer import Renderer |
|
from sparseags.render_utils.util import render_and_compare, align_to_mesh |
|
from sparseags.visual_utils import vis_output |
|
|
|
|
|
def seed_everything(seed): |
|
os.environ["PYTHONHASHSEED"] = str(seed) |
|
np.random.seed(seed) |
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed(seed) |
|
torch.backends.cudnn.deterministic = True |
|
torch.backends.cudnn.benchmark = True |
|
|
|
|
|
def do_reconstruction(cur_dir, out_dir, args, force_all_views=False): |
|
|
|
if not os.path.exists(os.path.join(cur_dir, f'{args.category}_mesh.{args.mesh_format}')): |
|
os.system(f'python sparseags/main_stage1.py ' |
|
f'--config configs/{args.config} ' |
|
f'camera_path={cur_dir}/cameras.json ' |
|
f'outdir=./ ' |
|
f'save_path={cur_dir}/{args.category} ' |
|
f'opt_cam=1 sh_degree=3 ' |
|
f'num_pts={args.num_pts} ' |
|
f'all_views={force_all_views} ' |
|
f'mesh_format={args.mesh_format}') |
|
|
|
|
|
if not os.path.exists(os.path.join(cur_dir, f'{args.category}.{args.mesh_format}')): |
|
os.system(f'python sparseags/main_stage2.py ' |
|
f'--config configs/{args.config} ' |
|
f'camera_path={cur_dir}/cameras_updated.json ' |
|
f'outdir=./ ' |
|
f'save_path={cur_dir}/{args.category} ' |
|
f'all_views={force_all_views} ' |
|
f'mesh_format={args.mesh_format}') |
|
|
|
|
|
mesh_path = os.path.join(cur_dir, f'{args.category}.{args.mesh_format}') |
|
if not os.path.exists(os.path.join(cur_dir, f'{args.category}.mp4')): |
|
os.system(f'python -m kiui.render {mesh_path} ' |
|
f'--save_video {cur_dir}/{args.category}.mp4 ' |
|
f'--wogui ' |
|
f'--elevation -30') |
|
|
|
|
|
def read_updated_cameras(cur_dir): |
|
camera_path = os.path.join(cur_dir, 'cameras_updated.json') |
|
with open(camera_path, 'r') as f: |
|
camera_data = json.load(f) |
|
|
|
return camera_data |
|
|
|
|
|
def save_updated_cameras(camera_data, cur_dir): |
|
camera_path = os.path.join(cur_dir, 'cameras.json') |
|
with open(camera_path, 'w') as f: |
|
json.dump(camera_data, f, indent=4) |
|
|
|
|
|
def save_metrics(metrics_track, out_dir): |
|
metrics_path = os.path.join(out_dir, 'metrics.json') |
|
with open(metrics_path, 'w') as f: |
|
json.dump(metrics_track, f, indent=4) |
|
|
|
|
|
def main(args): |
|
CATEGORY = args.category |
|
NUM_VIEWS = args.num_views |
|
src_dir = os.path.join('data/demo', CATEGORY) |
|
out_dir = os.path.join(args.output, CATEGORY) |
|
os.makedirs(out_dir, exist_ok=True) |
|
source_camera_path = os.path.join(src_dir, "cameras.json") |
|
|
|
print(f'======== processing {src_dir} ========') |
|
if not os.path.exists(source_camera_path): |
|
print(f'{source_camera_path} is missing!') |
|
sys.exit() |
|
|
|
stop = 0 |
|
stop_with_lower_quality = 0 |
|
stop_with_full_iters = 0 |
|
|
|
cnt = 0 |
|
|
|
if NUM_VIEWS <= 5: |
|
MAX_CNT = 1 |
|
elif NUM_VIEWS < 8: |
|
MAX_CNT = NUM_VIEWS - 4 |
|
else: |
|
MAX_CNT = 4 |
|
|
|
if not args.enable_loop: |
|
MAX_CNT = 0 |
|
|
|
THRESHOLD_LPIPS = 0.05 |
|
metrics_track = {} |
|
cur_dir = os.path.join(out_dir, f'round_{cnt}') |
|
os.makedirs(cur_dir, exist_ok=True) |
|
shutil.copy2(source_camera_path, os.path.join(cur_dir, 'cameras.json')) |
|
|
|
while not stop: |
|
metrics_track[cnt] = [] |
|
|
|
do_reconstruction(cur_dir, out_dir, args) |
|
camera_data = read_updated_cameras(cur_dir) |
|
|
|
lpips_losses, mse_losses = vis_output( |
|
camera_data, |
|
mesh_path=os.path.join(cur_dir, f'{CATEGORY}.{args.mesh_format}'), |
|
save_path=os.path.join(cur_dir, 'vis.png'), |
|
num_views=NUM_VIEWS |
|
) |
|
flags = np.array([int(v["flag"]) for k, v in camera_data.items()]) |
|
|
|
mean_lpips = np.sum(lpips_losses * flags) / flags.sum() |
|
mean_mse = np.sum(mse_losses * flags) / flags.sum() |
|
metrics_track[cnt].append(mean_lpips) |
|
|
|
|
|
|
|
if cnt != 0 and args.enable_loop: |
|
last_lpips = mean_lpips_wo_max |
|
diff_lpips = abs(last_lpips - mean_lpips) |
|
if mean_lpips > last_lpips or diff_lpips < THRESHOLD_LPIPS: |
|
stop_with_lower_quality = 1 |
|
|
|
if cnt >= MAX_CNT: |
|
stop_with_full_iters = 1 |
|
|
|
if stop_with_full_iters or stop_with_lower_quality: |
|
stop = 1 |
|
cnt_to_stop = cnt - 1 if stop_with_lower_quality else cnt |
|
camera_path_to_be_copied = os.path.join(out_dir, f'round_{cnt_to_stop}', 'cameras_updated.json') |
|
shutil.copy2(camera_path_to_be_copied, os.path.join(out_dir, 'cameras_outlier_removal.json')) |
|
save_metrics(metrics_track, out_dir) |
|
|
|
elif not args.enable_loop: |
|
stop = 1 |
|
cnt_to_stop = 0 |
|
save_metrics(metrics_track, out_dir) |
|
|
|
|
|
else: |
|
max_lpips_value = -float('inf') |
|
max_index = -1 |
|
|
|
for i in range(NUM_VIEWS): |
|
if flags[i] == 1 and lpips_losses[i] > max_lpips_value: |
|
max_lpips_value = lpips_losses[i] |
|
max_index = i |
|
|
|
flags[max_index] = 0 |
|
mean_lpips_wo_max = np.sum(lpips_losses * flags) / flags.sum() |
|
metrics_track[cnt].append(mean_lpips_wo_max) |
|
|
|
assert camera_data[list(camera_data.keys())[max_index]]["flag"] == 1 |
|
camera_data[list(camera_data.keys())[max_index]]["flag"] = 0 |
|
|
|
cnt += 1 |
|
cur_dir = os.path.join(out_dir, f'round_{cnt}') |
|
os.makedirs(cur_dir, exist_ok=True) |
|
|
|
|
|
save_updated_cameras(camera_data, cur_dir) |
|
|
|
if cnt_to_stop == 0: |
|
pass |
|
|
|
else: |
|
"""If we identified outliers, do render-and-compare to correct them""" |
|
camera_path_outlier_removal = os.path.join(out_dir, 'cameras_outlier_removal.json') |
|
assert os.path.exists(camera_path_outlier_removal) |
|
with open(camera_path_outlier_removal, 'r') as f: |
|
camera_data_outlier_removal = json.load(f) |
|
|
|
camera_path_render_and_compare = os.path.join(out_dir, 'cameras_render_and_compare.json') |
|
if not os.path.exists(camera_path_render_and_compare): |
|
mesh_path = os.path.join(out_dir, f'round_{cnt_to_stop}', f'{CATEGORY}.{args.mesh_format}') |
|
camera_data_render_and_compare = render_and_compare(copy.deepcopy(camera_data_outlier_removal), mesh_path, out_dir, num_views=NUM_VIEWS) |
|
|
|
with open(camera_path_render_and_compare, 'w') as f: |
|
json.dump(camera_data_render_and_compare, f, indent=4) |
|
|
|
|
|
cur_dir = os.path.join(out_dir, f'check_recovered_poses') |
|
os.makedirs(cur_dir, exist_ok=True) |
|
shutil.copy2(camera_path_render_and_compare, os.path.join(cur_dir, 'cameras.json')) |
|
do_reconstruction(cur_dir, out_dir, args, force_all_views=True) |
|
|
|
camera_data = read_updated_cameras(cur_dir) |
|
|
|
lpips_losses, mse_losses = vis_output( |
|
camera_data, |
|
mesh_path=os.path.join(cur_dir, f'{CATEGORY}.{args.mesh_format}'), |
|
save_path=os.path.join(cur_dir, 'vis.png'), |
|
num_views=args.num_views |
|
) |
|
|
|
|
|
cur_dir = os.path.join(out_dir, f'reconsider_init_poses') |
|
os.makedirs(cur_dir, exist_ok=True) |
|
mesh_path = os.path.join(out_dir, f'round_{cnt_to_stop}', f'{CATEGORY}.{args.mesh_format}') |
|
camera_data_aligned = align_to_mesh(camera_data_outlier_removal, mesh_path, cur_dir, num_views=NUM_VIEWS) |
|
save_updated_cameras(camera_data_aligned, cur_dir) |
|
|
|
|
|
do_reconstruction(cur_dir, out_dir, args, force_all_views=True) |
|
camera_data_init = read_updated_cameras(cur_dir) |
|
|
|
lpips_losses_init, mse_losses_init = vis_output( |
|
camera_data_init, |
|
mesh_path=os.path.join(cur_dir, f'{CATEGORY}.{args.mesh_format}'), |
|
save_path=os.path.join(cur_dir, 'vis.png'), |
|
num_views=NUM_VIEWS |
|
) |
|
|
|
flags_sum = np.array([int(v["flag"]) for k, v in camera_data_outlier_removal.items()]).sum() |
|
cnt_valid_cameras = 0 |
|
keep_init_poses = False |
|
if lpips_losses.mean() > lpips_losses_init.mean(): |
|
keep_init_poses = True |
|
|
|
else: |
|
for idx, (k, v) in enumerate(camera_data_outlier_removal.items()): |
|
if int(v["flag"]) == 1: |
|
continue |
|
|
|
if lpips_losses[idx] < lpips_losses_init[idx] and mse_losses[idx] < mse_losses_init[idx]: |
|
cnt_valid_cameras += 1 |
|
|
|
|
|
if cnt_valid_cameras + flags_sum == NUM_VIEWS: |
|
keep_init_poses = False |
|
else: |
|
keep_init_poses = True |
|
|
|
output_path = os.path.join(out_dir, 'cameras_final.json') |
|
if keep_init_poses: |
|
print("Keep the (optimized) initial camera poses.") |
|
with open(output_path.replace(".json", "_init.json"), 'w') as f: |
|
json.dump(camera_data_init, f, indent=4) |
|
else: |
|
print("Replace the initial cameras with the recovered ones!") |
|
with open(output_path.replace(".json", "_recovered.json"), 'w') as f: |
|
json.dump(camera_data, f, indent=4) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
seed_everything(0) |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('--output', default='output/demo', type=str, help='Directory where obj files will be saved') |
|
parser.add_argument('--category', default='jordan', type=str, help='Directory where obj files will be saved') |
|
parser.add_argument('--num_pts', default=25000, type=int, help='Number of points at initialization') |
|
parser.add_argument('--num_views', default=8, type=int, help='Number of input images') |
|
parser.add_argument('--mesh_format', default='obj', type=str, help='Format of output mesh') |
|
parser.add_argument('--enable_loop', action='store_true', help='Enable the loop-based strategy to detect and correct outliers') |
|
parser.add_argument('--config', default='navi.yaml', type=str, help='Path to config file') |
|
args = parser.parse_args() |
|
|
|
main(args) |
|
|
|
|
|
|