##
import evaluate
+from accelerate import Accelerator
+accelerator = Accelerator()
+train_dataloader, eval_dataloader, model, optimizer, scheduler = (
+    accelerator.prepare(
+        train_dataloader, eval_dataloader,
+        model, optimizer, scheduler
+    )
+)
metric = evaluate.load("accuracy")
for batch in train_dataloader:
    optimizer.zero_grad()
    inputs, targets = batch
-    inputs = inputs.to(device)
-    targets = targets.to(device)
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
-    loss.backward()
+    # Let the Accelerator run the backward pass for the prepared model.
+    accelerator.backward(loss)
    optimizer.step()
    scheduler.step()

model.eval()
for batch in eval_dataloader:
    inputs, targets = batch
-    inputs = inputs.to(device)
-    targets = targets.to(device)
    with torch.no_grad():
        outputs = model(inputs)
    predictions = outputs.argmax(dim=-1)
+    # Gather predictions *and* targets from all devices; the padded values
+    # added so every tensor has the same length are dropped automatically.
+    predictions, references = accelerator.gather_for_metrics(
+        (predictions, targets)
+    )
    metric.add_batch(
        predictions=predictions,
        references=references
    )
print(metric.compute())
## When calculating metrics on a validation set, you can use the `Accelerator.gather_for_metrics` method to gather the predictions and references from all devices and then compute the metric on the gathered values. It also *automatically* drops the padded values that were added to the gathered tensors so that all of them have the same length. This ensures the metric is computed only on the real samples.

## To learn more, check out the related documentation:

- API reference
- Example script