# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat

from .common_testing import TestCaseMixin


class TestLinearWithRepeat(TestCaseMixin, unittest.TestCase):
    """Check ``LinearWithRepeat`` against an explicit concatenate + ``nn.Linear``."""

    def setUp(self) -> None:
        super().setUp()
        # Fixed seed so the random inputs and weights are reproducible.
        torch.manual_seed(42)

    def _check_matches_linear(
        self,
        batch_shape: tuple,
        k: int,
        n1: int,
        n2: int,
        out_features: int,
    ) -> None:
        """Assert LinearWithRepeat((x, y)) equals nn.Linear on the repeated input.

        ``x`` has shape ``(*batch_shape, k, n1)`` and ``y`` has shape
        ``(*batch_shape, n2)``. The reference result is built by repeating
        ``y`` along ``x``'s ``k`` axis, concatenating on the feature axis and
        applying an ordinary ``nn.Linear(n1 + n2, out_features)``.

        Args:
            batch_shape: leading dimensions shared by ``x`` and ``y``.
            k: size of the second-to-last axis of ``x``; ``y`` is repeated
                this many times.
            n1: trailing feature size of ``x``.
            n2: trailing feature size of ``y``.
            out_features: output size of both layers.
        """
        x = torch.rand(*batch_shape, k, n1)
        y = torch.rand(*batch_shape, n2)

        linear = torch.nn.Linear(n1 + n2, out_features)
        # The nn.init helpers operate in place under no_grad, so there is no
        # need to go through ``.data``.
        torch.nn.init.xavier_uniform_(linear.weight)
        torch.nn.init.uniform_(linear.bias)

        # Materialise the repetition explicitly: y gains a k axis matching x,
        # then is appended to x's features.
        y_repeated = y.unsqueeze(-2).expand(*x.shape[:-1], n2)
        equivalent = torch.cat([x, y_repeated], dim=-1)
        expected = linear(equivalent)

        linear_with_repeat = LinearWithRepeat(n1 + n2, out_features)
        # load_state_dict is strict by default, so this also checks that the
        # parameter names line up with nn.Linear's.
        linear_with_repeat.load_state_dict(linear.state_dict())
        actual = linear_with_repeat((x, y))

        self.assertClose(actual, expected, rtol=1e-4)

    def test_simple(self):
        self._check_matches_linear((4, 6), k=7, n1=3, n2=4, out_features=8)

    def test_shapes(self):
        # (batch_shape, k, n1, n2, out_features): no batch dims, a single
        # repeated row, a single x feature, and several batch dims.
        cases = [
            ((), 5, 2, 3, 4),
            ((3,), 1, 4, 2, 5),
            ((2, 3, 2), 4, 1, 6, 3),
        ]
        for batch_shape, k, n1, n2, out_features in cases:
            with self.subTest(batch_shape=batch_shape, k=k, n1=n1, n2=n2):
                self._check_matches_linear(batch_shape, k, n1, n2, out_features)