# Provenance: uploaded by mazpie — commit "Initial commit" (2d9a728), 2.7 kB.
import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler
# Adapted from VISSL: https://github.com/facebookresearch/vissl/blob/94def58538d3c7037f5e093196494331eea1a2a2/vissl/data/data_helper.py#L93
class StatefulDistributedSampler(DistributedSampler):
    """
    Stateful distributed sampler that can resume mid-epoch.

    Shuffling uses the epoch, and the starting position uses the training
    iteration. PyTorch's DistributedSampler uses only the epoch and always
    starts from the beginning. When training on very large data we may train
    for a single epoch only. On resume we therefore want the sampler to
    continue from the training iteration at which we stopped.
    """

    def __init__(
        self, dataset, batch_size=None, seed: int = 0, *, num_replicas=None, rank=None
    ):
        """
        Initialize the sampler.

        Each rank gets a contiguous, equally sized slice of the dataset. The
        trailing ``len(dataset) % num_replicas`` samples are dropped. Within a
        rank the slice is shuffled with a generator seeded by
        ``epoch + seed``. Sampling starts at ``start_iter * batch_size``;
        ``start_iter`` is 0 initially and can be changed with
        ``set_start_iter`` when resuming from a checkpoint.

        Args:
            dataset (Dataset): PyTorch dataset that the sampler will shuffle.
            batch_size (int): batch size used by the consumer. It must be a
                positive int before iterating, because it converts
                ``start_iter`` into a sample offset.
            seed (int): seed for the torch generator.
            num_replicas (int, optional): number of participating processes.
                ``None`` defers to DistributedSampler, which queries the
                current process group.
            rank (int, optional): rank of this process. ``None`` defers to
                DistributedSampler, which queries the current process group.
        """
        # shuffle=False: the parent's shuffle flag is unused because
        # __iter__ below always applies its own seeded permutation.
        super().__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=False, seed=seed
        )
        self.start_iter = 0
        self.batch_size = batch_size
        # Drop the tail so every replica gets exactly the same number of
        # samples instead of padding with repeated indices.
        self.total_size = len(dataset) - (len(dataset) % self.num_replicas)
        self.num_samples = self.total_size // self.num_replicas
        print(f"rank: {self.rank}: Sampler created...")

    def __iter__(self):
        """
        Return an iterator over this rank's shuffled dataset indices.

        The first ``start_iter * batch_size`` indices are skipped.

        Raises:
            AssertionError: if ``batch_size`` is not a positive value.
        """
        # Partition data into contiguous per-rank slices and shuffle within
        # the rank. The permutation depends only on epoch + seed, so the same
        # order is reproduced after a restart.
        g = torch.Generator()
        g.manual_seed(self.epoch + self.seed)
        offset = self.rank * self.num_samples
        permutation = torch.randperm(self.num_samples, generator=g)
        indices = (permutation + offset).tolist()

        # make sure we have correct number of samples per replica
        assert len(indices) == self.num_samples
        # Check for None explicitly: with the default batch_size=None,
        # `None > 0` would raise a TypeError instead of this message.
        assert (
            self.batch_size is not None and self.batch_size > 0
        ), "batch_size not set for the sampler"

        # resume the sampler
        start_index = self.start_iter * self.batch_size
        return iter(indices[start_index:])

    def set_start_iter(self, start_iter):
        """
        Set the iteration number from which sampling should start.

        The next call to ``__iter__`` skips the first
        ``start_iter * batch_size`` indices of this rank's permutation.
        """
        self.start_iter = start_iter