"""Unit tests for the ``reward_estimator`` and ``her_data_enhancer`` middleware."""
import pytest
import torch
from ding.framework import OnlineRLContext
from ding.data.buffer import DequeBuffer
from typing import Any
import numpy as np
import copy
from ding.framework.middleware.functional.enhancer import reward_estimator, her_data_enhancer
from unittest.mock import Mock, patch
from ding.framework.middleware.tests import MockHerRewardModel, CONFIG

# Shared fixture: 20 transitions, each holding a random 2x2 ``obs`` and ``next_obs``.
DATA = [{'obs': torch.rand(2, 2), 'next_obs': torch.rand(2, 2)} for _ in range(20)]


class MockRewardModel(Mock):
    """Reward-model stand-in whose ``estimate`` validates the data it is handed."""

    def estimate(self, data: list) -> Any:
        """Assert that ``data`` mirrors the module-level ``DATA`` fixture.

        Checks that ``data`` has as many items as ``DATA`` and that the first
        ``obs`` tensor is element-wise equal to ``DATA[0]['obs']``.

        Raises:
            AssertionError: if either check fails.
        """
        assert len(data) == len(DATA)
        assert torch.equal(data[0]['obs'], DATA[0]['obs'])


@pytest.mark.unittest
def test_reward_estimator():
    """Run ``reward_estimator`` on a context whose train data is a copy of ``DATA``.

    The test contains no assertion of its own; the checks live in
    ``MockRewardModel.estimate``.
    NOTE(review): this presumes the middleware calls ``reward_model.estimate``
    with ``ctx.train_data`` -- confirm against the enhancer implementation.
    """
    ctx = OnlineRLContext()
    # Deep copy so the middleware cannot mutate the shared fixture in place.
    ctx.train_data = copy.deepcopy(DATA)
    with patch("ding.reward_model.HerRewardModel", MockHerRewardModel):
        reward_estimator(cfg=None, reward_model=MockRewardModel())(ctx)


@pytest.mark.unittest
def test_her_data_enhancer():
    """Check the amount of train data produced by ``her_data_enhancer``.

    Two scenarios are exercised:

    1. The reward model keeps its own ``episode_size``; the expected number of
       samples is ``episode_size * episode_element_size``.
    2. ``episode_size`` is set to ``None``; the expected number of samples is
       ``cfg.policy.learn.batch_size * episode_element_size``.

    In both scenarios every produced sample must have 6 keys.
    """
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    with patch("ding.reward_model.HerRewardModel", MockHerRewardModel):
        mock_her_reward_model = MockHerRewardModel()
        buffer = DequeBuffer(mock_her_reward_model.episode_size)

        # One episode per buffer slot; each episode holds 1, 4 or 5 random
        # transitions, every transition being a dict with six tensor fields.
        train_data = [
            [
                {
                    'action': torch.randint(low=0, high=5, size=(1, )),
                    'collect_train_iter': torch.tensor([0]),
                    'done': torch.tensor(False),
                    'next_obs': torch.randint(low=0, high=2, size=(10, ), dtype=torch.float32),
                    'obs': torch.randint(low=0, high=2, size=(10, ), dtype=torch.float32),
                    'reward': torch.randint(low=0, high=2, size=(1, ), dtype=torch.float32),
                } for _ in range(np.random.choice([1, 4, 5], size=1)[0])
            ] for _ in range(mock_her_reward_model.episode_size)
        ]

        for d in train_data:
            buffer.push(d)

        # Scenario 1: reward model with its default ``episode_size``.
        her_data_enhancer(cfg=cfg, buffer_=buffer, her_reward_model=MockHerRewardModel())(ctx)
        assert len(ctx.train_data) == mock_her_reward_model.episode_size * mock_her_reward_model.episode_element_size
        assert len(ctx.train_data[0]) == 6

        # Scenario 2: ``episode_size`` is None, so the expected sample count is
        # derived from ``cfg.policy.learn.batch_size`` instead.
        buffer = DequeBuffer(cfg.policy.learn.batch_size)
        for d in train_data:
            buffer.push(d)
        mock_her_reward_model.episode_size = None
        # Pass the instance that was just modified. Constructing a fresh
        # ``MockHerRewardModel()`` here would discard the ``episode_size = None``
        # override, and this scenario would never be exercised.
        her_data_enhancer(cfg=cfg, buffer_=buffer, her_reward_model=mock_her_reward_model)(ctx)
        assert len(ctx.train_data) == cfg.policy.learn.batch_size * mock_her_reward_model.episode_element_size
        assert len(ctx.train_data[0]) == 6