File size: 7,116 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import numpy as np
import pytest
import torch
import torch.nn.functional as F

from lzero.policy.utils import negative_cosine_similarity, to_torch_float_tensor, visualize_avg_softmax, \
    calculate_topk_accuracy, plot_topk_accuracy, compare_argmax, plot_argmax_distribution


# We use the pytest.mark.unittest decorator to mark this class for unit testing.
@pytest.mark.unittest
class TestVisualizationFunctions:
    """
    Smoke tests for the visualization and top-k accuracy helpers imported from ``lzero.policy.utils``.
    Every test feeds randomly generated logits and/or one-hot targets of shape (256, 10).
    """

    def test_visualize_avg_softmax(self):
        """
        Call ``visualize_avg_softmax`` on a batch of random logits.
        """
        n_samples, n_classes = 256, 10
        random_logits = torch.randn(n_samples, n_classes)

        # The return value is not inspected: the test passes as long as the call does not raise.
        # Correctness of the produced figure has to be judged by eye.
        visualize_avg_softmax(random_logits)

    def test_calculate_topk_accuracy(self):
        """
        Call ``calculate_topk_accuracy`` with k=5 and validate the type and range of its result.
        """
        n_samples, n_classes = 256, 10
        k = 5
        random_logits = torch.randn(n_samples, n_classes)
        label_ids = torch.randint(0, n_classes, (n_samples, ))
        one_hot_targets = F.one_hot(label_ids, n_classes)

        accuracy = calculate_topk_accuracy(random_logits, one_hot_targets, k)

        # The result is expected to be a plain Python float expressed as a percentage.
        assert isinstance(accuracy, float)
        assert 0 <= accuracy <= 100

    def test_plot_topk_accuracy(self):
        """
        Call ``plot_topk_accuracy`` for k = 1..5 on random logits and random one-hot targets.
        """
        n_samples, n_classes = 256, 10
        random_logits = torch.randn(n_samples, n_classes)
        label_ids = torch.randint(0, n_classes, (n_samples, ))
        one_hot_targets = F.one_hot(label_ids, n_classes)
        k_values = range(1, 6)

        # The return value is not inspected: the test passes as long as the call does not raise.
        # Correctness of the produced figure has to be judged by eye.
        plot_topk_accuracy(random_logits, one_hot_targets, k_values)

    def test_compare_argmax(self):
        """
        Call ``compare_argmax`` on random logits and random one-hot targets.
        """
        n_samples, n_classes = 256, 10
        random_logits = torch.randn(n_samples, n_classes)
        label_ids = torch.randint(0, n_classes, (n_samples, ))
        one_hot_chance = F.one_hot(label_ids, n_classes)

        # The return value is not inspected: the test passes as long as the call does not raise.
        # Correctness of the produced figure has to be judged by eye.
        compare_argmax(random_logits, one_hot_chance)

    def test_plot_argmax_distribution(self):
        """
        Call ``plot_argmax_distribution`` on random one-hot targets.
        """
        n_samples, n_classes = 256, 10
        label_ids = torch.randint(0, n_classes, (n_samples, ))
        one_hot_chance = F.one_hot(label_ids, n_classes)

        # The return value is not inspected: the test passes as long as the call does not raise.
        # Correctness of the produced figure has to be judged by eye.
        plot_argmax_distribution(one_hot_chance)


# We use the pytest.mark.unittest decorator to mark this class for unit testing.
@pytest.mark.unittest
class TestUtils:
    """
    Unit tests for ``negative_cosine_similarity`` and ``to_torch_float_tensor`` from ``lzero.policy.utils``.
    """

    def test_negative_cosine_similarity(self):
        """
        Check ``negative_cosine_similarity`` on random inputs (output shape and value range), then on the
        two extreme cases: inputs pointing in the same direction (expected -1) and inputs pointing in
        opposite directions (expected +1).
        """
        batch_size = 256
        dim = 512
        # Absolute tolerance for the extreme-value checks. Chosen to leave headroom for float32
        # round-off accumulated over a ``dim``-element reduction.
        atol = 1e-5

        x1 = torch.randn(batch_size, dim)
        x2 = torch.randn(batch_size, dim)
        output = negative_cosine_similarity(x1, x2)

        # One similarity value per batch row, each within [-1, 1].
        assert output.shape == (batch_size, )
        assert ((output >= -1) & (output <= 1)).all()

        # Same direction: scaling by a positive factor keeps the direction, so the
        # *negative* cosine similarity should be -1.
        x1 = torch.randn(batch_size, dim)
        positive_factor = torch.randint(1, 100, [1])
        output_positive = negative_cosine_similarity(x1, positive_factor.float() * x1)
        assert output_positive.shape == (batch_size, )
        # Two-sided check: a one-sided ``< tol`` would also accept values far below -1.
        assert ((output_positive - (-1)).abs() < atol).all()

        # Opposite direction: scaling by a negative factor flips the direction, so the
        # *negative* cosine similarity should be +1.
        negative_factor = -torch.randint(1, 100, [1])
        output_negative = negative_cosine_similarity(x1, negative_factor.float() * x1)
        assert output_negative.shape == (batch_size, )
        # Must inspect ``output_negative`` here; asserting on ``output_positive`` would be vacuous.
        assert ((output_negative - 1).abs() < atol).all()

    def test_to_torch_float_tensor(self):
        """
        Check that ``to_torch_float_tensor`` converts a list of numpy arrays (mixed float64 / float32)
        into tensors equal to ``torch.from_numpy(arr).to(device).float()`` for each array.
        """
        device = 'cpu'
        mask_batch_np, target_value_prefix_np, target_value_np, target_policy_np, weights_np = (
            np.random.randn(4, 5) for _ in range(5)
        )
        # Two of the five inputs are pre-cast to float32; the other three stay float64.
        data_list_np = [
            mask_batch_np,
            target_value_prefix_np.astype('float32'),
            target_value_np.astype('float32'),
            target_policy_np,
            weights_np,
        ]
        names = ['mask_batch', 'target_value_prefix', 'target_value', 'target_policy', 'weights']

        # Reference conversion, computed independently of the function under test.
        expected = [torch.from_numpy(arr).to(device).float() for arr in data_list_np]

        converted = list(to_torch_float_tensor(data_list_np, device))

        assert len(converted) == len(expected)
        # Compare tensor by tensor so that a failure names the offending entry.
        for name, got, want in zip(names, converted, expected):
            assert (got == want).all(), f"{name} differs from the reference conversion"