Linly-Talker / pytorch3d /tests /test_ops_utils.py
linxianzhong0128's picture
Upload folder using huggingface_hub
7088d16 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import numpy as np
import torch
from pytorch3d.ops import utils as oputil
from .common_testing import TestCaseMixin
class TestOpsUtils(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
np.random.seed(42)
def test_wmean(self):
device = torch.device("cuda:0")
n_points = 20
x = torch.rand(n_points, 3, device=device)
weight = torch.rand(n_points, device=device)
x_np = x.cpu().data.numpy()
weight_np = weight.cpu().data.numpy()
# test unweighted
mean = oputil.wmean(x, keepdim=False)
mean_gt = np.average(x_np, axis=-2)
self.assertClose(mean.cpu().data.numpy(), mean_gt)
# test weighted
mean = oputil.wmean(x, weight=weight, keepdim=False)
mean_gt = np.average(x_np, axis=-2, weights=weight_np)
self.assertClose(mean.cpu().data.numpy(), mean_gt)
# test keepdim
mean = oputil.wmean(x, weight=weight, keepdim=True)
self.assertClose(mean[0].cpu().data.numpy(), mean_gt)
# test binary weigths
mean = oputil.wmean(x, weight=weight > 0.5, keepdim=False)
mean_gt = np.average(x_np, axis=-2, weights=weight_np > 0.5)
self.assertClose(mean.cpu().data.numpy(), mean_gt)
# test broadcasting
x = torch.rand(10, n_points, 3, device=device)
x_np = x.cpu().data.numpy()
mean = oputil.wmean(x, weight=weight, keepdim=False)
mean_gt = np.average(x_np, axis=-2, weights=weight_np)
self.assertClose(mean.cpu().data.numpy(), mean_gt)
weight = weight[None, None, :].repeat(3, 1, 1)
mean = oputil.wmean(x, weight=weight, keepdim=False)
self.assertClose(mean[0].cpu().data.numpy(), mean_gt)
# test failing broadcasting
weight = torch.rand(x.shape[0], device=device)
with self.assertRaises(ValueError) as context:
oputil.wmean(x, weight=weight, keepdim=False)
self.assertTrue("weights are not compatible" in str(context.exception))
# test dim
weight = torch.rand(x.shape[0], n_points, device=device)
weight_np = np.tile(
weight[:, :, None].cpu().data.numpy(), (1, 1, x_np.shape[-1])
)
mean = oputil.wmean(x, dim=0, weight=weight, keepdim=False)
mean_gt = np.average(x_np, axis=0, weights=weight_np)
self.assertClose(mean.cpu().data.numpy(), mean_gt)
# test dim tuple
mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False)
mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np)
self.assertClose(mean.cpu().data.numpy(), mean_gt)
def test_masked_gather_errors(self):
idx = torch.randint(0, 10, size=(5, 10, 4, 2))
points = torch.randn(size=(5, 10, 3))
with self.assertRaisesRegex(ValueError, "format is not supported"):
oputil.masked_gather(points, idx)
points = torch.randn(size=(2, 10, 3))
with self.assertRaisesRegex(ValueError, "same batch dimension"):
oputil.masked_gather(points, idx)