File size: 5,592 Bytes
05c9ac2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import os
from typing import Dict, Any
from unittest.mock import MagicMock
import pytest
import mlagents.trainers.tests.mock_brain as mb
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.torch_entities.components.bc.module import BCModule
from mlagents.trainers.torch_entities.networks import SimpleActor
from mlagents.trainers.settings import (
TrainerSettings,
BehavioralCloningSettings,
NetworkSettings,
)
def create_bc_module(mock_behavior_specs, bc_settings, use_rnn, tanhresample):
    """Build a BCModule around a freshly constructed TorchPolicy.

    :param mock_behavior_specs: behavior specs the policy is built for.
    :param bc_settings: BehavioralCloningSettings handed to the BCModule.
    :param use_rnn: if true, the network gets ``NetworkSettings.MemorySettings()``;
        otherwise its memory setting is cleared to None.
    :param tanhresample: forwarded to the SimpleActor as ``tanh_squash``.
    :return: the BCModule. Its default epoch count is 3; its default batch size
        and policy learning rate come from a default ``TrainerSettings``.
    """
    trainer_config = TrainerSettings()
    hyperparameters = trainer_config.hyperparameters
    network_settings = trainer_config.network_settings
    network_settings.memory = NetworkSettings.MemorySettings() if use_rnn else None

    actor_kwargs: Dict[str, Any] = dict(
        conditional_sigma=False, tanh_squash=tanhresample
    )
    policy = TorchPolicy(
        0, mock_behavior_specs, network_settings, SimpleActor, actor_kwargs
    )
    return BCModule(
        policy,
        settings=bc_settings,
        policy_learning_rate=hyperparameters.learning_rate,
        default_batch_size=hyperparameters.batch_size,
        default_num_epoch=3,
    )
def assert_stats_are_float(stats):
    """Assert that every value in the ``stats`` mapping is a Python ``float``.

    Only values are type-checked; an empty mapping passes trivially.

    :param stats: mapping of stat name to value, e.g. the result of
        ``BCModule.update()``.
    :raises AssertionError: naming the first offending stat and its type.
    """
    for name, value in stats.items():
        # Include the key so a failure pinpoints which stat has the wrong type.
        assert isinstance(value, float), (
            f"stat {name!r} has type {type(value).__name__}, expected float"
        )
# Test default values
def test_bcmodule_defaults():
    """BCModule uses the supplied defaults unless the settings override them."""
    mock_specs = mb.create_mock_3dball_behavior_specs()
    demo_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test.demo")

    # No num_epoch/batch_size in the settings: the defaults that
    # create_bc_module passes (3 epochs, TrainerSettings batch size) are expected.
    bc_settings = BehavioralCloningSettings(demo_path=demo_path)
    bc_module = create_bc_module(mock_specs, bc_settings, False, False)
    assert bc_module.num_epoch == 3
    assert bc_module.batch_size == TrainerSettings().hyperparameters.batch_size

    # Assign deliberately unusual values and check that they override the defaults.
    bc_settings = BehavioralCloningSettings(
        demo_path=demo_path,
        num_epoch=100,
        batch_size=10000,
    )
    bc_module = create_bc_module(mock_specs, bc_settings, False, False)
    assert bc_module.num_epoch == 100
    assert bc_module.batch_size == 10000
# Test with continuous control env and vector actions
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_update(is_sac):
    """A single BC update on the 3DBall specs returns only float stats.

    ``is_sac`` is forwarded to create_bc_module as ``tanhresample``.
    """
    mock_specs = mb.create_mock_3dball_behavior_specs()
    demo_dir = os.path.dirname(os.path.abspath(__file__))
    bc_settings = BehavioralCloningSettings(
        demo_path=os.path.join(demo_dir, "test.demo")
    )
    bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac)
    stats = bc_module.update()
    assert_stats_are_float(stats)
# Test with constant pretraining learning rate
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_constant_lr_update(is_sac):
    """With ``steps=0`` the BC learning rate must be unchanged by an update."""
    mock_specs = mb.create_mock_3dball_behavior_specs()
    demo_dir = os.path.dirname(os.path.abspath(__file__))
    bc_settings = BehavioralCloningSettings(
        demo_path=os.path.join(demo_dir, "test.demo"),
        steps=0,
    )
    bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac)
    stats = bc_module.update()
    assert_stats_are_float(stats)

    old_learning_rate = bc_module.current_lr
    bc_module.update()
    assert bc_module.current_lr == old_learning_rate
# Test with linearly decaying pretraining learning rate
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_linear_lr_update(is_sac):
    """With ``steps=100`` the BC learning rate decays linearly with the policy step.

    The policy step is mocked to 10, so after one update the learning rate is
    expected to be 10/100 below its initial value (e.g. 0.0003 -> 0.00027).
    """
    anneal_steps = 100
    current_step = 10
    mock_specs = mb.create_mock_3dball_behavior_specs()
    demo_dir = os.path.dirname(os.path.abspath(__file__))
    bc_settings = BehavioralCloningSettings(
        demo_path=os.path.join(demo_dir, "test.demo"),
        steps=anneal_steps,
    )
    bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac)
    bc_module.policy.get_current_step = MagicMock(return_value=current_step)
    old_learning_rate = bc_module.current_lr
    bc_module.update()

    expected_lr = old_learning_rate * (1 - current_step / anneal_steps)
    # Use a tolerance relative to the LR itself. An absolute tolerance that is
    # large compared to the LR (~3e-4) would pass even if no decay happened.
    assert bc_module.current_lr < old_learning_rate
    assert bc_module.current_lr == pytest.approx(expected_lr, rel=1e-3)
# Test with RNN
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_rnn_update(is_sac):
    """A BC update with a recurrent (memory-enabled) policy returns float stats."""
    mock_specs = mb.create_mock_3dball_behavior_specs()
    demo_dir = os.path.dirname(os.path.abspath(__file__))
    bc_settings = BehavioralCloningSettings(
        demo_path=os.path.join(demo_dir, "test.demo")
    )
    bc_module = create_bc_module(mock_specs, bc_settings, True, is_sac)
    stats = bc_module.update()
    assert_stats_are_float(stats)
# Test with discrete control and visual observations
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_dc_visual_update(is_sac):
    """A BC update on the Banana specs with the testdcvis demo returns float stats."""
    mock_specs = mb.create_mock_banana_behavior_specs()
    demo_dir = os.path.dirname(os.path.abspath(__file__))
    bc_settings = BehavioralCloningSettings(
        demo_path=os.path.join(demo_dir, "testdcvis.demo")
    )
    bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac)
    stats = bc_module.update()
    assert_stats_are_float(stats)
# Test with discrete control, visual observations and RNN
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_rnn_dc_update(is_sac):
    """A recurrent BC update on the Banana specs returns float stats."""
    mock_specs = mb.create_mock_banana_behavior_specs()
    demo_dir = os.path.dirname(os.path.abspath(__file__))
    bc_settings = BehavioralCloningSettings(
        demo_path=os.path.join(demo_dir, "testdcvis.demo")
    )
    bc_module = create_bc_module(mock_specs, bc_settings, True, is_sac)
    stats = bc_module.update()
    assert_stats_are_float(stats)
if __name__ == "__main__":
    # Run only this module's tests, and propagate pytest's exit status so that
    # executing the file directly exits non-zero when a test fails.
    raise SystemExit(pytest.main([__file__]))
|