import argparse
import gzip
import json
import logging
import os
import random
import re
import sys
from datetime import datetime
from shutil import copyfile

import torch
import wandb
from torch.utils.data import IterableDataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="google/t5-v1_1-base")
parser.add_argument("--train_files", required=True, nargs='+', default=[])
parser.add_argument("--epochs", default=1, type=int)
parser.add_argument("--batch_size", default=32, type=int)
parser.add_argument("--max_source_length", default=320, type=int)
parser.add_argument("--max_target_length", default=64, type=int)
parser.add_argument("--name", required=True)
parser.add_argument("--train_size", default=10*1000*1000, type=int)
parser.add_argument("--eval_size", default=10000, type=int)
parser.add_argument("--fp16", default=False, action='store_true')
args = parser.parse_args()

wandb.init(project="doc2query", name=f"{args.name}-{args.model_name}")
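
# Streams (target, source) text pairs from a gzipped JSONL file. The first pass
# reads the file line by line and caches every parsed example; afterwards the
# iterator loops forever over the in-memory cache, reshuffling it on each pass.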
class PairDataset:
    def __init__(self, filepath):
        self.filepath = filepath
        self.examples = []

    def __iter__(self):
        print("open", self.filepath)
        with gzip.open(self.filepath, 'rt') as fIn:
            for line in fIn:
                example = self.get_example(json.loads(line))
                if example is not None:
                    self.examples.append(example)
                    yield example

        while True:
            random.shuffle(self.examples)
            for ex in self.examples:
                yield ex

    def get_example(self, raw_example):
        return [raw_example[0], raw_example[1]]
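
# Pairs Reddit post titles (targets) with post bodies (sources). Titles are
# cleaned: HTML-escaped ampersands are unescaped, leading/trailing bracket tags
# (e.g. "[Serious]") are removed, and leading subreddit prefixes ("/r/...") are stripped.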
class RedditTitleDataset(PairDataset):
    def get_example(self, raw_example):
        return [self.clean_title(raw_example['title']), raw_example['body']]

    def clean_title(self, text):
        text = text.replace("&amp;", "&").strip()
        if text.startswith("["):
            text = re.sub(r"^\[[a-zA-Z0-9]+\]", "", text).strip()
        if text.endswith("]"):
            text = re.sub(r"\[[a-zA-Z0-9\.]+\]$", "", text).strip()
        if text.startswith("/r"):
            text = re.sub(r"^/[a-zA-Z0-9/]+[;,: \-]+", "", text).strip()
        return text

class StackExchangeTitleBodyDataset(PairDataset):
    def get_example(self, raw_example):
        return raw_example['texts']
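
# Wraps several PairDataset-style streams into a single IterableDataset: examples
# are drawn round-robin, one per source file, and the order of the sources is
# reshuffled after each full cycle. __len__ reports the nominal train_size so the
# Trainer can schedule steps over this otherwise unbounded stream.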
class MultiDataset(IterableDataset):
    def __init__(self, filepaths, num_samples):
        self.num_samples = num_samples
        self.datasets = []
        self.data_iterators = []

        for filepath in filepaths:
            if 'reddit_title_text' in filepath:
                dataset = RedditTitleDataset(filepath)
            elif 'stackexchange_archive/jsonl' in filepath:
                dataset = StackExchangeTitleBodyDataset(filepath)
            else:
                dataset = PairDataset(filepath)
            self.datasets.append(dataset)
            self.data_iterators.append(iter(dataset))

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        while True:
            for dataset in self.data_iterators:
                yield next(dataset)
            random.shuffle(self.data_iterators)

    def delete_examples_cache(self):
        for dataset in self.datasets:
            dataset.examples = []
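
# Training entry point: loads the seq2seq model and tokenizer, snapshots this
# script into the output directory, builds the streaming train/eval datasets,
# and runs the Hugging Face Seq2SeqTrainer.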
def main():
    ############ Model
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    save_steps = 1000

    output_dir = 'output/' + args.name + '-' + args.model_name.replace("/", "-") + '-' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    print("Output dir:", output_dir)

    # Save a copy of this script (plus the command line it was called with) to the output dir
    os.makedirs(output_dir, exist_ok=True)

    train_script_path = os.path.join(output_dir, 'train_script.py')
    copyfile(__file__, train_script_path)
    with open(train_script_path, 'a') as fOut:
        fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
    ####

    ############ Arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        fp16=args.fp16,
        fp16_backend="amp",
        per_device_train_batch_size=args.batch_size,
        evaluation_strategy="steps",
        save_steps=save_steps,
        logging_steps=100,
        eval_steps=save_steps,
        warmup_steps=1000,
        save_total_limit=1,
        num_train_epochs=args.epochs,
        report_to="wandb",
    )

    ############ Load datasets
    train_dataset = MultiDataset(args.train_files, args.train_size)
    train_dataset_iter = iter(train_dataset)
    eval_dataset = [next(train_dataset_iter) for _ in range(args.eval_size)]
    train_dataset.delete_examples_cache()  # Make sure dev data is not re-used for training
    print("Target:", eval_dataset[0][0])
    print("Input:", eval_dataset[0][1])
    print("Train dataset len:", len(train_dataset))
    def data_collator(examples):
        targets = [row[0] for row in examples]
        inputs = [row[1] for row in examples]
        label_pad_token_id = -100

        model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=True, truncation=True, return_tensors='pt', pad_to_multiple_of=8 if training_args.fp16 else None)

        # Set up the tokenizer for targets
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, max_length=args.max_target_length, padding=True, truncation=True, pad_to_multiple_of=8 if training_args.fp16 else None)

        # Replace all tokenizer.pad_token_id in the labels with -100 to ignore padding in the loss
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else label_pad_token_id) for l in label] for label in labels["input_ids"]
        ]
        model_inputs["labels"] = torch.tensor(labels["input_ids"])

        return model_inputs

    ## Define the trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    ### Train and save the model
    train_result = trainer.train()
    trainer.save_model()


if __name__ == "__main__":
    main()
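
# After training, the saved checkpoint can be used to generate queries for a passage.
# A minimal sketch, kept commented out so it does not run as part of training; the
# checkpoint path and the generation settings (sampling, top_k, number of queries)
# are illustrative values, not something this script prescribes:
#
#   from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#
#   model_path = "output/<name>-<model>-<timestamp>"  # directory produced by this script
#   tokenizer = AutoTokenizer.from_pretrained(model_path)
#   model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
#
#   passage = "Python is a high-level, general-purpose programming language."
#   inputs = tokenizer(passage, max_length=320, truncation=True, return_tensors="pt")
#   outputs = model.generate(**inputs, max_length=64, do_sample=True, top_k=10, num_return_sequences=3)
#   for ids in outputs:
#       print(tokenizer.decode(ids, skip_special_tokens=True))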

# Script was called via:
#python train_hf_trainer.py --model_name google/t5-v1_1-base --train_files /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/academia.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/android.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/anime.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/apple.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/arduino.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/askubuntu.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/aviation.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/bicycles.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/biology.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/bitcoin.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/blender.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/boardgames.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/chemistry.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/christianity.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/civicrm.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/codereview.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/cooking.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/craftcms.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/crypto.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/cs.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/datascience.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/dba.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/diy.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/drupal.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/dsp.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/electronics.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/ell.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/emacs.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/english.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/ethereum.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/expressionengine.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/french.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/gamedev.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/gaming.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/gardening.stackexchange.com.jsonl.gz 
#/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/german.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/gis.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/graphicdesign.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/history.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/islam.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/japanese.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/judaism.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/law.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/magento.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/math.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/mathematica.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/mathoverflow.net.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/mechanics.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/meta.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/meta.stackoverflow.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/money.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/movies.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/music.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/networkengineering.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/philosophy.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/photo.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/physics.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/politics.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/puzzling.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/quant.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/raspberrypi.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/rpg.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/rus.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/salesforce.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/scifi.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/security.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/serverfault.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/sharepoint.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/small_stackexchanges.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/softwareengineering.stackexchange.com.jsonl.gz
#/home/stackexchange_archive/stackexchange_extracted/TitleAnswer/softwarerecs.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/space.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/stackoverflow.com-Posts.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/stats.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/superuser.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/tex.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/travel.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/unix.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/ux.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/webapps.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/webmasters.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/wordpress.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/workplace.stackexchange.com.jsonl.gz /home/stackexchange_archive/stackexchange_extracted/TitleAnswer/worldbuilding.stackexchange.com.jsonl.gz --name stackexchange_title_answer --train_size 100000000 --max_source_length 320