##
<pre>
from accelerate import Accelerator

accelerator = Accelerator()
dataloader, model, optimizer, scheduler = accelerator.prepare(
    dataloader, model, optimizer, scheduler
)

for batch in dataloader:
    inputs, targets = batch
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

# Save the full training state (model, optimizer, scheduler, RNG states, ...)
accelerator.save_state("checkpoint_dir")

# In a later run, restore the same state to resume training
accelerator.load_state("checkpoint_dir")</pre>
##
To save or load a checkpoint, `Accelerator` provides the `save_state` and `load_state` methods.
These methods save or load the state of the model, optimizer, and scheduler, along with random states and
any custom registered objects, from the main process on each device to the folder you pass in.
**This API is designed to save and resume training states only from within the same Python script or training setup.**
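The "custom registered objects" mentioned above are anything you pass to `register_for_checkpointing` before saving: any object that implements `state_dict` and `load_state_dict` can be included in the checkpoint. Here is a minimal sketch; the `StepCounter` class and the `"checkpoint_dir"` path are illustrative assumptions, not part of the library:
<pre>
from accelerate import Accelerator

class StepCounter:
    """Hypothetical custom object; anything exposing state_dict/load_state_dict works."""
    def __init__(self):
        self.step = 0

    def state_dict(self):
        return {"step": self.step}

    def load_state_dict(self, state):
        self.step = state["step"]

accelerator = Accelerator()
counter = StepCounter()

# From now on, every save_state/load_state call also includes the counter.
accelerator.register_for_checkpointing(counter)

accelerator.save_state("checkpoint_dir")  # writes counter state alongside model/optimizer
accelerator.load_state("checkpoint_dir")  # restores it on resume</pre>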
##
To learn more, check out the related documentation:
- <a href="https://huggingface.co/docs/accelerate/usage_guides/checkpoint" target="_blank">Saving and loading training states</a>
- <a href="https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" target="_blank">`save_state` API reference</a>
- <a href="https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.load_state" target="_blank">`load_state` API reference</a>
- <a href="https://github.com/huggingface/accelerate/blob/main/examples/by_feature/checkpointing.py" target="_blank">Example script</a>