simon_says_v2 / utils /metrics.py
ericmichael's picture
first commit
0079b8d
raw
history blame contribute delete
469 Bytes
import evaluate
# Load the "accuracy" metric from the `evaluate` library once, at import time,
# so that every call to accuracy() reuses the same metric object.
_accuracy_metric = evaluate.load("accuracy")
def accuracy(predictions, y):
    """Return the accuracy of *predictions* measured against the labels *y*.

    The labels may be any hashable values (e.g. strings); they are mapped to
    integer ids before being handed to the module-level accuracy metric.

    Args:
        predictions: Iterable of predicted labels.
        y: Iterable of reference (true) labels, aligned with *predictions*.

    Returns:
        The value stored under the ``"accuracy"`` key of the metric result.
    """
    # Materialize both inputs so that any iterable works (lists, tuples,
    # generators, arrays) and each one can be traversed more than once.
    predictions = list(predictions)
    y = list(y)

    # Build the label -> integer id mapping in first-seen order.  Unpacking
    # into a new list avoids `y + predictions`, which fails for mixed
    # list/tuple inputs and performs element-wise addition on array types;
    # dict.fromkeys keeps the ids deterministic, unlike iterating a set.
    label_mapping = {
        label: index
        for index, label in enumerate(dict.fromkeys([*y, *predictions]))
    }

    true_labels = [label_mapping[value] for value in y]
    pred_labels = [label_mapping[value] for value in predictions]

    # Compute accuracy on the integer-encoded labels.
    result = _accuracy_metric.compute(
        predictions=pred_labels, references=true_labels
    )
    return result["accuracy"]