# petrel-oss-python-sdk2/tests/dataloader_test.py
# Uploaded to the Hugging Face Hub by Weiyun1025 ("Upload folder using
# huggingface_hub"), commit 2abfccb, 2.71 kB.
import torch
from torch.utils.data import get_worker_info
from torch.utils.data import DataLoader
import random
import time
from functools import partial
from itertools import chain
from petrel_client.utils.data import DataLoader as MyDataLoader
# Pre-bind the petrel loader's extra options so every MyDataLoader(...) call
# below uses persistent_workers=True and prefetch_factor=4; the stock torch
# DataLoader built later from the same arguments does not get these options.
MyDataLoader = partial(MyDataLoader, persistent_workers=True, prefetch_factor=4)
def assert_equal(lhs, rhs, path='batch'):
    """Recursively check that two collated DataLoader outputs are identical.

    ``lhs`` drives the traversal:
    - dicts are compared key by key;
    - lists and tuples are compared element by element;
    - tensors are compared with ``torch.equal`` (same shape and values);
    - any other leaf is compared with ``==``.

    Mismatches raise ``AssertionError`` explicitly rather than through
    ``assert`` statements, so the check still runs under ``python -O``.

    Args:
        lhs: Reference value (e.g. batches from the stock torch DataLoader).
        rhs: Value that must match ``lhs``.
        path: Human-readable location of ``lhs`` in the top-level
            structure. It is only used to build failure messages.

    Raises:
        AssertionError: if the structures or values differ.
    """
    if isinstance(lhs, dict):
        rhs_keys = list(rhs.keys()) if hasattr(rhs, 'keys') else None
        if rhs_keys is None or lhs.keys() != rhs.keys():
            raise AssertionError(
                '%s: keys %r != %r' % (path, list(lhs.keys()), rhs_keys))
        for key in lhs:
            assert_equal(lhs[key], rhs[key], '%s[%r]' % (path, key))
    elif isinstance(lhs, (list, tuple)):
        if len(lhs) != len(rhs):
            raise AssertionError(
                '%s: length %d != %d' % (path, len(lhs), len(rhs)))
        for i, (a, b) in enumerate(zip(lhs, rhs)):
            assert_equal(a, b, '%s[%d]' % (path, i))
    elif isinstance(lhs, torch.Tensor):
        # torch.equal requires both sides to be tensors. Report a type
        # mismatch as a comparison failure rather than a TypeError.
        if not isinstance(rhs, torch.Tensor) or not torch.equal(lhs, rhs):
            raise AssertionError('%s: tensors differ' % path)
    elif lhs != rhs:
        raise AssertionError('%s: %r != %r' % (path, lhs, rhs))
def wait(dt):
    """Sleep for ``dt`` seconds.

    Every simulated delay in this file goes through this helper, so there
    is a single place to change how waiting is done.
    """
    return time.sleep(dt)
class Dataset(list):
    """List-backed map-style dataset whose item access is deliberately slow.

    ``dataset[i]`` sleeps for a pseudo-random 1-4 multiples of 0.05 s and
    then returns ``{'val': <the underlying list item>}``. Inside DataLoader
    worker 0 the sleep is doubled, so that worker lags behind the others
    (presumably to exercise out-of-order batch arrival).

    Delays are drawn from a private ``random.Random`` seeded with the
    DataLoader worker id, or 0 when called outside a worker. This keeps the
    delay sequence reproducible per worker without reseeding the
    process-global ``random`` module, which other code in the same process
    may be using. The generator is reseeded whenever the observed worker id
    changes. A dataset that was already indexed in the main process
    therefore still gets per-worker seeding once it is copied into workers.
    """

    def __init__(self, *args, **kwargs):
        super(Dataset, self).__init__(*args, **kwargs)
        # Created lazily in __getitem__, because the worker id is only known
        # once the code runs inside a DataLoader worker.
        self._rng = None
        # Worker id (None = main process) that self._rng was seeded for.
        self._rng_worker_id = None

    def __getitem__(self, *args, **kwargs):
        worker_info = get_worker_info()
        worker_id = None if worker_info is None else worker_info.id
        if self._rng is None or self._rng_worker_id != worker_id:
            self._rng = random.Random(0 if worker_id is None else worker_id)
            self._rng_worker_id = worker_id

        time_to_sleep = self._rng.randint(1, 4) * 0.05
        if worker_id == 0:
            # Make worker 0 consistently the slowest.
            time_to_sleep *= 2
        wait(time_to_sleep)

        val = super(Dataset, self).__getitem__(*args, **kwargs)
        return {'val': val}
def test(dataloader, result):
    """Iterate ``dataloader`` for two epochs while simulating training work.

    Every batch is appended to ``result`` in iteration order. After each
    batch the loop sleeps 0.05, 0.10 or 0.15 s to mimic a training step.
    The sleep lengths come from a private 0-seeded RNG, so the sequence is
    the same on every call. The loop prints how long it waited for each
    batch, in milliseconds, ten per row. At the end it prints the total
    wall time, the summed batch-wait ("data") time and the loader's type.

    Args:
        dataloader: A re-iterable loader; it is iterated twice back to back.
        result: List that receives every batch.
    """
    print('\ntest')
    # Private generator: reproducible step sleeps without reseeding the
    # process-global ``random`` module.
    step_rng = random.Random(0)
    data_time = 0
    # perf_counter is monotonic, unlike time.time, so it suits intervals.
    tstart = t1 = time.perf_counter()
    for i, data in enumerate(chain(dataloader, dataloader), 1):
        t2 = time.perf_counter()
        d = t2 - t1
        print('{0:>5}'.format(int(d * 1000)), end='')
        # Ten timings per row: tab-separated, newline after every 10th.
        print('\t' if i % 10 else '\n', end='')
        result.append(data)
        data_time += d
        wait(0.05 * step_rng.randrange(1, 4))
        # Restart the clock after the simulated step, so only time spent
        # waiting on the loader counts as data time.
        t1 = time.perf_counter()
    tend = time.perf_counter()
    print('\ntotal time: %.3f' % (tend - tstart))
    print('total data time: %.3f' % data_time)
    print(type(dataloader))


# Despite its name, this is a helper driven by the module-level script, not
# a pytest test. If pytest collects this file (e.g. as dataloader_test.py),
# it would otherwise try to run it and error on the missing 'dataloader'
# fixture.
test.__test__ = False
def worker_init_fn(worker_id):
    """Log that a DataLoader worker has started, then stall it for 3 seconds.

    The stall stands in for slow per-worker setup, so worker start-up cost
    shows up in the batch-wait timings of whatever loop consumes the loader.
    """
    startup_delay = 3  # seconds
    print('start worker:', worker_id)
    wait(startup_delay)
dataloader_args = dict(
    dataset=Dataset(range(1024)),
    drop_last=False,
    shuffle=False,
    batch_size=32,
    num_workers=8,
    worker_init_fn=worker_init_fn,
)


def _seeded_run(loader_cls):
    """Seed torch, build ``loader_cls(**dataloader_args)`` and run ``test`` on it.

    Returns the loader together with the list of batches it produced.
    """
    torch.manual_seed(0)
    loader = loader_cls(**dataloader_args)
    batches = []
    test(loader, batches)
    return loader, batches


# The petrel loader runs first, then the stock torch loader. Both are built
# from the same arguments immediately after torch.manual_seed(0).
l2, r2 = _seeded_run(MyDataLoader)
l1, r1 = _seeded_run(DataLoader)
print('len l1:', len(l1))
print('len l2:', len(l2))
# The stock torch loader's batches serve as the reference.
assert_equal(r1, r2)
# Show which torch installation was used (the module repr includes its path).
print(torch)