File size: 2,716 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import numpy as np
from mlagents.torch_utils import torch

from mlagents.trainers.buffer import AgentBuffer, BufferKey
from mlagents.trainers.torch_entities.agent_action import AgentAction


def test_agent_action_group_from_buffer():
    buff = AgentBuffer()
    # Create some actions
    for _ in range(3):
        buff[BufferKey.GROUP_CONTINUOUS_ACTION].append(
            3 * [np.ones((5,), dtype=np.float32)]
        )
        buff[BufferKey.GROUP_DISCRETE_ACTION].append(
            3 * [np.ones((4,), dtype=np.float32)]
        )
    # Some agents have died
    for _ in range(2):
        buff[BufferKey.GROUP_CONTINUOUS_ACTION].append(
            1 * [np.ones((5,), dtype=np.float32)]
        )
        buff[BufferKey.GROUP_DISCRETE_ACTION].append(
            1 * [np.ones((4,), dtype=np.float32)]
        )

    # Get the group actions, which will be a List of Lists of AgentAction, where each element is the same
    # length as the AgentBuffer but contains only one agent's obs. Dead agents are padded by
    # NaNs.
    gact = AgentAction.group_from_buffer(buff)
    # Agent 0 is full
    agent_0_act = gact[0]
    assert agent_0_act.continuous_tensor.shape == (buff.num_experiences, 5)
    assert agent_0_act.discrete_tensor.shape == (buff.num_experiences, 4)

    agent_1_act = gact[1]
    assert agent_1_act.continuous_tensor.shape == (buff.num_experiences, 5)
    assert agent_1_act.discrete_tensor.shape == (buff.num_experiences, 4)
    assert (agent_1_act.continuous_tensor[0:3] > 0).all()
    assert (agent_1_act.continuous_tensor[3:] == 0).all()
    assert (agent_1_act.discrete_tensor[0:3] > 0).all()
    assert (agent_1_act.discrete_tensor[3:] == 0).all()


def test_slice():
    # Both continuous and discrete
    aa = AgentAction(
        torch.tensor([[1.0], [1.0], [1.0]]),
        [torch.tensor([2, 1, 0]), torch.tensor([1, 2, 0])],
    )
    saa = aa.slice(0, 2)
    assert saa.continuous_tensor.shape == (2, 1)
    assert saa.discrete_tensor.shape == (2, 2)


def test_to_flat():
    # Both continuous and discrete
    aa = AgentAction(
        torch.tensor([[1.0, 1.0, 1.0]]), [torch.tensor([2]), torch.tensor([1])]
    )
    flattened_actions = aa.to_flat([3, 3])
    assert torch.eq(
        flattened_actions, torch.tensor([[1, 1, 1, 0, 0, 1, 0, 1, 0]])
    ).all()

    # Just continuous
    aa = AgentAction(torch.tensor([[1.0, 1.0, 1.0]]), None)
    flattened_actions = aa.to_flat([])
    assert torch.eq(flattened_actions, torch.tensor([1, 1, 1])).all()

    # Just discrete
    aa = AgentAction(torch.tensor([]), [torch.tensor([2]), torch.tensor([1])])
    flattened_actions = aa.to_flat([3, 3])
    assert torch.eq(flattened_actions, torch.tensor([0, 0, 1, 0, 1, 0])).all()