# performance-debugging/scripts/core_example_multigpu.py
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import evaluate
import torch
import torch.distributed as torch_distributed
import transformers
from datasets import load_dataset
from torch.nn.parallel import DistributedDataParallel
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup

from accelerate.utils import set_seed

# Keep the training logs readable: only show errors from transformers
transformers.logging.set_verbosity_error()


def get_dataloaders(batch_size: int = 16):
    """
    Creates a set of `DataLoader`s for the `glue/mrpc` dataset,
    using "bert-base-cased" as the tokenizer.

    Args:
        batch_size (`int`, *optional*, defaults to 16):
            The batch size for the training DataLoader; the validation
            DataLoader uses a fixed batch size of 32.
    """
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
datasets = load_dataset("glue", "mrpc")
def tokenize_function(examples):
outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
return outputs
tokenized_datasets = datasets.map(
tokenize_function,
batched=True,
remove_columns=["idx", "sentence1", "sentence2"],
)
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
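    # Dynamic padding: each batch is padded to its longest sequence at collate
    # time, and `pad_to_multiple_of=8` keeps shapes friendly to tensor cores.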
def collate_fn(examples):
return tokenizer.pad(
examples,
padding="longest",
max_length=None,
pad_to_multiple_of=8,
return_tensors="pt",
)
train_dataloader = DataLoader(
tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True
)
eval_dataloader = DataLoader(
tokenized_datasets["validation"],
shuffle=False,
collate_fn=collate_fn,
batch_size=32,
drop_last=False,
)
return train_dataloader, eval_dataloader
def training_function():
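    # Manual distributed setup: initialize the NCCL process group and bind this
    # process to the GPU given by LOCAL_RANK. The rank/world-size environment
    # variables are expected from a distributed launcher, e.g. (assumed, not part
    # of the original repo) `torchrun --nproc_per_node=<num_gpus> core_example_multigpu.py`.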
torch_distributed.init_process_group(backend="nccl")
num_processes = torch_distributed.get_world_size()
process_index = torch_distributed.get_rank()
local_process_index = int(os.environ.get("LOCAL_RANK", -1))
device = torch.device("cuda", local_process_index)
torch.cuda.set_device(device)
config = {"lr": 2e-5, "num_epochs": 3, "seed": 42}
seed = int(config["seed"])
    batch_size = 32  # TODO: confirm whether this should stay at 32 (get_dataloaders defaults to 16)
config["batch_size"] = batch_size
metric = evaluate.load("glue", "mrpc")
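    # `device_specific=False` gives every process the same seed; Accelerate's
    # `set_seed` can offset the seed by process rank when `device_specific=True`.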
set_seed(seed, device_specific=False)
train_dataloader, eval_dataloader = get_dataloaders(batch_size)
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True).to(device)
model = DistributedDataParallel(
model, device_ids=[local_process_index], output_device=local_process_index
)
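    # The DDP wrapper all-reduces gradients across processes during `backward()`,
    # keeping every replica's weights in sync after each optimizer step.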
optimizer = AdamW(params=model.parameters(), lr=config["lr"])
lr_scheduler = get_linear_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=0,
num_training_steps=(len(train_dataloader) * config["num_epochs"]),
)
current_step = 0
for epoch in range(config["num_epochs"]):
model.train()
total_loss = 0
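        # NOTE: no DistributedSampler is attached to the DataLoaders, so every
        # process iterates over the full dataset rather than its own shard.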
        for batch in train_dataloader:
            batch = batch.to(device)
outputs = model(**batch)
loss = outputs.loss
total_loss += loss.detach().cpu().float()
current_step += 1
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
model.eval()
        for batch in eval_dataloader:
            # There is no Accelerator managing device placement in this script,
            # so move the batch to the local GPU manually.
            batch = batch.to(device)
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
metric.add_batch(
predictions=predictions,
references=batch["labels"],
)
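        # The metric and the running loss are computed from this process's data
        # only; no cross-process reduction happens before rank 0 prints them.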
eval_metric = metric.compute()
if process_index == 0:
print(
f"epoch {epoch}: {eval_metric}\n"
f"train_loss: {total_loss.item()/len(train_dataloader)}"
)
def main():
training_function()
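    # Tear down the default process group once training is finished.
    torch_distributed.destroy_process_group()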
if __name__ == "__main__":
main()