import os

import evaluate
import torch
import torch.distributed as torch_distributed
import transformers
from datasets import load_dataset
from torch.nn.parallel import DistributedDataParallel
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup

from accelerate.utils import set_seed


# Silence transformers warnings (e.g. the newly initialized classification head notice),
# which would otherwise be printed once per rank.
transformers.logging.set_verbosity_error()
|
|
def get_dataloaders(batch_size: int = 16):
    """
    Creates a pair of `DataLoader`s for the `glue` MRPC dataset,
    using "bert-base-cased" as the tokenizer.

    Args:
        batch_size (`int`, *optional*, defaults to 16):
            The batch size for the train DataLoader. The validation DataLoader uses a fixed
            batch size of 32.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    datasets = load_dataset("glue", "mrpc")

    def tokenize_function(examples):
        # Tokenize the sentence pairs without padding; padding is applied per batch in `collate_fn` below.
        outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
        return outputs

    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=["idx", "sentence1", "sentence2"],
    )
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    def collate_fn(examples):
        # Pad each batch dynamically to its longest sequence (rounded up to a multiple of 8).
        return tokenizer.pad(
            examples,
            padding="longest",
            max_length=None,
            pad_to_multiple_of=8,
            return_tensors="pt",
        )

    train_dataloader = DataLoader(
        tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True
    )
    eval_dataloader = DataLoader(
        tokenized_datasets["validation"],
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=32,
        drop_last=False,
    )

    return train_dataloader, eval_dataloader
|
|
def training_function(): |
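    # Expects the environment variables set by a distributed launcher such as torchrun:
    # RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT for `init_process_group`, LOCAL_RANK for device placement.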
    torch_distributed.init_process_group(backend="nccl")
    num_processes = torch_distributed.get_world_size()
    process_index = torch_distributed.get_rank()
    local_process_index = int(os.environ.get("LOCAL_RANK", -1))
    device = torch.device("cuda", local_process_index)
    torch.cuda.set_device(device)
|
config = {"lr": 2e-5, "num_epochs": 3, "seed": 42} |
|
seed = int(config["seed"]) |
|
batch_size = 32 |
|
config["batch_size"] = batch_size |
|
metric = evaluate.load("glue", "mrpc") |
|
|
|
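    # The same seed is set on every rank, and the dataloaders are not wrapped in a DistributedSampler,
    # so each process iterates over the full training and validation sets rather than a per-rank shard.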
    set_seed(seed, device_specific=False)
    train_dataloader, eval_dataloader = get_dataloaders(batch_size)
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True).to(device)
|
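    # DistributedDataParallel broadcasts the weights from rank 0 at construction and averages
    # gradients across ranks during backward, keeping every replica in sync.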
    model = DistributedDataParallel(
        model, device_ids=[local_process_index], output_device=local_process_index
    )

    optimizer = AdamW(params=model.parameters(), lr=config["lr"])
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=(len(train_dataloader) * config["num_epochs"]),
    )
|
    current_step = 0
    for epoch in range(config["num_epochs"]):
        model.train()
        total_loss = 0
        for _, batch in enumerate(train_dataloader):
            batch = batch.to(device)
            outputs = model(**batch)
            loss = outputs.loss
            total_loss += loss.detach().cpu().float()
            current_step += 1
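            # backward() accumulates local gradients and, via DDP's hooks, all-reduces them across
            # ranks, so every process applies the same averaged gradient in optimizer.step().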
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
|
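        # The eval dataloader is not sharded, so every rank scores the full validation set and
        # computes the same metric value; only rank 0 prints it.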
        model.eval()
        for step, batch in enumerate(eval_dataloader):
            batch = batch.to(device)
            with torch.no_grad():
                outputs = model(**batch)
            predictions = outputs.logits.argmax(dim=-1)
            metric.add_batch(
                predictions=predictions,
                references=batch["labels"],
            )
|
        eval_metric = metric.compute()
        if process_index == 0:
            print(
                f"epoch {epoch}: {eval_metric}\n"
                f"train_loss: {total_loss.item() / len(train_dataloader)}"
            )
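
    # Tear down the process group once training and evaluation are finished.
    torch_distributed.destroy_process_group()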
|
|
def main():
    training_function()
|
|
if __name__ == "__main__":
    main()
|
|