# To compute FID, first install pytorch_fid # pip install pytorch-fid import os import cv2 as cv from tqdm import tqdm import shutil from eval.score import * cam_id = 18 ours_dir = './test_results/subject00/styleunet_gaussians3/testing__cam_%03d/batch_750000/rgb_map' % cam_id posevocab_dir = './test_results/subject00/posevocab/testing__cam_%03d/rgb_map' % cam_id tava_dir = './test_results/subject00/tava/cam_%03d' % cam_id arah_dir = './test_results/subject00/arah/cam_%03d' % cam_id slrf_dir = './test_results/subject00/slrf/cam_%03d' % cam_id gt_dir = 'Z:/MultiviewRGB/THuman4/subject00/images/cam%02d' % cam_id mask_dir = 'Z:/MultiviewRGB/THuman4/subject00/masks/cam%02d' % cam_id frame_list = list(range(2000, 2500, 1)) if __name__ == '__main__': ours_metrics = Metrics() posevocab_metrics = Metrics() slrf_metrics = Metrics() arah_metrics = Metrics() tava_metrics = Metrics() shutil.rmtree('./tmp_quant') os.makedirs('./tmp_quant/ours', exist_ok = True) os.makedirs('./tmp_quant/posevocab', exist_ok = True) os.makedirs('./tmp_quant/slrf', exist_ok = True) os.makedirs('./tmp_quant/arah', exist_ok = True) os.makedirs('./tmp_quant/tava', exist_ok = True) os.makedirs('./tmp_quant/gt', exist_ok = True) for frame_id in tqdm(frame_list): ours_img = (cv.imread(ours_dir + '/%08d.jpg' % frame_id, cv.IMREAD_UNCHANGED) / 255.).astype(np.float32) posevocab_img = (cv.imread(posevocab_dir + '/%08d.jpg' % frame_id, cv.IMREAD_UNCHANGED) / 255.).astype(np.float32) slrf_img = (cv.imread(slrf_dir + '/%08d.png' % frame_id, cv.IMREAD_UNCHANGED) / 255.).astype(np.float32) tava_img = (cv.imread(tava_dir + '/%d.jpg' % frame_id, cv.IMREAD_UNCHANGED) / 255.).astype(np.float32) arah_img = (cv.imread(arah_dir + '/%d.jpg' % frame_id, cv.IMREAD_UNCHANGED) / 255.).astype(np.float32) gt_img = (cv.imread(gt_dir + '/%08d.jpg' % frame_id, cv.IMREAD_UNCHANGED) / 255.).astype(np.float32) mask_img = cv.imread(mask_dir + '/%08d.jpg' % frame_id, cv.IMREAD_UNCHANGED) > 128 gt_img[~mask_img] = 1. ours_img_cropped, posevocab_img_cropped, slrf_img_cropped, tava_img_cropped, arah_img_cropped, gt_img_cropped = \ crop_image( mask_img, 512, ours_img, posevocab_img, slrf_img, tava_img, arah_img, gt_img ) cv.imwrite('./tmp_quant/ours/%08d.png' % frame_id, (ours_img_cropped * 255).astype(np.uint8)) cv.imwrite('./tmp_quant/posevocab/%08d.png' % frame_id, (posevocab_img_cropped * 255).astype(np.uint8)) cv.imwrite('./tmp_quant/slrf/%08d.png' % frame_id, (slrf_img_cropped * 255).astype(np.uint8)) cv.imwrite('./tmp_quant/tava/%08d.png' % frame_id, (tava_img_cropped * 255).astype(np.uint8)) cv.imwrite('./tmp_quant/arah/%08d.png' % frame_id, (arah_img_cropped * 255).astype(np.uint8)) cv.imwrite('./tmp_quant/gt/%08d.png' % frame_id, (gt_img_cropped * 255).astype(np.uint8)) if ours_img is not None: ours_metrics.psnr += compute_psnr(ours_img, gt_img) ours_metrics.ssim += compute_ssim(ours_img, gt_img) ours_metrics.lpips += compute_lpips(ours_img_cropped, gt_img_cropped) ours_metrics.count += 1 if posevocab_img is not None: posevocab_metrics.psnr += compute_psnr(posevocab_img, gt_img) posevocab_metrics.ssim += compute_ssim(posevocab_img, gt_img) posevocab_metrics.lpips += compute_lpips(posevocab_img_cropped, gt_img_cropped) posevocab_metrics.count += 1 if slrf_img is not None: slrf_metrics.psnr += compute_psnr(slrf_img, gt_img) slrf_metrics.ssim += compute_ssim(slrf_img, gt_img) slrf_metrics.lpips += compute_lpips(slrf_img_cropped, gt_img_cropped) slrf_metrics.count += 1 if arah_img is not None: arah_metrics.psnr += compute_psnr(arah_img, gt_img) arah_metrics.ssim += compute_ssim(arah_img, gt_img) arah_metrics.lpips += compute_lpips(arah_img_cropped, gt_img_cropped) arah_metrics.count += 1 if tava_img is not None: tava_metrics.psnr += compute_psnr(tava_img, gt_img) tava_metrics.ssim += compute_ssim(tava_img, gt_img) tava_metrics.lpips += compute_lpips(tava_img_cropped, gt_img_cropped) tava_metrics.count += 1 print('Ours metrics: ', ours_metrics) print('PoseVocab metrics: ', posevocab_metrics) print('SLRF metrics: ', slrf_metrics) print('ARAH metrics: ', arah_metrics) print('TAVA metrics: ', tava_metrics) print('--- Ours ---') os.system('python -m pytorch_fid --device cuda {} {}'.format('./tmp_quant/ours', './tmp_quant/gt')) print('--- PoseVocab ---') os.system('python -m pytorch_fid --device cuda {} {}'.format('./tmp_quant/posevocab', './tmp_quant/gt')) print('--- SLRF ---') os.system('python -m pytorch_fid --device cuda {} {}'.format('./tmp_quant/slrf', './tmp_quant/gt')) print('--- ARAH ---') os.system('python -m pytorch_fid --device cuda {} {}'.format('./tmp_quant/arah', './tmp_quant/gt')) print('--- TAVA ---') os.system('python -m pytorch_fid --device cuda {} {}'.format('./tmp_quant/tava', './tmp_quant/gt'))