import torch
import pytest
from torch import nn
from lzero.model.stochastic_muzero_model import ChanceEncoder

# Fixture: a ChanceEncoder for single observations of shape (3, 32, 32)
# with a 4-dimensional chance space.
@pytest.fixture
def encoder():
    return ChanceEncoder((3, 32, 32), 4)

def test_ChanceEncoder(encoder):
    # Dummy input: current and previous 3-channel observations stacked
    # along the channel dimension (3 + 3 = 6 channels)
    x_and_last_x = torch.randn(1, 6, 32, 32)

    # Forward pass through the encoder
    chance_encoding_t, chance_onehot_t = encoder(x_and_last_x)
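    # chance_encoding_t is the continuous encoding of the chance outcome;
    # chance_onehot_t is its discretized one-hot counterpart (inferred from
    # the names and the shape/one-hot assertions below).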

    # Check the output shapes
    assert chance_encoding_t.shape == (1, 4)
    assert chance_onehot_t.shape == (1, 4)

    # Check that chance_onehot_t is indeed one-hot
    assert torch.all((chance_onehot_t == 0) | (chance_onehot_t == 1))
    assert torch.all(torch.sum(chance_onehot_t, dim=1) == 1)

def test_ChanceEncoder_gradients_chance_encoding(encoder):
    # Dummy stacked-observation input, as in test_ChanceEncoder
    x_and_last_x = torch.randn(1, 6, 32, 32)

    # Forward pass through the encoder
    chance_encoding_t, chance_onehot_t = encoder(x_and_last_x)

    # Create a dummy target tensor for a simple loss function
    target = torch.randn(1, 4)

    # Use mean squared error as a simple loss function
    loss = nn.MSELoss()(chance_encoding_t, target)

    # Backward pass
    loss.backward()

    # Check that gradients were computed and have the correct shape
    for param in encoder.parameters():
        assert param.grad is not None
        assert param.grad.shape == param.shape

def test_ChanceEncoder_gradients_chance_onehot_t(encoder):
    # Dummy stacked-observation input, as in test_ChanceEncoder
    x_and_last_x = torch.randn(1, 6, 32, 32)

    # Forward pass through the encoder
    chance_encoding_t, chance_onehot_t = encoder(x_and_last_x)

    # Create a dummy target tensor for a simple loss function
    target = torch.randn(1, 4)

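    # Note: chance_onehot_t is discrete, so a loss on it only yields gradients
    # if the encoder makes the one-hot differentiable, e.g. via a
    # straight-through estimator as in Stochastic MuZero. This test asserts
    # exactly that behaviour.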
    # Use mean squared error as a simple loss function
    loss = nn.MSELoss()(chance_onehot_t, target)

    # Backward pass
    loss.backward()

    # Check that gradients were computed and have the correct shape
    for param in encoder.parameters():
        assert param.grad is not None
        assert param.grad.shape == param.shape
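
# A minimal extra sketch, not part of the original suite: the same invariants
# should hold for batch sizes > 1, assuming ChanceEncoder accepts an arbitrary
# leading batch dimension and produces the one-hot via an argmax over the
# encoding (both are assumptions about the implementation).
def test_ChanceEncoder_batched(encoder):
    # Batch of 8 stacked observation pairs (6 channels each).
    x_and_last_x = torch.randn(8, 6, 32, 32)
    chance_encoding_t, chance_onehot_t = encoder(x_and_last_x)

    # Per-sample shapes and the one-hot property.
    assert chance_encoding_t.shape == (8, 4)
    assert chance_onehot_t.shape == (8, 4)
    assert torch.all((chance_onehot_t == 0) | (chance_onehot_t == 1))
    assert torch.all(torch.sum(chance_onehot_t, dim=1) == 1)

    # If the one-hot comes from an argmax over the encoding (assumed above),
    # the two outputs should select the same index per sample.
    assert torch.equal(
        torch.argmax(chance_onehot_t, dim=1),
        torch.argmax(chance_encoding_t, dim=1),
    )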