gomoku / DI-engine /ding /utils /data /tests /test_dataloader.py
zjowowen's picture
init space
079c32c
raw
history blame
3.32 kB
import pytest
import threading
import time
import torch
import torch.nn as nn
from functools import partial
from itertools import product
from ding.utils import EasyTimer
from ding.utils.data import AsyncDataLoader
# Parameter grids consumed by the ``pytest.mark.parametrize`` decorators below.
# Each entry of ``args`` / ``unittest_args`` is a
# ``(batch_size, num_workers, chunk_size)`` tuple.
batch_size_args = [3, 6]
num_workers_args = [0, 4]
chunk_size_args = [1, 3]
# Full cartesian product of the three option lists (2 * 2 * 2 = 8 combinations).
args = list(product(batch_size_args, num_workers_args, chunk_size_args))
# Single fixed combination: batch_size=3, num_workers=2, chunk_size=1.
unittest_args = list(product([3], [2], [1]))
class Dataset:
    """Toy dataset whose item access is deliberately slow.

    Every item is the same random ``256 x 256`` tensor paired with the
    requested index; fetching an item blocks for half a second.
    """

    def __init__(self):
        # One random tensor created up front and returned by every lookup.
        self.data = torch.randn(256, 256)

    def __len__(self):
        """Report a fixed length of 100, independent of ``self.data``."""
        return 100

    def __getitem__(self, idx):
        """Sleep for 0.5 s, then return ``[self.data, idx]``."""
        time.sleep(0.5)
        sample = [self.data, idx]
        return sample
class TestAsyncDataLoader:
    """Timing, ordering and cleanup checks for ``AsyncDataLoader``.

    The loader is fed by a dataset whose items take 0.5 s each to fetch and is
    consumed by a model whose forward pass sleeps for 1 s.
    """

    def get_data_source(self):
        """Build the data-source callable handed to ``AsyncDataLoader``.

        Returns:
            A function taking ``batch_size`` and returning a list of
            ``batch_size`` zero-argument callables; the i-th callable fetches
            item ``i`` from a freshly created ``Dataset``.
        """
        dataset = Dataset()
        fetch = dataset.__getitem__

        def data_source_fn(batch_size):
            return [partial(fetch, idx=sample_idx) for sample_idx in range(batch_size)]

        return data_source_fn

    def get_model(self):
        """Return a new model: ten stacked 256->256 linear layers plus a 1 s sleep."""

        class Model(nn.Module):

            def __init__(self):
                super().__init__()
                layers = [nn.Linear(256, 256) for _ in range(10)]
                self.main = nn.Sequential(*layers)

            def forward(self, x):
                # ``x`` is indexed as [features, idx]; idx is passed through untouched.
                features = x[0]
                idx = x[1]
                out = self.main(features)
                # Artificial 1 s of "compute" per forward pass.
                time.sleep(1)
                return [out, idx]

        return Model()

    # @pytest.mark.unittest
    @pytest.mark.parametrize('batch_size, num_workers, chunk_size', unittest_args)
    def test_cpu(self, batch_size, num_workers, chunk_size):
        """Run the shared scenario on CPU."""
        self.entry(batch_size, num_workers, chunk_size, use_cuda=False)

    @pytest.mark.cudatest
    @pytest.mark.parametrize('batch_size, num_workers, chunk_size', args)
    def test_gpu(self, batch_size, num_workers, chunk_size):
        """Run the shared scenario on CUDA, then release cached GPU memory."""
        self.entry(batch_size, num_workers, chunk_size, use_cuda=True)
        torch.cuda.empty_cache()

    def entry(self, batch_size, num_workers, chunk_size, use_cuda):
        """Shared test body: pull ten batches, check indices and timings, then clean up.

        Args:
            batch_size: Batch size passed to the loader; also the number of
                distinct indices expected in every batch.
            num_workers: Forwarded to ``AsyncDataLoader``; also selects which
                data-time bound is asserted.
            chunk_size: Forwarded to ``AsyncDataLoader``.
            use_cuda: When true, the model is moved to CUDA and the loader is
                given the ``'cuda'`` device string.
        """
        total_steps = 10
        warmup_steps = 3  # the first three fetches are excluded from the timing sum
        measured_steps = total_steps - warmup_steps  # == 7

        model = self.get_model()
        if use_cuda:
            model.cuda()
        if use_cuda:
            device = 'cuda'
        else:
            device = 'cpu'

        timer = EasyTimer()
        data_source = self.get_data_source()
        dataloader = AsyncDataLoader(data_source, batch_size, device, num_workers=num_workers, chunk_size=chunk_size)

        expected_idx = torch.arange(batch_size)
        total_data_time = 0.
        for step in range(total_steps):
            # Time how long the loader takes to hand over the next batch.
            with timer:
                data = next(dataloader)
            data_time = timer.value
            if step >= warmup_steps:
                total_data_time += data_time

            # Time the forward pass together with the index check.
            with timer:
                with torch.no_grad():
                    _, idx = model(data)
                if use_cuda:
                    idx = idx.cpu()
                # Once sorted, the batch indices must be exactly 0 .. batch_size-1.
                sorted_idx = torch.sort(idx)[0]
                assert sorted_idx.eq(expected_idx).sum() == batch_size, idx
            model_time = timer.value
            print('count {}, data_time: {}, model_time: {}'.format(step, data_time, model_time))

        if num_workers < 1:
            # Bound built from the 0.5 s per-item fetch and the 1 s model sleep;
            # presumably fetch time not hidden behind the forward pass, plus slack.
            data_time_limit = measured_steps * batch_size * 0.5 + measured_steps * 0.01 - measured_steps * 1
        else:
            # With workers, each measured fetch is allowed 8 ms on average.
            data_time_limit = measured_steps * 0.008
        assert total_data_time <= data_time_limit

        # Tear the loader down explicitly, wait briefly, then require that at
        # most two threads are still alive.
        dataloader.__del__()
        time.sleep(0.5)
        assert len(threading.enumerate()) <= 2, threading.enumerate()