import evaluate
import torch
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup

from accelerate import Accelerator, DistributedType
from accelerate.utils import set_seed


def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
    """
    Creates a pair of `DataLoader`s for the GLUE MRPC dataset,
    using "bert-base-cased" as the tokenizer.

    Args:
        accelerator (`Accelerator`):
            An `Accelerator` object.
        batch_size (`int`, *optional*):
            The batch size for the train DataLoader; the validation DataLoader uses a fixed batch size of 32.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    datasets = load_dataset("glue", "mrpc")

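    # Tokenize without padding here (max_length=None); padding is applied per batch
    # in `collate_fn` below.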
    def tokenize_function(examples):
        outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
        return outputs

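    # Run the map on the main process first so the tokenized dataset is cached once
    # and the other processes reuse that cache instead of re-tokenizing.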
    with accelerator.main_process_first():
        tokenized_datasets = datasets.map(
            tokenize_function,
            batched=True,
            remove_columns=["idx", "sentence1", "sentence2"],
        )

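    # The model's forward pass expects the targets under the name "labels".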
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    def collate_fn(examples):
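        # On TPU, pad to a fixed length so XLA does not recompile for every new
        # batch shape; elsewhere, pad dynamically to the longest example in the batch.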
        max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
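        # When mixed precision is enabled, pad lengths to a multiple of 8, a common
        # heuristic for keeping matrix shapes friendly to Tensor Cores.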
        if accelerator.mixed_precision != "no":
            pad_to_multiple_of = 8
        else:
            pad_to_multiple_of = None

        return tokenizer.pad(
            examples,
            padding="longest",
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors="pt",
        )

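    # Build the dataloaders. drop_last=True keeps every training batch the same size.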
    train_dataloader = DataLoader(
        tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True
    )
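    # For fp8 runs, also drop the last partial evaluation batch; fp8 kernels are
    # generally sensitive to uneven batch dimensions.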
    eval_dataloader = DataLoader(
        tokenized_datasets["validation"],
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=32,
        drop_last=(accelerator.mixed_precision == "fp8"),
    )

    return train_dataloader, eval_dataloader


def training_function(config):
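    # Create the Accelerator with fp16 mixed precision and the Aim tracker; Aim
    # runs are stored under the "aim_logs" project directory.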
    accelerator = Accelerator(
        mixed_precision="fp16",
        log_with="aim",
        project_dir="aim_logs",
    )

    lr = config["lr"]
    num_epochs = int(config["num_epochs"])
    seed = int(config["seed"])
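    # Use a smaller per-process batch size when running on multiple GPUs; the
    # effective global batch size is batch_size * num_processes.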
    batch_size = 16 if accelerator.num_processes > 1 else 32
    config["batch_size"] = batch_size
    metric = evaluate.load("glue", "mrpc")

    set_seed(seed, device_specific=True)
    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True)
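    # Scale the learning rate linearly with the number of processes, since the
    # effective batch size grows with every added GPU.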
    lr = lr * accelerator.num_processes

    optimizer = AdamW(params=model.parameters(), lr=lr)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=(len(train_dataloader) * num_epochs),
    )

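    # prepare() wraps the model, optimizer, dataloaders, and scheduler for the
    # current device and distributed setup, and takes care of mixed precision.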
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

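    # Name the tracked run after the number of processes so scaling runs are easy
    # to compare side by side in the Aim UI.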
    accelerator.init_trackers(f"{accelerator.num_processes}_gpus", config)

    current_step = 0
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for _, batch in enumerate(train_dataloader):
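            # Read the learning rate currently set by the scheduler so it can be logged.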
            lr = lr_scheduler.get_last_lr()[0]
            outputs = model(**batch)
            loss = outputs.loss
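            # Gather the loss from all processes and average it, so the logged value
            # reflects the whole step rather than only the main process.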
            batch_loss = accelerator.gather(loss).detach().mean().cpu().float()
            total_loss += batch_loss
            current_step += 1
            accelerator.log(
                {
                    "batch_loss": batch_loss,
                    "learning_rate": lr,
                },
                step=current_step,
                log_kwargs={"aim": {"epoch": epoch}},
            )
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        model.eval()
        for step, batch in enumerate(eval_dataloader):
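            # The prepared eval_dataloader already places batches on the right device,
            # so this explicit move is redundant but harmless.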
            batch.to(accelerator.device)
            with torch.no_grad():
                outputs = model(**batch)
            predictions = outputs.logits.argmax(dim=-1)
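            # Gather predictions and labels from all processes; gather_for_metrics
            # drops the duplicate samples added to pad the final uneven batch.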
            predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
            metric.add_batch(
                predictions=predictions,
                references=references,
            )

        eval_metric = metric.compute()

        accelerator.print(f"epoch {epoch}:", eval_metric)
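        # Log epoch-level metrics: MRPC accuracy and F1, plus the average training loss.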
        accelerator.log(
            {
                "accuracy": eval_metric["accuracy"],
                "f1": eval_metric["f1"],
                "train_loss": total_loss.item() / len(train_dataloader),
            },
            log_kwargs={"aim": {"epoch": epoch}},
        )
    accelerator.end_training()


def main():
    config = {"lr": 2e-5, "num_epochs": 3, "seed": 42}
    training_function(config)


if __name__ == "__main__":
    main()