|
import pytest |
|
import threading |
|
import time |
|
import torch |
|
import torch.nn as nn |
|
from functools import partial |
|
from itertools import product |
|
|
|
from ding.utils import EasyTimer |
|
from ding.utils.data import AsyncDataLoader |
|
|
|
batch_size_args = [3, 6] |
|
num_workers_args = [0, 4] |
|
chunk_size_args = [1, 3] |
|
args = [item for item in product(*[batch_size_args, num_workers_args, chunk_size_args])] |
|
unittest_args = [item for item in product(*[[3], [2], [1]])] |
|
|
|
|
|
class Dataset(object): |
|
|
|
def __init__(self): |
|
self.data = torch.randn(256, 256) |
|
|
|
def __len__(self): |
|
return 100 |
|
|
|
def __getitem__(self, idx): |
|
time.sleep(0.5) |
|
return [self.data, idx] |
|
|
|
|
|
class TestAsyncDataLoader: |
|
|
|
def get_data_source(self): |
|
dataset = Dataset() |
|
|
|
def data_source_fn(batch_size): |
|
return [partial(dataset.__getitem__, idx=i) for i in range(batch_size)] |
|
|
|
return data_source_fn |
|
|
|
def get_model(self): |
|
|
|
class Model(nn.Module): |
|
|
|
def __init__(self): |
|
super(Model, self).__init__() |
|
self.main = [nn.Linear(256, 256) for _ in range(10)] |
|
self.main = nn.Sequential(*self.main) |
|
|
|
def forward(self, x): |
|
idx = x[1] |
|
x = self.main(x[0]) |
|
time.sleep(1) |
|
return [x, idx] |
|
|
|
return Model() |
|
|
|
|
|
@pytest.mark.parametrize('batch_size, num_workers, chunk_size', unittest_args) |
|
def test_cpu(self, batch_size, num_workers, chunk_size): |
|
self.entry(batch_size, num_workers, chunk_size, use_cuda=False) |
|
|
|
@pytest.mark.cudatest |
|
@pytest.mark.parametrize('batch_size, num_workers, chunk_size', args) |
|
def test_gpu(self, batch_size, num_workers, chunk_size): |
|
self.entry(batch_size, num_workers, chunk_size, use_cuda=True) |
|
torch.cuda.empty_cache() |
|
|
|
def entry(self, batch_size, num_workers, chunk_size, use_cuda): |
|
model = self.get_model() |
|
if use_cuda: |
|
model.cuda() |
|
timer = EasyTimer() |
|
data_source = self.get_data_source() |
|
device = 'cuda' if use_cuda else 'cpu' |
|
dataloader = AsyncDataLoader(data_source, batch_size, device, num_workers=num_workers, chunk_size=chunk_size) |
|
count = 0 |
|
total_data_time = 0. |
|
while True: |
|
with timer: |
|
data = next(dataloader) |
|
data_time = timer.value |
|
if count > 2: |
|
total_data_time += data_time |
|
with timer: |
|
with torch.no_grad(): |
|
_, idx = model(data) |
|
if use_cuda: |
|
idx = idx.cpu() |
|
sorted_idx = torch.sort(idx)[0] |
|
assert sorted_idx.eq(torch.arange(batch_size)).sum() == batch_size, idx |
|
model_time = timer.value |
|
print('count {}, data_time: {}, model_time: {}'.format(count, data_time, model_time)) |
|
count += 1 |
|
if count == 10: |
|
break |
|
if num_workers < 1: |
|
assert total_data_time <= 7 * batch_size * 0.5 + 7 * 0.01 - 7 * 1 |
|
else: |
|
assert total_data_time <= 7 * 0.008 |
|
dataloader.__del__() |
|
time.sleep(0.5) |
|
assert len(threading.enumerate()) <= 2, threading.enumerate() |
|
|