# Source: gomoku / DI-engine / ding/model/template/tests/test_procedure_cloning.py
# Commit 079c32c ("init space") by zjowowen; file size 1.28 kB.
import pytest
from itertools import product
import torch
from ding.model.template import ProcedureCloningMCTS, ProcedureCloningBFS
# Shared fixtures for the procedure-cloning shape tests.
B = 4  # batch size
T = 15  # length of the action sequence fed to ProcedureCloningMCTS
# Candidate values for the parametrized test arguments. Inside the test
# methods these module-level lists are shadowed by the per-case arguments
# of the same name injected by pytest.mark.parametrize.
obs_shape = [(64, 64, 3)]
action_dim = [9]
# Expected last dimension of ProcedureCloningMCTS's goal_preds output.
# NOTE(review): presumably tied to the model's default embedding size —
# update if the model defaults change.
obs_embeddings = 256
# Cartesian product of all candidate (obs_shape, action_dim) pairs.
args = list(product(obs_shape, action_dim))
@pytest.mark.unittest
@pytest.mark.parametrize('obs_shape, action_dim', args)
class TestProcedureCloning:
    """Output-shape checks for the ProcedureCloningMCTS and ProcedureCloningBFS templates.

    Every test runs once per ``(obs_shape, action_dim)`` pair in the
    module-level ``args`` list.
    """

    def test_procedure_cloning_mcts(self, obs_shape, action_dim):
        """Run ProcedureCloningMCTS on random states/goals/actions and check output shapes.

        ``goal_preds`` must be ``(B, obs_embeddings)`` and ``action_preds``
        must be ``(B, T + 1, action_dim)``.
        """
        # Random inputs are drawn in the same order: states, goals, actions.
        states = torch.randn(B, *obs_shape)
        goals = torch.randn(B, *obs_shape)
        actions = torch.randn(B, T, action_dim)

        model = ProcedureCloningMCTS(obs_shape=obs_shape, action_dim=action_dim)
        goal_preds, action_preds = model(states, goals, actions)

        assert goal_preds.shape == (B, obs_embeddings)
        # The model is expected to emit one more step than the T actions fed in.
        assert action_preds.shape == (B, T + 1, action_dim)

    def test_procedure_cloning_bfs(self, obs_shape, action_dim):
        """Run ProcedureCloningBFS on a random batch and check the ``'logit'`` output shape.

        The logits must have shape
        ``(B, obs_shape[0], obs_shape[1], action_dim + 1)``.
        """
        # The model is built with the last axis of obs_shape moved to the
        # front, i.e. (obs_shape[2], obs_shape[0], obs_shape[1]), while the
        # input tensor below keeps the original axis order.
        # NOTE(review): presumably obs_shape is (H, W, C) and the model
        # permutes to channel-first internally — confirm against
        # ProcedureCloningBFS.
        first, second, last = obs_shape[0], obs_shape[1], obs_shape[2]
        model = ProcedureCloningBFS(obs_shape=(last, first, second), action_shape=action_dim)

        observations = torch.randn(B, *obs_shape)
        logits = model(observations)['logit']

        # One value per (obs_shape[0], obs_shape[1]) cell, with one extra
        # channel beyond action_dim (its meaning is defined by the model).
        assert logits.shape == (B, first, second, action_dim + 1)