Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the BSD-style license found in the | |
# LICENSE file in the root directory of this source tree. | |
import unittest | |
import numpy as np | |
import torch | |
from pytorch3d.ops import utils as oputil | |
from .common_testing import TestCaseMixin | |
class TestOpsUtils(TestCaseMixin, unittest.TestCase): | |
def setUp(self) -> None: | |
super().setUp() | |
torch.manual_seed(42) | |
np.random.seed(42) | |
def test_wmean(self): | |
device = torch.device("cuda:0") | |
n_points = 20 | |
x = torch.rand(n_points, 3, device=device) | |
weight = torch.rand(n_points, device=device) | |
x_np = x.cpu().data.numpy() | |
weight_np = weight.cpu().data.numpy() | |
# test unweighted | |
mean = oputil.wmean(x, keepdim=False) | |
mean_gt = np.average(x_np, axis=-2) | |
self.assertClose(mean.cpu().data.numpy(), mean_gt) | |
# test weighted | |
mean = oputil.wmean(x, weight=weight, keepdim=False) | |
mean_gt = np.average(x_np, axis=-2, weights=weight_np) | |
self.assertClose(mean.cpu().data.numpy(), mean_gt) | |
# test keepdim | |
mean = oputil.wmean(x, weight=weight, keepdim=True) | |
self.assertClose(mean[0].cpu().data.numpy(), mean_gt) | |
# test binary weigths | |
mean = oputil.wmean(x, weight=weight > 0.5, keepdim=False) | |
mean_gt = np.average(x_np, axis=-2, weights=weight_np > 0.5) | |
self.assertClose(mean.cpu().data.numpy(), mean_gt) | |
# test broadcasting | |
x = torch.rand(10, n_points, 3, device=device) | |
x_np = x.cpu().data.numpy() | |
mean = oputil.wmean(x, weight=weight, keepdim=False) | |
mean_gt = np.average(x_np, axis=-2, weights=weight_np) | |
self.assertClose(mean.cpu().data.numpy(), mean_gt) | |
weight = weight[None, None, :].repeat(3, 1, 1) | |
mean = oputil.wmean(x, weight=weight, keepdim=False) | |
self.assertClose(mean[0].cpu().data.numpy(), mean_gt) | |
# test failing broadcasting | |
weight = torch.rand(x.shape[0], device=device) | |
with self.assertRaises(ValueError) as context: | |
oputil.wmean(x, weight=weight, keepdim=False) | |
self.assertTrue("weights are not compatible" in str(context.exception)) | |
# test dim | |
weight = torch.rand(x.shape[0], n_points, device=device) | |
weight_np = np.tile( | |
weight[:, :, None].cpu().data.numpy(), (1, 1, x_np.shape[-1]) | |
) | |
mean = oputil.wmean(x, dim=0, weight=weight, keepdim=False) | |
mean_gt = np.average(x_np, axis=0, weights=weight_np) | |
self.assertClose(mean.cpu().data.numpy(), mean_gt) | |
# test dim tuple | |
mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False) | |
mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np) | |
self.assertClose(mean.cpu().data.numpy(), mean_gt) | |
def test_masked_gather_errors(self): | |
idx = torch.randint(0, 10, size=(5, 10, 4, 2)) | |
points = torch.randn(size=(5, 10, 3)) | |
with self.assertRaisesRegex(ValueError, "format is not supported"): | |
oputil.masked_gather(points, idx) | |
points = torch.randn(size=(2, 10, 3)) | |
with self.assertRaisesRegex(ValueError, "same batch dimension"): | |
oputil.masked_gather(points, idx) | |