import math

import torch
from torch.optim import Adam
from transformers import (
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    get_scheduler,
)
from datasets import load_from_disk

from configuration_gpt1 import GPT1Config
from modeling_gpt1 import GPT1Model, GPT1ForCausalLM
# Register the custom GPT-1 classes with the corresponding Auto* factories.
# The config class is registered with its default auto class; each model class
# is registered under the auto class named next to it.
GPT1Config.register_for_auto_class()
for _model_cls, _auto_name in (
    (GPT1Model, 'AutoModel'),
    (GPT1ForCausalLM, 'AutoModelForCausalLM'),
):
    _model_cls.register_for_auto_class(_auto_name)
# Load the dataset stored on disk under 'data' and shuffle it with a fixed
# seed so the ordering is reproducible; later code reads its 'train' and
# 'test' splits.
tokenized_datasets = load_from_disk('data').shuffle(seed=42)
print(tokenized_datasets)
# The tokenizer is read from the current directory; the model is constructed
# from a default GPT1Config (no checkpoint is loaded at this point).
tokenizer = AutoTokenizer.from_pretrained('.')
config = GPT1Config()
model = GPT1ForCausalLM(config)
print(model)

# Total number of elements across every parameter tensor of the model.
_total_params = 0
for _param in model.parameters():
    _total_params += _param.numel()
print(f"Model parameters: {_total_params}")
# ---------------------------------------------------------------------------
# Hyperparameters, batch collation, Trainer arguments, optimizer, LR schedule.
# ---------------------------------------------------------------------------
batch_size = 16                   # per-device batch size (train and eval)
epochs = 100
gradient_accumulation_steps = 4   # micro-batches accumulated per optimizer update
warmup_steps = 2000

# Reuse the EOS token as the padding token, then build the language-modeling
# collator with masked-LM disabled (mlm=False).
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

args = TrainingArguments(
    output_dir='checkpoints',
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy='epoch',
    gradient_accumulation_steps=gradient_accumulation_steps,
    num_train_epochs=epochs,
    save_total_limit=10,
    max_grad_norm=1.0,
    logging_strategy='steps',
    logging_steps=100,
    logging_first_step=True,
    logging_nan_inf_filter=False,
    fp16=False,
)

# NOTE(review): torch.optim.Adam applies weight_decay as an L2 penalty added to
# the gradient, not as decoupled weight decay. If decoupled decay was intended,
# AdamW would be the matching optimizer -- confirm before changing.
optimizer = Adam(model.parameters(), lr=2.5e-4, weight_decay=0.01)

# The LR scheduler advances once per *optimizer update*, not once per training
# example. Its horizon must therefore be the number of updates the Trainer will
# perform: batches per epoch (across all devices/processes), divided by the
# gradient-accumulation factor, times the number of epochs. Passing the raw
# example count instead would stretch the cosine curve far beyond the real run,
# so the learning rate would barely decay.
_train_examples = len(tokenized_datasets['train'])
_examples_per_batch = args.train_batch_size * args.world_size
_batches_per_epoch = math.ceil(_train_examples / _examples_per_batch)
_updates_per_epoch = max(_batches_per_epoch // gradient_accumulation_steps, 1)
num_training_steps = epochs * _updates_per_epoch

# NOTE(review): with the corrected horizon, confirm that 2000 warmup steps is
# still a sensible fraction of num_training_steps for this dataset size.
scheduler = get_scheduler(
    'cosine',
    optimizer=optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=num_training_steps,
)
# Gather everything the Trainer needs in one mapping, then construct it.
# The optimizer/scheduler pair created earlier is handed over explicitly via
# `optimizers`.
_trainer_kwargs = {
    'model': model,
    'args': args,
    'data_collator': data_collator,
    'train_dataset': tokenized_datasets['train'],
    'eval_dataset': tokenized_datasets['test'],
    'tokenizer': tokenizer,
    'optimizers': (optimizer, scheduler),
}
trainer = Trainer(**_trainer_kwargs)
# Run training, then persist the artifacts: the Trainer's log history is
# written to 'trainer_history.pt' and the model is saved under 'trained'.
print("Starting training...")
trainer.train()

_log_history = trainer.state.log_history
torch.save(_log_history, 'trainer_history.pt')
trainer.save_model('trained')