File size: 5,592 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
from typing import Dict, Any
from unittest.mock import MagicMock
import pytest
import mlagents.trainers.tests.mock_brain as mb
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.torch_entities.components.bc.module import BCModule
from mlagents.trainers.torch_entities.networks import SimpleActor
from mlagents.trainers.settings import (
    TrainerSettings,
    BehavioralCloningSettings,
    NetworkSettings,
)


def create_bc_module(mock_behavior_specs, bc_settings, use_rnn, tanhresample):
    """Build a ``BCModule`` wrapped around a freshly constructed ``TorchPolicy``.

    :param mock_behavior_specs: Behavior specs the policy is built for.
    :param bc_settings: ``BehavioralCloningSettings`` handed to ``BCModule``.
    :param use_rnn: If True, the policy's network settings get a default
        ``NetworkSettings.MemorySettings()``; otherwise ``memory`` is ``None``.
    :param tanhresample: Forwarded to ``SimpleActor`` as ``tanh_squash``.
    :return: The ``BCModule``. Its fallback learning rate and batch size come
        from a default ``TrainerSettings``, and its fallback epoch count is 3.
    """
    trainer_config = TrainerSettings()
    trainer_config.network_settings.memory = (
        NetworkSettings.MemorySettings() if use_rnn else None
    )
    actor_kwargs: Dict[str, Any] = {
        "conditional_sigma": False,
        "tanh_squash": tanhresample,
    }
    policy = TorchPolicy(
        0,
        mock_behavior_specs,
        trainer_config.network_settings,
        SimpleActor,
        actor_kwargs,
    )
    bc_module = BCModule(
        policy,
        settings=bc_settings,
        policy_learning_rate=trainer_config.hyperparameters.learning_rate,
        default_batch_size=trainer_config.hyperparameters.batch_size,
        # test_bcmodule_defaults expects num_epoch == 3 when bc_settings
        # does not override it.
        default_num_epoch=3,
    )
    return bc_module


def assert_stats_are_float(stats):
    """Assert that every value in the ``stats`` mapping is a ``float``.

    :param stats: Mapping of stat name to value, such as the dict returned by
        ``BCModule.update()`` in the tests below.
    :raises AssertionError: Naming the first stat whose value is not a float
        and reporting its actual type.
    """
    for name, value in stats.items():
        # Include the key and type so a failure is diagnosable from the
        # message alone.
        assert isinstance(value, float), (
            f"stat {name!r} has type {type(value).__name__}, expected float"
        )


# Test default values
def test_bcmodule_defaults():
    """BC settings without overrides inherit the defaults; explicit values win.

    If the BC settings leave out ``num_epoch``/``batch_size``, the module should
    use the 3 epochs passed by ``create_bc_module`` and the default
    ``TrainerSettings`` batch size. Explicit (deliberately unusual) values
    must replace both.
    """
    specs = mb.create_mock_3dball_behavior_specs()
    demo = os.path.dirname(os.path.abspath(__file__)) + "/test.demo"

    plain = BehavioralCloningSettings(demo_path=demo)
    module = create_bc_module(specs, plain, use_rnn=False, tanhresample=False)
    assert module.num_epoch == 3
    assert module.batch_size == TrainerSettings().hyperparameters.batch_size

    overridden = BehavioralCloningSettings(
        demo_path=demo, num_epoch=100, batch_size=10000
    )
    module = create_bc_module(specs, overridden, use_rnn=False, tanhresample=False)
    assert module.num_epoch == 100
    assert module.batch_size == 10000


# Test with continuous control env and vector actions
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_update(is_sac):
    """One BC update on the 3DBall mock specs returns only float-valued stats.

    ``is_sac`` is forwarded as ``tanhresample``, toggling the actor's tanh squash.
    """
    specs = mb.create_mock_3dball_behavior_specs()
    here = os.path.dirname(os.path.abspath(__file__))
    settings = BehavioralCloningSettings(demo_path=here + "/test.demo")
    module = create_bc_module(specs, settings, use_rnn=False, tanhresample=is_sac)
    assert_stats_are_float(module.update())


# Test with constant pretraining learning rate
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_constant_lr_update(is_sac):
    """With ``steps=0`` the BC learning rate must stay the same across updates."""
    specs = mb.create_mock_3dball_behavior_specs()
    demo = os.path.dirname(os.path.abspath(__file__)) + "/test.demo"
    module = create_bc_module(
        specs,
        BehavioralCloningSettings(demo_path=demo, steps=0),
        use_rnn=False,
        tanhresample=is_sac,
    )
    assert_stats_are_float(module.update())
    lr_before = module.current_lr

    module.update()
    assert lr_before == module.current_lr


# Test with linearly decaying pretraining learning rate
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_linear_lr_update(is_sac):
    """With ``steps=100`` the BC learning rate decays with the policy step.

    The policy step is mocked to 10, so after one update the rate should have
    dropped by 10/100 of its initial value.

    NOTE(review): assumes BCModule anneals linearly from its initial rate
    toward ~0 over ``steps`` (as the original "10/100 * 0.0003" comment
    implies) — confirm against BCModule if this starts failing.
    """
    mock_specs = mb.create_mock_3dball_behavior_specs()
    anneal_steps = 100
    current_step = 10
    bc_settings = BehavioralCloningSettings(
        demo_path=os.path.dirname(os.path.abspath(__file__)) + "/" + "test.demo",
        steps=anneal_steps,
    )
    bc_module = create_bc_module(mock_specs, bc_settings, False, is_sac)
    bc_module.policy.get_current_step = MagicMock(return_value=current_step)
    old_learning_rate = bc_module.current_lr
    _ = bc_module.update()
    expected = old_learning_rate * (1 - current_step / anneal_steps)
    # The former abs=0.01 tolerance was far larger than the learning rate
    # itself, so the check passed even without any decay. Compare relatively.
    assert bc_module.current_lr < old_learning_rate
    assert bc_module.current_lr == pytest.approx(expected, rel=1e-4)


# Test with RNN
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_rnn_update(is_sac):
    """One BC update with memory enabled on the 3DBall specs gives float stats."""
    behavior_specs = mb.create_mock_3dball_behavior_specs()
    demo_file = os.path.dirname(os.path.abspath(__file__)) + "/test.demo"
    settings = BehavioralCloningSettings(demo_path=demo_file)
    module = create_bc_module(
        behavior_specs, settings, use_rnn=True, tanhresample=is_sac
    )
    stats = module.update()
    assert_stats_are_float(stats)


# Test with discrete control and visual observations
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_dc_visual_update(is_sac):
    """One BC update on the banana mock specs and ``testdcvis.demo`` gives floats.

    Per the surrounding comments, this covers discrete control with visual
    observations.
    """
    specs = mb.create_mock_banana_behavior_specs()
    demo_dir = os.path.dirname(os.path.abspath(__file__))
    module = create_bc_module(
        specs,
        BehavioralCloningSettings(demo_path=demo_dir + "/testdcvis.demo"),
        use_rnn=False,
        tanhresample=is_sac,
    )
    assert_stats_are_float(module.update())


# Test with discrete control, visual observations and RNN


@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_rnn_dc_update(is_sac):
    """Like the discrete/visual test, but with memory enabled on the policy."""
    banana_specs = mb.create_mock_banana_behavior_specs()
    demo_file = os.path.dirname(os.path.abspath(__file__)) + "/testdcvis.demo"
    settings = BehavioralCloningSettings(demo_path=demo_file)
    module = create_bc_module(
        banana_specs, settings, use_rnn=True, tanhresample=is_sac
    )
    stats = module.update()
    assert_stats_are_float(stats)


if __name__ == "__main__":
    import sys

    # Run only this module, keeping any extra CLI flags, and propagate pytest's
    # exit status. A bare pytest.main() collects from the current directory and
    # its return code would be discarded, so the script would always exit 0.
    sys.exit(pytest.main([__file__, *sys.argv[1:]]))