import os
from collections import namedtuple
from typing import List

import numpy as np
import pytest
import torch
from easydict import EasyDict
from tensorboardX import SummaryWriter

from ding.model.wrapper.model_wrappers import ArgmaxSampleWrapper, HiddenStateWrapper, EpsGreedySampleWrapper
from ding.policy.r2d3 import R2D3Policy
from ding.utils.data import offline_data_save_type

obs_space = 5
action_space = 4

cfg = dict(
    cuda=True,
    on_policy=False,
    priority=True,
    priority_IS_weight=True,
    model=dict(
        obs_shape=obs_space,
        action_shape=action_space,
        encoder_hidden_size_list=[128, 128, 512],
    ),
    discount_factor=0.99,
    burnin_step=2,
    nstep=5,
    learn_unroll_len=20,
    learn=dict(
        value_rescale=True,
        update_per_collect=8,
        batch_size=64,
        learning_rate=0.0005,
        target_update_theta=0.001,
        # DQfD-style loss weights; see the note below the config.
        lambda1=1.0,
        lambda2=1.0,
        lambda3=1e-5,
        lambda_one_step_td=1,
        margin_function=0.8,
        per_train_iter_k=0,
        ignore_done=False,
    ),
    collect=dict(
        n_sample=32,
        traj_len_inf=True,
        env_num=8,
        pho=1 / 4,
    ),
    eval=dict(env_num=8),
    other=dict(
        eps=dict(
            type='exp',
            start=0.95,
            end=0.1,
            decay=100000,
        ),
        replay_buffer=dict(
            replay_buffer_size=int(1e4),
            alpha=0.6,
            beta=0.4,
        ),
    ),
)
cfg = EasyDict(cfg)
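
# Note on the loss-related keys above: R2D3 follows the DQfD recipe, in which the
# total objective is (approximately; the exact weighting here is an assumption)
#
#     J(Q) = J_one_step_TD(Q) + lambda1 * J_nstep_TD(Q) + lambda2 * J_margin(Q) + lambda3 * J_L2(Q)
#
# where J_margin is the large-margin supervised loss on expert transitions with margin
# `margin_function`, `lambda_one_step_td` scales the one-step TD term, and `pho`
# controls the ratio of expert (demonstration) data mixed into training batches.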


def get_batch(size=8):
    # Fake collector input: a dict mapping env_id -> observation tensor,
    # the format consumed by _forward_collect/_forward_eval in the test below.
    data = {}
    for i in range(size):
        obs = torch.zeros(obs_space)
        data[i] = obs
    return data


def get_transition(size=20):
    # Fake trajectory of transitions with the fields R2D3 expects: obs, action,
    # done, recurrent prev_state (hidden/cell tensors), reward, importance-sampling
    # weight, and the is_expert flag that marks demonstration data.
    data = []
    for i in range(size):
        sample = {}
        sample['obs'] = torch.zeros(obs_space)
        sample['action'] = torch.tensor(np.array([int(i % action_space)]))
        sample['done'] = False
        sample['prev_state'] = [torch.randn(1, 1, 512) for _ in range(2)]
        sample['reward'] = torch.Tensor([1.])
        sample['IS'] = 1.
        sample['is_expert'] = bool(i % 2)
        data.append(sample)
    return data
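

# The test below checks that the learn/eval models are wrapped in ArgmaxSampleWrapper
# and the collect model in EpsGreedySampleWrapper. As a rough, self-contained sketch
# (not DI-engine's actual wrapper code), the difference between the two sampling modes
# can be illustrated as follows; `_demo_eps_greedy_sample` is a hypothetical helper
# used only for illustration.
def _demo_eps_greedy_sample(logits: torch.Tensor, eps: float) -> torch.Tensor:
    # Greedy (argmax) action per batch element.
    greedy = logits.argmax(dim=-1)
    # With probability `eps`, replace the greedy action with a uniformly random one.
    random_action = torch.randint(logits.shape[-1], greedy.shape)
    use_random = torch.rand(greedy.shape) < eps
    return torch.where(use_random, random_action, greedy)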


@pytest.mark.parametrize('cfg', [cfg])
@pytest.mark.unittest
def test_r2d3(cfg):
    policy = R2D3Policy(cfg, enable_field=['collect', 'eval'])
    policy._init_learn()
    assert type(policy._learn_model) == ArgmaxSampleWrapper
    assert type(policy._target_model) == HiddenStateWrapper
    policy._reset_learn()
    policy._reset_learn([0])
    state = policy._state_dict_learn()
    policy._load_state_dict_learn(state)
    policy._init_collect()
    assert type(policy._collect_model) == EpsGreedySampleWrapper
    policy._reset_collect()
    policy._reset_collect([0])
    policy._init_eval()
    assert type(policy._eval_model) == ArgmaxSampleWrapper
    policy._reset_eval()
    policy._reset_eval([0])
    assert policy.default_model()[0] == 'drqn'
    var = policy._monitor_vars_learn()
    assert isinstance(var, list)
    assert all(isinstance(s, str) for s in var)
    # Collect: epsilon-greedy forward pass over a fake batch of observations.
    batch = get_batch(8)
    out = policy._forward_collect(batch, eps=0.1)
    assert len(set(out[0].keys()).intersection({'logit', 'prev_state', 'action'})) == 3
    assert list(out[0]['logit'].shape) == [action_space]
    # Build a transition from one env step and check the stored fields.
    timestep = namedtuple('timestep', ['reward', 'done'])
    ts = timestep(1., 0.)
    ts = policy._process_transition(batch[0], out[0], ts)
    assert len(set(ts.keys()).intersection({'prev_state', 'action', 'reward', 'done', 'obs'})) == 5
    # Slice a fake trajectory into training samples of length _sequence_len.
    ts = get_transition(64 * policy._sequence_len)
    sample = policy._get_train_sample(ts)
    n_traj = len(ts) // policy._sequence_len
    assert len(sample) == (n_traj + 1 if len(ts) % policy._sequence_len != 0 else n_traj)
    # Eval: greedy forward pass.
    out = policy._forward_eval(batch)
    assert len(set(out[0].keys()).intersection({'logit', 'action'})) == 2
    assert list(out[0]['logit'].shape) == [action_space]
    # Drop the burn-in part of the IS weights, then run the learn forward pass
    # with and without value rescaling.
    for i in range(len(sample)):
        sample[i]['IS'] = sample[i]['IS'][cfg.burnin_step:]
    out = policy._forward_learn(sample)
    policy._value_rescale = False
    out = policy._forward_learn(sample)
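

# Illustrative sketch (an assumption, not DI-engine's exact implementation) of the
# exponential epsilon schedule configured by `cfg.other.eps` above: epsilon anneals
# from `start` toward `end` with time constant `decay`. `_demo_exp_eps_schedule` is
# a hypothetical helper used only for illustration, e.g.
# _demo_exp_eps_schedule(0) ~= 0.95 and _demo_exp_eps_schedule(10 ** 6) ~= 0.1.
def _demo_exp_eps_schedule(step: int, start: float = 0.95, end: float = 0.1, decay: float = 100000.) -> float:
    import math
    return end + (start - end) * math.exp(-step / decay)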