|
import numpy as np |
|
import torch |
|
from torch.utils.data import Dataset |
|
from torchvision.datasets import DatasetFolder |
|
|
|
|
|
class FBanksTripletDataset(Dataset):
    """Dataset that yields (anchor, positive, negative) triplets of ``.npy`` samples.

    Samples are discovered with ``DatasetFolder`` under ``root``; each sample's
    label is the class index that ``DatasetFolder`` assigns to it.  For an
    anchor at ``index`` the positive is drawn from the same label and the
    negative from a different, uniformly chosen label.

    NOTE(review): the per-label index ranges assume that ``DatasetFolder``
    stores its samples grouped by class index (all of class 0, then all of
    class 1, ...) -- confirm against the torchvision version in use.
    """

    def __init__(self, root):
        """Index every ``.npy`` file under ``root`` and precompute per-label ranges.

        Args:
            root: Directory handed to ``DatasetFolder`` as its ``root``.
        """
        super().__init__()
        self.dataset_folder = DatasetFolder(
            root=root,
            loader=FBanksTripletDataset._npy_loader,
            extensions='.npy',
        )
        self.len_ = len(self.dataset_folder.samples)
        self.num_classes = len(self.dataset_folder.classes)

        # ``minlength`` keeps one count per class even when trailing classes
        # have no samples; otherwise indexing the counts by label could fail.
        counts = np.bincount(self.dataset_folder.targets, minlength=self.num_classes)

        # label -> half-open [start, end) range of sample indices for that label.
        self.label_to_index_range = {}
        start = 0
        for label in range(self.num_classes):
            end = start + int(counts[label])
            self.label_to_index_range[label] = (start, end)
            start = end

    @staticmethod
    def _npy_loader(path):
        """Load one ``.npy`` sample and return it as a float tensor.

        The stored array must have 64, 64 and 1 as its first three dimensions;
        axis 2 is moved to the front, so a ``(64, 64, 1)`` array becomes a
        ``(1, 64, 64)`` tensor.

        Args:
            path: Location of the ``.npy`` file.

        Returns:
            A ``torch.float32`` tensor built from the re-ordered array.

        Raises:
            ValueError: If the array's leading dimensions are not (64, 64, 1).
        """
        sample = np.load(path)

        # Explicit check instead of ``assert`` so it is not stripped under -O
        # and so arrays with fewer than three dimensions get a clear error.
        if tuple(sample.shape[:3]) != (64, 64, 1):
            raise ValueError(
                'expected leading dimensions (64, 64, 1) in {!r}, got shape {}'.format(
                    path, tuple(sample.shape)
                )
            )

        sample = np.moveaxis(sample, 2, 0)
        return torch.from_numpy(sample).float()

    def __getitem__(self, index):
        """Build the triplet whose anchor is the sample at ``index``.

        Args:
            index: Position of the anchor sample in the underlying folder dataset.

        Returns:
            ``((anchor_x, anchor_y), (positive_x, positive_y),
            (negative_x, negative_y))`` where the positive shares the anchor's
            label and the negative has a different label.

        Raises:
            ValueError: If the dataset has fewer than two classes, since no
                negative label exists in that case.
        """
        if self.num_classes < 2:
            raise ValueError('at least two classes are required to sample a negative')

        anchor_x, anchor_y = self.dataset_folder[index]

        # Positive: same label as the anchor.  When the label has more than one
        # sample, draw from a range that is one shorter and step over the
        # anchor's own index so the positive is never the anchor itself.
        start, end = self.label_to_index_range[anchor_y]
        if end - start > 1:
            positive_index = int(np.random.randint(low=start, high=end - 1))
            if positive_index >= index:
                positive_index += 1
        else:
            # Single-sample label: the anchor is the only possible positive.
            positive_index = index
        positive_x, positive_y = self.dataset_folder[positive_index]

        # Negative: uniform over every label except the anchor's.  Draw from
        # num_classes - 1 values and step over ``anchor_y`` instead of
        # building and popping a list of all labels on every call.
        negative_label = int(np.random.randint(self.num_classes - 1))
        if negative_label >= anchor_y:
            negative_label += 1
        start, end = self.label_to_index_range[negative_label]
        negative_index = int(np.random.randint(low=start, high=end))
        negative_x, negative_y = self.dataset_folder[negative_index]

        return (anchor_x, anchor_y), (positive_x, positive_y), (negative_x, negative_y)

    def __len__(self):
        """Return the number of samples found by ``DatasetFolder``."""
        return self.len_