import csv
import functools
from typing import Dict, List, Optional, Tuple

import datasets
import pkg_resources
import seqio
import t5
import tensorflow as tf
from t5.data.glue_utils import get_glue_metric, get_super_glue_metric
from t5.evaluation import metrics as mt

import promptsource.templates
from promptsource.seqio_tasks import utils

GET_METRICS = {
    "BLEU": mt.bleu,
    "ROUGE": mt.rouge,
    "Span Squad": mt.span_squad,
    "Squad": mt.squad,
    "Trivia QA": mt.trivia_qa,
    "Accuracy": mt.accuracy,
    "Sequence Accuracy": mt.sequence_accuracy,
    "Pearson Correlation": mt.pearson_corrcoef,
    "Spearman Correlation": mt.spearman_corrcoef,
    "MultiRC": mt.multirc_f1_over_all_answers,
    "AUC": mt.auc,
    "COQA F1": mt.coqa_f1,
    "Edit Distance": mt.edit_distance,
    # "Mean Reciprocal Rank": mt.accuracy,  # NOTE: not in T5?
    "Other": mt.accuracy,
    # Missing support for mean_multiclass_f1 etc., which need a num_classes parameter
}

MAX_EXAMPLES_PER_DATASET = 500_000


def strip_whitespace(output_or_target, example=None, is_target=False):
    """Cached tasks from promptsource all have a leading space on the ground-truth targets."""
    return output_or_target.strip()


def maybe_get_class_id_postprocessor(template):
    if template.get_fixed_answer_choices_list():

        def postprocess_fn(output_or_target, example=None, is_target=False):
            output_or_target = strip_whitespace(output_or_target)
            return t5.data.postprocessors.string_label_to_class_id(
                output_or_target, label_classes=template.get_fixed_answer_choices_list()
            )

        return postprocess_fn

    else:
        return strip_whitespace
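
# For illustration (hypothetical values): if a template's fixed answer choices were
# ["negative", "positive"], the returned postprocess_fn would map the cached target
# " positive" to class id 1 via string_label_to_class_id, while templates without
# fixed answer choices only have their whitespace stripped.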


def get_tf_dataset(split, shuffle_files, seed, dataset_name, subset_name, template, split_mapping):
    # HF datasets does not support file-level shuffling
    del shuffle_files, seed
    dataset = datasets.load_dataset(dataset_name, subset_name)
    dataset = dataset[split_mapping[split]]
    dataset = utils.apply_template(dataset, template)
    return utils.hf_dataset_to_tf_dataset(dataset)


def add_task(dataset_name, subset_name, template_name, task_name=None, split_mapping=None):
    template = all_templates.get_dataset(dataset_name, subset_name)[template_name]
    task_name = task_name or utils.get_task_name(dataset_name, subset_name, template_name)

    if dataset_name == "glue":
        metrics = get_glue_metric(subset_name)
    elif dataset_name == "super_glue":
        if subset_name in ("wsc.fixed", "multirc"):
            # TODO: WSC and MultiRC need special pre/postprocessing
            metrics = [mt.accuracy]
        else:
            metrics = get_super_glue_metric(subset_name)
    else:
        # TODO: what if the metric list is empty?
        metrics = [GET_METRICS[m] for m in template.metadata.metrics]

    dataset_splits = utils.get_dataset_splits(dataset_name, subset_name)
    split_mapping = split_mapping or {k: k for k in dataset_splits.keys()}

    dataset_fn = functools.partial(
        get_tf_dataset,
        seed=None,
        dataset_name=dataset_name,
        subset_name=subset_name,
        template=template,
        split_mapping=split_mapping,
    )
    data_source = seqio.FunctionDataSource(
        dataset_fn,
        splits=list(split_mapping.keys()),
        num_input_examples={s: dataset_splits[split_mapping[s]].num_examples for s in split_mapping.keys()},
    )
    output_features = {
        "inputs": seqio.Feature(t5.data.get_default_vocabulary(), add_eos=False, dtype=tf.int32),
        "targets": seqio.Feature(t5.data.get_default_vocabulary(), add_eos=True, dtype=tf.int32),
    }
    preprocessors = [
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos,
        seqio.CacheDatasetPlaceholder(required=False),
    ]

    # Add train and normal eval tasks
    seqio.TaskRegistry.add(
        task_name,
        data_source,
        preprocessors=preprocessors,
        output_features=output_features,
        metric_fns=metrics,
        postprocess_fn=maybe_get_class_id_postprocessor(template),
    )

    # Add rank classification eval task
    if template.answer_choices:
        rank_classification_preprocessor = functools.partial(
            t5.data.preprocessors.rank_classification,
            inputs_fn=lambda ex: tf.fill((len(ex["answer_choices"]),), ex["inputs"]),
            targets_fn=lambda ex: ex["answer_choices"],
            is_correct_fn=lambda ex: tf.equal(ex["answer_choices"], tf.strings.strip(ex["targets"])),
            weight_fn=lambda ex: 1.0,
        )

        fixed_choices = template.get_fixed_answer_choices_list()
        num_classes = len(fixed_choices) if fixed_choices else None
        seqio.TaskRegistry.add(
            task_name + "_score_eval",
            data_source,
            preprocessors=[rank_classification_preprocessor] + preprocessors,
            output_features=output_features,
            metric_fns=[functools.partial(t5.evaluation.metrics.rank_classification, num_classes=num_classes)],
            postprocess_fn=t5.data.postprocessors.rank_classification,
        )
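
# Note: each add_task call registers one ordinary seqio task under the name produced by
# utils.get_task_name(dataset_name, subset_name, template_name) and, when the template
# defines answer_choices, a companion "<task_name>_score_eval" task that scores every
# candidate answer with t5's rank_classification preprocessor and metric.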


dataset_subset_tuple = Tuple[str, Optional[str]]
# (dataset_name, subset_name) tuples, grouped by their role in training/evaluation
d4_train: List[dataset_subset_tuple] = []
d4_eval: List[dataset_subset_tuple] = []
d3_train_gpt: List[dataset_subset_tuple] = []
d3_train_sglue: List[dataset_subset_tuple] = []
bias_fairness_eval: List[dataset_subset_tuple] = []
gsheet: Dict[dataset_subset_tuple, Dict] = {}

experiment_path = pkg_resources.resource_filename(__name__, "experiment_D4.csv")
with open(experiment_path) as exp_file:
    reader = csv.DictReader(exp_file)
    for row in reader:
        if row["skip"]:
            continue
        if row["subset"] == "":
            row["subset"] = None  # to match promptsource.Template object
        dataset_subset = (row["HF_name"], row["subset"])
        if row["do_train"] == "TRUE":
            d4_train.append(dataset_subset)
        if row["do_eval"] == "TRUE":
            d4_eval.append(dataset_subset)
        if row["D3_do_train"] == "TRUE" and "GPT" in row["seed_paper"]:
            d3_train_gpt.append(dataset_subset)
        if row["D3_do_train"] == "TRUE" and row["HF_name"] == "super_glue":
            d3_train_sglue.append(dataset_subset)
        if (
            row["do_eval"] == "TRUE"
            and row["task_by_convention"] == "bias_and_fairness"
            and row["HF_name"] != "winogender"
        ):
            bias_fairness_eval.append(dataset_subset)
        gsheet[dataset_subset] = row

all_datasets = d4_train + d4_eval + d3_train_gpt + d3_train_sglue + bias_fairness_eval

all_templates = promptsource.templates.TemplateCollection()
all_templates.remove("anli")  # Need to special-case ANLI due to weird split conventions

# 3 stages of training/ablation: D4 -> GPT -> SuperGLUE
d4_train_mixture: List[str] = []  # strings are dataset_subset_template
gpt_train_mixture: List[str] = []
sglue_train_mixture: List[str] = []
d4_eval_mixture: List[str] = []
bias_fairness_eval_mixture: List[str] = []
mixture_cap: Dict[str, int] = {}
single_original_task: Dict[Tuple[str, str], str] = {}
all_original_tasks: List[str] = []

for dataset_name, subset_name in all_templates.keys:
    if (dataset_name, subset_name) not in all_datasets:
        all_templates.remove(dataset_name, subset_name)
        continue

    dataset = all_templates.get_dataset(dataset_name, subset_name)
    num_templates = len(dataset.all_template_names)
    train_size = gsheet[(dataset_name, subset_name)]["train_size"]
    if train_size == "":
        train_size = 0
    else:
        train_size = int(train_size)
    # Cap the per-template contribution: large datasets are limited to
    # MAX_EXAMPLES_PER_DATASET examples in total, split evenly across their templates;
    # smaller datasets keep their full train split per template.
    if train_size > MAX_EXAMPLES_PER_DATASET:
        cap = MAX_EXAMPLES_PER_DATASET // num_templates
    else:
        cap = train_size

    for template_name in dataset.all_template_names:
        add_task(dataset_name, subset_name, template_name)

        template = dataset[template_name]
        task_name = utils.get_task_name(dataset_name, subset_name, template_name)

        if (dataset_name, subset_name) not in single_original_task and template.metadata.original_task:
            single_original_task[(dataset_name, subset_name)] = task_name
        if template.metadata.original_task:
            all_original_tasks.append(task_name)

        if (dataset_name, subset_name) in d4_train:
            d4_train_mixture.append(task_name)
            mixture_cap[task_name] = cap
        if (dataset_name, subset_name) in d3_train_gpt:
            gpt_train_mixture.append(task_name)
            mixture_cap[task_name] = cap
        if (dataset_name, subset_name) in d3_train_sglue:
            sglue_train_mixture.append(task_name)
            mixture_cap[task_name] = cap
        if (dataset_name, subset_name) in d4_eval:
            if template.metadata.original_task:
                d4_eval_mixture.append(task_name)
            # TODO use template.metadata.answer_choices here for rank eval
        if (dataset_name, subset_name) in bias_fairness_eval:
            bias_fairness_eval_mixture.append(task_name)

# Special case for ANLI, which has weirdly-named splits and rounds that should be subsets
dataset_name, subset_name = ("anli", None)
dataset = all_templates.get_dataset(dataset_name, subset_name)
for anli_round in ("r1", "r2", "r3"):
    for template_name in all_templates.get_dataset(dataset_name, subset_name).all_template_names:
        task_name = utils.get_task_name(dataset_name, subset_name, template_name) + f"_{anli_round}"
        split_mapping = {
            "train": f"train_{anli_round}",
            "validation": f"dev_{anli_round}",
            "test": f"test_{anli_round}",
        }
        add_task(dataset_name, subset_name, template_name, task_name, split_mapping)

        template = dataset[template_name]
        if template.metadata.original_task:
            d4_eval_mixture.append(task_name)  # TODO or add to ANLI special mixture
        # TODO use template.metadata.answer_choices here for rank eval

TASK_BLACKLIST = [
    # Tasks which often tokenize to > 1024 tokens currently
    "hotpot_qa_distractor_Generate_Explanations",
    "hotpot_qa_fullwiki_Generate_Explanations",
    "hotpot_qa_distractor_Generate_Answer_and_Explanations",
    "hotpot_qa_fullwiki_Generate_Answer_and_Explanations",
    "hotpot_qa_fullwiki_Generate_Answer",
    "hotpot_qa_distractor_Generate_Answer",
    "hotpot_qa_distractor_Generate_Title_2",
    "hotpot_qa_fullwiki_Generate_Title_2",
    "hotpot_qa_fullwiki_Generate_Title_1",
    "hotpot_qa_distractor_Generate_Title_1",
    "hotpot_qa_distractor_Generate_Question",
    "hotpot_qa_fullwiki_Generate_Question",
    "tab_fact_tab_fact_tab_fact_3",
    "tab_fact_tab_fact_tab_fact_2",
    "tab_fact_tab_fact_tab_fact_1",
    "tab_fact_tab_fact_tab_fact_7",
    "tab_fact_tab_fact_tab_fact_4",
    "tab_fact_tab_fact_tab_fact_5",
    "tab_fact_tab_fact_tab_fact_6",
    "wiki_hop_masked_Choose_Best_Object_Candidate",
    "wiki_hop_masked_Indirect_Question_about_Birthplace_Citizenship_Place_of_Death",
    "narrativeqa_Template_05",
    "ecthr_cases_alleged_violation_prediction_silver_rationales",
    # Tasks with broken cached files
    "gigaword_summarize_",
]

# Tasks that failed caching (won't try to fix them for now) - remove when we are done
D4_TRAIN_SCORE_EVAL_TASK_BLACKLIST = [
    "amazon_polarity_Is_this_product_review_positive_score_eval",
    "amazon_polarity_Is_this_review_negative_score_eval",
    "amazon_polarity_Is_this_review_score_eval",
    "amazon_polarity_User_recommend_this_product_score_eval",
    "amazon_polarity_convey_negative_or_positive_sentiment_score_eval",
    "amazon_polarity_flattering_or_not_score_eval",
    "amazon_polarity_negative_or_positive_tone_score_eval",
    "amazon_polarity_user_satisfied_score_eval",
    "amazon_polarity_would_you_buy_score_eval",
    "dbpedia_14_given_a_choice_of_categories__score_eval",
    "dbpedia_14_given_list_what_category_does_the_paragraph_belong_to_score_eval",
    "dbpedia_14_pick_one_category_for_the_following_text_score_eval",
    "wiki_hop_original_choose_best_object_affirmative_1_score_eval",
    "wiki_hop_original_choose_best_object_affirmative_2_score_eval",
    "wiki_hop_original_choose_best_object_affirmative_3_score_eval",
    "wiki_hop_original_choose_best_object_interrogative_1_score_eval",
    "wiki_hop_original_choose_best_object_interrogative_2_score_eval",
]

seqio.MixtureRegistry.add(
    "d4_train",
    [task for task in d4_train_mixture if task not in TASK_BLACKLIST],
    default_rate=lambda t: mixture_cap[t.name],
)

seqio.MixtureRegistry.add(
    "gpt_train",
    [task for task in gpt_train_mixture if task not in TASK_BLACKLIST],
    default_rate=lambda t: mixture_cap[t.name],
)

seqio.MixtureRegistry.add(
    "sglue_train",
    [task for task in sglue_train_mixture if task not in TASK_BLACKLIST],
    default_rate=lambda t: mixture_cap[t.name],
)

seqio.MixtureRegistry.add(
    "d4_gpt_train",
    [task for task in d4_train_mixture + gpt_train_mixture if task not in TASK_BLACKLIST],
    default_rate=lambda t: mixture_cap[t.name],
)

seqio.MixtureRegistry.add(
    "d4_gpt_sglue_train",
    [task for task in d4_train_mixture + gpt_train_mixture + sglue_train_mixture if task not in TASK_BLACKLIST],
    default_rate=lambda t: mixture_cap[t.name],
)

seqio.MixtureRegistry.add(
    "d4_eval",
    [task for task in d4_eval_mixture if task not in TASK_BLACKLIST],
    default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000),
)  # eval mixture does not need to be capped

seqio.MixtureRegistry.add(
    "d4_score_eval",
    [
        task
        for task in seqio.TaskRegistry.names()
        if task.endswith("_score_eval")
        and task.split("_score_eval")[0] in d4_eval_mixture
        and task.split("_score_eval")[0] not in TASK_BLACKLIST
    ],
    default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000),
)

# Train tasks we don't care about evaluating on
D4_TRAIN_SKIP_EVAL = [
    "paws_labeled_final",
    "adversarial_qa_dbidaf",
    "adversarial_qa_dbert",
    "duorc_ParaphraseRC",
    "dream",
    "amazon_polarity",
    "app_reviews",
    "imdb",
    "wiki_bio",
    "gigaword",
    "multi_news",
    "samsum",
    "dbpedia_14",
    "trec",
]

seqio.MixtureRegistry.add(
    "d4_train_eval",
    [
        task
        for task in d4_train_mixture
        if task not in TASK_BLACKLIST
        and not any(skip in task for skip in D4_TRAIN_SKIP_EVAL)
        and task in all_original_tasks
    ],
    default_rate=lambda t: mixture_cap[t.name],
)

seqio.MixtureRegistry.add(
    "d4_train_score_eval",
    [
        task
        for task in seqio.TaskRegistry.names()
        if task.endswith("_score_eval")
        and task.split("_score_eval")[0] in d4_train_mixture
        and task.split("_score_eval")[0] not in TASK_BLACKLIST
        and task not in D4_TRAIN_SCORE_EVAL_TASK_BLACKLIST
        and not any(skip in task for skip in D4_TRAIN_SKIP_EVAL)
        and task.split("_score_eval")[0] in all_original_tasks
    ],
    default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000),
)

seqio.MixtureRegistry.add(
    "d4_train_one_og_prompt",
    [task for task in single_original_task.values() if task in d4_train_mixture and task not in TASK_BLACKLIST],
    default_rate=lambda t: mixture_cap[t.name],
)

seqio.MixtureRegistry.add(
    "d4_train_all_og_prompts",
    [task for task in all_original_tasks if task in d4_train_mixture and task not in TASK_BLACKLIST],
    default_rate=lambda t: mixture_cap[t.name],
)

seqio.MixtureRegistry.add(
    "bias_fairness_eval",
    bias_fairness_eval_mixture,
    default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000),
)

seqio.MixtureRegistry.add(
    "bias_fairness_eval_score_eval",
    [
        task
        for task in seqio.TaskRegistry.names()
        if task.endswith("_score_eval") and task.split("_score_eval")[0] in bias_fairness_eval_mixture
    ],
    default_rate=functools.partial(seqio.mixing_rate_num_examples, maximum=500_000),
)
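

# Minimal usage sketch (assumptions: the HF datasets referenced above are downloadable in
# this environment, and building a dataset may be slow; shown only for illustration).
# All names used below are defined or registered in this module.
if __name__ == "__main__":
    d4_train_mix = seqio.get_mixture_or_task("d4_train")
    print(f"d4_train contains {len(d4_train_mix.tasks)} tasks")

    # Materialize a couple of tokenized training examples from the mixture.
    ds = d4_train_mix.get_dataset(
        sequence_length={"inputs": 1024, "targets": 256},
        split="train",
        shuffle=False,
    )
    for ex in ds.take(2):
        print({k: v.shape for k, v in ex.items()})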