File size: 1,944 Bytes
b213d84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from torch.utils.data.sampler import BatchSampler, Sampler
class _UndefinedLengthError(NotImplementedError, TypeError):
    """Error raised by ``GroupedBatchSampler.__len__``.

    It derives from ``NotImplementedError`` so that existing handlers keep
    working, and from ``TypeError`` because that is the only exception the
    interpreter tolerates when it probes ``len()`` for a size hint (for
    example inside ``list(sampler)``) before falling back to plain iteration.
    """


class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield mini-batches of indices.

    Every yielded batch only contains elements that share the same group id.
    Indices are buffered per group in the order the wrapped sampler produces
    them, and a batch is emitted as soon as one group's buffer holds
    `batch_size` indices, so the batch ordering stays as close as possible to
    the ordering of the original sampler.

    Note:
        Incomplete batches are never emitted. The per-group buffers live on
        the instance, so indices left over when the wrapped sampler is
        exhausted stay buffered and are combined with the indices produced by
        the next iteration over this object.
    """

    def __init__(self, sampler, group_ids, batch_size):
        """
        Args:
            sampler (Sampler): Base sampler.
            group_ids (list[int]): If the sampler produces indices in range [0, N),
                `group_ids` must be a list of `N` ints which contains the group id
                of each sample.
            batch_size (int): Size of mini-batch. Must be a positive integer.

        Raises:
            ValueError: If `sampler` is not a `torch.utils.data.Sampler`, if
                `group_ids` is not one-dimensional, or if `batch_size` is not
                a positive integer.
        """
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )

        group_ids = np.asarray(group_ids)
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        if group_ids.ndim != 1:
            raise ValueError(
                "group_ids should be a 1-D sequence, but got an array "
                "with ndim={}".format(group_ids.ndim)
            )

        # A batch is emitted only when a buffer length *equals* batch_size, so
        # anything other than a positive integer would silently yield nothing
        # while the buffers keep growing.
        if not isinstance(batch_size, (int, np.integer)) or batch_size <= 0:
            raise ValueError(
                "batch_size should be a positive integer value, "
                "but got batch_size={}".format(batch_size)
            )

        self.sampler = sampler
        self.group_ids = group_ids
        self.batch_size = batch_size

        # Buffer the indices of each group until batch size is reached.
        self.buffer_per_group = {
            group: [] for group in np.unique(self.group_ids).tolist()
        }

    def __iter__(self):
        """Yield lists of `batch_size` indices that all belong to one group."""
        for idx in self.sampler:
            group_buffer = self.buffer_per_group[self.group_ids[idx]]
            group_buffer.append(idx)
            if len(group_buffer) == self.batch_size:
                # Yield a copy: the buffer itself is emptied and reused.
                yield group_buffer[:]
                del group_buffer[:]

    def __len__(self):
        """Always raise: the number of batches cannot be known in advance.

        Raises:
            NotImplementedError: Always. The raised error is also a
                `TypeError`, so `list(sampler)` still works.
        """
        raise _UndefinedLengthError(
            "len() of GroupedBatchSampler is not well-defined."
        )
|