##
<pre>
import evaluate
+from accelerate import Accelerator
+accelerator = Accelerator()
+train_dataloader, eval_dataloader, model, optimizer, scheduler = (
+    accelerator.prepare(
+        train_dataloader, eval_dataloader, 
+        model, optimizer, scheduler
+    )
+)
metric = evaluate.load("accuracy")
for batch in train_dataloader:
    optimizer.zero_grad()
    inputs, targets = batch
-    inputs = inputs.to(device)
-    targets = targets.to(device)
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
    loss.backward()
    optimizer.step()
    scheduler.step()

model.eval()
for batch in eval_dataloader:
    inputs, targets = batch
-    inputs = inputs.to(device)
-    targets = targets.to(device)
    with torch.no_grad():
        outputs = model(inputs)
    predictions = outputs.argmax(dim=-1)
+    predictions, references = accelerator.gather_for_metrics(
+        (predictions, targets)
+    )
    metric.add_batch(
        predictions = predictions,
        references = references
    )
print(metric.compute())</pre>

##
When calculating metrics on a validation set, use the `Accelerator.gather_for_metrics`
method to gather the predictions and references from all devices, then compute the metric on the gathered values.
The method also *automatically* drops the extra samples that were added to the last batch so that every
device receives the same number of samples, which means each example is counted exactly once in the metric.
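
To see why dropping the extra samples matters, the following is a minimal single-process sketch in plain Python. It is a simplified simulation and not Accelerate's actual implementation: the helpers `shard_with_padding` and `gather_and_truncate` are hypothetical names invented for illustration, and the two "processes" are simulated with lists.

```python
# Simplified simulation of sharded evaluation; NOT Accelerate's real code.
# Both helper functions are hypothetical and exist only for illustration.

def shard_with_padding(samples, num_processes):
    """Split samples evenly across processes, repeating the first
    samples so that every process receives the same number of items."""
    remainder = len(samples) % num_processes
    padded = list(samples)
    if remainder:
        padded += samples[: num_processes - remainder]
    # Process `rank` receives every num_processes-th sample starting at `rank`.
    return [padded[rank::num_processes] for rank in range(num_processes)]


def gather_and_truncate(shards, dataset_length):
    """Interleave the per-process shards back into dataset order,
    then drop the samples that were only added as padding."""
    gathered = [sample for group in zip(*shards) for sample in group]
    return gathered[:dataset_length]


references = [0, 1, 1, 0, 1]   # 5 labels split across 2 simulated processes
predictions = [0, 1, 0, 0, 1]

pred_shards = shard_with_padding(predictions, num_processes=2)
ref_shards = shard_with_padding(references, num_processes=2)

# A naive gather keeps the duplicated sample, so 6 items are scored.
naive_preds = [p for group in zip(*pred_shards) for p in group]
naive_refs = [r for group in zip(*ref_shards) for r in group]
naive_accuracy = sum(p == r for p, r in zip(naive_preds, naive_refs)) / len(naive_refs)

# Truncating to the true dataset length scores each example exactly once.
all_preds = gather_and_truncate(pred_shards, len(predictions))
all_refs = gather_and_truncate(ref_shards, len(references))
accuracy = sum(p == r for p, r in zip(all_preds, all_refs)) / len(all_refs)

print(f"naive: {naive_accuracy:.3f}, deduplicated: {accuracy:.3f}")
```

In this toy case the naive gather scores one example twice and reports a different accuracy than the deduplicated gather, which is the kind of skew `gather_for_metrics` is meant to avoid.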
##
To learn more, check out the related documentation:
- [API reference](https://huggingface.co/docs/accelerate/v0.15.0/package_reference/accelerator#accelerate.Accelerator.gather_for_metrics)
- [Example script](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/multi_process_metrics.py)