""" |
|
This file runs Masked Language Model. You provide a training file. Each line is interpreted as a sentence / paragraph. |
|
Optionally, you can also provide a dev file. |
|
|
|
The fine-tuned model is stored in the output/model_name folder. |
|
|
|
python train_mlm.py model_name data/train_sentences.txt [data/dev_sentences.txt] |
|
""" |
|
|
|

from transformers import AutoModelForMaskedLM, AutoTokenizer
from transformers import DataCollatorForLanguageModeling, DataCollatorForWholeWordMask
from transformers import Trainer, TrainingArguments
import sys
import gzip
from datetime import datetime
import wandb

wandb.init(project="bert-word2vec")
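
# The Hugging Face Trainer picks up the active wandb run and reports training / evaluation
# metrics to the "bert-word2vec" project (this assumes wandb is installed and configured).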
model_name = "nicoladecao/msmarco-word2vec256000-distilbert-base-uncased" |
|
per_device_train_batch_size = 16 |
|
save_steps = 5000 |
|
eval_steps = 1000 |
|
num_train_epochs = 3 |
|
use_fp16 = True |
|
max_length = 250 |
|
do_whole_word_mask = True |
|
mlm_prob = 15 |
|
|
|

model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Freeze the (word2vec-based) embedding layer so that only the transformer weights are updated
# during MLM training.
model.distilbert.embeddings.requires_grad_(False)
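
# Optional sanity check: confirm that the embedding parameters are excluded from training.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print("Trainable parameters: {} / {}".format(trainable_params, total_params))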
output_dir = "output/{}-{}".format(model_name.replace("/", "_"), datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) |
|
print("Save checkpoints to:", output_dir) |
|
|
|
|
|
|
|
|
|

# Load the training data; every line with at least 10 characters is used as one training sample.
train_sentences = []
train_path = sys.argv[1]
with gzip.open(train_path, 'rt', encoding='utf8') if train_path.endswith('.gz') else open(train_path, 'r', encoding='utf8') as fIn:
    for line in fIn:
        line = line.strip()
        if len(line) >= 10:
            train_sentences.append(line)

print("Train sentences:", len(train_sentences))

# Optionally load a dev set that is used for evaluation during training.
dev_sentences = []
if len(sys.argv) >= 3:
    dev_path = sys.argv[2]
    with gzip.open(dev_path, 'rt', encoding='utf8') if dev_path.endswith('.gz') else open(dev_path, 'r', encoding='utf8') as fIn:
        for line in fIn:
            line = line.strip()
            if len(line) >= 10:
                dev_sentences.append(line)

print("Dev sentences:", len(dev_sentences))
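

# A small map-style dataset: tokenization happens lazily in __getitem__, while padding and
# masking are left to the data collator, so each batch is padded dynamically.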
class TokenizedSentencesDataset:
    def __init__(self, sentences, tokenizer, max_length, cache_tokenization=False):
        self.tokenizer = tokenizer
        self.sentences = sentences
        self.max_length = max_length
        self.cache_tokenization = cache_tokenization

    def __getitem__(self, item):
        if not self.cache_tokenization:
            return self.tokenizer(self.sentences[item], add_special_tokens=True, truncation=True, max_length=self.max_length, return_special_tokens_mask=True)

        if isinstance(self.sentences[item], str):
            self.sentences[item] = self.tokenizer(self.sentences[item], add_special_tokens=True, truncation=True, max_length=self.max_length, return_special_tokens_mask=True)
        return self.sentences[item]

    def __len__(self):
        return len(self.sentences)
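

# The dev sentences are tokenized once and cached because they are reused at every evaluation
# step; the training sentences are tokenized on the fly to keep memory usage low.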
train_dataset = TokenizedSentencesDataset(train_sentences, tokenizer, max_length)
dev_dataset = TokenizedSentencesDataset(dev_sentences, tokenizer, max_length, cache_tokenization=True) if len(dev_sentences) > 0 else None
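
# Whole word masking masks all word pieces that belong to the same word together, which is a
# harder training signal than masking individual word pieces.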
if do_whole_word_mask:
    data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob)
else:
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob)

training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    num_train_epochs=num_train_epochs,
    evaluation_strategy="steps" if dev_dataset is not None else "no",
    per_device_train_batch_size=per_device_train_batch_size,
    eval_steps=eval_steps,
    save_steps=save_steps,
    save_total_limit=1,
    prediction_loss_only=True,
    fp16=use_fp16
)
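
# The Trainer handles batching, applies the data collator (which performs the masking),
# runs evaluation on the dev set and writes checkpoints to output_dir.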
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset
)

trainer.train()
print("Save model to:", output_dir) |
|
model.save_pretrained(output_dir) |
|
tokenizer.save_pretrained(output_dir) |
|
|
|
print("Training done") |