# Web-page header residue kept from the original paste (not part of the module):
#   zjowowen's picture
#   init space
#   079c32c
import pytest
import torch
from ding.framework import OnlineRLContext
from ding.data.buffer import DequeBuffer
from typing import Any
import numpy as np
import copy
from ding.framework.middleware.functional.enhancer import reward_estimator, her_data_enhancer
from unittest.mock import Mock, patch
from ding.framework.middleware.tests import MockHerRewardModel, CONFIG
# Shared fixture: 20 entries, each with a random 2x2 'obs' tensor and a random 2x2 'next_obs' tensor.
DATA = [
    dict(obs=torch.rand(2, 2), next_obs=torch.rand(2, 2))
    for _ in range(20)
]
class MockRewardModel(Mock):
    """Reward-model stand-in whose ``estimate`` only verifies the data it is handed."""

    def estimate(self, data: list) -> Any:
        """Assert that ``data`` matches the module-level ``DATA`` fixture.

        Checks that ``data`` has as many entries as ``DATA`` and that the
        ``'obs'`` tensor of the first entry equals the one in ``DATA``.
        Returns ``None``.
        """
        expected_count = len(DATA)
        assert len(data) == expected_count
        received_obs = data[0]['obs']
        expected_obs = DATA[0]['obs']
        assert torch.equal(received_obs, expected_obs)
@pytest.mark.unittest
def test_reward_estimator():
    """Run the ``reward_estimator`` middleware on a context carrying a copy of ``DATA``.

    The checks themselves live in ``MockRewardModel.estimate``.
    """
    context = OnlineRLContext()
    # Deep copy so the middleware never touches the shared fixture.
    context.train_data = copy.deepcopy(DATA)
    with patch("ding.reward_model.HerRewardModel", MockHerRewardModel):
        estimator_middleware = reward_estimator(cfg=None, reward_model=MockRewardModel())
        estimator_middleware(context)
@pytest.mark.unittest
def test_her_data_enhancer():
    """Check the amount and layout of ``ctx.train_data`` produced by ``her_data_enhancer``.

    The same randomly generated episodes are used for two scenarios:

    1. The reward model carries an ``episode_size``; the test expects
       ``episode_size * episode_element_size`` training samples.
    2. The reward model's ``episode_size`` is ``None``; the test expects
       ``cfg.policy.learn.batch_size * episode_element_size`` training samples.

    In both scenarios the first training sample must have 6 entries, matching
    the 6 keys of the transition dicts built here.
    """
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    with patch("ding.reward_model.HerRewardModel", MockHerRewardModel):
        mock_her_reward_model = MockHerRewardModel()

        # One list of transitions per episode; episode length is drawn from {1, 4, 5}.
        train_data = [
            [
                {
                    'action': torch.randint(low=0, high=5, size=(1, )),
                    'collect_train_iter': torch.tensor([0]),
                    'done': torch.tensor(False),
                    'next_obs': torch.randint(low=0, high=2, size=(10, ), dtype=torch.float32),
                    'obs': torch.randint(low=0, high=2, size=(10, ), dtype=torch.float32),
                    'reward': torch.randint(low=0, high=2, size=(1, ), dtype=torch.float32),
                } for _ in range(np.random.choice([1, 4, 5], size=1)[0])
            ] for _ in range(mock_her_reward_model.episode_size)
        ]

        # Scenario 1: the reward model provides ``episode_size``.
        buffer = DequeBuffer(mock_her_reward_model.episode_size)
        for episode in train_data:
            buffer.push(episode)

        # Pass the very instance whose attributes the assertions read, so the
        # enhancer and the expectations always refer to the same configuration.
        her_data_enhancer(cfg=cfg, buffer_=buffer, her_reward_model=mock_her_reward_model)(ctx)
        assert len(ctx.train_data) == mock_her_reward_model.episode_size * mock_her_reward_model.episode_element_size
        assert len(ctx.train_data[0]) == 6

        # Scenario 2: ``episode_size`` is ``None``.
        buffer = DequeBuffer(cfg.policy.learn.batch_size)
        for episode in train_data:
            buffer.push(episode)

        mock_her_reward_model.episode_size = None
        # The instance that had ``episode_size`` cleared must be the one handed to
        # the enhancer; a freshly constructed model would ignore the assignment
        # above and this scenario would never actually run.
        her_data_enhancer(cfg=cfg, buffer_=buffer, her_reward_model=mock_her_reward_model)(ctx)
        assert len(ctx.train_data) == cfg.policy.learn.batch_size * mock_her_reward_model.episode_element_size
        assert len(ctx.train_data[0]) == 6