zjowowen's picture
init space
079c32c
raw
history blame
1.39 kB
from copy import deepcopy
import pytest
import torch
from easydict import EasyDict
from ding.model.wrapper.model_wrappers import BaseModelWrapper, MultinomialSampleWrapper
from ding.policy import PPOSTDIMPolicy
# Toy problem dimensions shared by the policy config and the transition generator.
obs_shape = 4
action_shape = 2

# Discrete-action config: PPOSTDIMPolicy defaults with the toy model shapes patched in.
cfg1 = EasyDict(PPOSTDIMPolicy.default_config())
cfg1.model.obs_shape, cfg1.model.action_shape = obs_shape, action_shape

# Independent copy of the config switched to a continuous action space.
# NOTE(review): not used by the visible parametrization of test_stdim (only cfg1 is).
cfg2 = deepcopy(cfg1)
cfg2.action_space = "continuous"
def get_transition_discrete(size=64):
    """Return a list of ``size`` random transitions with a discrete action.

    Every transition is a dict with:
      - ``obs`` / ``next_obs``: ``torch.rand(obs_shape)``
      - ``action``: the fixed long tensor ``[0]``
      - ``value`` / ``reward``: ``torch.rand(1)``
      - ``logit``: ``torch.rand(action_shape)``
      - ``done``: ``False``
    """
    # Dict literal values are evaluated left to right, so the random draws
    # happen in the same order as the original field-by-field construction.
    return [
        {
            'obs': torch.rand(obs_shape),
            'next_obs': torch.rand(obs_shape),
            'action': torch.tensor([0], dtype=torch.long),
            'value': torch.rand(1),
            'logit': torch.rand(size=(action_shape, )),
            'done': False,
            'reward': torch.rand(1),
        } for _ in range(size)
    ]
@pytest.mark.parametrize('cfg', [cfg1])
@pytest.mark.unittest
def test_stdim(cfg):
    """Smoke-test PPOSTDIMPolicy wrapping, learner state round-trip and one learn step.

    Checks that the learn and collect models carry exactly the expected wrapper
    types, that the learner state dict can be saved and loaded back, and that
    ``_forward_learn`` runs on a batch of random discrete transitions without
    raising.
    """
    policy = PPOSTDIMPolicy(cfg, enable_field=['collect', 'eval', 'learn'])
    # Exact wrapper types are asserted (not isinstance); ``is`` is the correct
    # operator for comparing type identity.
    assert type(policy._learn_model) is BaseModelWrapper
    assert type(policy._collect_model) is MultinomialSampleWrapper

    # Saving and re-loading the learner state must not raise.
    state = policy._state_dict_learn()
    policy._load_state_dict_learn(state)

    # One learn step on a freshly generated batch. The test passes if no
    # exception is raised; the returned info is not inspected here.
    sample = get_transition_discrete(size=64)
    policy._forward_learn(sample)