Spaces:
Runtime error
Runtime error
pminervini
commited on
Commit
·
2561b63
1
Parent(s):
f00379a
update
Browse files
halueval-cli.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
from huggingface_hub import snapshot_download
|
4 |
+
|
5 |
+
from src.backend.envs import EVAL_REQUESTS_PATH_BACKEND
|
6 |
+
from src.backend.manage_requests import get_eval_requests
|
7 |
+
from src.backend.manage_requests import EvalRequest
|
8 |
+
from src.backend.run_eval_suite import run_evaluation
|
9 |
+
|
10 |
+
from lm_eval.tasks import initialize_tasks, include_task_folder
|
11 |
+
from lm_eval import tasks, evaluator, utils
|
12 |
+
|
13 |
+
from src.backend.envs import Tasks, EVAL_REQUESTS_PATH_BACKEND, EVAL_RESULTS_PATH_BACKEND, DEVICE, LIMIT, Task
|
14 |
+
from src.envs import QUEUE_REPO
|
15 |
+
|
16 |
+
|
17 |
+
def main():
|
18 |
+
snapshot_download(repo_id=QUEUE_REPO, revision="main", local_dir=EVAL_REQUESTS_PATH_BACKEND, repo_type="dataset", max_workers=60)
|
19 |
+
|
20 |
+
PENDING_STATUS = "PENDING"
|
21 |
+
RUNNING_STATUS = "RUNNING"
|
22 |
+
FINISHED_STATUS = "FINISHED"
|
23 |
+
FAILED_STATUS = "FAILED"
|
24 |
+
|
25 |
+
status = [PENDING_STATUS, RUNNING_STATUS, FINISHED_STATUS, FAILED_STATUS]
|
26 |
+
|
27 |
+
# Get all eval request that are FINISHED, if you want to run other evals, change this parameter
|
28 |
+
eval_requests: list[EvalRequest] = get_eval_requests(job_status=status, hf_repo=QUEUE_REPO, local_dir=EVAL_REQUESTS_PATH_BACKEND)
|
29 |
+
eval_request = [r for r in eval_requests if 'bloom-560m' in r.model][0]
|
30 |
+
|
31 |
+
task_names = ['halueval_qa']
|
32 |
+
|
33 |
+
include_task_folder("src/backend/tasks/")
|
34 |
+
initialize_tasks('INFO')
|
35 |
+
|
36 |
+
print(tasks.ALL_TASKS)
|
37 |
+
|
38 |
+
task_names = utils.pattern_match(task_names, tasks.ALL_TASKS)
|
39 |
+
|
40 |
+
print(f"Selected Tasks: {task_names}")
|
41 |
+
|
42 |
+
results = evaluator.simple_evaluate(model="hf-auto", model_args=eval_request.get_model_args(), tasks=task_names, num_fewshot=0,
|
43 |
+
batch_size=4, device=DEVICE, use_cache=None, limit=8, write_out=True)
|
44 |
+
|
45 |
+
print('AAA', results)
|
46 |
+
|
47 |
+
if __name__ == "__main__":
|
48 |
+
main()
|
src/backend/tasks/halueval/halueval_qa.yaml
CHANGED
@@ -25,7 +25,7 @@ metric_list:
|
|
25 |
- metric: em
|
26 |
aggregation: mean
|
27 |
higher_is_better: true
|
28 |
-
- metric:
|
29 |
aggregation: mean
|
30 |
higher_is_better: true
|
31 |
metadata:
|
|
|
25 |
- metric: em
|
26 |
aggregation: mean
|
27 |
higher_is_better: true
|
28 |
+
- metric: correctness
|
29 |
aggregation: mean
|
30 |
higher_is_better: true
|
31 |
metadata:
|
src/backend/tasks/halueval/utils.py
CHANGED
@@ -36,52 +36,39 @@ You should try your best to determine if the answer contains non-factual or hall
|
|
36 |
|
37 |
|
38 |
def doc_to_text_qa(doc: dict[str, str]) -> str:
|
|
|
39 |
doc_text = QA_INSTURCTIONS + "\n\n#Question#: " + doc["question"] + "\n#Answer#: " + doc["answer"] + "\n#Your Judgement#:"
|
40 |
return doc_text
|
41 |
|
42 |
|
43 |
def doc_to_target_qa(doc: dict[str, str]) -> str:
|
|
|
44 |
return doc['hallucination']
|
45 |
|
46 |
|
47 |
-
def
|
48 |
-
|
49 |
-
em_sum = 0.0
|
50 |
-
if len(gold_list) > 1:
|
51 |
-
for i in range(len(gold_list)):
|
52 |
-
gold_answers = gold_list[0:i] + gold_list[i + 1 :]
|
53 |
-
# predictions compared against (n) golds and take maximum
|
54 |
-
em_sum += max(squad_metrics.compute_exact(a, predictions) for a in gold_answers)
|
55 |
-
else:
|
56 |
-
em_sum += max(squad_metrics.compute_exact(a, predictions) for a in gold_list)
|
57 |
-
return em_sum / max(1, len(gold_list))
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
-
|
61 |
-
f1_sum = 0.0
|
62 |
-
em_sum = 0.0
|
63 |
|
64 |
-
|
65 |
-
|
|
|
66 |
|
67 |
-
|
68 |
-
for i in range(len(gold_list)):
|
69 |
-
gold_answers = gold_list[0:i] + gold_list[i + 1 :]
|
70 |
-
# predictions compared against (n) golds and take maximum
|
71 |
-
em_sum += max(squad_metrics.compute_exact(a, predictions) for a in gold_answers)
|
72 |
-
f1_sum += max(squad_metrics.compute_f1(a, predictions) for a in gold_answers)
|
73 |
-
else:
|
74 |
-
em_sum += max(squad_metrics.compute_exact(a, predictions) for a in gold_list)
|
75 |
-
f1_sum += max(squad_metrics.compute_f1(a, predictions) for a in gold_list)
|
76 |
|
77 |
-
return {
|
78 |
-
"em": em_sum / max(1, len(gold_list)),
|
79 |
-
"f1": f1_sum / max(1, len(gold_list)),
|
80 |
-
}
|
81 |
|
82 |
-
|
83 |
-
|
84 |
gold_list = doc_to_target_qa(doc)
|
85 |
-
|
86 |
-
|
|
|
87 |
return scores
|
|
|
36 |
|
37 |
|
38 |
def doc_to_text_qa(doc: dict[str, str]) -> str:
|
39 |
+
# print('XXX doc_to_text_qa')
|
40 |
doc_text = QA_INSTURCTIONS + "\n\n#Question#: " + doc["question"] + "\n#Answer#: " + doc["answer"] + "\n#Your Judgement#:"
|
41 |
return doc_text
|
42 |
|
43 |
|
44 |
def doc_to_target_qa(doc: dict[str, str]) -> str:
|
45 |
+
# print('XXX doc_to_target_qa')
|
46 |
return doc['hallucination']
|
47 |
|
48 |
|
49 |
+
def compute_metrics_qa(gold_answer: str, prediction: str) -> dict[str, float]:
|
50 |
+
is_correct = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
+
if ("Yes" in prediction and "No" in prediction) or ("Yes" not in prediction and "No" not in prediction):
|
53 |
+
is_correct = False
|
54 |
+
elif "Yes" in prediction:
|
55 |
+
prediction = "yes"
|
56 |
+
elif "No" in prediction:
|
57 |
+
prediction = "no"
|
58 |
|
59 |
+
is_exact = (gold_answer == prediction)
|
|
|
|
|
60 |
|
61 |
+
res = {"correctness": 1.0 if is_correct else 0.0}
|
62 |
+
if is_correct:
|
63 |
+
res["em"] = 1.0 if is_exact else 0.0
|
64 |
|
65 |
+
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
|
|
|
|
|
|
|
|
67 |
|
68 |
+
def process_results_qa(doc: dict[str, str], results: list[str]):
|
69 |
+
# results is e.g., ['Yes']
|
70 |
gold_list = doc_to_target_qa(doc)
|
71 |
+
# gold_list is e.g., 'yes'
|
72 |
+
prediction = results[0].strip().split("\n")[0]
|
73 |
+
scores = compute_metrics_qa(gold_list, prediction)
|
74 |
return scores
|