import logging
import math
from typing import Callable, Iterable, List, Optional, Sequence

import torch

from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Subset
from torch.utils.data.distributed import DistributedSampler


class MixedDataLoader:
    def __init__(self, dataloaders: List[DataLoader], mixing_prob: torch.FloatTensor):
        """
        Args:
            dataloaders (List[DataLoader]): List of DataLoaders to be mixed.
            mixing_prob (torch.FloatTensor): Probability of sampling from each dataloader.
        """
        assert len(dataloaders) == mixing_prob.shape[0]
        self.dataloaders = dataloaders
        self.mixing_prob = mixing_prob
        # Per-iteration state, populated in __iter__.
        self._iter_dls = None
        self._iter_mixing_prob = None
        self.random_generator = torch.Generator()

    def __len__(self):
        return sum([len(d) for d in self.dataloaders])

    def __iter__(self):
        # Reseed with a fixed value so the mixing order is deterministic and
        # identical across processes.
        self.random_generator.manual_seed(42)
        self._iter_dls = [iter(loader) for loader in self.dataloaders]
        self._iter_mixing_prob = self.mixing_prob.clone()
        return self

    def __next__(self):
        """
        Sample a dataloader according to the mixing probabilities and return its
        next batch. When a dataloader is exhausted, its probability is zeroed out
        and sampling continues from the remaining loaders until all are exhausted.
        """
        if self._iter_dls is None:
            raise TypeError(f"{type(self).__name__} object is not an iterator")

        while self._iter_mixing_prob.any():  # at least one loader has non-zero probability
            dataset_idx = self._iter_mixing_prob.multinomial(
                1, generator=self.random_generator
            ).item()
            try:
                item = next(self._iter_dls[dataset_idx])
                return item
            except StopIteration:
                # This dataloader is exhausted; zero its probability and resample.
                self._iter_mixing_prob[dataset_idx] = 0
            except Exception as e:
                # Log and re-raise any unexpected error.
                logging.error(e)
                raise e

        # All dataloaders are exhausted.
        raise StopIteration
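
# A minimal usage sketch (illustrative only; the two toy datasets below are not
# part of this module). Each __iter__/__next__ cycle draws batches from the two
# loaders in proportion to mixing_prob until both are exhausted:
#
#     from torch.utils.data import DataLoader, TensorDataset
#
#     loader_a = DataLoader(TensorDataset(torch.arange(8.0)), batch_size=2)
#     loader_b = DataLoader(TensorDataset(torch.arange(100.0, 104.0)), batch_size=2)
#     mixed = MixedDataLoader([loader_a, loader_b], torch.tensor([0.75, 0.25]))
#     for batch in mixed:
#         ...  # 4 + 2 = 6 batches in total, interleaved at random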


class TorchTrainMixedDataset:
    def __init__(
        self,
        datasets: List[Dataset],
        batch_sizes: List[int],
        num_workers: int,
        shuffle: bool,
        pin_memory: bool,
        drop_last: bool,
        collate_fn: Optional[Callable] = None,
        worker_init_fn: Optional[Callable] = None,
        phases_per_epoch: int = 1,
        dataset_prob: Optional[List[float]] = None,
    ) -> None:
        """
        Args:
            datasets (List[Dataset]): List of Datasets to be mixed.
            batch_sizes (List[int]): Batch sizes for each dataset in the list.
            num_workers (int): Number of workers per dataloader.
            shuffle (bool): Whether or not to shuffle data.
            pin_memory (bool): If True, use pinned memory when loading tensors from disk.
            drop_last (bool): Whether or not to drop the last batch of data.
            collate_fn (Callable): Function to merge a list of samples into a mini-batch.
            worker_init_fn (Callable): Function to init each dataloader worker.
            phases_per_epoch (int): Number of phases per epoch; each phase iterates
                over a disjoint chunk of every dataset.
            dataset_prob (List[float]): Probability of sampling from each dataset.
                Should sum to 1.0.
        """
        self.datasets = datasets
        self.batch_sizes = batch_sizes
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.collate_fn = collate_fn
        self.worker_init_fn = worker_init_fn
        assert len(self.datasets) > 0
        for dataset in self.datasets:
            assert not isinstance(dataset, IterableDataset), "Not supported"
            # Initialize each dataset's epoch counter before first use.
            self._set_dataset_epoch(dataset, 0)
        self.phases_per_epoch = phases_per_epoch
        self.chunks = [None] * len(datasets)
        if dataset_prob is None:
            # By default, weight each dataset by its number of batches per epoch.
            dataset_lens = [
                (math.floor(len(d) / bs) if drop_last else math.ceil(len(d) / bs))
                for d, bs in zip(datasets, batch_sizes)
            ]
            total_len = sum(dataset_lens)
            dataset_prob = torch.tensor([d_len / total_len for d_len in dataset_lens])
        else:
            assert len(dataset_prob) == len(datasets)
            dataset_prob = torch.tensor(dataset_prob)

        logging.info(f"Dataset mixing probabilities: {dataset_prob.tolist()}")
        assert dataset_prob.sum().item() == 1.0, "Probabilities should sum to 1.0"
        self.dataset_prob = dataset_prob
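        # Worked example (illustrative numbers, not from this codebase): with two
        # datasets of 1000 and 250 samples, batch sizes 8 and 2, and drop_last=True,
        # the per-epoch batch counts are 1000 // 8 = 125 and 250 // 2 = 125, so the
        # default dataset_prob computed above is [0.5, 0.5].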

    def _set_dataset_epoch(self, dataset, epoch: int) -> None:
        if hasattr(dataset, "epoch"):
            dataset.epoch = epoch
        if hasattr(dataset, "set_epoch"):
            dataset.set_epoch(epoch)

    def get_loader(self, epoch: int) -> Iterable:
        dataloaders = []
        for d_idx, (dataset, batch_size) in enumerate(
            zip(self.datasets, self.batch_sizes)
        ):
            if self.phases_per_epoch > 1:
                # Outer epoch shared by all datasets: `phases_per_epoch` calls to
                # get_loader() make up one full pass over each dataset.
                main_epoch = epoch // self.phases_per_epoch

                # Phase within the current outer epoch.
                local_phase = epoch % self.phases_per_epoch

                # At the start of a new outer epoch (or on first use), reshuffle the
                # dataset indices and split them into `phases_per_epoch` chunks.
                if local_phase == 0 or self.chunks[d_idx] is None:
                    self._set_dataset_epoch(dataset, main_epoch)

                    g = torch.Generator()
                    g.manual_seed(main_epoch)
                    self.chunks[d_idx] = torch.chunk(
                        torch.randperm(len(dataset), generator=g),
                        self.phases_per_epoch,
                    )

                dataset = Subset(dataset, self.chunks[d_idx][local_phase])
            else:
                self._set_dataset_epoch(dataset, epoch)

            sampler = DistributedSampler(dataset, shuffle=self.shuffle)
            sampler.set_epoch(epoch)

            batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)
            dataloaders.append(
                DataLoader(
                    dataset,
                    num_workers=self.num_workers,
                    pin_memory=self.pin_memory,
                    batch_sampler=batch_sampler,
                    collate_fn=self.collate_fn,
                    worker_init_fn=self.worker_init_fn,
                )
            )
        return MixedDataLoader(dataloaders, self.dataset_prob)
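
# A minimal end-to-end sketch (illustrative only; the dataset sizes below are
# made up). `get_loader` builds one DataLoader per dataset, each wrapped in a
# DistributedSampler, so torch.distributed must be initialized before calling it:
#
#     import torch.distributed as dist
#     from torch.utils.data import TensorDataset
#
#     dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
#                             rank=0, world_size=1)
#     train_data = TorchTrainMixedDataset(
#         datasets=[TensorDataset(torch.randn(1000, 4)), TensorDataset(torch.randn(250, 4))],
#         batch_sizes=[8, 2],
#         num_workers=0,
#         shuffle=True,
#         pin_memory=False,
#         drop_last=True,
#     )
#     for epoch in range(2):
#         for batch in train_data.get_loader(epoch):
#             ...  # batches from the two datasets, mixed ~50/50 (see example above)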