AnnaMats's picture
Second Push
05c9ac2
import os
import unittest
from unittest import mock
import pytest
import mlagents.trainers.tests.mock_brain as mb
from mlagents.trainers.policy.checkpoint_manager import ModelCheckpoint
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.tests.test_buffer import construct_fake_buffer
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.policy import Policy
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes
from mlagents_envs.base_env import ActionSpec
import os.path
# Add concrete implementations of abstract methods
class FakeTrainer(RLTrainer):
def set_is_policy_updating(self, is_updating):
self.update_policy = is_updating
def get_policy(self, name_behavior_id):
return mock.Mock()
def _is_ready_update(self):
return True
def _update_policy(self):
return self.update_policy
def add_policy(self, mock_behavior_id, mock_policy):
def checkpoint_path(brain_name, step):
onnx_file_path = os.path.join(
self.model_saver.model_path, f"{brain_name}-{step}.onnx"
)
other_file_paths = [
os.path.join(self.model_saver.model_path, f"{brain_name}-{step}.pt")
]
return onnx_file_path, other_file_paths
self.policies[mock_behavior_id] = mock_policy
mock_model_saver = mock.Mock()
mock_model_saver.model_path = self.artifact_path
mock_model_saver.save_checkpoint.side_effect = checkpoint_path
self.model_saver = mock_model_saver
def create_optimizer(self) -> TorchOptimizer:
return mock.Mock()
def create_policy(
self,
parsed_behavior_id: BehaviorIdentifiers,
behavior_spec: BehaviorSpec,
) -> Policy:
return mock.Mock()
def _process_trajectory(self, trajectory):
super()._process_trajectory(trajectory)
def create_rl_trainer():
trainer = FakeTrainer(
"test_trainer",
TrainerSettings(max_steps=100, checkpoint_interval=10, summary_freq=20),
True,
False,
"mock_model_path",
0,
)
trainer.set_is_policy_updating(True)
return trainer
def test_rl_trainer():
trainer = create_rl_trainer()
agent_id = "0"
trainer.collected_rewards["extrinsic"] = {agent_id: 3}
# Test end episode
trainer.end_episode()
for rewards in trainer.collected_rewards.values():
for agent_id in rewards:
assert rewards[agent_id] == 0
def test_clear_update_buffer():
trainer = create_rl_trainer()
trainer.update_buffer = construct_fake_buffer(0)
trainer._clear_update_buffer()
for _, arr in trainer.update_buffer.items():
assert len(arr) == 0
@mock.patch("mlagents.trainers.trainer.trainer.Trainer.save_model")
def test_advance(mocked_save_model):
trainer = create_rl_trainer()
mock_policy = mock.Mock()
trainer.add_policy("TestBrain", mock_policy)
trajectory_queue = AgentManagerQueue("testbrain")
policy_queue = AgentManagerQueue("testbrain")
trainer.subscribe_trajectory_queue(trajectory_queue)
trainer.publish_policy_queue(policy_queue)
time_horizon = 10
trajectory = mb.make_fake_trajectory(
length=time_horizon,
observation_specs=create_observation_specs_with_shapes([(1,)]),
max_step_complete=True,
action_spec=ActionSpec.create_discrete((2,)),
)
trajectory_queue.put(trajectory)
trainer.advance()
policy_queue.get_nowait()
# Check that get_step is correct
assert trainer.get_step == time_horizon
# Check that we can turn off the trainer and that the buffer is cleared
for _ in range(0, 5):
trajectory_queue.put(trajectory)
trainer.advance()
# Check that there is stuff in the policy queue
policy_queue.get_nowait()
# Check that if the policy doesn't update, we don't push it to the queue
trainer.set_is_policy_updating(False)
for _ in range(0, 10):
trajectory_queue.put(trajectory)
trainer.advance()
# Check that there nothing in the policy queue
with pytest.raises(AgentManagerQueue.Empty):
policy_queue.get_nowait()
# Check that no model has been saved
assert not trainer.should_still_train
assert mocked_save_model.call_count == 0
@mock.patch("mlagents.trainers.trainer.trainer.StatsReporter.write_stats")
@mock.patch(
"mlagents.trainers.trainer.rl_trainer.ModelCheckpointManager.add_checkpoint"
)
def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary):
trainer = create_rl_trainer()
mock_policy = mock.Mock()
trainer.add_policy("TestBrain", mock_policy)
trajectory_queue = AgentManagerQueue("testbrain")
policy_queue = AgentManagerQueue("testbrain")
trainer.subscribe_trajectory_queue(trajectory_queue)
trainer.publish_policy_queue(policy_queue)
time_horizon = 10
summary_freq = trainer.trainer_settings.summary_freq
checkpoint_interval = trainer.trainer_settings.checkpoint_interval
trajectory = mb.make_fake_trajectory(
length=time_horizon,
observation_specs=create_observation_specs_with_shapes([(1,)]),
max_step_complete=True,
action_spec=ActionSpec.create_discrete((2,)),
)
# Check that we can turn off the trainer and that the buffer is cleared
num_trajectories = 5
for _ in range(0, num_trajectories):
trajectory_queue.put(trajectory)
trainer.advance()
# Check that there is stuff in the policy queue
policy_queue.get_nowait()
# Check that we have called write_summary the appropriate number of times
calls = [
mock.call(step)
for step in range(summary_freq, num_trajectories * time_horizon, summary_freq)
]
mock_write_summary.assert_has_calls(calls, any_order=True)
checkpoint_range = range(
checkpoint_interval, num_trajectories * time_horizon, checkpoint_interval
)
calls = [mock.call(trainer.brain_name, step) for step in checkpoint_range]
trainer.model_saver.save_checkpoint.assert_has_calls(calls, any_order=True)
export_ext = "onnx"
add_checkpoint_calls = [
mock.call(
trainer.brain_name,
ModelCheckpoint(
step,
f"{trainer.model_saver.model_path}{os.path.sep}{trainer.brain_name}-{step}.{export_ext}",
None,
mock.ANY,
[
f"{trainer.model_saver.model_path}{os.path.sep}{trainer.brain_name}-{step}.pt"
],
),
trainer.trainer_settings.keep_checkpoints,
)
for step in checkpoint_range
]
mock_add_checkpoint.assert_has_calls(add_checkpoint_calls)
def test_update_buffer_append():
trainer = create_rl_trainer()
mock_policy = mock.Mock()
trainer.add_policy("TestBrain", mock_policy)
trajectory_queue = AgentManagerQueue("testbrain")
policy_queue = AgentManagerQueue("testbrain")
trainer.subscribe_trajectory_queue(trajectory_queue)
trainer.publish_policy_queue(policy_queue)
time_horizon = 10
trajectory = mb.make_fake_trajectory(
length=time_horizon,
observation_specs=create_observation_specs_with_shapes([(1,)]),
max_step_complete=True,
action_spec=ActionSpec.create_discrete((2,)),
)
agentbuffer_trajectory = trajectory.to_agentbuffer()
assert trainer.update_buffer.num_experiences == 0
# Check that if we append, our update buffer gets longer.
# max_steps = 100
for i in range(10):
trainer._process_trajectory(trajectory)
trainer._append_to_update_buffer(agentbuffer_trajectory)
assert trainer.update_buffer.num_experiences == (i + 1) * time_horizon
# Check that if we append after stopping training, nothing happens.
# We process enough trajectories to hit max steps
trainer.set_is_policy_updating(False)
trainer._process_trajectory(trajectory)
trainer._append_to_update_buffer(agentbuffer_trajectory)
assert trainer.update_buffer.num_experiences == (i + 1) * time_horizon
class RLTrainerWarningTest(unittest.TestCase):
def test_warning_group_reward(self):
with self.assertLogs("mlagents.trainers", level="WARN") as cm:
rl_trainer = create_rl_trainer()
# This one should warn
trajectory = mb.make_fake_trajectory(
length=10,
observation_specs=create_observation_specs_with_shapes([(1,)]),
max_step_complete=True,
action_spec=ActionSpec.create_discrete((2,)),
group_reward=1.0,
)
buff = trajectory.to_agentbuffer()
rl_trainer._warn_if_group_reward(buff)
assert len(cm.output) > 0
len_of_first_warning = len(cm.output)
rl_trainer = create_rl_trainer()
# This one shouldn't
trajectory = mb.make_fake_trajectory(
length=10,
observation_specs=create_observation_specs_with_shapes([(1,)]),
max_step_complete=True,
action_spec=ActionSpec.create_discrete((2,)),
)
buff = trajectory.to_agentbuffer()
rl_trainer._warn_if_group_reward(buff)
# Make sure warnings don't get bigger
assert len(cm.output) == len_of_first_warning