Spaces:
Running
Running
"""
Copied from https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/dataset.py
"""
import queue as Queue | |
import threading | |
import torch | |
from torch.utils.data import DataLoader | |
class BackgroundGenerator(threading.Thread):
    """Iterate over ``generator`` on a daemon thread, prefetching into a bounded queue.

    The thread starts immediately on construction. Items reach the consumer
    in their original order through ``self.queue``; at most ``max_prefetch``
    items are buffered ahead of the consumer.

    Args:
        generator: Any iterable; it is consumed on the background thread.
        local_rank: CUDA device index made current on the background thread
            (only when CUDA is available).
        max_prefetch: Capacity of the prefetch queue.

    An exception raised while iterating ``generator`` (or by
    ``torch.cuda.set_device``) is re-raised in the consumer by the next call
    to ``next()``, after every item produced before the failure. Without
    this, the worker would die silently and leave the consumer blocked
    forever on ``queue.get()``.
    """

    # Private end-of-stream marker. A dedicated object (rather than None)
    # lets the wrapped iterable legitimately yield None.
    _END = object()

    def __init__(self, generator, local_rank, max_prefetch=6):
        super().__init__()
        self.queue = Queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        # Exception captured on the worker thread; re-raised by next().
        self._worker_exc = None
        # Set once the end marker has been consumed, so later next() calls
        # raise StopIteration instead of blocking on an empty queue.
        self._exhausted = False
        self.daemon = True
        self.start()

    def run(self):
        """Thread body: enqueue every item of ``generator``, then the end marker."""
        try:
            if torch.cuda.is_available():
                torch.cuda.set_device(self.local_rank)
            for item in self.generator:
                self.queue.put(item)
        except Exception as exc:  # forwarded to the consumer in next()
            self._worker_exc = exc
        finally:
            # Always emit the end marker, even on failure, so the consumer
            # never waits forever on queue.get().
            self.queue.put(self._END)

    def next(self):
        """Return the next prefetched item; raise StopIteration at the end.

        Blocks until the worker has produced an item. If the worker failed,
        its exception is raised here once; later calls raise StopIteration.
        """
        if self._exhausted:
            raise StopIteration
        item = self.queue.get()
        if item is self._END:
            self._exhausted = True
            if self._worker_exc is not None:
                raise self._worker_exc
            raise StopIteration
        return item

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self
class DataLoaderX(DataLoader):
    """DataLoader that prefetches the next batch onto a GPU using a side CUDA stream.

    Batches come from the regular ``DataLoader`` iterator, which runs on a
    ``BackgroundGenerator`` thread. Each batch element is then copied to
    device ``local_rank`` with ``non_blocking=True`` inside ``self.stream``,
    so the copy of the next batch is issued while the caller works on the
    current one.

    Args:
        local_rank: CUDA device index that batches are copied to.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader``.

    Each batch is expected to be an indexable, mutable sequence (e.g. the
    list produced by default collation). Its elements are replaced in place
    by their device copies.
    """

    def __init__(self, local_rank, **kwargs):
        super(DataLoaderX, self).__init__(**kwargs)
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        self.iter = super(DataLoaderX, self).__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def preload(self):
        """Fetch the next batch and issue its async device copies on ``self.stream``.

        Sets ``self.batch`` to None when the underlying iterator is exhausted.
        """
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        with torch.cuda.stream(self.stream):
            for k, item in enumerate(self.batch):
                # Elements without a ``.to`` method (e.g. lists of strings)
                # are passed through unchanged instead of raising.
                if hasattr(item, "to"):
                    self.batch[k] = item.to(device=self.local_rank,
                                            non_blocking=True)

    def __next__(self):
        current = torch.cuda.current_stream()
        # Make the compute stream wait for the copies issued in preload().
        current.wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        # These tensors were allocated on self.stream but are used on the
        # current stream. record_stream prevents the caching allocator from
        # reusing their memory for the next preload() on self.stream while
        # work queued on the current stream may still be reading them.
        for item in batch:
            if isinstance(item, torch.Tensor) and item.is_cuda:
                item.record_stream(current)
        self.preload()
        return batch