File size: 4,424 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import pytest
import torch
from torch import nn

from itertools import product
from easydict import EasyDict
from ding.world_model.ddppo import DDPPOWorldMode, get_batch_jacobian, get_neighbor_index
from ding.utils import deep_merge_dicts

# arguments
state_size = [16]
action_size = [16, 1]
args = list(product(*[state_size, action_size]))


@pytest.mark.unittest
class TestDDPPO:

    def get_world_model(self, state_size, action_size):
        cfg = DDPPOWorldMode.default_config()
        cfg.model.max_epochs_since_update = 0
        cfg = deep_merge_dicts(
            cfg, dict(cuda=False, model=dict(state_size=state_size, action_size=action_size, reward_size=1))
        )
        fake_env = EasyDict(termination_fn=lambda obs: torch.zeros_like(obs.sum(-1)).bool())
        model = DDPPOWorldMode(cfg, fake_env, None)
        model.serial_calc_nn = True
        return model

    def test_get_neighbor_index(self):
        k = 2
        data = torch.tensor([[0, 0, 0], [0, 0, 1], [0, 0, -1], [5, 0, 0], [5, 0, 1], [5, 0, -1]])
        idx = get_neighbor_index(data, k, serial=True)
        target_idx = torch.tensor([[2, 1], [0, 2], [0, 1], [5, 4], [3, 5], [3, 4]])
        assert (idx - target_idx).sum() == 0

    def test_get_batch_jacobian(self):
        B, in_dim, out_dim = 64, 4, 8
        net = nn.Linear(in_dim, out_dim)
        x = torch.randn(B, in_dim)
        jacobian = get_batch_jacobian(net, x, out_dim)
        assert jacobian.shape == (B, out_dim, in_dim)

    @pytest.mark.parametrize('state_size, action_size', args)
    def test_get_jacobian(self, state_size, action_size):
        B, ensemble_size = 64, 7
        model = self.get_world_model(state_size, action_size)
        train_input_reg = torch.randn(ensemble_size, B, state_size + action_size)
        jacobian = model._get_jacobian(model.gradient_model, train_input_reg)
        assert jacobian.shape == (ensemble_size, B, state_size + 1, state_size + action_size)
        assert jacobian.requires_grad

    @pytest.mark.parametrize('state_size, action_size', args)
    def test_step(self, state_size, action_size):
        states = torch.rand(128, state_size)
        actions = torch.rand(128, action_size)
        model = self.get_world_model(state_size, action_size)
        model.elite_model_idxes = [0, 1]
        rewards, next_obs, dones = model.step(states, actions)
        assert rewards.shape == (128, )
        assert next_obs.shape == (128, state_size)
        assert dones.shape == (128, )

    @pytest.mark.parametrize('state_size, action_size', args)
    def test_train_rollout_model(self, state_size, action_size):
        states = torch.rand(1280, state_size)
        actions = torch.rand(1280, action_size)

        next_states = states + actions.mean(1, keepdim=True)
        rewards = next_states.mean(1, keepdim=True).repeat(1, 1)

        inputs = torch.cat([states, actions], dim=1)
        labels = torch.cat([rewards, next_states], dim=1)

        model = self.get_world_model(state_size, action_size)
        model._train_rollout_model(inputs[:64], labels[:64])

    @pytest.mark.parametrize('state_size, action_size', args)
    def test_train_graident_model(self, state_size, action_size):
        states = torch.rand(1280, state_size)
        actions = torch.rand(1280, action_size)

        next_states = states + actions.mean(1, keepdim=True)
        rewards = next_states.mean(1, keepdim=True)

        inputs = torch.cat([states, actions], dim=1)
        labels = torch.cat([rewards, next_states], dim=1)

        model = self.get_world_model(state_size, action_size)
        model._train_gradient_model(inputs[:64], labels[:64], inputs[:64], labels[:64])

    @pytest.mark.parametrize('state_size, action_size', args[:1])
    def test_others(self, state_size, action_size):
        states = torch.rand(1280, state_size)
        actions = torch.rand(1280, action_size)

        next_states = states + actions.mean(1, keepdim=True)
        rewards = next_states.mean(1, keepdim=True)

        inputs = torch.cat([states, actions], dim=1)
        labels = torch.cat([rewards, next_states], dim=1)

        model = self.get_world_model(state_size, action_size)
        model._train_rollout_model(inputs[:64], labels[:64])
        model._train_gradient_model(inputs[:64], labels[:64], inputs[:64], labels[:64])
        model._save_states()
        model._load_states()
        model._save_best(0, [1, 2, 3])