muellerzr HF staff
Adding tabs for different set of Accelerate's features and content for large scale training features (#2)
##
<pre>
import evaluate
import torch
+from accelerate import Accelerator
+accelerator = Accelerator()
+train_dataloader, eval_dataloader, model, optimizer, scheduler = (
+    accelerator.prepare(
+        train_dataloader, eval_dataloader,
+        model, optimizer, scheduler
+    )
+)
metric = evaluate.load("accuracy")
for batch in train_dataloader:
    optimizer.zero_grad()
    inputs, targets = batch
-   inputs = inputs.to(device)
-   targets = targets.to(device)
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
-   loss.backward()
+   accelerator.backward(loss)
    optimizer.step()
    scheduler.step()

model.eval()
for batch in eval_dataloader:
    inputs, targets = batch
-   inputs = inputs.to(device)
-   targets = targets.to(device)
    with torch.no_grad():
        outputs = model(inputs)
    predictions = outputs.argmax(dim=-1)
+   # Collect results from every process, dropping duplicated samples
+   predictions, targets = accelerator.gather_for_metrics(
+       (predictions, targets)
+   )
    metric.add_batch(
        predictions=predictions,
        references=targets
    )
print(metric.compute())</pre>
##
When calculating metrics on a validation set, use the `Accelerator.gather_for_metrics`
method to collect the predictions and references from all processes, then compute the metric on the gathered values.
In a distributed run, the last batch may be filled with duplicated samples so that every process receives the same number of samples;
`gather_for_metrics` *automatically* drops these duplicates from the gathered tensors, so each validation example is counted exactly once.
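To see why dropping the duplicates matters, here is a simplified, self-contained simulation in plain Python that needs neither Accelerate nor PyTorch. It assumes batches are dealt to processes in turn and that the final round is filled by wrapping around to the start of the dataset. The helpers `shard`, `gather`, and `accuracy` are hypothetical illustrations, not Accelerate's implementation. With 10 examples, 4 processes, and a batch size of 2, a naive gather returns 16 predictions and overstates accuracy. Truncating to the dataset length recovers the true value.

```python
# Simplified simulation with hypothetical helpers (not Accelerate's internals).

# Toy validation set: 10 examples identified by ID. The "model" is assumed
# to be right on IDs 0-5 and wrong on IDs 6-9, so the true accuracy is 0.6.
DATASET = list(range(10))


def is_correct(sample_id: int) -> bool:
    return sample_id < 6


def shard(samples, num_processes, batch_size):
    """Deal batches to processes in turn, padding the final round by wrapping
    around to the start of the dataset so every process gets a full batch."""
    per_round = num_processes * batch_size
    num_rounds = -(-len(samples) // per_round)  # ceiling division
    padded = [samples[i % len(samples)] for i in range(num_rounds * per_round)]
    shards = [[] for _ in range(num_processes)]
    for r in range(num_rounds):
        chunk = padded[r * per_round:(r + 1) * per_round]
        for p in range(num_processes):
            shards[p].append(chunk[p * batch_size:(p + 1) * batch_size])
    return shards


def gather(shards, keep_first=None):
    """Concatenate each process's batch step by step (like an all-gather),
    optionally keeping only the first `keep_first` samples."""
    gathered = []
    for step in range(len(shards[0])):
        for process_batches in shards:
            gathered.extend(process_batches[step])
    return gathered if keep_first is None else gathered[:keep_first]


def accuracy(sample_ids):
    return sum(is_correct(i) for i in sample_ids) / len(sample_ids)


shards = shard(DATASET, num_processes=4, batch_size=2)
naive = gather(shards)                             # includes wrapped-around duplicates
deduped = gather(shards, keep_first=len(DATASET))  # duplicates dropped
print(len(naive), accuracy(naive))      # 16 0.75
print(len(deduped), accuracy(deduped))  # 10 0.6
```

The six wrapped-around samples (IDs 0-5) happen to be ones the model gets right, so counting them again inflates accuracy from 0.6 to 0.75.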
##
To learn more, check out the related documentation:
- <a href="https://huggingface.co/docs/accelerate/en/quicktour#distributed-evaluation" target="_blank">Quicktour - Calculating metrics</a>
- <a href="https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.gather_for_metrics" target="_blank">API reference</a>
- <a href="https://github.com/huggingface/accelerate/blob/main/examples/by_feature/multi_process_metrics.py" target="_blank">Example script</a>