from itertools import product

import pytest
import torch

from ding.model.template import ProcedureCloningMCTS, ProcedureCloningBFS


# Batch size used for every random input tensor in the tests.
B = 4
# Number of steps in the action sequence passed to ProcedureCloningMCTS.
T = 15
# Candidate observation shapes; the BFS test moves index 2 to the front,
# so these are presumably channel-last (H, W, C) -- confirm against the models.
obs_shape = [(64, 64, 3)]
# Candidate action dimensions.
action_dim = [9]
# Expected size of the last axis of the MCTS goal prediction.
obs_embeddings = 256
# One (obs_shape, action_dim) tuple per parametrized test case.
args = list(product(obs_shape, action_dim))


@pytest.mark.unittest
@pytest.mark.parametrize('obs_shape, action_dim', args)
class TestProcedureCloning:
    """Output-shape checks for the procedure cloning model templates.

    Each test is run once per ``(obs_shape, action_dim)`` pair in ``args``.
    """

    def test_procedure_cloning_mcts(self, obs_shape, action_dim):
        """Run ProcedureCloningMCTS on random inputs and check both output shapes."""
        inputs = {
            'states': torch.randn(B, *obs_shape),
            'goals': torch.randn(B, *obs_shape),
            'actions': torch.randn(B, T, action_dim),
        }
        model = ProcedureCloningMCTS(obs_shape=obs_shape, action_dim=action_dim)
        goal_preds, action_preds = model(inputs['states'], inputs['goals'], inputs['actions'])
        assert goal_preds.shape == (B, obs_embeddings)
        # The model is expected to emit one more action step than it is given.
        assert action_preds.shape == (B, T + 1, action_dim)

    def test_procedure_cloning_bfs(self, obs_shape, action_dim):
        """Run ProcedureCloningBFS on a random input and check the logit shape."""
        # The model is constructed with the last axis of obs_shape moved to the
        # front, while the input tensor itself keeps the original axis order.
        o_shape = (obs_shape[2], obs_shape[0], obs_shape[1])
        model = ProcedureCloningBFS(obs_shape=o_shape, action_shape=action_dim)

        inputs = torch.randn(B, *obs_shape)
        map_preds = model(inputs)
        # Logits are expected per spatial position, with action_dim + 1 entries each.
        assert map_preds['logit'].shape == (B, obs_shape[0], obs_shape[1], action_dim + 1)