|
import pytest |
|
from mlagents.torch_utils import torch |
|
|
|
from mlagents.trainers.torch_entities.decoders import ValueHeads |
|
|
|
|
|
def test_valueheads():
    """Verify the per-stream output shapes produced by ``ValueHeads``.

    Three scenarios are exercised in order:

    1. Heads built from only the stream names and an input size must give,
       for every stream, an output of shape ``(batch_size,)``.
    2. Calling those same heads with an input whose last dimension is
       ``input_size + 2`` must raise.
    3. Heads built with an explicit ``output_size`` must give, for every
       stream, an output of shape ``(batch_size, output_size)``.
    """
    batch_size, input_size = 4, 5
    stream_names = [f"reward_signal_{num}" for num in range(5)]

    def assert_stream_shapes(heads, expected_shape):
        # One forward pass on an all-ones batch, then check each named stream.
        outputs = heads(torch.ones((batch_size, input_size)))
        for name in stream_names:
            assert outputs[name].shape == expected_shape

    # Scenario 1: no explicit output size.
    default_heads = ValueHeads(stream_names, input_size)
    assert_stream_shapes(default_heads, (batch_size,))

    # Scenario 2: the input width does not match what the heads were built for.
    with pytest.raises(Exception):
        default_heads(torch.ones((batch_size, input_size + 2)))

    # Scenario 3: explicit output size per stream.
    output_size = 4
    sized_heads = ValueHeads(stream_names, input_size, output_size)
    assert_stream_shapes(sized_heads, (batch_size, output_size))
|
|