import evaluate
# Accuracy metric object; it is filled in by the evaluation loop further down.
metric = evaluate.load("accuracy")

# Training pass: one optimizer update (and one scheduler step) per batch.
for batch in train_dataloader:
    # Each batch unpacks into exactly two items; both are moved to `device`.
    inputs, targets = batch
    inputs, targets = inputs.to(device), targets.to(device)

    # Forward pass and loss for this batch.
    outputs = model(inputs)
    loss = loss_function(outputs, targets)

    # Backward pass, then step the optimizer followed by the scheduler.
    loss.backward()
    optimizer.step()
    scheduler.step()

    # Reset gradients before the next batch is processed.
    optimizer.zero_grad()

# Evaluation pass: put the model in eval mode, run it over every batch of
# `eval_dataloader`, accumulate predictions and labels into `metric`, and
# print the computed result at the end.
model.eval()
for batch in eval_dataloader:
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)
    # Gradient tracking is disabled for the forward pass during evaluation.
    with torch.no_grad():
        outputs = model(inputs)
    # Assumes `outputs` holds per-class scores along the last dimension, so
    # argmax over dim=-1 yields the predicted class index per example.
    predictions = outputs.argmax(dim=-1)
    # `targets` holds this batch's labels; there is no `references` variable
    # in this script, so the labels must be passed as `references=targets`.
    metric.add_batch(predictions=predictions, references=targets)
print(metric.compute())