File size: 1,072 Bytes
7088d16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat

from .common_testing import TestCaseMixin


class TestLinearWithRepeat(TestCaseMixin, unittest.TestCase):
    """Tests for ``LinearWithRepeat``.

    The module is checked against a plain ``torch.nn.Linear`` (sharing the
    same state dict) applied to the explicit concatenation of both inputs,
    where the second input is repeated along the second-to-last dimension
    of the first.
    """

    def setUp(self) -> None:
        super().setUp()
        # Fixed seed so the random inputs and weight initialisation are
        # reproducible across runs.
        torch.manual_seed(42)

    def test_simple(self):
        """``LinearWithRepeat((x, y))`` matches ``Linear(cat([x, repeat(y)]))``."""
        batch_dims = (4, 6)
        # Kept different from ``in_features`` so that confusing the repeat
        # axis with the feature axis produces a shape error or mismatch.
        n_repeats = 5
        x_features, y_features = 3, 4
        in_features = x_features + y_features
        out_features = 8

        x = torch.rand(*batch_dims, n_repeats, x_features)
        y = torch.rand(*batch_dims, y_features)

        linear = torch.nn.Linear(in_features, out_features)
        # Modify parameters in place without recording autograd history,
        # instead of going through the discouraged ``.data`` attribute.
        with torch.no_grad():
            torch.nn.init.xavier_uniform_(linear.weight)
            linear.bias.uniform_()

        # Reference: repeat ``y`` along ``x``'s second-to-last dim (shape derived
        # from ``x`` so it cannot drift), then concatenate on the feature dim.
        y_repeated = y.unsqueeze(-2).expand(*x.shape[:-1], -1)
        equivalent = torch.cat([x, y_repeated], dim=-1)
        expected = linear(equivalent)

        linear_with_repeat = LinearWithRepeat(in_features, out_features)
        linear_with_repeat.load_state_dict(linear.state_dict())
        # Call the module (not ``.forward``) so the normal ``nn.Module``
        # call path, including any hooks, is exercised.
        actual = linear_with_repeat((x, y))

        self.assertEqual(actual.shape, (*batch_dims, n_repeats, out_features))
        self.assertClose(actual, expected, rtol=1e-4)