Spaces:
Runtime error
Runtime error
pminervini
commited on
Commit
•
b06387f
1
Parent(s):
6524ea0
update
Browse files- cli/analysis-cli.py +136 -0
cli/analysis-cli.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
import json
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
import pandas as pd
|
10 |
+
import seaborn as sns
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
|
13 |
+
from scipy.cluster.hierarchy import linkage
|
14 |
+
|
15 |
+
from src.backend.envs import Tasks, EVAL_REQUESTS_PATH_BACKEND, EVAL_RESULTS_PATH_BACKEND, DEVICE, LIMIT, Task
|
16 |
+
|
17 |
+
from src.envs import QUEUE_REPO, RESULTS_REPO, API
|
18 |
+
from src.utils import my_snapshot_download
|
19 |
+
|
20 |
+
|
21 |
+
def find_json_files(json_path):
|
22 |
+
res = []
|
23 |
+
for root, dirs, files in os.walk(json_path):
|
24 |
+
for file in files:
|
25 |
+
if file.endswith(".json"):
|
26 |
+
res.append(os.path.join(root, file))
|
27 |
+
return res
|
28 |
+
|
29 |
+
|
30 |
+
my_snapshot_download(repo_id=RESULTS_REPO, revision="main", local_dir=EVAL_RESULTS_PATH_BACKEND, repo_type="dataset", max_workers=60)
|
31 |
+
my_snapshot_download(repo_id=QUEUE_REPO, revision="main", local_dir=EVAL_REQUESTS_PATH_BACKEND, repo_type="dataset", max_workers=60)
|
32 |
+
|
33 |
+
result_path_lst = find_json_files(EVAL_RESULTS_PATH_BACKEND)
|
34 |
+
request_path_lst = find_json_files(EVAL_REQUESTS_PATH_BACKEND)
|
35 |
+
|
36 |
+
model_name_to_model_map = {}
|
37 |
+
|
38 |
+
for path in request_path_lst:
|
39 |
+
with open(path, 'r') as f:
|
40 |
+
data = json.load(f)
|
41 |
+
model_name_to_model_map[data["model"]] = data
|
42 |
+
|
43 |
+
model_dataset_metric_to_result_map = {}
|
44 |
+
data_map = {}
|
45 |
+
|
46 |
+
for path in result_path_lst:
|
47 |
+
with open(path, 'r') as f:
|
48 |
+
data = json.load(f)
|
49 |
+
model_name = data["config"]["model_name"]
|
50 |
+
for dataset_name, results_dict in data["results"].items():
|
51 |
+
for metric_name, value in results_dict.items():
|
52 |
+
|
53 |
+
# print(model_name, dataset_name, metric_name, value)
|
54 |
+
|
55 |
+
if ',' in metric_name and '_stderr' not in metric_name \
|
56 |
+
and 'f1' not in metric_name \
|
57 |
+
and 'selfcheckgpt' not in dataset_name \
|
58 |
+
and model_name_to_model_map[model_name]["likes"] > 256:
|
59 |
+
|
60 |
+
to_add = True
|
61 |
+
|
62 |
+
if 'nq_open' in dataset_name or 'triviaqa' in dataset_name:
|
63 |
+
to_add = False
|
64 |
+
# pass
|
65 |
+
|
66 |
+
# breakpoint()
|
67 |
+
|
68 |
+
if 'bertscore' in metric_name:
|
69 |
+
if 'precision' not in metric_name:
|
70 |
+
to_add = False
|
71 |
+
|
72 |
+
if 'correctness,' in metric_name or 'em,' in metric_name:
|
73 |
+
to_add = False
|
74 |
+
|
75 |
+
if 'rouge' in metric_name:
|
76 |
+
if 'rougeL' not in metric_name:
|
77 |
+
to_add = False
|
78 |
+
|
79 |
+
if 'ifeval' in dataset_name:
|
80 |
+
if 'prompt_level_strict_acc' not in metric_name:
|
81 |
+
to_add = False
|
82 |
+
|
83 |
+
if 'squad' in dataset_name:
|
84 |
+
to_add = False
|
85 |
+
|
86 |
+
if 'fever' in dataset_name:
|
87 |
+
to_add = False
|
88 |
+
|
89 |
+
if 'rouge' in metric_name:
|
90 |
+
value /= 100.0
|
91 |
+
|
92 |
+
if to_add:
|
93 |
+
sanitised_metric_name = metric_name.split(',')[0]
|
94 |
+
model_dataset_metric_to_result_map[(model_name, dataset_name, sanitised_metric_name)] = value
|
95 |
+
|
96 |
+
# if (model_name, dataset_name) not in data_map:
|
97 |
+
# data_map[(model_name, dataset_name)] = {}
|
98 |
+
# data_map[(model_name, dataset_name)][metric_name] = value
|
99 |
+
|
100 |
+
if model_name not in data_map:
|
101 |
+
data_map[model_name] = {}
|
102 |
+
data_map[model_name][(dataset_name, sanitised_metric_name)] = value
|
103 |
+
|
104 |
+
print('model_name', model_name, 'dataset_name', dataset_name, 'metric_name', metric_name, 'value', value)
|
105 |
+
|
106 |
+
model_name_lst = [m for m in data_map.keys()]
|
107 |
+
for m in model_name_lst:
|
108 |
+
if len(data_map[m]) < 8:
|
109 |
+
del data_map[m]
|
110 |
+
|
111 |
+
df = pd.DataFrame.from_dict(data_map, orient='index')
|
112 |
+
o_df = df.copy(deep=True)
|
113 |
+
|
114 |
+
print(df)
|
115 |
+
|
116 |
+
# Check for NaN or infinite values and replace them
|
117 |
+
df.replace([np.inf, -np.inf], np.nan, inplace=True) # Replace infinities with NaN
|
118 |
+
df.fillna(0, inplace=True) # Replace NaN with 0 (or use another imputation strategy)
|
119 |
+
|
120 |
+
from sklearn.preprocessing import MinMaxScaler
|
121 |
+
|
122 |
+
# scaler = MinMaxScaler()
|
123 |
+
# df = pd.DataFrame(scaler.fit_transform(df), index=df.index, columns=df.columns)
|
124 |
+
|
125 |
+
sns.set_context("notebook", font_scale=1.0)
|
126 |
+
|
127 |
+
# fig = sns.clustermap(df, method='average', metric='cosine', cmap='coolwarm', figsize=(16, 12), annot=True)
|
128 |
+
fig = sns.clustermap(df, method='ward', metric='euclidean', cmap='coolwarm', figsize=(16, 12), annot=True, mask=o_df.isnull())
|
129 |
+
|
130 |
+
# Adjust the size of the cells (less wide)
|
131 |
+
plt.setp(fig.ax_heatmap.get_yticklabels(), rotation=0)
|
132 |
+
plt.setp(fig.ax_heatmap.get_xticklabels(), rotation=90)
|
133 |
+
|
134 |
+
# Save the clustermap to file
|
135 |
+
fig.savefig('plots/clustermap.pdf')
|
136 |
+
fig.savefig('plots/clustermap.png')
|