File size: 1,282 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import pytest
from itertools import product

import torch

from ding.model.template import ProcedureCloningMCTS, ProcedureCloningBFS

# Batch size and trajectory length shared by every test case.
B, T = 4, 15
# Candidate observation shapes and action dimensions; every pairing of the
# two lists becomes one parametrized test case (see ``args`` below).
obs_shape = [(64, 64, 3)]
action_dim = [9]
# Expected trailing dimension of the MCTS model's goal predictions.
obs_embeddings = 256
# Cartesian product of the candidate lists, as (obs_shape, action_dim) pairs.
args = list(product(obs_shape, action_dim))


@pytest.mark.unittest
@pytest.mark.parametrize('obs_shape, action_dim', args)
class TestProcedureCloning:
    """Output-shape checks for the procedure-cloning model templates.

    Each test runs once per ``(obs_shape, action_dim)`` pair in ``args``.
    """

    def test_procedure_cloning_mcts(self, obs_shape, action_dim):
        """Check the goal/action prediction shapes of ``ProcedureCloningMCTS``.

        The model is fed a batch of ``B`` states and goals of shape
        ``obs_shape`` plus ``T`` steps of ``action_dim``-wide actions. It
        must return goal predictions of shape ``(B, obs_embeddings)`` and
        action predictions covering ``T + 1`` steps.
        """
        # Sample tensors in the same order as before (states, goals,
        # actions), then build the model, so the RNG stream is unchanged.
        states = torch.randn(B, *obs_shape)
        goals = torch.randn(B, *obs_shape)
        actions = torch.randn(B, T, action_dim)
        model = ProcedureCloningMCTS(obs_shape=obs_shape, action_dim=action_dim)

        goal_preds, action_preds = model(states, goals, actions)

        assert goal_preds.shape == (B, obs_embeddings)
        assert action_preds.shape == (B, T + 1, action_dim)

    def test_procedure_cloning_bfs(self, obs_shape, action_dim):
        """Check the ``'logit'`` output shape of ``ProcedureCloningBFS``.

        The model is configured with ``obs_shape`` reordered as
        ``(obs_shape[2], obs_shape[0], obs_shape[1])``. The input batch,
        however, keeps the original ``obs_shape`` ordering.
        """
        # NOTE(review): presumably obs_shape is (H, W, C) and the model
        # wants a channel-first (C, H, W) config while accepting
        # channel-last input tensors — TODO confirm against the model.
        first, second, third = obs_shape[0], obs_shape[1], obs_shape[2]
        model = ProcedureCloningBFS(obs_shape=(third, first, second), action_shape=action_dim)

        observations = torch.randn(B, *obs_shape)
        logits = model(observations)['logit']

        # One extra output channel beyond action_dim is expected.
        assert logits.shape == (B, first, second, action_dim + 1)