pminervini's picture
update
7b5f39c
raw
history blame
12.5 kB
#!/usr/bin/env python3
import os
import sys
import json
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import linkage
from src.backend.envs import Tasks, EVAL_REQUESTS_PATH_BACKEND, EVAL_RESULTS_PATH_BACKEND, DEVICE, LIMIT, Task
from src.envs import QUEUE_REPO, RESULTS_REPO, API
from src.utils import my_snapshot_download
def is_float(string):
    """Return True if ``float(string)`` succeeds, False if it raises ValueError.

    Only ValueError is caught, so non-string inputs such as None still raise TypeError.
    """
    try:
        float(string)
    except ValueError:
        return False
    else:
        return True
def find_json_files(json_path):
    """Recursively collect the paths of all files ending in ``.json`` under *json_path*.

    Paths are returned in ``os.walk`` traversal order. A missing directory yields an empty list.
    """
    return [
        os.path.join(directory, filename)
        for directory, _subdirs, filenames in os.walk(json_path)
        for filename in filenames
        if filename.endswith(".json")
    ]
def sanitise_metric(name: str) -> str:
    """Turn a raw metric name into the short label shown in the plots.

    Substitutions are applied in the listed order, and each one sees the result of the
    previous ones. For example ``HasAns_exact`` first becomes ``HasAns_EM`` and then ``HasAns``.
    """
    substitutions = (
        ("prompt_level_strict_acc", "Prompt-Level Accuracy"),
        ("acc", "Accuracy"),
        ("exact_match", "EM"),
        ("avg-selfcheckgpt", "AVG"),
        ("max-selfcheckgpt", "MAX"),
        ("rouge", "ROUGE-"),
        ("bertscore_precision", "BERT-P"),
        ("exact", "EM"),
        ("HasAns_EM", "HasAns"),
        ("NoAns_EM", "NoAns"),
        ("em", "EM"),
    )
    label = name
    for old, new in substitutions:
        label = label.replace(old, new)
    return label
def sanitise_dataset(name: str) -> str:
    """Turn a raw dataset/task name into the human-readable label shown in the plots.

    Substitutions run in the listed order, and later ones see earlier output. For example
    ``halueval_qa`` becomes ``HaluEval_QA``, and the final step turns every underscore
    into a space.
    """
    substitutions = (
        ("tqa8", "TriviaQA (8-shot)"),
        ("nq8", "NQ (8-shot)"),
        ("nq_open", "NQ (64-shot)"),
        ("triviaqa", "TriviaQA (64-shot)"),
        ("truthfulqa", "TruthfulQA"),
        ("ifeval", "IFEval"),
        ("selfcheckgpt", "SelfCheckGPT"),
        ("truefalse_cieacf", "True-False"),
        ("mc", "MC"),
        ("race", "RACE"),
        ("squad", "SQuAD"),
        ("memo-trap", "MemoTrap"),
        ("cnndm", "CNN/DM"),
        ("xsum", "XSum"),
        ("qa", "QA"),
        ("summarization", "Summarization"),
        ("dialogue", "Dialog"),
        ("halueval", "HaluEval"),
        ("_", " "),
    )
    label = name
    for old, new in substitutions:
        label = label.replace(old, new)
    return label
# Pickle file memoising the parsed evaluation results between runs.
# When it is absent, the script below re-downloads and re-parses everything.
cache_file = 'data_map_cache.pkl'


def load_data_map_from_cache(cache_file):
    """Return the object pickled in *cache_file*, or None if the file does not exist.

    NOTE: unpickling executes arbitrary code. Only point this at a cache written by
    save_data_map_to_cache, never at a file obtained from elsewhere.
    """
    if not os.path.exists(cache_file):
        return None
    with open(cache_file, 'rb') as handle:
        return pickle.load(handle)


def save_data_map_to_cache(data_map, cache_file):
    """Pickle *data_map* into *cache_file*, overwriting any previous cache."""
    with open(cache_file, 'wb') as handle:
        pickle.dump(data_map, handle)
# Only models with strictly more likes than this on their evaluation request are kept.
LIKES_THRESHOLD = 128


def _keep_result_metric(dataset_name, metric_name):
    """Return True unless one of the exclusion rules below drops this (dataset, metric) pair.

    Both arguments are the raw (unsanitised) names from the results JSON.
    """
    # F1 scores and standard errors are never plotted.
    if 'f1' in metric_name or 'stderr' in metric_name:
        return False
    # Datasets / dataset versions that are left out entirely.
    if 'memo-trap_v2' in dataset_name or 'faithdial' in dataset_name:
        return False
    if 'truthfulqa_gen' in dataset_name or 'fever' in dataset_name:
        return False
    if ('xsum' in dataset_name or 'cnn' in dataset_name) and 'v2' in dataset_name:
        return False
    # Only one variant of these metric families is kept.
    if 'bertscore' in metric_name and 'precision' not in metric_name:
        return False
    if 'halueval' in dataset_name and 'acc' not in metric_name:
        return False
    if 'ifeval' in dataset_name and 'prompt_level_strict_acc' not in metric_name:
        return False
    if 'squad' in dataset_name and 'best_exact' in metric_name:
        return False
    # All SelfCheckGPT metrics (avg and max) pass this filter.
    return True


# Try to load the data_map from the cache file; otherwise rebuild it from the Hub datasets.
data_map = load_data_map_from_cache(cache_file)
if data_map is None:
    my_snapshot_download(repo_id=RESULTS_REPO, revision="main", local_dir=EVAL_RESULTS_PATH_BACKEND, repo_type="dataset", max_workers=60)
    my_snapshot_download(repo_id=QUEUE_REPO, revision="main", local_dir=EVAL_REQUESTS_PATH_BACKEND, repo_type="dataset", max_workers=60)

    result_path_lst = find_json_files(EVAL_RESULTS_PATH_BACKEND)
    request_path_lst = find_json_files(EVAL_REQUESTS_PATH_BACKEND)

    # Model name -> request JSON (the last file seen wins); used for the like count.
    model_name_to_model_map = {}
    for path in request_path_lst:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        model_name_to_model_map[data["model"]] = data

    # Flat (model, dataset, metric) -> value view; not used by the plotting code below.
    model_dataset_metric_to_result_map = {}
    # data_map[model_name][(dataset_name, sanitised_metric_name)] = value
    data_map = {}
    for path in result_path_lst:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        model_name = data["config"]["model_name"]

        # Skip, rather than crash on, results whose model has no request file.
        request = model_name_to_model_map.get(model_name)
        if request is None:
            print(f'Skipping {path}: no evaluation request found for {model_name}')
            continue
        # The like filter depends only on the model, so check it once per results file.
        if request.get("likes", 0) <= LIKES_THRESHOLD:
            continue

        for dataset_name, results_dict in data["results"].items():
            for metric_name, value in results_dict.items():
                if not _keep_result_metric(dataset_name, metric_name):
                    continue
                # Some scores are stored as strings; drop those that are not numeric.
                if isinstance(value, str):
                    if not is_float(value):
                        continue
                    value = float(value)

                # ROUGE and SQuAD scores are divided by 100, presumably to put them on the
                # same 0-1 scale as the other metrics.
                if 'rouge' in metric_name:
                    value /= 100.0
                if 'squad' in dataset_name:
                    value /= 100.0

                # Metric keys may carry a ",<filter>" suffix; only the part before the comma is kept.
                sanitised_metric_name = sanitise_metric(metric_name.split(',')[0])
                sanitised_dataset_name = sanitise_dataset(dataset_name)

                model_dataset_metric_to_result_map[(model_name, sanitised_dataset_name, sanitised_metric_name)] = value
                data_map.setdefault(model_name, {})[(sanitised_dataset_name, sanitised_metric_name)] = value

                print('model_name', model_name, 'dataset_name', sanitised_dataset_name, 'metric_name', sanitised_metric_name, 'value', value)

    save_data_map_to_cache(data_map, cache_file)
# Prune models with poor coverage: a model is dropped when it reports more than five
# metrics fewer than the best-covered model. data_map is modified in place.
model_name_lst = list(data_map.keys())
nb_max_metrics = max(map(len, data_map.values()))
min_nb_metrics = nb_max_metrics - 5
for name in [m for m in model_name_lst if len(data_map[m]) < min_nb_metrics]:
    del data_map[name]
# Heatmap families to render; each selects a subset of the (dataset, metric) rows.
plot_type_lst = ['all', 'summ', 'qa', 'instr', 'detect', 'rc']

# Each heatmap cell is this many inches high / wide.
cell_height = 1.0
cell_width = 1.0

# Per-plot-type layout: (width offset, height offset, dendrogram_ratio, row_cluster).
# The offsets (in inches) are added to the cell-based figure size.
_DEFAULT_PLOT_LAYOUT = (0, 0, (.1, .1), True)
_PLOT_LAYOUT = {
    'all': _DEFAULT_PLOT_LAYOUT,
    'summ': (-2, 2.0, (.1, .1), False),  # summarisation rows are not clustered
    'qa': (-2, 4, (.1, .2), True),
    'instr': (-2, 5.2, (.1, .4), True),
    'detect': (-2, 5.2, (.1, .2), True),
    'rc': (-2, 5.2, (.1, .4), True),
}


def _keep_for_plot(plot_type, dataset_name, metric_name):
    """Return True if the sanitised (dataset_name, metric_name) row belongs in the *plot_type* heatmap.

    Raises:
        ValueError: if *plot_type* is not a known plot type.
    """
    if plot_type == 'all':
        # Overview: ROUGE-L only, SQuAD EM only, SelfCheckGPT MAX only, no 64-shot QA.
        if 'ROUGE' in metric_name and 'ROUGE-L' not in metric_name:
            return False
        if 'SQuAD' in dataset_name and 'EM' not in metric_name:
            return False
        if 'SelfCheckGPT' in dataset_name and 'MAX' not in metric_name:
            return False
        return '64-shot' not in dataset_name
    if plot_type == 'summ':
        return 'CNN' in dataset_name or 'XSum' in dataset_name
    if plot_type == 'qa':
        return 'TriviaQA' in dataset_name or 'NQ' in dataset_name or 'TruthfulQA' in dataset_name
    if plot_type == 'instr':
        return 'MemoTrap' in dataset_name or 'IFEval' in dataset_name
    if plot_type == 'detect':
        return 'HaluEval' in dataset_name or 'SelfCheck' in dataset_name
    if plot_type == 'rc':
        return 'RACE' in dataset_name or 'SQuAD' in dataset_name
    raise ValueError(f"Unknown plot type: {plot_type}")


sns.set_context("notebook", font_scale=1.3)
# Make sure the output directories exist before writing into them.
os.makedirs('plots', exist_ok=True)
os.makedirs(os.path.join('blog', 'figures'), exist_ok=True)

for plot_type in plot_type_lst:
    # data_map_v2[(dataset, metric)][model_name] = value, restricted to this plot's rows.
    data_map_v2 = {}
    for model_name, model_results in data_map.items():
        for dataset_metric, value in model_results.items():
            if _keep_for_plot(plot_type, *dataset_metric):
                data_map_v2.setdefault(dataset_metric, {})[model_name] = value

    # Rows: "dataset, metric" labels; columns: models.
    df = pd.DataFrame.from_dict(data_map_v2, orient='index')
    if df.empty:
        print(f'No data for {plot_type}, skipping')
        continue
    df.index = [', '.join(map(str, idx)) for idx in df.index]
    # Untouched copy (NaN marks a missing score): dumped to JSON and used as the heatmap mask.
    o_df = df.copy(deep=True)
    print(df)

    # Replace infinities with NaN, then NaN with 0, so clustering only sees finite values.
    # Missing cells are hidden again via mask=o_df.isnull() below.
    df.replace([np.inf, -np.inf], np.nan, inplace=True)
    df.fillna(0, inplace=True)

    n_rows = len(df.index)  # (dataset, metric) rows
    n_cols = len(df.columns)  # models
    width_offset, height_offset, dendrogram_ratio, row_cluster = _PLOT_LAYOUT.get(plot_type, _DEFAULT_PLOT_LAYOUT)
    col_cluster = True
    fig_width = cell_width * n_cols + width_offset
    fig_height = cell_height * n_rows + height_offset
    print('figsize', (fig_width, fig_height))

    o_df.to_json(f'plots/clustermap_{plot_type}.json', orient='split')

    print(f'Generating the clustermaps for {plot_type}')
    for cmap in [None, 'coolwarm', 'viridis']:
        fig = sns.clustermap(df,
                             method='ward',
                             metric='euclidean',
                             cmap=cmap,
                             figsize=(fig_width, fig_height),
                             annot=True,
                             mask=o_df.isnull(),
                             dendrogram_ratio=dendrogram_ratio,
                             fmt='.2f',
                             col_cluster=col_cluster,
                             row_cluster=row_cluster)
        plt.setp(fig.ax_heatmap.get_yticklabels(), rotation=0)
        plt.setp(fig.ax_heatmap.get_xticklabels(), rotation=90)

        cmap_suffix = '' if cmap is None else f'_{cmap}'
        fig.savefig(f'blog/figures/clustermap_{plot_type}{cmap_suffix}.pdf')
        fig.savefig(f'blog/figures/clustermap_{plot_type}{cmap_suffix}.png')
        fig.savefig(f'blog/figures/clustermap_{plot_type}{cmap_suffix}_t.png', transparent=True, facecolor="none")
        # Release the figure; otherwise every clustermap stays open for the rest of the run.
        plt.close(fig.fig)