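"""Tests for the SAC (Soft Actor-Critic) TorchSACOptimizer.

Exercises optimizer updates across discrete/continuous action spaces,
visual/vector observations, and recurrent/non-recurrent networks, as well as
reward-signal (curiosity) updates.
"""
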
import pytest
from mlagents.torch_utils import torch

from mlagents.trainers.buffer import BufferKey, RewardSignalUtil
from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.torch_entities.networks import SimpleActor
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.tests.dummy_config import (  # noqa: F401
    sac_dummy_config,
    curiosity_dummy_config,
)


@pytest.fixture
def dummy_config():
    return sac_dummy_config()


VECTOR_ACTION_SPACE = 2
VECTOR_OBS_SPACE = 8
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 64
NUM_AGENTS = 12


def create_sac_optimizer_mock(dummy_config, use_rnn, use_discrete, use_visual):
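    """Build a TorchSACOptimizer around a mock behavior spec.

    The spec uses discrete or continuous actions and visual or vector
    observations according to the flags; an LSTM memory section is attached
    to the network settings when use_rnn is True.
    """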
    mock_brain = mb.setup_test_behavior_specs(
        use_discrete,
        use_visual,
        vector_action_space=DISCRETE_ACTION_SPACE
        if use_discrete
        else VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE if not use_visual else 0,
    )
    trainer_settings = dummy_config
    trainer_settings.network_settings.memory = (
        NetworkSettings.MemorySettings(sequence_length=16, memory_size=12)
        if use_rnn
        else None
    )
    actor_kwargs = {
        "conditional_sigma": False,
        "tanh_squash": False,
    }
    policy = TorchPolicy(
        0, mock_brain, trainer_settings.network_settings, SimpleActor, actor_kwargs
    )
    optimizer = TorchSACOptimizer(policy, trainer_settings)
    return optimizer


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_sac_optimizer_update(dummy_config, rnn, visual, discrete):
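    """Run one SAC update on a simulated rollout and verify that all expected
    loss, entropy-coefficient, and learning-rate stats are reported."""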
    torch.manual_seed(0)
    # Create the optimizer for the parametrized configuration
    optimizer = create_sac_optimizer_mock(
        dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    # Test update
    update_buffer = mb.simulate_rollout(
        BUFFER_INIT_SAMPLES, optimizer.policy.behavior_spec, memory_size=12
    )
    # Mock out reward signal eval
    update_buffer[RewardSignalUtil.rewards_key("extrinsic")] = update_buffer[
        BufferKey.ENVIRONMENT_REWARDS
    ]
    # Mock out value memories
    update_buffer[BufferKey.CRITIC_MEMORY] = update_buffer[BufferKey.MEMORY]
    return_stats = optimizer.update(
        update_buffer,
        num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length,
    )
    # Make sure we have the right stats
    required_stats = [
        "Losses/Policy Loss",
        "Losses/Value Loss",
        "Losses/Q1 Loss",
        "Losses/Q2 Loss",
        "Policy/Continuous Entropy Coeff",
        "Policy/Discrete Entropy Coeff",
        "Policy/Learning Rate",
    ]
    for stat in required_stats:
        assert stat in return_stats.keys()


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
def test_sac_update_reward_signals(
    dummy_config, curiosity_dummy_config, discrete  # noqa: F811
):
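    """Add a curiosity reward signal and verify that its forward and inverse
    losses are reported by update_reward_signals()."""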
    # Add a Curiosity module
    dummy_config.reward_signals = curiosity_dummy_config
    optimizer = create_sac_optimizer_mock(
        dummy_config, use_rnn=False, use_discrete=discrete, use_visual=False
    )

    # Simulate a rollout to build the update buffer.
    update_buffer = mb.simulate_rollout(
        BUFFER_INIT_SAMPLES, optimizer.policy.behavior_spec
    )

    # Mock out reward signal eval
    update_buffer[RewardSignalUtil.rewards_key("extrinsic")] = update_buffer[
        BufferKey.ENVIRONMENT_REWARDS
    ]
    update_buffer[RewardSignalUtil.rewards_key("curiosity")] = update_buffer[
        BufferKey.ENVIRONMENT_REWARDS
    ]
    return_stats = optimizer.update_reward_signals(update_buffer)
    required_stats = ["Losses/Curiosity Forward Loss", "Losses/Curiosity Inverse Loss"]
    for stat in required_stats:
        assert stat in return_stats.keys()


if __name__ == "__main__":
    pytest.main()