import pytest
import torch
from easydict import EasyDict
from ding.policy.r2d3 import R2D3Policy
from ding.utils.data import offline_data_save_type
from tensorboardX import SummaryWriter
from ding.model.wrapper.model_wrappers import ArgmaxSampleWrapper, HiddenStateWrapper, EpsGreedySampleWrapper
import os
from typing import List
from collections import namedtuple

obs_space = 5
action_space = 4

cfg = dict(
    cuda=True,
    on_policy=False,
    priority=True,
    priority_IS_weight=True,
    model=dict(
        obs_shape=obs_space,
        action_shape=action_space,
        encoder_hidden_size_list=[128, 128, 512],
    ),
    discount_factor=0.99,
    burnin_step=2,
    nstep=5,
    learn_unroll_len=20,
    burning_step=5,
    learn=dict(
        value_rescale=True,
        update_per_collect=8,
        batch_size=64,
        learning_rate=0.0005,
        target_update_theta=0.001,
        lambda1=1.0,  # weight of the n-step return loss
        lambda2=1.0,  # weight of the supervised loss
        lambda3=1e-5,  # L2 regularization weight; it is important to set optim_type='adamw' for the Adam optimizer
        lambda_one_step_td=1,  # weight of the 1-step return loss
        margin_function=0.8,  # margin used in the supervised loss J_E, implemented here as a constant
        per_train_iter_k=0,
        ignore_done=False,
    ),
    collect=dict(
        n_sample=32,
        traj_len_inf=True,
        env_num=8,
        pho=1 / 4,
    ),
    eval=dict(env_num=8, ),
    other=dict(
        eps=dict(
            type='exp',
            start=0.95,
            end=0.1,
            decay=100000,
        ),
        replay_buffer=dict(
            replay_buffer_size=int(1e4),
            alpha=0.6,
            beta=0.4,
        ),
    ),
)
cfg = EasyDict(cfg)


def get_batch(size=8):
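    """Build a dummy observation batch: a dict mapping env id to a zero observation tensor."""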
    data = {}
    for i in range(size):
        obs = torch.zeros(obs_space)
        data[i] = obs
    return data


def get_transition(size=20):
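    """Build a list of dummy transitions (obs, action, reward, done, prev_state, IS, is_expert)."""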
    import numpy as np
    data = []
    for i in range(size):
        sample = {}
        sample['obs'] = torch.zeros(obs_space)
        sample['action'] = torch.tensor(np.array([int(i % action_space)]))
        sample['done'] = False
        sample['prev_state'] = [torch.randn(1, 1, 512) for __ in range(2)]
        sample['reward'] = torch.Tensor([1.])
        sample['IS'] = 1.
        sample['is_expert'] = bool(i % 2)
        data.append(sample)
    return data


@pytest.mark.parametrize('cfg', [cfg])
@pytest.mark.unittest
def test_r2d3(cfg):
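    """Smoke test for R2D3Policy: check model wrappers and reset/state-dict handling for the
    learn, collect and eval modes, then run collect, eval and learn forward passes."""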
    policy = R2D3Policy(cfg, enable_field=['collect', 'eval'])
    policy._init_learn()
    assert type(policy._learn_model) == ArgmaxSampleWrapper
    assert type(policy._target_model) == HiddenStateWrapper
    policy._reset_learn()
    policy._reset_learn([0])
    state = policy._state_dict_learn()
    policy._load_state_dict_learn(state)
    policy._init_collect()
    assert type(policy._collect_model) == EpsGreedySampleWrapper
    policy._reset_collect()
    policy._reset_collect([0])
    policy._init_eval()
    assert type(policy._eval_model) == ArgmaxSampleWrapper
    policy._reset_eval()
    policy._reset_eval([0])
    assert policy.default_model()[0] == 'drqn'
    var = policy._monitor_vars_learn()
    assert type(var) == list
    assert sum([type(s) == str for s in var]) == len(var)
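    # collect mode: forward a dummy batch with epsilon-greedy sampling and check the output keys/shapes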
    batch = get_batch(8)
    out = policy._forward_collect(batch, eps=0.1)
    assert len(set(out[0].keys()).intersection({'logit', 'prev_state', 'action'})) == 3
    assert list(out[0]['logit'].shape) == [action_space]
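    # pack a single timestep into a transition dict via _process_transition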
    timestep = namedtuple('timestep', ['reward', 'done'])
    ts = timestep(1., 0.)
    ts = policy._process_transition(batch[0], out[0], ts)
    assert len(set(ts.keys()).intersection({'prev_state', 'action', 'reward', 'done', 'obs'})) == 5
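    # split a long transition list into training samples of length _sequence_len
    # (one extra sample when the length does not divide evenly)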
    ts = get_transition(64 * policy._sequence_len)
    sample = policy._get_train_sample(ts)
    n_traj = len(ts) // policy._sequence_len
    assert len(sample) == (n_traj + 1 if len(ts) % policy._sequence_len != 0 else n_traj)
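    # eval mode: greedy (argmax) forward pass without epsilon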
    out = policy._forward_eval(batch)
    assert len(set(out[0].keys()).intersection({'logit', 'action'})) == 2
    assert list(out[0]['logit'].shape) == [action_space]
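    # drop the burn-in part of the IS weights, then run a learn step with and without value rescale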
    for i in range(len(sample)):
        sample[i]['IS'] = sample[i]['IS'][cfg.burnin_step:]
    out = policy._forward_learn(sample)
    policy._value_rescale = False
    out = policy._forward_learn(sample)