from copy import deepcopy import pytest import torch from easydict import EasyDict from ding.model.wrapper.model_wrappers import BaseModelWrapper, MultinomialSampleWrapper from ding.policy import PPOSTDIMPolicy obs_shape = 4 action_shape = 2 cfg1 = EasyDict(PPOSTDIMPolicy.default_config()) cfg1.model.obs_shape = obs_shape cfg1.model.action_shape = action_shape cfg2 = deepcopy(cfg1) cfg2.action_space = "continuous" def get_transition_discrete(size=64): data = [] for i in range(size): sample = {} sample['obs'] = torch.rand(obs_shape) sample['next_obs'] = torch.rand(obs_shape) sample['action'] = torch.tensor([0], dtype=torch.long) sample['value'] = torch.rand(1) sample['logit'] = torch.rand(size=(action_shape, )) sample['done'] = False sample['reward'] = torch.rand(1) data.append(sample) return data @pytest.mark.parametrize('cfg', [cfg1]) @pytest.mark.unittest def test_stdim(cfg): policy = PPOSTDIMPolicy(cfg, enable_field=['collect', 'eval', 'learn']) assert type(policy._learn_model) == BaseModelWrapper assert type(policy._collect_model) == MultinomialSampleWrapper sample = get_transition_discrete(size=64) state = policy._state_dict_learn() policy._load_state_dict_learn(state) sample = get_transition_discrete(size=64) out = policy._forward_learn(sample)