Commit 13e3de7 by muellerzr: Add comma between optimizer, scheduler as reported in https://github.com/huggingface/accelerate/issues/2661
<pre>
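from accelerate import Accelerator

accelerator = Accelerator()

# `dataloader`, `model`, `optimizer`, `scheduler` and `loss_function` are your
# existing PyTorch objects. prepare() returns them in the same order it receives
# them, wrapped for the current device and distributed setup.
train_dataloader, model, optimizer, scheduler = accelerator.prepare(
    dataloader, model, optimizer, scheduler
)

model.train()
for batch in train_dataloader:
    inputs, targets = batch  # batches are already on accelerator.device
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
    accelerator.backward(loss)  # use instead of loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()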
</pre>