File size: 4,293 Bytes
404d2af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import torch
from torch.utils.data import Sampler, ConcatDataset
class RandomConcatSampler(Sampler):
    """Random sampler for a ``ConcatDataset``.

    Every call to ``__iter__`` (i.e. every epoch) draws
    ``n_samples_per_subset`` indices from each sub-dataset of the
    ``ConcatDataset``. If ``subset_replacement`` is ``True``, sampling within
    each subset is done with replacement; otherwise indices are drawn without
    replacement, and a subset smaller than ``n_samples_per_subset`` is padded
    with extra indices drawn with replacement.

    Sampling without replacement *across* epochs is not supported; that would
    require a stateful sampler that lives for the entire training phase.

    Args:
        data_source (ConcatDataset): dataset to sample from.
        n_samples_per_subset (int): number of indices drawn from each
            sub-dataset per epoch.
        subset_replacement (bool): sample with replacement inside each subset.
        shuffle (bool): shuffle the sampled indices across all sub-datasets.
        repeat (int): repeatedly use the sampled indices multiple times for
            training. [arXiv:1902.05509, arXiv:1901.09335]
        seed (int, optional): if given, ``torch.manual_seed(seed)`` is called,
            which reseeds the *global* RNG, and the returned global generator
            is used for sampling. If ``None``, a private generator is created
            and seeded with a value drawn from the global RNG.

    Raises:
        TypeError: if ``data_source`` is not a ``ConcatDataset``.
        ValueError: if ``repeat`` is smaller than 1.

    NOTE: Don't re-initialize the sampler with the same ``seed`` between
          epochs; the generator restarts from the same state, which leads to
          repeated samples. Keep one sampler alive instead.
    NOTE: This sampler behaves differently from DistributedSampler.
          It assumes the dataset is split across ranks instead of replicated.
    TODO: Add a `set_epoch()` method to fulfill sampling without replacement across epochs.
          ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
    """

    def __init__(self,
                 data_source: ConcatDataset,
                 n_samples_per_subset: int,
                 subset_replacement: bool = True,
                 shuffle: bool = True,
                 repeat: int = 1,
                 seed: int = None):
        if not isinstance(data_source, ConcatDataset):
            raise TypeError("data_source should be torch.utils.data.ConcatDataset")
        # Validate with a real exception: `assert` is stripped under `python -O`.
        if repeat < 1:
            raise ValueError(f"repeat should be >= 1, got {repeat}")

        self.data_source = data_source
        self.n_subset = len(self.data_source.datasets)
        self.n_samples_per_subset = n_samples_per_subset
        self.n_samples = self.n_subset * self.n_samples_per_subset * repeat
        self.subset_replacement = subset_replacement
        self.repeat = repeat
        self.shuffle = shuffle

        if seed is None:
            # `torch.manual_seed(None)` raises TypeError, so the default
            # argument needs its own path. Derive a seed from the global RNG
            # (so a global `torch.manual_seed()` still makes runs
            # reproducible) and keep the state in a private generator.
            derived_seed = int(torch.empty((), dtype=torch.int64).random_().item())
            self.generator = torch.Generator()
            self.generator.manual_seed(derived_seed)
        else:
            # Legacy behaviour kept as-is: reseeds the global RNG and samples
            # from the global default generator it returns.
            self.generator = torch.manual_seed(seed)

    def __len__(self):
        return self.n_samples

    def _sample_subset(self, d_idx):
        """Draw ``n_samples_per_subset`` global indices from sub-dataset ``d_idx``."""
        cumulative_sizes = self.data_source.cumulative_sizes
        # [low, high) is the range of global indices owned by this subset.
        low = 0 if d_idx == 0 else cumulative_sizes[d_idx - 1]
        high = cumulative_sizes[d_idx]

        if self.subset_replacement:
            return torch.randint(low, high, (self.n_samples_per_subset,),
                                 generator=self.generator, dtype=torch.int64)

        # Sample without replacement: random permutation of the subset.
        len_subset = len(self.data_source.datasets[d_idx])
        permuted = torch.randperm(len_subset, generator=self.generator) + low
        if len_subset >= self.n_samples_per_subset:
            return permuted[:self.n_samples_per_subset]

        # Subset is too small: use all of it, then pad with replacement.
        padding = torch.randint(low, high, (self.n_samples_per_subset - len_subset,),
                                generator=self.generator, dtype=torch.int64)
        return torch.cat([permuted, padding])

    def __iter__(self):
        # Sample from each sub-dataset, in dataset order.
        indices = torch.cat([self._sample_subset(d_idx)
                             for d_idx in range(self.n_subset)])

        if self.shuffle:  # shuffle the sampled indices (from multiple subsets)
            indices = indices[torch.randperm(len(indices), generator=self.generator)]

        # Repeat the sampled indices (can be used for RepeatAugmentation or
        # pure RepeatSampling). Each extra copy gets its own permutation when
        # shuffling is enabled.
        if self.repeat > 1:
            copies = [indices.clone() for _ in range(self.repeat - 1)]
            if self.shuffle:
                copies = [c[torch.randperm(len(c), generator=self.generator)]
                          for c in copies]
            indices = torch.cat([indices, *copies], 0)

        # Internal invariant: __len__ must agree with what is yielded.
        assert indices.shape[0] == self.n_samples
        return iter(indices.tolist())
|