File size: 437 Bytes
b91e31d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d091751
b91e31d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
<pre>
from accelerate import Accelerator
# Minimal Accelerate training loop.
# NOTE(review): `dataloader`, `model`, `optimizer`, `scheduler` and
# `loss_function` are not defined in this snippet; they are assumed to be
# created earlier by the caller.
accelerator = Accelerator()

# accelerator.prepare() returns the prepared objects in the same order they
# were passed in. The names on the left must therefore match the arguments
# one-to-one, comma-separated.
train_dataloader, model, optimizer, scheduler = accelerator.prepare(
    dataloader, model, optimizer, scheduler
)

model.train()
for batch in train_dataloader:
    # Each batch is unpacked as an (inputs, targets) pair.
    inputs, targets = batch
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
    # Use accelerator.backward() rather than loss.backward() so Accelerate
    # can handle the backward pass for the configured setup.
    accelerator.backward(loss)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
</pre>