# RobustViT/SegmentationTest/utils/confusionmatrix.py
import numpy as np
import torch
from . import metric
class ConfusionMatrix(metric.Metric):
"""Constructs a confusion matrix for a multi-class classification problems.
Does not support multi-label, multi-class problems.
Keyword arguments:
- num_classes (int): number of classes in the classification problem.
- normalized (boolean, optional): Determines whether or not the confusion
matrix is normalized or not. Default: False.
Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py
"""
def __init__(self, num_classes, normalized=False):
super().__init__()
self.conf = np.ndarray((num_classes, num_classes), dtype=np.int32)
self.normalized = normalized
self.num_classes = num_classes
self.reset()
def reset(self):
self.conf.fill(0)
def add(self, predicted, target):
"""Computes the confusion matrix
The shape of the confusion matrix is K x K, where K is the number
of classes.
Keyword arguments:
- predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of
predicted scores obtained from the model for N examples and K classes,
or an N-tensor/array of integer values between 0 and K-1.
- target (Tensor or numpy.ndarray): Can be an N x K tensor/array of
ground-truth classes for N examples and K classes, or an N-tensor/array
of integer values between 0 and K-1.
"""
# If target and/or predicted are tensors, convert them to numpy arrays
if torch.is_tensor(predicted):
predicted = predicted.cpu().numpy()
if torch.is_tensor(target):
target = target.cpu().numpy()
assert predicted.shape[0] == target.shape[0], \
'number of targets and predicted outputs do not match'
if np.ndim(predicted) != 1:
assert predicted.shape[1] == self.num_classes, \
'number of predictions does not match size of confusion matrix'
predicted = np.argmax(predicted, 1)
else:
assert (predicted.max() < self.num_classes) and (predicted.min() >= 0), \
'predicted values are not between 0 and k-1'
if np.ndim(target) != 1:
assert target.shape[1] == self.num_classes, \
'Onehot target does not match size of confusion matrix'
assert (target >= 0).all() and (target <= 1).all(), \
'in one-hot encoding, target values should be 0 or 1'
assert (target.sum(1) == 1).all(), \
'multi-label setting is not supported'
target = np.argmax(target, 1)
else:
assert (target.max() < self.num_classes) and (target.min() >= 0), \
'target values are not between 0 and k-1'
        # Encode each (target, predicted) pair as a single index in [0, K^2)
        # so the joint counts can be obtained with one call to np.bincount.
        x = predicted + self.num_classes * target
bincount_2d = np.bincount(
x.astype(np.int32), minlength=self.num_classes**2)
assert bincount_2d.size == self.num_classes**2
conf = bincount_2d.reshape((self.num_classes, self.num_classes))
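        # conf[i, j] counts samples whose ground-truth class is i and whose
        # predicted class is j (rows = targets, columns = predictions).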
self.conf += conf
def value(self):
"""
Returns:
            Confusion matrix of K rows and K columns, where rows correspond
            to ground-truth targets and columns correspond to predicted
            targets.
"""
if self.normalized:
conf = self.conf.astype(np.float32)
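            # Row-normalize so each row sums to 1; the clip guards against
            # division by zero for classes that never appear in the targets.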
return conf / conf.sum(1).clip(min=1e-12)[:, None]
else:
return self.conf
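

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of how this meter might be driven, assuming the module is
# imported as part of its package (the relative `metric` import above means it
# cannot run as a standalone script; e.g. use `python -m <package>.confusionmatrix`).
if __name__ == '__main__':
    cm = ConfusionMatrix(num_classes=3, normalized=False)

    # Integer class labels: N-length tensors with values in [0, K-1].
    cm.add(predicted=torch.tensor([0, 2, 1, 1]),
           target=torch.tensor([0, 2, 2, 1]))

    # N x K inputs are also accepted: scores for `predicted`, one-hot rows
    # for `target`; both are reduced with an argmax over dimension 1.
    cm.add(predicted=torch.tensor([[0.9, 0.05, 0.05],
                                   [0.1, 0.80, 0.10]]),
           target=torch.tensor([[1, 0, 0],
                                [0, 1, 0]]))

    # 3 x 3 count matrix: rows are ground-truth classes, columns predictions.
    print(cm.value())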