File size: 2,241 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import torch
import pytest
from torch import nn
from lzero.model.stochastic_muzero_model import ChanceEncoder
# Shared ChanceEncoder fixture used by every test in this module.
@pytest.fixture
def encoder():
    """Build a new ``ChanceEncoder`` for each requesting test.

    The first constructor argument is ``(3, 32, 32)`` and the second is ``4``;
    the tests in this module expect both encoder outputs to have width 4.
    """
    input_dims = (3, 32, 32)
    encoding_size = 4
    return ChanceEncoder(input_dims, encoding_size)
def test_ChanceEncoder(encoder):
    """A forward pass returns a (1, 4) encoding and a valid (1, 4) one-hot.

    The one-hot tensor may only contain 0s and 1s, and each row must sum to 1.
    """
    # Six input channels for an encoder built with a 3-channel first dim;
    # presumably the current and previous inputs stacked (cf. the original
    # name ``x_and_last_x``) -- TODO confirm against ChanceEncoder.
    inputs = torch.randn(1, 6, 32, 32)
    encoding, onehot = encoder(inputs)

    expected_shape = (1, 4)
    assert encoding.shape == expected_shape
    assert onehot.shape == expected_shape

    # Every entry must be exactly 0 or exactly 1 ...
    is_binary = (onehot == 0) | (onehot == 1)
    assert is_binary.all()
    # ... and each row must contain exactly one 1.
    assert (onehot.sum(dim=1) == 1).all()
def test_ChanceEncoder_gradients_chance_encoding(encoder):
    """An MSE loss on the encoding output back-propagates to every parameter.

    After ``backward()``, each parameter must hold a gradient whose shape
    matches the parameter's own shape.
    """
    inputs = torch.randn(1, 6, 32, 32)
    encoding, _ = encoder(inputs)

    # Arbitrary random target: only the presence of gradients is checked.
    target = torch.randn(1, 4)
    nn.MSELoss()(encoding, target).backward()

    params = list(encoder.parameters())
    # Every parameter received a gradient ...
    assert all(p.grad is not None for p in params)
    # ... with the same shape as the parameter itself.
    assert all(p.grad.shape == p.shape for p in params)
def test_ChanceEncoder_gradients_chance_onehot_t(encoder):
    """An MSE loss on the one-hot output back-propagates to every parameter.

    The one-hot output must therefore stay connected to the autograd graph
    (presumably via a straight-through estimator -- TODO confirm in
    ChanceEncoder). After ``backward()``, each parameter must hold a gradient
    shaped like the parameter.
    """
    inputs = torch.randn(1, 6, 32, 32)
    _, onehot = encoder(inputs)

    # Arbitrary random target: only the presence of gradients is checked.
    target = torch.randn(1, 4)
    nn.MSELoss()(onehot, target).backward()

    params = list(encoder.parameters())
    # Every parameter received a gradient ...
    assert all(p.grad is not None for p in params)
    # ... with the same shape as the parameter itself.
    assert all(p.grad.shape == p.shape for p in params)
|