# NOTE: hosting-page header residue ("Spaces: Running") -- not part of the script.
# To compute FID, first install pytorch_fid:
#   pip install pytorch-fid
import os
import shutil
import subprocess
import sys

import cv2 as cv
import numpy as np
from tqdm import tqdm

from eval.score import *
# Index of the camera view whose renderings are evaluated.
cam_id = 18

# Per-method directories holding the rendered RGB frames for that camera.
# These folder names embed the camera index zero-padded to three digits.
ours_dir = './test_results/subject00/styleunet_gaussians3/testing__cam_%03d/batch_750000/rgb_map' % cam_id
posevocab_dir = './test_results/subject00/posevocab/testing__cam_%03d/rgb_map' % cam_id
tava_dir = './test_results/subject00/tava/cam_%03d' % cam_id
arah_dir = './test_results/subject00/arah/cam_%03d' % cam_id
slrf_dir = './test_results/subject00/slrf/cam_%03d' % cam_id

# Ground-truth images and foreground masks.
# NOTE: unlike the result folders above, these use a two-digit camera index (cam%02d).
gt_dir = 'Z:/MultiviewRGB/THuman4/subject00/images/cam%02d' % cam_id
mask_dir = 'Z:/MultiviewRGB/THuman4/subject00/masks/cam%02d' % cam_id

# Frame indices to evaluate: 2000 .. 2499 inclusive, step 1.
frame_list = list(range(2000, 2500, 1))
def _load_rgb(path):
    """Read the image at ``path``, divide by 255 and cast to float32.

    Returns ``None`` when OpenCV cannot read the file, so callers can skip it
    instead of failing on ``None / 255.``.
    """
    raw = cv.imread(path, cv.IMREAD_UNCHANGED)
    # cv.imread signals a missing/unreadable file by returning None, not by raising.
    if raw is None:
        return None
    return (raw / 255.).astype(np.float32)


def _run_fid(pred_dir, ref_dir):
    """Run ``pytorch_fid`` on two image folders; its output goes to the console."""
    # sys.executable keeps the FID run in the same interpreter/environment as this
    # script; the argument list avoids going through a shell.
    subprocess.run(
        [sys.executable, '-m', 'pytorch_fid', '--device', 'cuda', pred_dir, ref_dir],
        check=False,
    )


if __name__ == '__main__':
    # One row per compared method:
    # (printed label, ./tmp_quant sub-folder, source directory, frame file-name pattern).
    methods = [
        ('Ours', 'ours', ours_dir, '/%08d.jpg'),
        ('PoseVocab', 'posevocab', posevocab_dir, '/%08d.jpg'),
        ('SLRF', 'slrf', slrf_dir, '/%08d.png'),
        ('ARAH', 'arah', arah_dir, '/%d.jpg'),
        ('TAVA', 'tava', tava_dir, '/%d.jpg'),
    ]
    metrics = {key: Metrics() for _, key, _, _ in methods}

    # Start from a clean scratch folder. On a first run it does not exist yet,
    # and an unconditional rmtree would raise FileNotFoundError.
    if os.path.isdir('./tmp_quant'):
        shutil.rmtree('./tmp_quant')
    for _, key, _, _ in methods:
        os.makedirs('./tmp_quant/%s' % key, exist_ok=True)
    os.makedirs('./tmp_quant/gt', exist_ok=True)

    for frame_id in tqdm(frame_list):
        gt_img = _load_rgb(gt_dir + '/%08d.jpg' % frame_id)
        mask_raw = cv.imread(mask_dir + '/%08d.jpg' % frame_id, cv.IMREAD_UNCHANGED)
        if gt_img is None or mask_raw is None:
            # Without the reference nothing can be scored; fail with a clear message.
            raise FileNotFoundError(
                'Missing ground-truth image or mask for frame %d' % frame_id
            )
        mask_img = mask_raw > 128
        # Set every ground-truth pixel outside the mask to 1.
        gt_img[~mask_img] = 1.

        # Load whichever method renderings exist for this frame.
        rendered = {}
        for _, key, src_dir, pattern in methods:
            img = _load_rgb(src_dir + pattern % frame_id)
            if img is not None:
                rendered[key] = img
        if not rendered:
            # No method produced this frame, so there is nothing to compare.
            continue

        # Crop all available renderings plus the ground truth in one call, as the
        # original script did; the ground truth is passed last.
        keys = list(rendered)
        *method_crops, gt_img_cropped = crop_image(
            mask_img,
            512,
            *[rendered[key] for key in keys],
            gt_img
        )

        # The cropped images feed both LPIPS and the FID folders.
        cv.imwrite('./tmp_quant/gt/%08d.png' % frame_id, (gt_img_cropped * 255).astype(np.uint8))
        for key, img_cropped in zip(keys, method_crops):
            cv.imwrite(
                './tmp_quant/%s/%08d.png' % (key, frame_id),
                (img_cropped * 255).astype(np.uint8),
            )
            img = rendered[key]
            scores = metrics[key]
            # PSNR/SSIM use the full frames, LPIPS the cropped ones.
            scores.psnr += compute_psnr(img, gt_img)
            scores.ssim += compute_ssim(img, gt_img)
            scores.lpips += compute_lpips(img_cropped, gt_img_cropped)
            scores.count += 1

    for label, key, _, _ in methods:
        if metrics[key].count == 0:
            print('%s metrics: ' % label, 'no frames evaluated')
        else:
            print('%s metrics: ' % label, metrics[key])

    for label, key, _, _ in methods:
        print('--- %s ---' % label)
        if metrics[key].count == 0:
            # No images were written for this method, so there is nothing to score.
            print('skipped: no images')
            continue
        _run_fid('./tmp_quant/%s' % key, './tmp_quant/gt')