```python
import torch
import evaluate

# Load the accuracy metric from the Hugging Face `evaluate` library
metric = evaluate.load("accuracy")

# Training: one pass over the training data
model.train()
for batch in train_dataloader:
    optimizer.zero_grad()
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
    loss.backward()
    optimizer.step()
    scheduler.step()

# Evaluation: accumulate predictions batch by batch, then compute accuracy once
model.eval()
for batch in eval_dataloader:
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)
    with torch.no_grad():
        outputs = model(inputs)
    predictions = outputs.argmax(dim=-1)
    # The ground-truth labels for this batch are `targets`
    metric.add_batch(predictions=predictions, references=targets)

print(metric.compute())
```
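The loop above calls `metric.add_batch()` for every evaluation batch and `metric.compute()` only once at the end. A minimal pure-Python sketch of that accumulate-then-compute pattern is shown below; `RunningAccuracy` is a hypothetical illustration class, not part of the `evaluate` API, and resetting the stored batches inside `compute()` is a design choice of this sketch.

```python
from typing import Dict, List, Sequence


class RunningAccuracy:
    """Hypothetical stand-in that mimics an add_batch()/compute() metric."""

    def __init__(self) -> None:
        self._predictions: List[int] = []
        self._references: List[int] = []

    def add_batch(self, predictions: Sequence[int], references: Sequence[int]) -> None:
        if len(predictions) != len(references):
            raise ValueError("predictions and references must have the same length")
        # Only store the batch; nothing is computed yet
        self._predictions.extend(int(p) for p in predictions)
        self._references.extend(int(r) for r in references)

    def compute(self) -> Dict[str, float]:
        # Accuracy over every example added since the last compute()
        total = len(self._references)
        correct = sum(p == r for p, r in zip(self._predictions, self._references))
        # This sketch returns 0.0 when no examples were added
        result = {"accuracy": correct / total if total else 0.0}
        # Clear stored batches so the object can be reused for the next evaluation pass
        self._predictions.clear()
        self._references.clear()
        return result


metric = RunningAccuracy()
metric.add_batch(predictions=[1, 0, 2], references=[1, 1, 2])
metric.add_batch(predictions=[0, 0], references=[0, 1])
print(metric.compute())  # {'accuracy': 0.6}
```

Accumulating raw predictions and computing once gives the exact accuracy over the whole evaluation set (3 correct out of 5 here). Averaging per-batch accuracies instead would weight a smaller final batch the same as a full one and could give a different number.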