|
import pytest |
|
import torch |
|
from ding.torch_utils import is_differentiable |
|
|
|
from lzero.model.common import RepresentationNetwork |
|
|
|
|
|
@pytest.mark.unittest
class TestCommon:
    """Unit tests for the building blocks exported by ``lzero.model.common``."""

    def output_check(self, model, outputs):
        """Reduce ``outputs`` to one scalar loss and run the differentiability check.

        Args:
            model: The module that produced ``outputs``; forwarded unchanged to
                ``is_differentiable``.
            outputs: A ``torch.Tensor``, a list/tuple of tensors, or a dict whose
                values are tensors.

        Raises:
            TypeError: If ``outputs`` is none of the supported container types.
        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        elif isinstance(outputs, (list, tuple)):
            loss = sum(t.sum() for t in outputs)
        elif isinstance(outputs, dict):
            loss = sum(v.sum() for v in outputs.values())
        else:
            # Without this branch ``loss`` would be unbound below and the helper
            # would fail with an UnboundLocalError that hides the real cause.
            raise TypeError(
                "outputs must be a torch.Tensor, list, tuple or dict, "
                f"got {type(outputs).__name__}"
            )
        is_differentiable(loss, model)

    @pytest.mark.parametrize('batch_size', [10])
    def test_representation_network(self, batch_size):
        """The latent state keeps the 3x3 spatial size and has ``num_channels`` channels."""
        num_channels = 16
        obs = torch.rand(batch_size, 1, 3, 3)
        representation_network = RepresentationNetwork(
            observation_shape=[1, 3, 3], num_res_blocks=1, num_channels=num_channels, downsample=False
        )
        state = representation_network(obs)
        # Compare against the parametrized batch size rather than a literal 10,
        # so extending the parametrization does not produce a spurious failure.
        assert state.shape == torch.Size([batch_size, num_channels, 3, 3])
|
|