# Source: Linly-Talker/pytorch3d/tests/test_common_linear_with_repeat.py
# Uploaded by linxianzhong0128 ("Upload folder using huggingface_hub"),
# commit 7088d16 (verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from .common_testing import TestCaseMixin
class TestLinearWithRepeat(TestCaseMixin, unittest.TestCase):
    """Compares LinearWithRepeat with a plain ``torch.nn.Linear`` layer."""

    def setUp(self) -> None:
        """Run the inherited set-up, then seed torch's global RNG with 42."""
        super().setUp()
        torch.manual_seed(42)

    def test_simple(self):
        # Feeds the tuple (head, tail) to LinearWithRepeat and checks that the
        # output is close to what an ordinary Linear layer with the same
        # state dict produces on the explicitly concatenated input.
        head = torch.rand(4, 6, 7, 3)
        tail = torch.rand(4, 6, 4)

        # Reference layer: 3 + 4 = 7 input features, 8 outputs, with its
        # weight and bias re-drawn after construction.
        reference = torch.nn.Linear(7, 8)
        torch.nn.init.xavier_uniform_(reference.weight.data)
        reference.bias.data.uniform_()

        # Insert an axis before the last dimension of `tail`, expand it to
        # size 7 so it lines up with `head`, and join the two on the last axis.
        tail_expanded = tail.unsqueeze(-2).expand(4, 6, 7, 4)
        joined = torch.cat([head, tail_expanded], dim=-1)
        expected = reference.forward(joined)

        # Same parameters, but the two tensors are passed un-concatenated.
        under_test = LinearWithRepeat(7, 8)
        under_test.load_state_dict(reference.state_dict())
        actual = under_test.forward((head, tail))

        self.assertClose(actual, expected, rtol=1e-4)