import numpy as np
import torch
from sklearn.metrics import pairwise_distances
from torchmetrics import Metric
from tqdm import tqdm

from metrics.utils import haversine, reverse
from utils.kde import BatchedKDE


class HaversineMetrics(Metric):
    """Accumulate geolocation metrics over batches of predicted / true points.

    For every batch the metric accumulates:

    * the sum of ``haversine(pred, gt)`` distances (reported as their mean),
    * the sum of ``5000 * exp(-distance / 1492.7)`` (reported as "Geoguessr"),
    * for each radius in ``acc_radiuses``, the number of predictions whose
      distance is strictly below that radius,
    * for each area level in ``acc_area``, the number of predictions whose
      area (as returned by ``reverse``) equals the ground-truth area,
    * the raw predicted and ground-truth points, used by ``compute`` to
      derive precision / recall / density / coverage.

    Args:
        acc_radiuses: radii used as distance thresholds for the accuracy.
            The output keys are named ``Accuracy_{radius}_km_radius``.
        acc_area: area levels forwarded to ``reverse`` for the area accuracy.
        use_kde: if True, ``pred["gps"]`` is treated as a set of samples per
            item and replaced by the mode found by ``estimate_kde_mode``.
        manifold_k: number of nearest neighbours used by the manifold metrics.
    """

    def __init__(
        self,
        acc_radiuses=(),
        acc_area=("country", "region", "sub-region", "city"),
        use_kde=False,
        manifold_k=3,
    ):
        super().__init__()
        self.use_kde = use_kde
        # Private copies: never alias the defaults or the caller's sequences.
        self.acc_radius = list(acc_radiuses)
        self.acc_area = list(acc_area)
        self.manifold_k = manifold_k

        self.add_state("haversine_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("geoguessr_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        for acc in self.acc_radius:
            self.add_state(
                f"close_enough_points_{acc}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
        for acc in self.acc_area:
            self.add_state(
                f"close_enough_points_{acc}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
            # Area accuracies have their own denominator (see ``update``).
            self.add_state(
                f"count_{acc}", default=torch.tensor(0), dist_reduce_fx="sum"
            )
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        # List states: every valid batch of points is kept for ``compute``.
        self.add_state(
            "real_points",
            [],
            dist_reduce_fx=None,
        )
        self.add_state(
            "fake_points",
            [],
            dist_reduce_fx=None,
        )

    def _add_to_state(self, name, value):
        """Add ``value`` in place to the tensor state registered as ``name``."""
        state = getattr(self, name)
        state += value

    def update(self, pred, gt):
        """Accumulate the statistics of one batch.

        Args:
            pred: mapping with a ``"gps"`` tensor of predictions.
            gt: mapping with a ``"gps"`` tensor of ground-truth points; it is
                also forwarded as-is to ``reverse`` when ``acc_area`` is set.

        Rows of the prediction containing NaN are dropped (together with the
        matching ground-truth rows); if nothing is left the batch is skipped.
        Neither ``pred`` nor ``gt`` is modified.
        """
        pred_points = pred["gps"]
        if self.use_kde:
            # Work on a local tensor so the caller's ``pred`` dict is untouched.
            (x_mode, y_mode), _ = estimate_kde_mode(pred_points)
            pred_points = torch.stack([x_mode, y_mode], dim=1)

        nan_entries = pred_points.isnan()
        if nan_entries.any():
            valid_mask = ~nan_entries.any(dim=1)
            pred_gps = pred_points[valid_mask]
            gt_gps = gt["gps"][valid_mask]
            if len(pred_gps) == 0:
                # Skip if no valid predictions remain
                return
        else:
            pred_gps = pred_points
            gt_gps = gt["gps"]

        haversine_distance = haversine(pred_gps, gt_gps)
        for acc in self.acc_radius:
            self._add_to_state(
                f"close_enough_points_{acc}", (haversine_distance < acc).sum()
            )

        if len(self.acc_area) > 0:
            # NOTE(review): ``pred_gps`` is NaN-filtered but ``gt`` is passed
            # unfiltered -- confirm ``reverse`` copes with that mismatch.
            area_pred, area_gt = reverse(pred_gps, gt, self.acc_area)
            for acc in self.acc_area:
                gt_area = area_gt["_".join(["unique", acc])]
                self._add_to_state(
                    f"close_enough_points_{acc}", (area_pred[acc] == gt_area).sum()
                )
                self._add_to_state(f"count_{acc}", len(gt_area))

        self.haversine_sum += haversine_distance.sum()
        self.geoguessr_sum += 5000 * torch.exp(-haversine_distance / 1492.7).sum()
        self.real_points.append(gt_gps)
        self.fake_points.append(pred_gps)
        self.count += pred_gps.shape[0]

    def compute(self):
        """Return a dict with every accumulated metric.

        Keys: ``"Haversine"``, ``"Geoguessr"``, one
        ``"Accuracy_{radius}_km_radius"`` per radius, one ``"Accuracy_{area}"``
        per area level, and ``"precision"``, ``"recall"``, ``"density"``,
        ``"coverage"`` from ``manifold_metrics``.
        """
        output = {
            "Haversine": self.haversine_sum / self.count,
            "Geoguessr": self.geoguessr_sum / self.count,
        }
        for acc in self.acc_radius:
            output[f"Accuracy_{acc}_km_radius"] = (
                getattr(self, f"close_enough_points_{acc}") / self.count
            )
        for acc in self.acc_area:
            # Area accuracies are normalised by their own per-area count.
            output[f"Accuracy_{acc}"] = getattr(
                self, f"close_enough_points_{acc}"
            ) / getattr(self, f"count_{acc}")

        real_points = torch.cat(self.real_points, dim=0)
        fake_points = torch.cat(self.fake_points, dim=0)
        (
            output["precision"],
            output["recall"],
            output["density"],
            output["coverage"],
        ) = self.manifold_metrics(real_points, fake_points, self.manifold_k)
        return output

    def compute_pairwise_distance(self, data_x, data_y=None):
        """Return the pairwise haversine distances between two point sets.

        Args:
            data_x: numpy.ndarray([N, feature_dim], dtype=np.float32)
            data_y: numpy.ndarray([M, feature_dim], dtype=np.float32);
                defaults to ``data_x``.
        Returns:
            numpy.ndarray([N, M]) of pairwise distances.

        NOTE(review): scikit-learn's ``"haversine"`` metric documents its
        input as [latitude, longitude] in radians -- confirm the stored points
        follow that convention.
        """
        if data_y is None:
            data_y = data_x
        dists = pairwise_distances(data_x, data_y, metric="haversine", n_jobs=8)
        return dists

    def get_kth_value(self, unsorted, k, axis=-1):
        """Return the k-th smallest value along ``axis``.

        Args:
            unsorted: numpy.ndarray of any dimensionality.
            k: int, 1-based rank; must satisfy ``1 <= k <= unsorted.shape[axis]``.
            axis: axis along which the rank is taken.
        Returns:
            kth values along the designated axis.
        Raises:
            ValueError: if ``k`` is outside the valid range.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        kth_index = k - 1
        # After partitioning, position ``kth_index`` holds the value it would
        # have in a full sort, i.e. the k-th smallest one.
        partitioned = np.partition(unsorted, kth_index, axis=axis)
        return np.take(partitioned, kth_index, axis=axis)

    def compute_nearest_neighbour_distances(self, input_features, nearest_k):
        """Return, for each sample, the distance to its k-th nearest neighbour.

        Args:
            input_features: numpy.ndarray([N, feature_dim], dtype=np.float32)
            nearest_k: int
        Returns:
            numpy.ndarray([N]) of distances to the kth nearest neighbours.
        """
        distances = self.compute_pairwise_distance(input_features)
        # +1 because every sample is its own nearest neighbour (distance 0).
        # Clamp to the sample count so small sets (e.g. the short trailing
        # chunk produced by ``Tensor.chunk``) fall back to the farthest
        # sample instead of raising.
        k = min(nearest_k + 1, distances.shape[-1])
        radii = self.get_kth_value(distances, k=k, axis=-1)
        return radii

    def compute_prdc(self, real_features, fake_features, nearest_k):
        """Compute precision, recall, density and coverage of two point sets.

        Args:
            real_features: numpy.ndarray([N, feature_dim], dtype=np.float32)
            fake_features: numpy.ndarray([M, feature_dim], dtype=np.float32)
            nearest_k: int.
        Returns:
            tuple ``(precision, recall, density, coverage)``.
        """
        real_radii = self.compute_nearest_neighbour_distances(
            real_features, nearest_k
        )
        fake_radii = self.compute_nearest_neighbour_distances(
            fake_features, nearest_k
        )
        # Rows index real samples, columns index fake samples.
        distance_real_fake = self.compute_pairwise_distance(
            real_features, fake_features
        )

        # [i, j] is True when fake j lies inside the k-NN ball of real i.
        in_real_ball = distance_real_fake < np.expand_dims(real_radii, axis=1)
        # [i, j] is True when real i lies inside the k-NN ball of fake j.
        in_fake_ball = distance_real_fake < np.expand_dims(fake_radii, axis=0)

        # Fraction of fake samples covered by at least one real ball.
        precision = in_real_ball.any(axis=0).mean()
        # Fraction of real samples covered by at least one fake ball.
        recall = in_fake_ball.any(axis=1).mean()
        # Mean number of real balls containing each fake sample, scaled by 1/k.
        density = (1.0 / float(nearest_k)) * in_real_ball.sum(axis=0).mean()
        # Fraction of real samples whose closest fake is inside their own ball.
        coverage = (distance_real_fake.min(axis=1) < real_radii).mean()
        return precision, recall, density, coverage

    def manifold_metrics(self, real_features, fake_features, nearest_k, num_splits=20):
        """Average precision, recall, density and coverage over chunks.

        Both inputs are split into at most ``num_splits`` chunks along dim 0;
        ``compute_prdc`` is evaluated on each pair of chunks and the results
        are averaged.

        Args:
            real_features: torch.Tensor([N, feature_dim], dtype=torch.float32)
            fake_features: torch.Tensor([N, feature_dim], dtype=torch.float32)
            nearest_k: int.
            num_splits: int. Number of splits to use for computing metrics.
        Returns:
            tuple ``(precision, recall, density, coverage)`` of Python floats.
        """
        real_features = real_features.chunk(num_splits, dim=0)
        fake_features = fake_features.chunk(num_splits, dim=0)
        precision, recall, density, coverage = [], [], [], []
        for real, fake in tqdm(
            zip(real_features, fake_features),
            total=len(real_features),  # a zip has no length of its own
            desc="Computing manifold",
        ):
            p, r, d, c = self.compute_prdc(
                real.cpu().numpy(), fake.cpu().numpy(), nearest_k=nearest_k
            )
            precision.append(torch.tensor(p, device=real.device))
            recall.append(torch.tensor(r, device=real.device))
            density.append(torch.tensor(d, device=real.device))
            coverage.append(torch.tensor(c, device=real.device))
        return (
            torch.stack(precision).mean().item(),
            torch.stack(recall).mean().item(),
            torch.stack(density).mean().item(),
            torch.stack(coverage).mean().item(),
        )


def estimate_kde_mode(points):
    """Estimate, per batch item, the grid point with the highest KDE score.

    A ``BatchedKDE`` is fitted on ``points`` and scored on a 100 x 100 grid
    spanning each item's bounding box (see ``make_grid``).

    Args:
        points: tensor of shape (batch, n_points, 2).
    Returns:
        ``((x_mode, y_mode), kde)``: two tensors of shape (batch,) holding the
        coordinates of the best grid point, and the fitted KDE.
    """
    kde = BatchedKDE()
    kde.fit(points)
    batch_size = points.shape[0]

    # The grid is built on CPU and then moved to the device of ``points``.
    X, Y, positions = batched_make_grid(points.cpu())
    X = X.to(points.device)
    Y = Y.to(points.device)
    positions = positions.to(points.device)

    Z = kde.score(positions).reshape(X.shape)
    # Flat index of the best grid cell of each item, computed once.
    best_cell = Z.reshape(batch_size, -1).argmax(dim=1)
    rows = torch.arange(batch_size, device=best_cell.device)
    x_mode = X.reshape(batch_size, -1)[rows, best_cell]
    y_mode = Y.reshape(batch_size, -1)[rows, best_cell]
    return (x_mode, y_mode), kde


def make_grid(points):
    """Build a 100 x 100 grid over the bounding box of one set of 2-D points.

    Args:
        points: tensor of shape (n_points, 2).
    Returns:
        ``(X, Y, positions)``: the two (100, 100) coordinate grids and the
        flattened (10000, 2) list of grid positions.
    """
    (lat_min, long_min), _ = points.min(dim=-2)
    (lat_max, long_max), _ = points.max(dim=-2)
    x = torch.linspace(lat_min, lat_max, 100)
    y = torch.linspace(long_min, long_max, 100)
    # "ij" is what meshgrid used implicitly; explicit to avoid the warning.
    X, Y = torch.meshgrid(x, y, indexing="ij")
    positions = torch.vstack([X.flatten(), Y.flatten()]).transpose(-1, -2)
    return X, Y, positions


# Batched variant of ``make_grid``: maps over the leading (batch) dimension.
batched_make_grid = torch.vmap(make_grid)