# Source: gomoku / DI-engine / ding/torch_utils/tests/test_model_helper.py
# Author: zjowowen, commit 079c32c ("init space"); original file size 438 bytes.
import pytest
import torch
from ding.torch_utils.model_helper import get_num_params
@pytest.mark.unittest
class TestModelHelper:

    def test_model_helper(self):
        r"""
        Overview:
            Check that ``get_num_params`` reports the total number of parameter
            elements for simple bias-free layers.
        """
        # Each entry pairs a module with its expected parameter count. Bias is
        # disabled, so the count is just the number of elements in the weight.
        cases = [
            # Linear weight has shape (out_features, in_features) -> 4 * 3
            (torch.nn.Linear(3, 4, bias=False), 4 * 3),
            # Conv2d weight has shape (out_ch, in_ch, kH, kW) -> 3 * 3 * 3 * 3
            (torch.nn.Conv2d(3, 3, kernel_size=3, bias=False), 3 * 3 * 3 * 3),
        ]
        for module, expected in cases:
            assert get_num_params(module) == expected