File size: 1,109 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import pytest
from mlagents.torch_utils import torch

from mlagents.trainers.torch_entities.decoders import ValueHeads


def test_valueheads():
    """Check the output shapes ValueHeads produces for every reward stream.

    Three scenarios are exercised in order: heads built with the default
    output size, a forward pass with an input whose feature width does not
    match the configured input size, and heads built with several outputs
    per stream.
    """
    n_features = 5
    n_samples = 4
    stream_names = [f"reward_signal_{idx}" for idx in range(5)]
    features = torch.ones((n_samples, n_features))

    # Default construction: each stream yields one value per batch row.
    single_heads = ValueHeads(stream_names, n_features)
    single_out = single_heads(features)
    for name in stream_names:
        assert single_out[name].shape == (n_samples,)

    # An input with the wrong feature width must be rejected.
    too_wide = torch.ones((n_samples, n_features + 2))
    with pytest.raises(Exception):
        single_heads(too_wide)

    # Several values per head (e.g. discrete Q function).
    n_outputs = 4
    multi_heads = ValueHeads(stream_names, n_features, n_outputs)
    multi_out = multi_heads(features)
    for name in stream_names:
        assert multi_out[name].shape == (n_samples, n_outputs)