# Source snapshot metadata (from the hosting file viewer), kept as comments so
# the module remains valid Python:
#   author: zjowowen — "init space" — revision 079c32c — raw / history / blame — 3.38 kB
import pytest
import random
import copy
import torch
import treetensor.torch as ttorch
from unittest.mock import Mock, patch
from ding.data.buffer import DequeBuffer
from ding.framework import OnlineRLContext, task
from ding.framework.middleware import trainer, multistep_trainer, OffPolicyLearner, HERLearner
from ding.framework.middleware.tests import MockHerRewardModel, CONFIG
class MockPolicy(Mock):
    """Stub policy for single-step train mode: ``forward`` always reports a
    fixed total loss, ignoring its inputs."""
    _device = 'cpu'

    def forward(self, train_data, **kwargs):
        # Constant loss regardless of the training batch.
        return {'total_loss': 0.1}
class MultiStepMockPolicy(Mock):
    """Stub policy for multi-step train mode: ``forward`` yields one loss dict
    per training step (two steps here), ignoring its inputs."""
    _device = 'cpu'

    def forward(self, train_data, **kwargs):
        # One entry per simulated training step, with fixed losses.
        step_losses = (0.1, 1.0)
        return [{'total_loss': loss} for loss in step_losses]
def get_mock_train_input():
    """Build one minimal fake training sample as a tree-structured tensor.

    The sample mimics a single transition: random 2x2 observation tensors, a
    scalar reward, and an empty info dict.
    """
    sample = {
        'obs': torch.rand(2, 2),
        'next_obs': torch.rand(2, 2),
        'reward': random.random(),
        'info': {},
    }
    return ttorch.as_tensor(sample)
@pytest.mark.unittest
def test_trainer():
    """The trainer middleware must skip when ``train_data`` is None and must
    advance ``train_iter`` once per call otherwise."""
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    # Phase 1: no data available -> no training step is taken.
    ctx.train_data = None
    with patch("ding.policy.Policy", MockPolicy):
        policy = MockPolicy()
        for _step in range(10):
            trainer(cfg, policy)(ctx)
    assert ctx.train_iter == 0

    # Phase 2: with data, every call performs exactly one training step.
    ctx.train_data = get_mock_train_input()
    with patch("ding.policy.Policy", MockPolicy):
        policy = MockPolicy()
        for _step in range(30):
            trainer(cfg, policy)(ctx)
    assert ctx.train_iter == 30
    assert ctx.train_output["total_loss"] == 0.1
@pytest.mark.unittest
def test_multistep_trainer():
    """The multistep trainer must advance ``train_iter`` by the number of loss
    dicts the policy returns per call (two here) and expose them as a list."""
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    # Without data the plain trainer leaves the iteration counter untouched.
    ctx.train_data = None
    with patch("ding.policy.Policy", MockPolicy):
        policy = MockPolicy()
        for _step in range(10):
            trainer(cfg, policy)(ctx)
    assert ctx.train_iter == 0

    # With data: 30 calls x 2 steps per call -> 60 iterations total.
    ctx.train_data = get_mock_train_input()
    with patch("ding.policy.Policy", MultiStepMockPolicy):
        policy = MultiStepMockPolicy()
        for _step in range(30):
            multistep_trainer(policy, 10)(ctx)
    assert ctx.train_iter == 60
    assert ctx.train_output[0]["total_loss"] == 0.1
    assert ctx.train_output[1]["total_loss"] == 1.0
@pytest.mark.unittest
def test_offpolicy_learner():
    """OffPolicyLearner should sample the replay buffer and leave 4 train
    outputs on the context after one call."""
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    replay = DequeBuffer(size=10)
    for _ in range(10):
        replay.push(get_mock_train_input())

    with patch("ding.policy.Policy", MockPolicy):
        with task.start():
            learner = OffPolicyLearner(cfg, MockPolicy(), replay)
            learner(ctx)
    assert len(ctx.train_output) == 4
@pytest.mark.unittest
def test_her_learner():
    """HERLearner should consume buffered transition lists (episodes) together
    with the HER reward model and leave 4 train outputs on the context."""
    cfg = copy.deepcopy(CONFIG)
    ctx = OnlineRLContext()

    replay = DequeBuffer(size=10)
    for _ in range(10):
        # HER operates on lists of transitions, hence pairs are pushed.
        replay.push([get_mock_train_input(), get_mock_train_input()])

    with patch("ding.policy.Policy", MockPolicy), patch("ding.reward_model.HerRewardModel", MockHerRewardModel):
        with task.start():
            reward_model = MockHerRewardModel()
            learner = HERLearner(cfg, MockPolicy(), replay, reward_model)
            learner(ctx)
    assert len(ctx.train_output) == 4