# Source: qwerrwe/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py (1.84 kB)
# Commit 00568c1 (unverified) by winglian: support for true batches with multipack (#1230)
"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
# pylint: disable=protected-access
import torch
from torch.utils.data._utils.fetch import _BaseDatasetFetcher
from torch.utils.data._utils.worker import _worker_loop
class _MapDatasetFetcher(_BaseDatasetFetcher):
def fetch(self, possibly_batched_index):
if isinstance(possibly_batched_index[0], list):
data = [None for i in possibly_batched_index]
for i, possibly_batched_index_ in enumerate(possibly_batched_index):
if self.auto_collation:
if (
hasattr(self.dataset, "__getitems__")
and self.dataset.__getitems__
):
data[i] = self.dataset.__getitems__(possibly_batched_index_)
else:
data[i] = [self.dataset[idx] for idx in possibly_batched_index_]
else:
data[i] = self.dataset[possibly_batched_index_]
else:
if self.auto_collation:
if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
data = self.dataset.__getitems__(possibly_batched_index)
else:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
def patch_fetchers():
    """Rebind torch's ``_MapDatasetFetcher`` to the class defined in this module.

    The attribute is set on the fetch module as reached through both
    ``torch.utils.data._utils`` and ``torch.utils.data.dataloader._utils``.
    """
    fetch_modules = (
        torch.utils.data._utils.fetch,
        torch.utils.data.dataloader._utils.fetch,
    )
    for fetch_module in fetch_modules:
        fetch_module._MapDatasetFetcher = _MapDatasetFetcher
def patched_worker_loop(*args, **kwargs):
    """Apply the fetcher patch, then run torch's original ``_worker_loop``.

    All positional and keyword arguments are forwarded unchanged, and the
    original loop's return value is passed back to the caller.
    """
    # Re-apply the patch first; presumably this covers worker processes that
    # do not inherit the parent's patched module state (confirm).
    patch_fetchers()
    outcome = _worker_loop(*args, **kwargs)
    return outcome
# Import-time side effect: replace torch's worker loop with the wrapper so the
# fetcher patch is re-applied before the original loop runs.
# NOTE(review): assumes DataLoader resolves `_worker_loop` through this module
# attribute when it starts workers; confirm against the installed torch.
torch.utils.data._utils.worker._worker_loop = patched_worker_loop
# Install the patched fetcher in the current process as soon as this module is imported.
patch_fetchers()