# Source: gomoku / DI-engine, ding/torch_utils/tests/test_reshape_helper.py
# (uploaded by zjowowen, commit 079c32c "init space", 1.23 kB)
import pytest
import torch
from ding.torch_utils.reshape_helper import fold_batch, unfold_batch, unsqueeze_repeat
@pytest.mark.unittest
def test_fold_unfold_batch():
    """Check that ``fold_batch`` merges the leading batch dims and ``unfold_batch`` undoes it.

    Two layouts are exercised: an image-like ``(T, B, C, H, W)`` tensor with three
    non-batch dims, and a vector-like ``(T, B, N)`` tensor with one. All sizes are
    small (the check is about views/shapes, not data volume) and pairwise distinct,
    so an accidental axis permutation changes at least one shape assertion.
    """
    T, B = 5, 6
    C, H, W = 3, 7, 8
    N = 11
    cases = [
        ((T, B, C, H, W), 3),  # image-like: three trailing non-batch dims
        ((T, B, N), 1),  # vector-like: one trailing non-batch dim
    ]
    for shape, nonbatch_ndims in cases:
        data = torch.randn(*shape)
        feature_shape = shape[-nonbatch_ndims:]

        folded, batch_dim = fold_batch(data, nonbatch_ndims=nonbatch_ndims)
        assert folded.shape == (T * B, *feature_shape)
        assert batch_dim == (T, B)
        # Expect row-major merging of the batch dims: folded[t * B + b] == data[t, b].
        assert torch.equal(folded, data.reshape(T * B, *feature_shape))

        # Keep ``data`` untouched so the round trip can be checked by value, not only by shape.
        unfolded = unfold_batch(folded, batch_dim)
        assert unfolded.shape == shape
        assert torch.equal(unfolded, data)
@pytest.mark.unittest
def test_unsqueeze_repeat():
    """Check that ``unsqueeze_repeat`` inserts a new axis holding ``repeat_times`` copies.

    Three insertion positions are covered: the default (leading axis), ``-1``
    (trailing axis) and ``2`` (an interior axis). Besides the resulting shape, every
    slice along the new axis must equal the input tensor. Sizes are small and pairwise
    distinct so a misplaced axis is visible in the shape check.
    """
    T, B, C, H, W = 5, 6, 3, 7, 8
    repeat_times = 4
    data = torch.randn(T, B, C, H, W)

    # (extra positional args after repeat_times, index of the new axis, expected shape)
    cases = [
        ((), 0, (repeat_times, T, B, C, H, W)),  # default unsqueeze_dim
        ((-1,), -1, (T, B, C, H, W, repeat_times)),
        ((2,), 2, (T, B, repeat_times, C, H, W)),
    ]
    for extra_args, new_axis, expected_shape in cases:
        ensembled_data = unsqueeze_repeat(data, repeat_times, *extra_args)
        assert ensembled_data.shape == expected_shape
        # Each entry along the inserted axis should be an exact copy of the input.
        for copy_idx in range(repeat_times):
            assert torch.equal(ensembled_data.select(new_axis, copy_idx), data)