File size: 4,554 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
from itertools import product
import pytest
import torch
from ding.torch_utils import is_differentiable
from lzero.model.alphazero_model import PredictionNetwork
# Hyper-parameter grid for the parametrized PredictionNetwork test. Each list
# holds the candidate values for the test argument of the same name.
action_space_size = [2, 3]
batch_size = [100, 200]
num_res_blocks = [3]
num_channels = [3]
value_head_channels = [8]
policy_head_channels = [8]
# Each candidate for the fc layer arguments is itself a list, hence the nesting.
fc_value_layers = [[16]]
fc_policy_layers = [[16]]
output_support_size = [2]

# Entries 1 and 2 are multiplied with the head channel counts to obtain the
# flattened head input sizes.
observation_shape = [1, 3, 3]

# Cartesian product of all axes above: one tuple per parameter combination,
# with fields in the same order as the axes are listed here.
prediction_network_args = [
    *product(
        action_space_size,
        batch_size,
        num_res_blocks,
        num_channels,
        value_head_channels,
        policy_head_channels,
        fc_value_layers,
        fc_policy_layers,
        output_support_size,
    )
]
@pytest.mark.unittest
class TestAlphaZeroModel:
    """Output-shape tests for the AlphaZero ``PredictionNetwork``."""

    def output_check(self, model, outputs):
        """
        Reduce ``outputs`` to a single summed loss and hand it, together with
        ``model``, to ``is_differentiable``.

        Args:
            model: Forwarded unchanged to ``is_differentiable``.
            outputs: A tensor, a list or tuple of tensors, or a dict whose
                values are tensors.

        Raises:
            TypeError: If ``outputs`` is none of the supported types.
        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        elif isinstance(outputs, (list, tuple)):
            # Tuples are accepted as well: the network's forward result is
            # unpacked as a ``(policy, value)`` pair in the test below.
            loss = sum(t.sum() for t in outputs)
        elif isinstance(outputs, dict):
            loss = sum(v.sum() for v in outputs.values())
        else:
            # Without this branch ``loss`` would be unbound and the failure
            # would surface as an opaque UnboundLocalError.
            raise TypeError(
                "outputs must be a torch.Tensor, list, tuple or dict, got {}".format(type(outputs).__name__)
            )
        is_differentiable(loss, model)

    @pytest.mark.parametrize(
        'action_space_size, batch_size, num_res_blocks, num_channels, value_head_channels, policy_head_channels, fc_value_layers, fc_policy_layers, output_support_size',
        prediction_network_args
    )
    def test_prediction_network(
        self, action_space_size, batch_size, num_res_blocks, num_channels, value_head_channels,
        policy_head_channels, fc_value_layers, fc_policy_layers, output_support_size
    ):
        """
        Build a ``PredictionNetwork`` for one parameter combination, run a
        random batch through it and check the shapes of both outputs.
        """
        # Take the spatial size from the module-level ``observation_shape`` so
        # the random input and the flattened head sizes cannot drift apart.
        _, height, width = observation_shape
        obs = torch.rand(batch_size, num_channels, height, width)
        flatten_output_size_for_value_head = value_head_channels * height * width
        flatten_output_size_for_policy_head = policy_head_channels * height * width

        prediction_network = PredictionNetwork(
            action_space_size=action_space_size,
            continuous_action_space=False,
            num_res_blocks=num_res_blocks,
            num_channels=num_channels,
            value_head_channels=value_head_channels,
            policy_head_channels=policy_head_channels,
            fc_value_layers=fc_value_layers,
            fc_policy_layers=fc_policy_layers,
            output_support_size=output_support_size,
            flatten_output_size_for_value_head=flatten_output_size_for_value_head,
            flatten_output_size_for_policy_head=flatten_output_size_for_policy_head,
            last_linear_layer_init_zero=True,
        )
        policy, value = prediction_network(obs)

        assert policy.shape == torch.Size([batch_size, action_space_size])
        assert value.shape == torch.Size([batch_size, output_support_size])
if __name__ == "__main__":
    # Manual smoke run of a single configuration, bypassing pytest.
    # These assignments rebind the module-level parameter lists to scalars.
    action_space_size, batch_size = 2, 100
    num_res_blocks, num_channels = 3, 3
    reward_head_channels = 2  # only echoed in the summary line below
    value_head_channels, policy_head_channels = 8, 8
    fc_value_layers, fc_policy_layers = [16], [16]
    output_support_size = 2
    observation_shape = [1, 3, 3]

    # Spatial size shared by the random input and the flattened head sizes.
    _, height, width = observation_shape
    obs = torch.rand(batch_size, num_channels, height, width)
    flatten_output_size_for_value_head = value_head_channels * height * width
    flatten_output_size_for_policy_head = policy_head_channels * height * width

    # Echo the configuration between two separator lines.
    print('=' * 20)
    print(
        batch_size,
        num_res_blocks,
        num_channels,
        action_space_size,
        reward_head_channels,
        fc_value_layers,
        fc_policy_layers,
        output_support_size,
    )
    print('=' * 20)

    prediction_network = PredictionNetwork(
        action_space_size=action_space_size,
        num_res_blocks=num_res_blocks,
        num_channels=num_channels,
        value_head_channels=value_head_channels,
        policy_head_channels=policy_head_channels,
        fc_value_layers=fc_value_layers,
        fc_policy_layers=fc_policy_layers,
        output_support_size=output_support_size,
        flatten_output_size_for_value_head=flatten_output_size_for_value_head,
        flatten_output_size_for_policy_head=flatten_output_size_for_policy_head,
        last_linear_layer_init_zero=True,
    )
    policy, value = prediction_network(obs)

    # Same shape expectations as the parametrized test.
    expected_policy_shape = torch.Size([batch_size, action_space_size])
    expected_value_shape = torch.Size([batch_size, output_support_size])
    assert policy.shape == expected_policy_shape
    assert value.shape == expected_value_shape
|