""" |
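This file runs Masked Language Model (MLM) training. You provide a training file; each line is interpreted as a sentence / paragraph.
Optionally, you can also provide a dev file.
The fine-tuned model is stored in the output-mlm/<model_name>-<timestamp> folder.
Usage: python train_mlm.py model_name data/train_sentences.txt [data/dev_sentences.txt]
Note: in this version the model name and the data paths (data/train.txt, data/dev.txt) are hard-coded in the script; the command line is only appended as a comment to the saved copy of the script.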
""" |
import gzip
import os
import sys
from datetime import datetime
from shutil import copyfile

import wandb
from transformers import AutoModelForMaskedLM, AutoTokenizer
from transformers import DataCollatorForLanguageModeling, DataCollatorForWholeWordMask
from transformers import Trainer, TrainingArguments
# Start a Weights & Biases run for this training session.
wandb.init(project="bert-word2vec")

# ---- Training configuration -------------------------------------------------
model_name = "nicoladecao/msmarco-word2vec256000-distilbert-base-uncased"
per_device_train_batch_size = 16
save_steps = 5000  # Checkpoint interval (in steps) handed to TrainingArguments
eval_steps = 1000  # Evaluation interval (in steps) handed to TrainingArguments
num_train_epochs = 3
use_fp16 = True  # Forwarded as TrainingArguments(fp16=...)
max_length = 250  # Tokenizer truncation length
do_whole_word_mask = False  # True selects DataCollatorForWholeWordMask
# Forwarded as `mlm_probability`; it is a probability, so it must lie in [0, 1]
# (the previous value of 15 was outside that range).
mlm_prob = 0.15

# ---- Model and tokenizer ----------------------------------------------------
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Turn off gradients for the word-embedding matrix so it is not updated during training.
model.distilbert.embeddings.word_embeddings.requires_grad_(False)

# ---- Output folder ----------------------------------------------------------
# One folder per run: "/" in the model name is replaced so it is a single path
# component, and a timestamp keeps runs apart.
output_dir = "output-mlm/{}-{}".format(model_name.replace("/", "_"), datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
print("Save checkpoints to:", output_dir)
os.makedirs(output_dir, exist_ok=True)

# Keep a copy of this script next to the checkpoints and record, as a trailing
# comment, the command line it was started with.
train_script_path = os.path.join(output_dir, 'train_script.py')
copyfile(__file__, train_script_path)
with open(train_script_path, 'a') as fOut:
    fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
def _read_sentences(path, min_length=10):
    """Read one sentence / paragraph per line from a text file.

    Args:
        path: File to read. A name ending in ``.gz`` is opened with gzip,
            anything else as a plain file. The content is decoded as UTF-8.
        min_length: Lines with fewer characters than this (after stripping
            surrounding whitespace) are dropped.

    Returns:
        List of the stripped lines that passed the length filter, in file order.
    """
    opener = gzip.open if path.endswith('.gz') else open
    with opener(path, 'rt', encoding='utf8') as fIn:
        stripped_lines = (line.strip() for line in fIn)
        return [line for line in stripped_lines if len(line) >= min_length]


# ---- Training data (required) -----------------------------------------------
train_path = 'data/train.txt'
train_sentences = _read_sentences(train_path)
print("Train sentences:", len(train_sentences))

# ---- Dev data (optional) ----------------------------------------------------
dev_path = 'data/dev.txt'
try:
    dev_sentences = _read_sentences(dev_path)
except FileNotFoundError:
    # The dev file is optional. An empty list makes the rest of the script
    # skip building a dev dataset and disables evaluation.
    print("Dev file not found, training without evaluation:", dev_path)
    dev_sentences = []
print("Dev sentences:", len(dev_sentences))
class TokenizedSentencesDataset:
    """Index-addressable dataset that tokenizes its entries on access.

    Args:
        sentences: Indexable collection of raw text entries.
        tokenizer: Callable invoked as ``tokenizer(text, add_special_tokens=True,
            truncation=True, max_length=max_length, return_special_tokens_mask=True)``.
        max_length: Truncation length forwarded to the tokenizer.
        cache_tokenization: When true, the tokenizer output for an index is
            stored on first access and the same object is returned afterwards.
            When false, every access tokenizes again.
    """

    def __init__(self, sentences, tokenizer, max_length, cache_tokenization=False):
        self.tokenizer = tokenizer
        self.sentences = sentences
        self.max_length = max_length
        self.cache_tokenization = cache_tokenization
        # Tokenized entries are kept apart from `sentences` so the collection
        # passed in by the caller is never overwritten.
        self._cache = {}

    def _tokenize(self, text):
        """Run the tokenizer with the fixed settings used for every entry."""
        return self.tokenizer(
            text,
            add_special_tokens=True,
            truncation=True,
            max_length=self.max_length,
            return_special_tokens_mask=True,
        )

    def __getitem__(self, item):
        """Return the tokenizer output for ``sentences[item]``."""
        # Slices cannot reliably serve as dict keys, so they bypass the cache.
        if not self.cache_tokenization or isinstance(item, slice):
            return self._tokenize(self.sentences[item])

        if item not in self._cache:
            self._cache[item] = self._tokenize(self.sentences[item])
        return self._cache[item]

    def __len__(self):
        """Return the number of raw entries."""
        return len(self.sentences)
# ---- Datasets ---------------------------------------------------------------
# Training entries are tokenized on every access; dev entries are cached.
train_dataset = TokenizedSentencesDataset(train_sentences, tokenizer, max_length)
dev_dataset = None
if len(dev_sentences) > 0:
    dev_dataset = TokenizedSentencesDataset(dev_sentences, tokenizer, max_length, cache_tokenization=True)

# ---- Masking collator -------------------------------------------------------
# Both collator classes are constructed with the same keyword arguments, so only
# the class is chosen by the flag.
collator_cls = DataCollatorForWholeWordMask if do_whole_word_mask else DataCollatorForLanguageModeling
data_collator = collator_cls(tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob)

# ---- Trainer configuration --------------------------------------------------
# Without a dev dataset there is nothing to evaluate on, so evaluation is turned off.
eval_mode = "no" if dev_dataset is None else "steps"
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    num_train_epochs=num_train_epochs,
    evaluation_strategy=eval_mode,
    per_device_train_batch_size=per_device_train_batch_size,
    eval_steps=eval_steps,
    save_steps=save_steps,
    save_total_limit=1,
    prediction_loss_only=True,
    fp16=use_fp16,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
)

# ---- Train and store the final model ----------------------------------------
trainer.train()

print("Save model to:", output_dir)
for artifact in (model, tokenizer):
    # Model first, then tokenizer, both into the run's output folder.
    artifact.save_pretrained(output_dir)

print("Training done")