import numpy as np
import pytest
import torch
import torch.nn.functional as F
from lzero.policy.utils import negative_cosine_similarity, to_torch_float_tensor, visualize_avg_softmax, \
calculate_topk_accuracy, plot_topk_accuracy, compare_argmax, plot_argmax_distribution
# The pytest.mark.unittest decorator marks this class for unit testing.
@pytest.mark.unittest
class TestVisualizationFunctions:
    """
    Smoke tests for the visualization helpers in ``lzero.policy.utils``.

    Most of the functions under test only draw a plot and return nothing,
    so these tests can only verify that each call completes without
    raising; the figures themselves must be inspected visually.
    """

    # Shared dimensions for the synthetic inputs used by every test below.
    batch_size = 256
    num_classes = 10

    def _make_logits(self):
        """Return random logits of shape (batch_size, num_classes)."""
        return torch.randn(self.batch_size, self.num_classes)

    def _make_one_hot(self):
        """Return random one-hot labels of shape (batch_size, num_classes)."""
        labels = torch.randint(0, self.num_classes, [self.batch_size])
        return F.one_hot(labels, self.num_classes)

    def test_visualize_avg_softmax(self):
        """
        visualize_avg_softmax should compute the average softmax
        probabilities over the batch and plot them without raising.
        """
        visualize_avg_softmax(self._make_logits())
        # The function returns nothing and only creates a plot, so the
        # plot can only be checked by visual inspection.

    def test_calculate_topk_accuracy(self):
        """
        calculate_topk_accuracy should return the top-k match percentage
        as a float within [0, 100].
        """
        match_percentage = calculate_topk_accuracy(
            self._make_logits(), self._make_one_hot(), 5
        )
        assert isinstance(match_percentage, float)
        assert 0 <= match_percentage <= 100

    def test_plot_topk_accuracy(self):
        """
        plot_topk_accuracy should plot the top-k accuracy for each k in
        the given range without raising.
        """
        plot_topk_accuracy(self._make_logits(), self._make_one_hot(), range(1, 6))
        # No return value; the plot can only be inspected visually.

    def test_compare_argmax(self):
        """
        compare_argmax should plot the comparison of argmax values taken
        from the logits and the one-hot labels without raising.
        """
        compare_argmax(self._make_logits(), self._make_one_hot())
        # No return value; the plot can only be inspected visually.

    def test_plot_argmax_distribution(self):
        """
        plot_argmax_distribution should plot the distribution of argmax
        values of the one-hot labels without raising.
        """
        plot_argmax_distribution(self._make_one_hot())
        # No return value; the plot can only be inspected visually.
# We use the pytest.mark.unittest decorator to mark this class for unit testing.
@pytest.mark.unittest
class TestUtils():
    """Unit tests for the tensor utilities in ``lzero.policy.utils``."""

    def test_negative_cosine_similarity(self):
        """
        negative_cosine_similarity(x1, x2) should return one value per batch
        element, equal to minus the cosine similarity of the paired rows.
        Checked on random inputs and on the parallel / anti-parallel
        degenerate cases.
        """
        # We initialize the input parameters.
        batch_size = 256
        dim = 512
        x1 = torch.randn(batch_size, dim)
        x2 = torch.randn(batch_size, dim)
        # We call the negative_cosine_similarity function.
        output = negative_cosine_similarity(x1, x2)
        # One scalar per batch element, each within the cosine range [-1, 1].
        assert output.shape == (batch_size, )
        assert ((output >= -1) & (output <= 1)).all()
        # Special case: the two vectors point in the same direction.
        # The cosine similarity is 1, so its negation should be -1.
        x1 = torch.randn(batch_size, dim)
        positive_factor = torch.randint(1, 100, [1])
        output_positive = negative_cosine_similarity(x1, positive_factor.float() * x1)
        assert output_positive.shape == (batch_size, )
        # Bug fix: use a symmetric |x - c| < eps tolerance. The original
        # one-sided check `(x - (-1)) < 1e-6` also passed for any value
        # far below -1.
        assert (torch.abs(output_positive - (-1)) < 1e-6).all()
        # Special case: the two vectors point in opposite directions.
        # The cosine similarity is -1, so its negation should be 1.
        negative_factor = -torch.randint(1, 100, [1])
        output_negative = negative_cosine_similarity(x1, negative_factor.float() * x1)
        assert output_negative.shape == (batch_size, )
        # Bug fix: the original asserted on `output_positive` here (always
        # true since output_positive is about -1), leaving the anti-parallel
        # case effectively untested.
        assert (torch.abs(output_negative - 1) < 1e-6).all()

    def test_to_torch_float_tensor(self):
        """
        to_torch_float_tensor should convert a list of numpy arrays (with a
        mix of float64 and pre-cast float32 dtypes) into float tensors on
        the requested device, element-wise equal to a manual
        ``torch.from_numpy(...).to(device).float()`` conversion.
        """
        device = 'cpu'
        # Five random (4, 5) arrays; two are pre-cast to float32 so both
        # input dtypes are covered.
        mask_batch_np = np.random.randn(4, 5)
        target_value_prefix_np = np.random.randn(4, 5)
        target_value_np = np.random.randn(4, 5)
        target_policy_np = np.random.randn(4, 5)
        weights_np = np.random.randn(4, 5)
        data_list_np = [
            mask_batch_np,
            target_value_prefix_np.astype('float32'),
            target_value_np.astype('float32'),
            target_policy_np,
            weights_np,
        ]
        converted = to_torch_float_tensor(data_list_np, device)
        # Build the expected tensors by hand, mirroring the input dtypes.
        expected = [
            torch.from_numpy(mask_batch_np).to(device).float(),
            torch.from_numpy(target_value_prefix_np.astype('float32')).to(device).float(),
            torch.from_numpy(target_value_np.astype('float32')).to(device).float(),
            torch.from_numpy(target_policy_np).to(device).float(),
            torch.from_numpy(weights_np).to(device).float(),
        ]
        assert len(converted) == len(expected)
        for got, want in zip(converted, expected):
            assert (got == want).all()
|