import pytest import random import copy import torch import treetensor.torch as ttorch from unittest.mock import Mock, patch from ding.data.buffer import DequeBuffer from ding.framework import OnlineRLContext, task from ding.framework.middleware import trainer, multistep_trainer, OffPolicyLearner, HERLearner from ding.framework.middleware.tests import MockHerRewardModel, CONFIG class MockPolicy(Mock): _device = 'cpu' # MockPolicy class for train mode def forward(self, train_data, **kwargs): res = { 'total_loss': 0.1, } return res class MultiStepMockPolicy(Mock): _device = 'cpu' # MockPolicy class for multi-step train mode def forward(self, train_data, **kwargs): res = [ { 'total_loss': 0.1, }, { 'total_loss': 1.0, }, ] return res def get_mock_train_input(): data = {'obs': torch.rand(2, 2), 'next_obs': torch.rand(2, 2), 'reward': random.random(), 'info': {}} return ttorch.as_tensor(data) @pytest.mark.unittest def test_trainer(): cfg = copy.deepcopy(CONFIG) ctx = OnlineRLContext() ctx.train_data = None with patch("ding.policy.Policy", MockPolicy): policy = MockPolicy() for _ in range(10): trainer(cfg, policy)(ctx) assert ctx.train_iter == 0 ctx.train_data = get_mock_train_input() with patch("ding.policy.Policy", MockPolicy): policy = MockPolicy() for _ in range(30): trainer(cfg, policy)(ctx) assert ctx.train_iter == 30 assert ctx.train_output["total_loss"] == 0.1 @pytest.mark.unittest def test_multistep_trainer(): cfg = copy.deepcopy(CONFIG) ctx = OnlineRLContext() ctx.train_data = None with patch("ding.policy.Policy", MockPolicy): policy = MockPolicy() for _ in range(10): trainer(cfg, policy)(ctx) assert ctx.train_iter == 0 ctx.train_data = get_mock_train_input() with patch("ding.policy.Policy", MultiStepMockPolicy): policy = MultiStepMockPolicy() for _ in range(30): multistep_trainer(policy, 10)(ctx) assert ctx.train_iter == 60 assert ctx.train_output[0]["total_loss"] == 0.1 assert ctx.train_output[1]["total_loss"] == 1.0 @pytest.mark.unittest def test_offpolicy_learner(): cfg = copy.deepcopy(CONFIG) ctx = OnlineRLContext() buffer = DequeBuffer(size=10) for _ in range(10): buffer.push(get_mock_train_input()) with patch("ding.policy.Policy", MockPolicy): with task.start(): policy = MockPolicy() learner = OffPolicyLearner(cfg, policy, buffer) learner(ctx) assert len(ctx.train_output) == 4 @pytest.mark.unittest def test_her_learner(): cfg = copy.deepcopy(CONFIG) ctx = OnlineRLContext() buffer = DequeBuffer(size=10) for _ in range(10): buffer.push([get_mock_train_input(), get_mock_train_input()]) with patch("ding.policy.Policy", MockPolicy), patch("ding.reward_model.HerRewardModel", MockHerRewardModel): with task.start(): policy = MockPolicy() her_reward_model = MockHerRewardModel() learner = HERLearner(cfg, policy, buffer, her_reward_model) learner(ctx) assert len(ctx.train_output) == 4