""" This file runs Masked Language Model. You provide a training file. Each line is interpreted as a sentence / paragraph. Optionally, you can also provide a dev file. The fine-tuned model is stored in the output/model_name folder. python train_mlm.py model_name data/train_sentences.txt [data/dev_sentences.txt] """ from transformers import AutoModelForMaskedLM, AutoTokenizer from transformers import DataCollatorForLanguageModeling, DataCollatorForWholeWordMask from transformers import Trainer, TrainingArguments import sys import gzip from datetime import datetime import wandb wandb.init(project="bert-word2vec") model_name = "nicoladecao/msmarco-word2vec256000-distilbert-base-uncased" per_device_train_batch_size = 16 save_steps = 5000 eval_steps = 1000 num_train_epochs = 3 use_fp16 = True #Set to True, if your GPU supports FP16 operations max_length = 250 #Max length for a text input do_whole_word_mask = True #If set to true, whole words are masked mlm_prob = 15 #Probability that a word is replaced by a [MASK] token model = AutoModelForMaskedLM.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name) ## Freeze embedding layer model.distilbert.embeddings.requires_grad = False output_dir = "output/{}-{}".format(model_name.replace("/", "_"), datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) print("Save checkpoints to:", output_dir) ##### Load our training datasets train_sentences = [] train_path = 'data/train.txt' with gzip.open(train_path, 'rt', encoding='utf8') if train_path.endswith('.gz') else open(train_path, 'r', encoding='utf8') as fIn: for line in fIn: line = line.strip() if len(line) >= 10: train_sentences.append(line) print("Train sentences:", len(train_sentences)) dev_sentences = [] dev_path = 'data/dev.txt' with gzip.open(dev_path, 'rt', encoding='utf8') if dev_path.endswith('.gz') else open(dev_path, 'r', encoding='utf8') as fIn: for line in fIn: line = line.strip() if len(line) >= 10: dev_sentences.append(line) print("Dev sentences:", len(dev_sentences)) #A dataset wrapper, that tokenizes our data on-the-fly class TokenizedSentencesDataset: def __init__(self, sentences, tokenizer, max_length, cache_tokenization=False): self.tokenizer = tokenizer self.sentences = sentences self.max_length = max_length self.cache_tokenization = cache_tokenization def __getitem__(self, item): if not self.cache_tokenization: return self.tokenizer(self.sentences[item], add_special_tokens=True, truncation=True, max_length=self.max_length, return_special_tokens_mask=True) if isinstance(self.sentences[item], str): self.sentences[item] = self.tokenizer(self.sentences[item], add_special_tokens=True, truncation=True, max_length=self.max_length, return_special_tokens_mask=True) return self.sentences[item] def __len__(self): return len(self.sentences) train_dataset = TokenizedSentencesDataset(train_sentences, tokenizer, max_length) dev_dataset = TokenizedSentencesDataset(dev_sentences, tokenizer, max_length, cache_tokenization=True) if len(dev_sentences) > 0 else None ##### Training arguments if do_whole_word_mask: data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob) else: data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob) training_args = TrainingArguments( output_dir=output_dir, overwrite_output_dir=True, num_train_epochs=num_train_epochs, evaluation_strategy="steps" if dev_dataset is not None else "no", per_device_train_batch_size=per_device_train_batch_size, eval_steps=eval_steps, 
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    num_train_epochs=num_train_epochs,
    evaluation_strategy="steps" if dev_dataset is not None else "no",
    per_device_train_batch_size=per_device_train_batch_size,
    eval_steps=eval_steps,
    save_steps=save_steps,
    save_total_limit=1,
    prediction_loss_only=True,
    fp16=use_fp16,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
)

trainer.train()

print("Save model to:", output_dir)
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

print("Training done")
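# --- Optional sanity check (illustrative sketch, not part of the original script) ---
# Uncomment to run a quick fill-mask prediction with the freshly saved model.
# The example sentence is an arbitrary placeholder.
#
# from transformers import pipeline
#
# fill_mask = pipeline("fill-mask", model=output_dir, tokenizer=output_dir)
# masked_sentence = f"paris is the {fill_mask.tokenizer.mask_token} of france."
# for prediction in fill_mask(masked_sentence):
#     print(prediction["token_str"], round(prediction["score"], 4))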