File size: 3,318 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import pytest
import threading
import time
import torch
import torch.nn as nn
from functools import partial
from itertools import product
from ding.utils import EasyTimer
from ding.utils.data import AsyncDataLoader
# Candidate values for each AsyncDataLoader setting exercised by the tests.
batch_size_args = [3, 6]
num_workers_args = [0, 4]
chunk_size_args = [1, 3]

# Full grid of (batch_size, num_workers, chunk_size) tuples, in product order
# (the last axis varies fastest). Used to parametrize ``test_gpu``.
args = list(product(batch_size_args, num_workers_args, chunk_size_args))

# Single (3, 2, 1) combination used to parametrize ``test_cpu``.
unittest_args = list(product([3], [2], [1]))
class Dataset:
    """Synthetic dataset whose item access is deliberately slow.

    Every item shares the same random ``(256, 256)`` tensor; only the index
    returned alongside it differs between items.

    Args:
        delay: Seconds to sleep inside each ``__getitem__`` call, simulating a
            slow data source. Defaults to ``0.5``, the previously hard-coded
            value.
    """

    def __init__(self, delay=0.5):
        self.delay = delay
        self.data = torch.randn(256, 256)

    def __len__(self):
        """Return the fixed number of items (100)."""
        return 100

    def __getitem__(self, idx):
        """Return ``[data, idx]`` after sleeping for ``self.delay`` seconds.

        Args:
            idx: Item position; must satisfy ``0 <= idx < len(self)``.

        Returns:
            A two-element list: the shared tensor and ``idx`` unchanged.

        Raises:
            IndexError: If ``idx`` is out of range. Without this, the legacy
                ``__getitem__`` iteration protocol (``for x in dataset``)
                would never terminate, because nothing signalled the end.
        """
        if not 0 <= idx < len(self):
            raise IndexError(
                'index {} is out of range for Dataset of length {}'.format(idx, len(self))
            )
        time.sleep(self.delay)
        return [self.data, idx]
class TestAsyncDataLoader:
    """Ordering and timing checks for ``AsyncDataLoader``.

    The tests pair a dataset that sleeps on every item access with a model
    that sleeps on every forward pass, then assert on how long ``next()`` on
    the loader takes and on the indices contained in each batch.
    """

    def get_data_source(self):
        """Build the data-source callable handed to ``AsyncDataLoader``.

        Returns:
            A function that maps ``batch_size`` to a list of ``batch_size``
            zero-argument callables; the i-th callable evaluates
            ``dataset[i]`` on one shared ``Dataset`` instance.
        """
        dataset = Dataset()

        def data_source_fn(batch_size):
            return [partial(dataset.__getitem__, idx=i) for i in range(batch_size)]

        return data_source_fn

    def get_model(self):
        """Return a small model whose forward pass sleeps for one second.

        The model stacks ten ``nn.Linear(256, 256)`` layers. ``forward`` takes
        a two-element ``[tensor, idx]`` input, applies the stack to the
        tensor, and passes ``idx`` through untouched.
        """

        class Model(nn.Module):

            def __init__(self):
                super().__init__()
                self.main = nn.Sequential(*(nn.Linear(256, 256) for _ in range(10)))

            def forward(self, x):
                idx = x[1]
                x = self.main(x[0])
                time.sleep(1)  # simulated compute time
                return [x, idx]

        return Model()

    # NOTE(review): the ``unittest`` marker below is commented out, so this
    # case is not selected by that marker -- confirm this is intentional.
    # @pytest.mark.unittest
    @pytest.mark.parametrize('batch_size, num_workers, chunk_size', unittest_args)
    def test_cpu(self, batch_size, num_workers, chunk_size):
        """Run the shared scenario on CPU for each ``unittest_args`` combination."""
        self.entry(batch_size, num_workers, chunk_size, use_cuda=False)

    @pytest.mark.cudatest
    @pytest.mark.parametrize('batch_size, num_workers, chunk_size', args)
    def test_gpu(self, batch_size, num_workers, chunk_size):
        """Run the shared scenario on CUDA for each ``args`` combination."""
        self.entry(batch_size, num_workers, chunk_size, use_cuda=True)
        torch.cuda.empty_cache()

    def entry(self, batch_size, num_workers, chunk_size, use_cuda):
        """Fetch ten batches, run the model on each, and check order and timing.

        Args:
            batch_size: Batch size passed to ``AsyncDataLoader``; each batch is
                expected to contain exactly the indices ``0..batch_size-1``.
            num_workers: Forwarded to ``AsyncDataLoader``; also selects which
                timing bound is asserted.
            chunk_size: Forwarded to ``AsyncDataLoader``.
            use_cuda: If true, the model is moved to CUDA and the loader is
                created with device ``'cuda'``; otherwise ``'cpu'``.

        Raises:
            AssertionError: If a batch has unexpected indices, the accumulated
                fetch time exceeds its bound, or more than two threads are
                still alive after the loader is shut down.
        """
        total_iters = 10
        warmup_iters = 3  # the first three fetches are left out of the timing total
        measured_iters = total_iters - warmup_iters

        model = self.get_model()
        if use_cuda:
            model.cuda()
        timer = EasyTimer()
        data_source = self.get_data_source()
        device = 'cuda' if use_cuda else 'cpu'
        dataloader = AsyncDataLoader(
            data_source, batch_size, device, num_workers=num_workers, chunk_size=chunk_size
        )
        total_data_time = 0.
        # Shut the loader down even when an assertion fails; previously a
        # failing check skipped the explicit ``__del__`` call entirely.
        try:
            for count in range(total_iters):
                with timer:
                    data = next(dataloader)
                data_time = timer.value
                if count >= warmup_iters:
                    total_data_time += data_time
                with timer:
                    with torch.no_grad():
                        _, idx = model(data)
                    if use_cuda:
                        idx = idx.cpu()
                    sorted_idx = torch.sort(idx)[0]
                    assert sorted_idx.eq(torch.arange(batch_size)).sum() == batch_size, idx
                model_time = timer.value
                print('count {}, data_time: {}, model_time: {}'.format(count, data_time, model_time))
            if num_workers < 1:
                # 0.5 is the per-item sleep in Dataset.__getitem__ and 1 is the
                # sleep in Model.forward. NOTE(review): the bound appears to
                # assume fetching overlaps with the forward pass, with 0.01 s
                # of slack per iteration -- confirm against AsyncDataLoader.
                assert total_data_time <= (
                    measured_iters * batch_size * 0.5 + measured_iters * 0.01 - measured_iters * 1
                )
            else:
                # NOTE(review): 0.008 s per fetch presumably means batches are
                # already prepared when requested -- confirm.
                assert total_data_time <= measured_iters * 0.008
        finally:
            # NOTE(review): presumably stops the loader's background workers;
            # see AsyncDataLoader.__del__.
            dataloader.__del__()
        time.sleep(0.5)
        # At most two threads may remain alive after shutdown.
        assert len(threading.enumerate()) <= 2, threading.enumerate()
|