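"""Evaluation utilities for generated motion: text-motion matching score,
R-precision, FID, diversity, and multimodality, with support for repeated
replications and 95% confidence intervals."""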
import sys
from datetime import datetime
import numpy as np
import torch
from utils.metrics import *
from collections import OrderedDict
import os


def evaluate_matching_score(eval_wrapper, motion_loaders, file):
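    """Compute the text-motion matching score and R-precision (top-1/2/3) for each
    motion loader, and collect motion embeddings for later FID and diversity
    evaluation. Results are printed to stdout and to `file`."""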
    match_score_dict = OrderedDict({})
    R_precision_dict = OrderedDict({})
    activation_dict = OrderedDict({})

    print('========== Evaluating Matching Score ==========')
    for motion_loader_name, motion_loader in motion_loaders.items():
        all_motion_embeddings = []
        score_list = []
        all_size = 0
        matching_score_sum = 0
        top_k_count = 0

        with torch.no_grad():
            for idx, batch in enumerate(motion_loader):
                word_embeddings, pos_one_hots, _, sent_lens, motions, m_lens, _ = batch
                text_embeddings, motion_embeddings = eval_wrapper.get_co_embeddings(
                    word_embs=word_embeddings,
                    pos_ohot=pos_one_hots,
                    cap_lens=sent_lens,
                    motions=motions,
                    m_lens=m_lens
                )
                dist_mat = euclidean_distance_matrix(text_embeddings.cpu().numpy(),
                                                     motion_embeddings.cpu().numpy())
                matching_score_sum += dist_mat.trace()

                argsmax = np.argsort(dist_mat, axis=1)
                top_k_mat = calculate_top_k(argsmax, top_k=3)
                top_k_count += top_k_mat.sum(axis=0)

                all_size += text_embeddings.shape[0]

                all_motion_embeddings.append(motion_embeddings.cpu().numpy())

            all_motion_embeddings = np.concatenate(all_motion_embeddings, axis=0)

        matching_score = matching_score_sum / all_size
        R_precision = top_k_count / all_size
        match_score_dict[motion_loader_name] = matching_score
        R_precision_dict[motion_loader_name] = R_precision
        activation_dict[motion_loader_name] = all_motion_embeddings

        print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}')
        print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}', file=file, flush=True)

        line = f'---> [{motion_loader_name}] R_precision: '
        for i in range(len(R_precision)):
            line += '(top %d): %.4f ' % (i + 1, R_precision[i])
        print(line)
        print(line, file=file, flush=True)

    return match_score_dict, R_precision_dict, activation_dict


def evaluate_fid(eval_wrapper, groundtruth_loader, activation_dict, file):
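    """Compute the FID between the ground-truth motion embeddings and the
    embeddings collected in `activation_dict` for each evaluated model."""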
    eval_dict = OrderedDict({})
    gt_motion_embeddings = []
    print('========== Evaluating FID ==========')
    with torch.no_grad():
        for idx, batch in enumerate(groundtruth_loader):
            _, _, _, sent_lens, motions, m_lens, _ = batch
            motion_embeddings = eval_wrapper.get_motion_embeddings(
                motions=motions,
                m_lens=m_lens
            )
            gt_motion_embeddings.append(motion_embeddings.cpu().numpy())
    gt_motion_embeddings = np.concatenate(gt_motion_embeddings, axis=0)
    gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings)

    for model_name, motion_embeddings in activation_dict.items():
        mu, cov = calculate_activation_statistics(motion_embeddings)

        fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
        print(f'---> [{model_name}] FID: {fid:.4f}')
        print(f'---> [{model_name}] FID: {fid:.4f}', file=file, flush=True)
        eval_dict[model_name] = fid
    return eval_dict


def evaluate_diversity(activation_dict, file, diversity_times):
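    """Compute the embedding diversity for each model in `activation_dict`,
    using `diversity_times` samples."""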
    eval_dict = OrderedDict({})
    print('========== Evaluating Diversity ==========')
    for model_name, motion_embeddings in activation_dict.items():
        diversity = calculate_diversity(motion_embeddings, diversity_times)
        eval_dict[model_name] = diversity
        print(f'---> [{model_name}] Diversity: {diversity:.4f}')
        print(f'---> [{model_name}] Diversity: {diversity:.4f}', file=file, flush=True)
    return eval_dict


def evaluate_multimodality(eval_wrapper, mm_motion_loaders, file, mm_num_times):
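    """Compute the multimodality score for each model from its multi-sample
    motion loader, using `mm_num_times` samples."""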
    eval_dict = OrderedDict({})
    print('========== Evaluating MultiModality ==========')
    for model_name, mm_motion_loader in mm_motion_loaders.items():
        mm_motion_embeddings = []
        with torch.no_grad():
            for idx, batch in enumerate(mm_motion_loader):
                motions, m_lens = batch
                motion_embeddings = eval_wrapper.get_motion_embeddings(motions[0], m_lens[0])
                mm_motion_embeddings.append(motion_embeddings.unsqueeze(0))
        if len(mm_motion_embeddings) == 0:
            multimodality = 0
        else:
            mm_motion_embeddings = torch.cat(mm_motion_embeddings, dim=0).cpu().numpy()
            multimodality = calculate_multimodality(mm_motion_embeddings, mm_num_times)
        print(f'---> [{model_name}] Multimodality: {multimodality:.4f}')
        print(f'---> [{model_name}] Multimodality: {multimodality:.4f}', file=file, flush=True)
        eval_dict[model_name] = multimodality
    return eval_dict


def get_metric_statistics(values, replication_times):
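    """Return the per-metric mean and 95% confidence interval half-width
    (1.96 * std / sqrt(replication_times)) across replications."""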
    mean = np.mean(values, axis=0)
    std = np.std(values, axis=0)
    conf_interval = 1.96 * std / np.sqrt(replication_times)
    return mean, conf_interval


def evaluation(eval_wrapper, gt_loader, eval_motion_loaders, log_file, replication_times, diversity_times, mm_num_times, run_mm=False):
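    """Run the full evaluation (matching score, R-precision, FID, diversity and,
    optionally, multimodality) for `replication_times` replications, logging to
    `log_file`. Returns per-metric means with confidence intervals when more than
    one replication is run, otherwise the raw accumulated metrics."""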
    with open(log_file, 'a') as f:
        all_metrics = OrderedDict({'Matching Score': OrderedDict({}),
                                   'R_precision': OrderedDict({}),
                                   'FID': OrderedDict({}),
                                   'Diversity': OrderedDict({}),
                                   'MultiModality': OrderedDict({})})

        for replication in range(replication_times):
            print(f'Time: {datetime.now()}')
            print(f'Time: {datetime.now()}', file=f, flush=True)
            motion_loaders = {}
            motion_loaders['ground truth'] = gt_loader
            mm_motion_loaders = {}

            for motion_loader_name, motion_loader_getter in eval_motion_loaders.items():
                motion_loader, mm_motion_loader, eval_generate_time = motion_loader_getter()
                print(f'---> [{motion_loader_name}] batch_generate_time: {eval_generate_time}s', file=f, flush=True)
                motion_loaders[motion_loader_name] = motion_loader
                mm_motion_loaders[motion_loader_name] = mm_motion_loader

            if replication_times > 1:
                print(f'==================== Replication {replication} ====================')
                print(f'==================== Replication {replication} ====================', file=f, flush=True)

            mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(eval_wrapper, motion_loaders, f)

            fid_score_dict = evaluate_fid(eval_wrapper, gt_loader, acti_dict, f)

            div_score_dict = evaluate_diversity(acti_dict, f, diversity_times)

            # Multimodality is optional; keep an empty dict so the accumulation
            # loop below is a no-op when run_mm is False.
            mm_score_dict = OrderedDict({})
            if run_mm:
                mm_score_dict = evaluate_multimodality(eval_wrapper, mm_motion_loaders, f, mm_num_times)

            print('!!! DONE !!!')
            print('!!! DONE !!!', file=f, flush=True)

            for key, item in mat_score_dict.items():
                if key not in all_metrics['Matching Score']:
                    all_metrics['Matching Score'][key] = [item]
                else:
                    all_metrics['Matching Score'][key] += [item]

            for key, item in R_precision_dict.items():
                if key not in all_metrics['R_precision']:
                    all_metrics['R_precision'][key] = [item]
                else:
                    all_metrics['R_precision'][key] += [item]

            for key, item in fid_score_dict.items():
                if key not in all_metrics['FID']:
                    all_metrics['FID'][key] = [item]
                else:
                    all_metrics['FID'][key] += [item]

            for key, item in div_score_dict.items():
                if key not in all_metrics['Diversity']:
                    all_metrics['Diversity'][key] = [item]
                else:
                    all_metrics['Diversity'][key] += [item]

            for key, item in mm_score_dict.items():
                if key not in all_metrics['MultiModality']:
                    all_metrics['MultiModality'][key] = [item]
                else:
                    all_metrics['MultiModality'][key] += [item]

        mean_dict = {}
        if replication_times > 1:
            for metric_name, metric_dict in all_metrics.items():
                print('========== %s Summary ==========' % metric_name)
                print('========== %s Summary ==========' % metric_name, file=f, flush=True)

                for model_name, values in metric_dict.items():
                    mean, conf_interval = get_metric_statistics(np.array(values), replication_times)
                    mean_dict[metric_name + '_' + model_name] = mean

                    if isinstance(mean, (np.float64, np.float32)):
                        print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}')
                        print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True)
                    elif isinstance(mean, np.ndarray):
                        line = f'---> [{model_name}]'
                        for i in range(len(mean)):
                            line += '(top %d) Mean: %.4f CInt: %.4f;' % (i + 1, mean[i], conf_interval[i])
                        print(line)
                        print(line, file=f, flush=True)
            return mean_dict
        else:
            return all_metrics


def distributed_evaluation(eval_wrapper, gt_loader, eval_motion_loader, log_file, replication_times, diversity_times):
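    """Variant of `evaluation` for a single pre-built motion loader (no
    multimodality): computes matching score, R-precision, FID and diversity over
    `replication_times` replications and returns the per-metric means."""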
    with open(log_file, 'a') as f:
        all_metrics = OrderedDict({'Matching Score': OrderedDict({}),
                                   'R_precision': OrderedDict({}),
                                   'FID': OrderedDict({}),
                                   'Diversity': OrderedDict({}),
                                   'MultiModality': OrderedDict({})})

        for replication in range(replication_times):
            print(f'Time: {datetime.now()}')
            print(f'Time: {datetime.now()}', file=f, flush=True)
            motion_loaders = {'test': eval_motion_loader}

            if replication_times > 1:
                print(f'==================== Replication {replication} ====================')
                print(f'==================== Replication {replication} ====================', file=f, flush=True)

            mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(eval_wrapper, motion_loaders, f)

            fid_score_dict = evaluate_fid(eval_wrapper, gt_loader, acti_dict, f)

            div_score_dict = evaluate_diversity(acti_dict, f, diversity_times)

            print('!!! DONE !!!')
            print('!!! DONE !!!', file=f, flush=True)

            for key, item in mat_score_dict.items():
                if key not in all_metrics['Matching Score']:
                    all_metrics['Matching Score'][key] = [item]
                else:
                    all_metrics['Matching Score'][key] += [item]

            for key, item in R_precision_dict.items():
                if key not in all_metrics['R_precision']:
                    all_metrics['R_precision'][key] = [item]
                else:
                    all_metrics['R_precision'][key] += [item]

            for key, item in fid_score_dict.items():
                if key not in all_metrics['FID']:
                    all_metrics['FID'][key] = [item]
                else:
                    all_metrics['FID'][key] += [item]

            for key, item in div_score_dict.items():
                if key not in all_metrics['Diversity']:
                    all_metrics['Diversity'][key] = [item]
                else:
                    all_metrics['Diversity'][key] += [item]

        mean_dict = {}
        for metric_name, metric_dict in all_metrics.items():
            print('========== %s Summary ==========' % metric_name)
            print('========== %s Summary ==========' % metric_name, file=f, flush=True)

            for model_name, values in metric_dict.items():
                mean, conf_interval = get_metric_statistics(np.array(values), replication_times)
                mean_dict[metric_name + '_' + model_name] = mean

                if isinstance(mean, (np.float64, np.float32)):
                    print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}')
                    print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True)
                elif isinstance(mean, np.ndarray):
                    line = f'---> [{model_name}]'
                    for i in range(len(mean)):
                        line += '(top %d) Mean: %.4f CInt: %.4f;' % (i + 1, mean[i], conf_interval[i])
                    print(line)
                    print(line, file=f, flush=True)
        return mean_dict