AnnaMats's picture
Second Push
05c9ac2
import pytest
from mlagents.torch_utils import torch
from mlagents.trainers.torch_entities.decoders import ValueHeads
def test_valueheads():
stream_names = [f"reward_signal_{num}" for num in range(5)]
input_size = 5
batch_size = 4
# Test default 1 value per head
value_heads = ValueHeads(stream_names, input_size)
input_data = torch.ones((batch_size, input_size))
value_out = value_heads(input_data) # Note: mean value will be removed shortly
for stream_name in stream_names:
assert value_out[stream_name].shape == (batch_size,)
# Test that inputting the wrong size input will throw an error
with pytest.raises(Exception):
value_out = value_heads(torch.ones((batch_size, input_size + 2)))
# Test multiple values per head (e.g. discrete Q function)
output_size = 4
value_heads = ValueHeads(stream_names, input_size, output_size)
input_data = torch.ones((batch_size, input_size))
value_out = value_heads(input_data)
for stream_name in stream_names:
assert value_out[stream_name].shape == (batch_size, output_size)