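"""Load annotated prompt metadata from experiment_D3.csv.

Prompts flagged with any of the safe exclusion criteria below are dropped, and
evaluation metrics are attached to the remaining tasks marked do_eval
(GLUE/SuperGLUE metrics come from t5, everything else from NON_GLUE_METRICS).
"""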
import csv
from pprint import pprint
from typing import Dict, List

import pkg_resources
from t5.data.glue_utils import get_glue_metric, get_super_glue_metric
from t5.evaluation.metrics import accuracy, mean_multiclass_f1, rouge


SAFE_EXCLUDE_CRITERIA = [
    "template_bug",
    "negated_answers",
    "counting",
    "answer_span_indices",
    "non_natural_language",
    "generative_non_true_implausible",
]

AGGRESSIVE_EXCLUDE_CRITERIA = [
    "generative_non_true_task",
    "nontrivial_choices_hidden",
    "awkward_phrasing",
    "ungrammatical",
] + SAFE_EXCLUDE_CRITERIA


NON_GLUE_METRICS = {  # for those with do_eval = True
    "anli": [accuracy],
    "hans": [accuracy],
    "circa_goldstandard1_judgement": [mean_multiclass_f1(num_classes=8), accuracy],
    "circa_goldstandard2_judgement": [mean_multiclass_f1(num_classes=5), accuracy],
    "mc_taco": [accuracy],
    "nq_open": [accuracy],
    "qa_srl": [accuracy],
    "openbookqa": [accuracy],
    "race": [accuracy],
    "social_i_qa": [accuracy],
    "emo": [mean_multiclass_f1(num_classes=4)],
    "xsum": [rouge],
}


def exclude_bad_prompts(prompt: Dict) -> bool:
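    """Return True if the prompt should be kept, False if any exclusion flag is set.

    Fields come from csv.DictReader, so every value is a string: an empty
    cell is falsy, and any non-empty flag excludes the prompt.
    """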
    for criterion in SAFE_EXCLUDE_CRITERIA:  # or AGGRESSIVE_EXCLUDE_CRITERIA
        if prompt.get(criterion):
            return False
    return True


def load_annotated_prompts() -> List[Dict]:
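    """Read experiment_D3.csv, filter out flagged prompts, and assign eval metrics."""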
    annotated_csv_path = pkg_resources.resource_filename(__name__, "experiment_D3.csv")
    with open(annotated_csv_path) as in_file:
        reader = csv.DictReader(in_file)
        all_tasks = [row for row in reader]

    clean_tasks = list(filter(exclude_bad_prompts, all_tasks))

    # Assign metrics
    non_glue_eval_sets = list(NON_GLUE_METRICS.keys())
    for task in clean_tasks:
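        # "do_eval" is a raw CSV string; an empty cell skips metric assignment.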
        if not task["do_eval"]:
            continue

        full_name = task["dataset_subset_template"]
        if full_name.startswith("glue"):
            subset = full_name.split("_")[1]
            task["metrics"] = get_glue_metric(subset)
        elif full_name.startswith("super_glue"):
            subset = full_name.split("_")[2]
            if subset in ("wsc.fixed", "multirc"):
                # TODO: WSC and MultiRC need special pre/postprocessing
                task["metrics"] = [accuracy]
                continue
            task["metrics"] = get_super_glue_metric(subset)

        for dataset_name in non_glue_eval_sets:
            if full_name.startswith(dataset_name):
                task["metrics"] = NON_GLUE_METRICS[dataset_name]

        # Skip rank_classification for now until we actually support it
        # if task["nontrivial_choices_hidden"]:
        #     # Trick of plugging in answer options and ranking LM probabilities as predictions.
        #     # Required for all prompts with nontrivial_choices_hidden,
        #     # but could be used for other tasks as well where answer choices are given.
        #     if "metrics" not in task:
        #         task["metrics"] = [rank_classification]
        #     elif rank_classification not in task["metrics"]:
        #         task["metrics"].append(rank_classification)

        # should already be handled by NON_GLUE_METRICS
        # if task['generative_true_task'] or task['generative_non_true_task']:
        #     task['metrics'] = rouge

    return clean_tasks


def preview() -> None:
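    """Print the training task names and the eval tasks with their assigned metrics."""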
    clean_tasks = load_annotated_prompts()

    train_tasks = [t for t in clean_tasks if not t["skip_train"]]
    eval_tasks = [t for t in clean_tasks if t["do_eval"]]

    pprint([t["dataset_subset_template"] for t in train_tasks])
    print(len(train_tasks))

    pprint([f'{t["dataset_subset_template"]} {t["metrics"]}' for t in eval_tasks])
    print(len(eval_tasks))


if __name__ == "__main__":
    preview()