|
"""monkey patches for the dataset fetcher to handle batches of packed indexes""" |
|
|
|
|
|
import torch |
|
from torch.utils.data._utils.fetch import _BaseDatasetFetcher |
|
from torch.utils.data._utils.worker import _worker_loop |
|
|
|
|
|
class _MapDatasetFetcher(_BaseDatasetFetcher):
    """Map-style dataset fetcher that also accepts batches of packed indexes.

    Besides a plain index (or flat batch of indexes), ``fetch`` accepts a
    sequence whose elements are themselves lists of indexes. Each inner list
    is fetched as its own group and the list of groups is handed to
    ``collate_fn`` in one call.
    """

    def fetch(self, possibly_batched_index):
        """Fetch the samples for ``possibly_batched_index`` and collate them.

        Args:
            possibly_batched_index: One of
                * a single index, passed straight to ``dataset[...]`` when
                  ``auto_collation`` is false,
                * a flat batch of indexes, or
                * a sequence of index lists (packed batch); detection looks
                  only at the first element.

        Returns:
            Whatever ``self.collate_fn`` returns for the fetched data. For a
            packed batch the data is a list with one entry per index group.
        """
        if self._is_packed(possibly_batched_index):
            data = [self._fetch_group(group) for group in possibly_batched_index]
        else:
            data = self._fetch_group(possibly_batched_index)
        return self.collate_fn(data)

    @staticmethod
    def _is_packed(index):
        """Return True when ``index[0]`` exists and is a ``list``."""
        # A scalar index (e.g. an int) is not subscriptable and an empty batch
        # has no first element; neither is a packed batch, so they must fall
        # through to the ordinary path instead of raising here.
        try:
            return isinstance(index[0], list)
        except (TypeError, IndexError, KeyError):
            return False

    def _fetch_group(self, index):
        """Fetch one index, or one flat batch of indexes, from the dataset."""
        if not self.auto_collation:
            return self.dataset[index]
        # Prefer the dataset's bulk accessor when it exists and is truthy.
        getitems = getattr(self.dataset, "__getitems__", None)
        if getitems:
            return getitems(index)
        return [self.dataset[idx] for idx in index]
|
|
|
|
|
def patch_fetchers():
    """Rebind torch's ``_MapDatasetFetcher`` to this module's implementation.

    The attribute is assigned through two module paths, in this order:
    ``torch.utils.data._utils.fetch`` and
    ``torch.utils.data.dataloader._utils.fetch``.
    """
    fetcher_cls = _MapDatasetFetcher
    data_pkg = torch.utils.data
    # Path reached directly through the ``_utils`` package.
    data_pkg._utils.fetch._MapDatasetFetcher = fetcher_cls
    # Path reached through the dataloader module's ``_utils`` reference.
    # NOTE(review): presumably the same module object as above -- confirm.
    data_pkg.dataloader._utils.fetch._MapDatasetFetcher = fetcher_cls
|
|
|
|
|
def patched_worker_loop(*args, **kwargs):
    """Apply the fetcher patch, then run torch's original ``_worker_loop``.

    All positional and keyword arguments are forwarded unchanged, and the
    original loop's return value is returned as-is.
    """
    # Patch first so the loop below sees the replacement fetcher.
    # NOTE(review): presumably needed because worker processes may not
    # inherit the parent's patched module state -- confirm.
    patch_fetchers()
    outcome = _worker_loop(*args, **kwargs)
    return outcome
|
|
|
|
|
# Import-time side effects: swap torch's worker loop for the wrapper that
# re-applies the fetcher patch, then patch the fetcher in this process.
setattr(torch.utils.data._utils.worker, "_worker_loop", patched_worker_loop)
patch_fetchers()
|
|