File size: 438 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import pytest
import torch

from ding.torch_utils.model_helper import get_num_params


@pytest.mark.unittest
class TestModelHelper:

    def test_model_helper(self):
        r"""
        Overview:
            Check ``get_num_params`` on two bias-free layers, comparing its result with the
            number of elements in each layer's weight tensor.
        """
        # Linear(in=3, out=4) without bias stores a single (4, 3) weight matrix.
        linear = torch.nn.Linear(3, 4, bias=False)
        assert get_num_params(linear) == 4 * 3

        # Conv2d(in=3, out=3, kernel 3x3) without bias stores a single (3, 3, 3, 3) weight.
        conv = torch.nn.Conv2d(3, 3, kernel_size=3, bias=False)
        assert get_num_params(conv) == 3 * 3 * 3 * 3