zjowowen's picture
init space
079c32c
raw
history blame
952 Bytes
import pytest
import torch
from itertools import product
from ding.world_model.model.ensemble import EnsembleFC, EnsembleModel
# Parameter grid for the parametrized EnsembleModel test: one candidate list
# per dimension, expanded into every (state, action, reward) combination.
state_size = [16]
action_size = [16, 1]
reward_size = [1]
args = list(product(state_size, action_size, reward_size))
@pytest.mark.unittest
def test_EnsembleFC():
    """Check the output shape of an ``EnsembleFC`` layer.

    Builds a layer from (input width, output width, ensemble size), feeds it a
    random tensor of shape ``(ensemble, batch, input width)`` and asserts the
    result has shape ``(ensemble, batch, output width)``.
    """
    n_in, n_out = 4, 8
    n_members, batch = 7, 64

    layer = EnsembleFC(n_in, n_out, n_members)
    inputs = torch.randn(n_members, batch, n_in)
    outputs = layer(inputs)

    expected_shape = (n_members, batch, n_out)
    assert outputs.shape == expected_shape
@pytest.mark.unittest
@pytest.mark.parametrize('state_size, action_size, reward_size', args)
def test_EnsembleModel(state_size, action_size, reward_size):
    """Check the output structure of ``EnsembleModel`` for each size combination.

    The ``unittest`` marker matches the one on ``test_EnsembleFC`` so that
    marker-filtered runs (``pytest -m unittest``) also select these cases.

    Args:
        state_size: Width of the state part of the input.
        action_size: Width of the action part of the input.
        reward_size: Width of the reward part of the output.

    The model is fed a random tensor of shape
    ``(ensemble_size, B, state_size + action_size)`` and must return exactly two
    tensors, each of shape ``(ensemble_size, B, state_size + reward_size)``.
    """
    ensemble_size, B = 7, 64
    model = EnsembleModel(state_size, action_size, reward_size, ensemble_size)
    x = torch.randn(ensemble_size, B, state_size + action_size)
    y = model(x)

    assert len(y) == 2
    # NOTE(review): presumably the two outputs are a predicted mean and a
    # (log-)variance -- confirm against EnsembleModel.forward.
    first, second = y
    expected_shape = (ensemble_size, B, state_size + reward_size)
    # Checked separately so a failure identifies which output has the wrong shape.
    assert first.shape == expected_shape
    assert second.shape == expected_shape