import pytest
import numpy as np
import torch
from easydict import EasyDict
from collections import namedtuple

from ding.policy.r2d3 import R2D3Policy
from ding.model.wrapper.model_wrappers import ArgmaxSampleWrapper, HiddenStateWrapper, EpsGreedySampleWrapper

obs_space = 5
action_space = 4

# Minimal R2D3 policy config exercised by this test; hyper-parameter values are
# deliberately small and illustrative rather than tuned for real training.
cfg = dict(
    cuda=True,
    on_policy=False,
    priority=True,
    priority_IS_weight=True,
    model=dict(
        obs_shape=obs_space,
        action_shape=action_space,
        encoder_hidden_size_list=[128, 128, 512],
    ),
    discount_factor=0.99,
    burnin_step=2,
    nstep=5,
    learn_unroll_len=20,
    learn=dict(
        value_rescale=True,
        update_per_collect=8,
        batch_size=64,
        learning_rate=0.0005,
        target_update_theta=0.001,
        lambda1=1.0,  # weight of the n-step TD loss
        lambda2=1.0,  # weight of the supervised (large-margin) loss
        lambda3=1e-5,  # L2 regularization weight; it is important to configure the Adam optimizer with optim_type='adamw'
        lambda_one_step_td=1,  # weight of the 1-step TD loss
        margin_function=0.8,  # margin of the expert supervised loss J_E, implemented here as a constant
        per_train_iter_k=0,
        ignore_done=False,
    ),
    collect=dict(
        n_sample=32,
        traj_len_inf=True,
        env_num=8,
        pho=1 / 4,  # ratio of expert (demonstration) data in each training batch
    ),
    eval=dict(env_num=8, ),
    other=dict(
        eps=dict(
            type='exp',
            start=0.95,
            end=0.1,
            decay=100000,
        ),
        replay_buffer=dict(
            replay_buffer_size=int(1e4),
            alpha=0.6,
            beta=0.4,
        ),
    ),
)
cfg = EasyDict(cfg)


def get_batch(size=8):
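    # Build a {env_id: observation} dict of zero tensors, the input format
    # expected by _forward_collect / _forward_eval.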
    data = {}
    for i in range(size):
        obs = torch.zeros(obs_space)
        data[i] = obs
    return data


def get_transition(size=20):
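    # Build a list of fake transitions containing every field R2D3 consumes:
    # obs, action, done, the DRQN hidden state (h, c), reward, IS weight and an
    # is_expert flag that alternates between agent and expert data.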
    data = []
    for i in range(size):
        sample = {}
        sample['obs'] = torch.zeros(obs_space)
        sample['action'] = torch.tensor(np.array([int(i % action_space)]))
        sample['done'] = False
        sample['prev_state'] = [torch.randn(1, 1, 512) for __ in range(2)]
        sample['reward'] = torch.Tensor([1.])
        sample['IS'] = 1.
        sample['is_expert'] = bool(i % 2)
        data.append(sample)
    return data


@pytest.mark.parametrize('cfg', [cfg])
@pytest.mark.unittest
def test_r2d3(cfg):
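    # Smoke test of the R2D3Policy API: learn/collect/eval initialisation,
    # model wrappers, forward passes, transition processing, train-sample
    # generation and a couple of learning iterations.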
    policy = R2D3Policy(cfg, enable_field=['collect', 'eval'])
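    # Learn mode: the learn model is expected to use argmax sampling and the
    # target model to manage the hidden state of the recurrent Q-network.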
    policy._init_learn()
    assert type(policy._learn_model) == ArgmaxSampleWrapper
    assert type(policy._target_model) == HiddenStateWrapper
    policy._reset_learn()
    policy._reset_learn([0])
    state = policy._state_dict_learn()
    policy._load_state_dict_learn(state)
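    # Collect mode: exploration is driven by an eps-greedy sample wrapper.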
    policy._init_collect()
    assert type(policy._collect_model) == EpsGreedySampleWrapper
    policy._reset_collect()
    policy._reset_collect([0])
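    # Eval mode: greedy (argmax) action selection, no exploration.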
    policy._init_eval()
    assert type(policy._eval_model) == ArgmaxSampleWrapper
    policy._reset_eval()
    policy._reset_eval([0])
    assert policy.default_model()[0] == 'drqn'
    var = policy._monitor_vars_learn()
    assert type(var) == list
    assert sum([type(s) == str for s in var]) == len(var)
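    # Collect-mode forward pass: each output entry should carry logit, action
    # and the recurrent prev_state.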
    batch = get_batch(8)
    out = policy._forward_collect(batch, eps=0.1)
    assert len(set(out[0].keys()).intersection({'logit', 'prev_state', 'action'})) == 3
    assert list(out[0]['logit'].shape) == [action_space]
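    # Build a minimal env timestep and turn it into a stored transition.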
    timestep = namedtuple('timestep', ['reward', 'done'])
    ts = timestep(1., 0.)
    ts = policy._process_transition(batch[0], out[0], ts)
    assert len(set(ts.keys()).intersection({'prev_state', 'action', 'reward', 'done', 'obs'})) == 5
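    # Slice the collected transitions into fixed-length sequences
    # (policy._sequence_len) suitable for recurrent training.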
    ts = get_transition(64 * policy._sequence_len)
    sample = policy._get_train_sample(ts)
    n_traj = len(ts) // policy._sequence_len
    assert len(sample) == (n_traj + 1 if len(ts) % policy._sequence_len != 0 else n_traj)
    out = policy._forward_eval(batch)
    assert len(set(out[0].keys()).intersection({'logit', 'action'})) == 2
    assert list(out[0]['logit'].shape) == [action_space]
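    # Drop the IS weights of the burn-in steps, then run one learning iteration
    # with value rescaling enabled and one with it disabled.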
    for i in range(len(sample)):
        sample[i]['IS'] = sample[i]['IS'][cfg.burnin_step:]
    out = policy._forward_learn(sample)
    policy._value_rescale = False
    out = policy._forward_learn(sample)
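

if __name__ == "__main__":
    # Optional convenience entry point: run this file directly with
    # ``python test_r2d3.py`` instead of invoking the pytest CLI.
    test_r2d3(cfg)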