import pytest
import torch
from easydict import EasyDict
from ding.policy.r2d3 import R2D3Policy
from ding.utils.data import offline_data_save_type
from tensorboardX import SummaryWriter
from ding.model.wrapper.model_wrappers import ArgmaxSampleWrapper, HiddenStateWrapper, EpsGreedySampleWrapper
import os
from typing import List
from collections import namedtuple
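
# Observation/action sizes and a minimal R2D3 config used by the test: a small
# DRQN model, burn-in + n-step returns, the R2D3 loss weights (n-step TD,
# supervised large-margin, L2) and prioritized replay settings.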
obs_space = 5
action_space = 4
cfg = dict(
    cuda=True,
    on_policy=False,
    priority=True,
    priority_IS_weight=True,
    model=dict(
        obs_shape=obs_space,
        action_shape=action_space,
        encoder_hidden_size_list=[128, 128, 512],
    ),
    discount_factor=0.99,
    burnin_step=2,
    nstep=5,
    learn_unroll_len=20,
    burning_step=5,
    learn=dict(
        value_rescale=True,
        update_per_collect=8,
        batch_size=64,
        learning_rate=0.0005,
        target_update_theta=0.001,
        lambda1=1.0,  # weight of the n-step TD loss
        lambda2=1.0,  # weight of the supervised (large-margin) loss
        lambda3=1e-5,  # L2 regularization weight; it is important to set the optimizer to optim_type='adamw'
        lambda_one_step_td=1,  # weight of the 1-step TD loss
        margin_function=0.8,  # margin function in J_E, implemented here as a constant
        per_train_iter_k=0,
        ignore_done=False,
    ),
    collect=dict(
        n_sample=32,
        traj_len_inf=True,
        env_num=8,
        pho=1 / 4,
    ),
    eval=dict(env_num=8, ),
    other=dict(
        eps=dict(
            type='exp',
            start=0.95,
            end=0.1,
            decay=100000,
        ),
        replay_buffer=dict(
            replay_buffer_size=int(1e4),
            alpha=0.6,
            beta=0.4,
        ),
    ),
)
cfg = EasyDict(cfg)
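

# Helper: build a dummy batch of zero observations keyed by environment id,
# in the format passed to _forward_collect / _forward_eval below.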
def get_batch(size=8):
    data = {}
    for i in range(size):
        obs = torch.zeros(obs_space)
        data[i] = obs
    return data
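

# Helper: build a fake trajectory of transitions with constant reward, dummy
# DRQN hidden states, unit importance-sampling (IS) weights and alternating
# is_expert flags, as consumed by _get_train_sample / _forward_learn.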
def get_transition(size=20):
    data = []
    import numpy as np
    for i in range(size):
        sample = {}
        sample['obs'] = torch.zeros(obs_space)
        sample['action'] = torch.tensor(np.array([int(i % action_space)]))
        sample['done'] = False
        sample['prev_state'] = [torch.randn(1, 1, 512) for __ in range(2)]
        sample['reward'] = torch.Tensor([1.])
        sample['IS'] = 1.
        sample['is_expert'] = bool(i % 2)
        data.append(sample)
    return data
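

# Smoke test covering the R2D3 policy API: learn/collect/eval initialisation,
# model wrappers, forward passes, transition processing, train-sample
# generation and a learning step with and without value rescaling.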
@pytest.mark.parametrize('cfg', [cfg])
@pytest.mark.unittest
def test_r2d3(cfg):
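    # Learner mode: the learn model should be wrapped for argmax (greedy) action
    # selection and the target model should carry the recurrent hidden state;
    # also exercise reset and the state_dict save/load round trip.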
    policy = R2D3Policy(cfg, enable_field=['collect', 'eval'])
    policy._init_learn()
    assert type(policy._learn_model) == ArgmaxSampleWrapper
    assert type(policy._target_model) == HiddenStateWrapper
    policy._reset_learn()
    policy._reset_learn([0])
    state = policy._state_dict_learn()
    policy._load_state_dict_learn(state)
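
    # Collector mode: actions are drawn through an eps-greedy sampling wrapper.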
    policy._init_collect()
    assert type(policy._collect_model) == EpsGreedySampleWrapper
    policy._reset_collect()
    policy._reset_collect([0])
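
    # Evaluator mode: greedy (argmax) action selection.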
    policy._init_eval()
    assert type(policy._eval_model) == ArgmaxSampleWrapper
    policy._reset_eval()
    policy._reset_eval([0])
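
    # The default model is DRQN and the monitored learn variables are plain strings.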
    assert policy.default_model()[0] == 'drqn'
    var = policy._monitor_vars_learn()
    assert type(var) == list
    assert sum([type(s) == str for s in var]) == len(var)
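
    # Collect forward pass on a dummy batch: each per-env output should contain
    # a logit, an action and the recurrent prev_state.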
    batch = get_batch(8)
    out = policy._forward_collect(batch, eps=0.1)
    assert len(set(out[0].keys()).intersection({'logit', 'prev_state', 'action'})) == 3
    assert list(out[0]['logit'].shape) == [action_space]
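
    # Turn one step of (obs, model output, timestep) into a stored transition.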
    timestep = namedtuple('timestep', ['reward', 'done'])
    ts = timestep(1., 0.)
    ts = policy._process_transition(batch[0], out[0], ts)
    assert len(set(ts.keys()).intersection({'prev_state', 'action', 'reward', 'done', 'obs'})) == 5
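
    # Slice a full trajectory into fixed-length unroll samples; the sample count
    # should match the trajectory length divided by policy._sequence_len.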
    ts = get_transition(64 * policy._sequence_len)
    sample = policy._get_train_sample(ts)
    n_traj = len(ts) // policy._sequence_len
    assert len(sample) == (n_traj + 1 if len(ts) % policy._sequence_len != 0 else n_traj)
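
    # Eval forward pass: outputs only need a logit and an action.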
    out = policy._forward_eval(batch)
    assert len(set(out[0].keys()).intersection({'logit', 'action'})) == 2
    assert list(out[0]['logit'].shape) == [action_space]
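
    # Drop the burn-in portion of the IS weights, then run a learning step both
    # with and without value rescaling enabled.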
    for i in range(len(sample)):
        sample[i]['IS'] = sample[i]['IS'][cfg.burnin_step:]
    out = policy._forward_learn(sample)
    policy._value_rescale = False
    out = policy._forward_learn(sample)