zjowowen's picture
init space
079c32c
raw
history blame
957 Bytes
import pytest
import torch
from ding.torch_utils import is_differentiable
from lzero.model.common import RepresentationNetwork
@pytest.mark.unittest
class TestCommon:
    """Unit tests for ``RepresentationNetwork`` from ``lzero.model.common``."""

    def output_check(self, model, outputs):
        """Reduce ``outputs`` to a scalar loss and pass it to ``is_differentiable`` with ``model``.

        Arguments:
            - model: module handed to ``ding.torch_utils.is_differentiable`` together with the loss.
            - outputs: a tensor, a list/tuple of tensors, or a dict whose values are tensors.

        Raises:
            - TypeError: if ``outputs`` is not one of the supported types.
        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        elif isinstance(outputs, (list, tuple)):
            loss = sum(t.sum() for t in outputs)
        elif isinstance(outputs, dict):
            loss = sum(v.sum() for v in outputs.values())
        else:
            # Fail with a clear message instead of an UnboundLocalError on ``loss`` below.
            raise TypeError(f"output_check does not support outputs of type {type(outputs).__name__}")
        is_differentiable(loss, model)

    @pytest.mark.parametrize('batch_size', [10])
    def test_representation_network(self, batch_size):
        """Check that ``RepresentationNetwork`` with ``downsample=False`` maps a
        ``(batch_size, 1, 3, 3)`` input to a ``(batch_size, num_channels, 3, 3)`` state.
        """
        observation_shape = [1, 3, 3]
        num_channels = 16
        obs = torch.rand(batch_size, *observation_shape)
        representation_network = RepresentationNetwork(
            observation_shape=observation_shape, num_res_blocks=1, num_channels=num_channels, downsample=False
        )
        state = representation_network(obs)
        # Derive the expected shape from the parameters (not literals) so extra
        # batch sizes can be added to the parametrize list without editing this line.
        assert state.shape == torch.Size([batch_size, num_channels, *observation_shape[1:]])