# Plonk / metrics/distance_based.py
# nicolas-dufour's picture
# squash: merge all unpushed commits
# c4c7cee
import torch
from metrics.utils import haversine, reverse
from sklearn.metrics import pairwise_distances
from torchmetrics import Metric
import numpy as np
from utils.kde import BatchedKDE
from tqdm import tqdm
class HaversineMetrics(Metric):
    """
    Accumulates distance-based metrics between predicted and ground truth points.

    Over all batches passed to ``update`` it tracks:
      * the average haversine distance (as returned by ``metrics.utils.haversine``),
      * the average Geoguessr score, ``5000 * exp(-distance / 1492.7)``,
      * the accuracy for each radius in ``acc_radiuses``,
      * the accuracy for each area level in ``acc_area``,
      * precision / recall / density / coverage between the predicted and
        the ground truth point sets.

    Args:
        acc_radiuses (list): list of radiuses to compute the accuracy from.
            ``None`` (the default) means no radius accuracy.
        acc_area (list): list of areas to compute the accuracy from.
            ``None`` (the default) means
            ``["country", "region", "sub-region", "city"]``.
        use_kde (bool): if True, ``pred["gps"]`` is reduced to one point per
            batch element with ``estimate_kde_mode`` before scoring.
        manifold_k (int): number of nearest neighbours used for the
            precision / recall / density / coverage metrics.
    """

    def __init__(
        self,
        acc_radiuses=None,
        acc_area=None,
        use_kde=False,
        manifold_k=3,
    ):
        super().__init__()
        # Build fresh lists per instance: list literals used as default
        # arguments would be shared by every instance of the class.
        acc_radiuses = [] if acc_radiuses is None else list(acc_radiuses)
        acc_area = (
            ["country", "region", "sub-region", "city"]
            if acc_area is None
            else list(acc_area)
        )
        self.use_kde = use_kde
        self.add_state("haversine_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("geoguessr_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        for acc in acc_radiuses:
            self.add_state(
                f"close_enough_points_{acc}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
        for acc in acc_area:
            self.add_state(
                f"close_enough_points_{acc}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
            # Area accuracies are normalised by their own counter (see compute()).
            self.add_state(
                f"count_{acc}", default=torch.tensor(0), dist_reduce_fx="sum"
            )
        self.acc_radius = acc_radiuses
        self.acc_area = acc_area
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        # Raw points are kept so that compute() can evaluate the manifold metrics.
        self.add_state(
            "real_points",
            [],
            dist_reduce_fx=None,
        )
        self.add_state(
            "fake_points",
            [],
            dist_reduce_fx=None,
        )
        self.manifold_k = manifold_k

    def _add_to_state(self, name, value):
        """Add ``value`` in place to the metric state called ``name``."""
        state = getattr(self, name)
        state += value
        setattr(self, name, state)

    def update(self, pred, gt):
        """
        Accumulate the statistics of one batch.

        Args:
            pred (dict): must contain ``"gps"``, the predicted points.
            gt (dict): must contain ``"gps"``, the ground truth points. The
                whole dict is also forwarded to ``metrics.utils.reverse``
                when area accuracies are enabled.

        Neither ``pred`` nor ``gt`` is modified. Rows whose prediction
        contains NaN are ignored; if no row remains, the batch is skipped.
        """
        pred_gps = pred["gps"]
        if self.use_kde:
            # Only the mode is needed here; the fitted KDE is discarded.
            # The result is kept in a local variable so that the caller's
            # dict is left untouched.
            (x_mode, y_mode), _ = estimate_kde_mode(pred_gps)
            pred_gps = torch.stack([x_mode, y_mode], dim=1)
        gt_gps = gt["gps"]

        # Handle NaN values without modifying the original inputs
        if pred_gps.isnan().any():
            valid_mask = ~pred_gps.isnan().any(dim=1)
            pred_gps = pred_gps[valid_mask]
            gt_gps = gt_gps[valid_mask]
            if len(pred_gps) == 0:  # Skip if no valid predictions remain
                return

        haversine_distance = haversine(pred_gps, gt_gps)
        for acc in self.acc_radius:
            self._add_to_state(
                f"close_enough_points_{acc}", (haversine_distance < acc).sum()
            )
        if self.acc_area:
            # NOTE(review): `gt` is forwarded unfiltered, so after NaN rows
            # were dropped above `pred_gps` may hold fewer rows than `gt` -
            # confirm that `reverse` copes with that.
            area_pred, area_gt = reverse(pred_gps, gt, self.acc_area)
            for acc in self.acc_area:
                gt_area = area_gt["_".join(["unique", acc])]
                self._add_to_state(
                    f"close_enough_points_{acc}", (area_pred[acc] == gt_area).sum()
                )
                self._add_to_state(f"count_{acc}", len(gt_area))
        self.haversine_sum += haversine_distance.sum()
        self.geoguessr_sum += 5000 * torch.exp(-haversine_distance / 1492.7).sum()
        # Detach so that the stored points never keep an autograd graph alive
        # and can be converted to numpy in manifold_metrics().
        self.real_points.append(gt_gps.detach())
        self.fake_points.append(pred_gps.detach())
        self.count += pred_gps.shape[0]

    def compute(self):
        """
        Returns:
            dict with the keys ``"Haversine"``, ``"Geoguessr"``,
            ``"Accuracy_{radius}_km_radius"`` for each radius,
            ``"Accuracy_{area}"`` for each area, and ``"precision"``,
            ``"recall"``, ``"density"``, ``"coverage"``. The manifold metrics
            are NaN when no point has been accumulated.
        """
        output = {
            "Haversine": self.haversine_sum / self.count,
            "Geoguessr": self.geoguessr_sum / self.count,
        }
        for acc in self.acc_radius:
            output[f"Accuracy_{acc}_km_radius"] = (
                getattr(self, f"close_enough_points_{acc}") / self.count
            )
        for acc in self.acc_area:
            output[f"Accuracy_{acc}"] = getattr(
                self, f"close_enough_points_{acc}"
            ) / getattr(self, f"count_{acc}")

        if len(self.real_points) == 0 or len(self.fake_points) == 0:
            # torch.cat raises on an empty list; report NaN instead.
            manifold = (float("nan"),) * 4
        else:
            real_points = torch.cat(self.real_points, dim=0)
            fake_points = torch.cat(self.fake_points, dim=0)
            manifold = self.manifold_metrics(
                real_points, fake_points, self.manifold_k
            )
        (
            output["precision"],
            output["recall"],
            output["density"],
            output["coverage"],
        ) = manifold
        return output

    def compute_pairwise_distance(self, data_x, data_y=None):
        """
        Args:
            data_x: numpy.ndarray([N, feature_dim], dtype=np.float32)
            data_y: numpy.ndarray([M, feature_dim], dtype=np.float32).
                Defaults to ``data_x``.
        Returns:
            numpy.ndarray([N, M]) of pairwise distances computed by
            scikit-learn with ``metric="haversine"``.
        """
        # NOTE(review): scikit-learn documents its haversine metric as taking
        # [latitude, longitude] in radians - confirm the stored points follow
        # that convention.
        if data_y is None:
            data_y = data_x
        dists = pairwise_distances(data_x, data_y, metric="haversine", n_jobs=8)
        return dists

    def get_kth_value(self, unsorted, k, axis=-1):
        """
        Args:
            unsorted: numpy.ndarray of any dimensionality.
            k: int, 1-based rank of the value to extract (k >= 1).
        Returns:
            kth smallest values along the designated axis.
        Raises:
            ValueError: if ``k`` is smaller than 1 or larger than the size of
                ``unsorted`` along ``axis``.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        # After partitioning around index k - 1, that index holds the kth
        # smallest value. This only needs k entries along the axis.
        partitioned = np.partition(unsorted, k - 1, axis=axis)
        kth_values = np.take(partitioned, k - 1, axis=axis)
        return kth_values

    def compute_nearest_neighbour_distances(self, input_features, nearest_k):
        """
        Args:
            input_features: numpy.ndarray([N, feature_dim], dtype=np.float32)
            nearest_k: int
        Returns:
            Distances to kth nearest neighbours.
        """
        distances = self.compute_pairwise_distance(input_features)
        # + 1 because every row also contains the distance of the point to itself.
        radii = self.get_kth_value(distances, k=nearest_k + 1, axis=-1)
        return radii

    def compute_prdc(self, real_features, fake_features, nearest_k):
        """
        Computes precision, recall, density, and coverage given two manifolds.
        Args:
            real_features: numpy.ndarray([N, feature_dim], dtype=np.float32)
            fake_features: numpy.ndarray([M, feature_dim], dtype=np.float32)
            nearest_k: int.
        Returns:
            tuple (precision, recall, density, coverage).
        """
        real_nearest_neighbour_distances = self.compute_nearest_neighbour_distances(
            real_features, nearest_k
        )
        fake_nearest_neighbour_distances = self.compute_nearest_neighbour_distances(
            fake_features, nearest_k
        )
        # Rows index real points, columns index fake points.
        distance_real_fake = self.compute_pairwise_distance(
            real_features, fake_features
        )

        # [i, j] is True when fake point j lies inside the k-NN radius of real
        # point i; shared by precision and density.
        fake_in_real_radius = distance_real_fake < np.expand_dims(
            real_nearest_neighbour_distances, axis=1
        )
        # [i, j] is True when real point i lies inside the k-NN radius of fake point j.
        real_in_fake_radius = distance_real_fake < np.expand_dims(
            fake_nearest_neighbour_distances, axis=0
        )

        precision = fake_in_real_radius.any(axis=0).mean()
        recall = real_in_fake_radius.any(axis=1).mean()
        density = (1.0 / float(nearest_k)) * fake_in_real_radius.sum(axis=0).mean()
        coverage = (
            distance_real_fake.min(axis=1) < real_nearest_neighbour_distances
        ).mean()
        return precision, recall, density, coverage

    def manifold_metrics(self, real_features, fake_features, nearest_k, num_splits=20):
        """
        Computes precision, recall, density, and coverage given two manifolds.

        The points are split into at most ``num_splits`` chunks, the metrics
        are computed chunk by chunk and then averaged over the chunks.

        Args:
            real_features: torch.Tensor([N, feature_dim], dtype=torch.float32)
            fake_features: torch.Tensor([N, feature_dim], dtype=torch.float32)
            nearest_k: int.
            num_splits: int. Number of splits to use for computing metrics.
        Returns:
            tuple of floats (precision, recall, density, coverage). All four
            are NaN when no chunk holds at least ``nearest_k + 1`` points.
        """
        real_chunks = real_features.chunk(num_splits, dim=0)
        fake_chunks = fake_features.chunk(num_splits, dim=0)
        # A k-NN radius needs the point itself plus nearest_k neighbours.
        min_points = nearest_k + 1
        per_chunk = []
        for real, fake in tqdm(
            zip(real_chunks, fake_chunks),
            total=min(len(real_chunks), len(fake_chunks)),
            desc="Computing manifold",
        ):
            if min(real.shape[0], fake.shape[0]) < min_points:
                # Too few points for a k-NN radius (typically a short last chunk).
                continue
            per_chunk.append(
                self.compute_prdc(
                    real.detach().cpu().numpy(),
                    fake.detach().cpu().numpy(),
                    nearest_k=nearest_k,
                )
            )
        if not per_chunk:
            return (float("nan"),) * 4
        # Average on the host: the per-chunk values are float64 numpy scalars.
        precision, recall, density, coverage = np.asarray(
            per_chunk, dtype=np.float64
        ).mean(axis=0)
        return float(precision), float(recall), float(density), float(coverage)
def estimate_kde_mode(points):
    """
    Fit a ``BatchedKDE`` on ``points`` and pick, for every batch element, the
    grid location with the highest KDE score.

    Args:
        points: torch.Tensor whose first dimension is the batch dimension.
            # assumes shape [batch, num_points, 2] - TODO confirm against BatchedKDE
    Returns:
        ((x_mode, y_mode), kde): two tensors with one entry per batch element
        holding the coordinates of the best grid cell, and the fitted KDE.
    """
    kde = BatchedKDE()
    kde.fit(points)
    batch_size = points.shape[0]
    # The grid is built from a CPU copy of the points, then moved back to
    # their device.
    X, Y, positions = batched_make_grid(points.cpu())
    X = X.to(points.device)
    Y = Y.to(points.device)
    positions = positions.to(points.device)
    Z = kde.score(positions).reshape(X.shape)
    # Flat index of the best cell per batch element, computed once and reused
    # for both coordinates.
    best_cell = Z.reshape(batch_size, -1).argmax(dim=1)
    batch_index = torch.arange(batch_size, device=X.device)
    x_mode = X.reshape(batch_size, -1)[batch_index, best_cell]
    y_mode = Y.reshape(batch_size, -1)[batch_index, best_cell]
    return (x_mode, y_mode), kde
def make_grid(points, resolution=100):
    """
    Build a regular grid covering the bounding box of ``points``.

    Args:
        points: torch.Tensor([N, 2]); the code names the two columns
            latitude and longitude.
        resolution: int. Number of grid steps along each of the two axes.
    Returns:
        (X, Y, positions): ``X`` and ``Y`` are [resolution, resolution]
        coordinate grids (``X`` varies along the first axis, ``Y`` along the
        second) and ``positions`` is the [resolution * resolution, 2] list of
        grid points, ordered like ``X.flatten()`` / ``Y.flatten()``.
    """
    (lat_min, long_min), _ = points.min(dim=-2)
    (lat_max, long_max), _ = points.max(dim=-2)
    x = torch.linspace(lat_min, lat_max, resolution)
    y = torch.linspace(long_min, long_max, resolution)
    # "ij" is the layout meshgrid already used by default; passing it
    # explicitly avoids the warning emitted when `indexing` is omitted.
    X, Y = torch.meshgrid(x, y, indexing="ij")
    positions = torch.vstack([X.flatten(), Y.flatten()]).transpose(-1, -2)
    return X, Y, positions


# Same as make_grid, applied independently along a leading batch dimension.
batched_make_grid = torch.vmap(make_grid)