# mTk-AdversarialQA_en-SberQuAD_ru-1B / train_mt5_qa_en_AQA+ru_info.py
# Author: michal-stefanik
# Commit 32980f3: "README & training scripts"
from typing import List
from adaptor.adapter import Adapter
from adaptor.evaluators.generative import BLEU
from adaptor.lang_module import LangModule
from adaptor.schedules import ParallelSchedule
from adaptor.utils import AdaptationArguments, StoppingStrategy
from datasets import load_dataset
from priming_objective import Priming
# Hyper-parameters of the joint English (AdversarialQA) + Russian (SberQuAD) adaptation run.
training_arguments = AdaptationArguments(
    output_dir="train_dir_AQA_info_large_ru",
    do_train=True,
    do_eval=True,
    # optimisation
    learning_rate=2e-5,  # we set LR=2e-4 for pre-training experiments
    warmup_steps=1000,
    gradient_accumulation_steps=30,
    # run length & early stopping
    # (alternative used in other runs: stopping_strategy=StoppingStrategy.NUM_STEPS_TOTAL)
    stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
    stopping_patience=10,
    max_steps=10000,
    # NOTE(review): if AdaptationArguments follows HF TrainingArguments semantics, a positive
    # max_steps overrides num_train_epochs -- confirm whether this value has any effect.
    num_train_epochs=5,
    # evaluation, logging and checkpointing cadence
    evaluation_strategy="steps",
    eval_steps=500,
    logging_steps=10,
    save_steps=500,
    save_total_limit=10,
)

# How many validation samples (taken from the end of each validation split) are evaluated.
eval_examples = 200

# priming
# NOTE(review): not referenced anywhere else in this script -- confirm whether it should be
# passed to the Priming objectives.
num_demonstrations = 3
def _construct_priming_prompt(previous_examples: List[str], current_example: str) -> str:
return " ".join(previous_examples + [current_example])
lang_module = LangModule("google/mt5-large")

# priming
# NOTE(review): not referenced anywhere else in this script; kept so the module namespace is unchanged.
per_type_examples = {}

qa_en = load_dataset("adversarial_qa", "adversarialQA")
# Train only on samples whose context is shorter than 2000 characters.
qa_train = qa_en["train"].filter(lambda entry: len(entry["context"]) < 2000)

# Validation metric shared by both QA objectives below. "▁" is presumably the SentencePiece
# word-boundary marker emitted by the mT5 tokenizer -- confirm against adaptor's BLEU docs.
val_metrics = [BLEU(additional_sep_char="▁")]

# Adversarial QA dataset & objective:
def _get_firstword_categories(data) -> List[str]:
return [question.split()[0] if not question.startswith("To")
else " ".join(question.split()[:2])
for question in data["question"]]
def _build_qa_priming_objective(train_split, val_split, lang_id: str, objective_id: str) -> Priming:
    """Create an extractive-QA Priming objective from a training and a validation split.

    Both splits must expose "question", "context" and "answers" columns, where each
    "answers" entry has a "text" list; its first element is used as the target answer.
    Only the last `eval_examples` rows of `val_split` are used for validation.

    :param train_split: training samples (already filtered by the caller).
    :param val_split: full validation split; its trailing rows are evaluated.
    :param lang_id: passed to Priming as `source_lang_id`.
    :param objective_id: passed to Priming as `objective_id`.
    :return: the configured Priming objective bound to the shared `lang_module`.
    """
    # Select the trailing validation rows once, instead of materialising and slicing every
    # column separately. max(..., 0) keeps the range valid for short splits, and gives an
    # empty subset when eval_examples == 0. A `[-0:]` slice would instead return everything.
    num_val = len(val_split)
    val_subset = val_split.select(range(max(num_val - eval_examples, 0), num_val))
    return Priming(lang_module,
                   max_eval_samples=eval_examples,
                   demos_selection_strategy="informative",
                   texts_or_path=train_split["question"],
                   text_pair_or_path=train_split["context"],
                   val_texts_or_path=val_subset["question"],
                   val_text_pair_or_path=val_subset["context"],
                   labels_or_path=[a["text"][0] for a in train_split["answers"]],
                   val_labels_or_path=[a["text"][0] for a in val_subset["answers"]],
                   train_question_categories=_get_firstword_categories(train_split),
                   val_question_categories=_get_firstword_categories(val_subset),
                   batch_size=1,
                   val_evaluators=val_metrics,
                   source_lang_id=lang_id,
                   objective_id=objective_id)


q_answering_en = _build_qa_priming_objective(qa_train, qa_en["validation"],
                                             lang_id="en", objective_id="AQA-en")

qa_ru = load_dataset("sberquad")
# Train only on samples whose context is shorter than 800 characters. This is a tighter limit
# than the English set; the rationale is not stated in this script.
qa_ru_train = qa_ru["train"].filter(lambda entry: len(entry["context"]) < 800)
skipped = 0  # NOTE(review): never used in this script; kept so the module namespace is unchanged.
q_answering_ru = _build_qa_priming_objective(qa_ru_train, qa_ru["validation"],
                                             lang_id="ru", objective_id="SQuAD-ru")
# Put both QA objectives under one ParallelSchedule, then train the shared mT5 module with it.
schedule = ParallelSchedule(
    objectives=[q_answering_en, q_answering_ru],
    args=training_arguments,
)
adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()