# Source listing: 1,703 bytes, revision f831146.
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.datasets import DatasetFolder
class FBanksTripletDataset(Dataset):
    """Triplet dataset over a class-per-folder tree of ``.npy`` samples.

    Indexing returns ``(anchor, positive, negative)``; each element is a
    ``(sample, label)`` pair as produced by the wrapped ``DatasetFolder``.
    The positive shares the anchor's label, the negative comes from a
    different class chosen uniformly at random.

    Sampling draws from NumPy's global random state, so seed ``np.random``
    to get reproducible triplets.
    """

    # Leading dimensions every stored array must have: (rows, cols, channels).
    _EXPECTED_SHAPE = (64, 64, 1)

    def __init__(self, root):
        """Index the samples under ``root`` and record each class's range.

        Args:
            root: Directory passed to ``DatasetFolder``; only files with the
                ``.npy`` extension are picked up.
        """
        super().__init__()
        self.dataset_folder = DatasetFolder(
            root=root,
            loader=FBanksTripletDataset._npy_loader,
            extensions='.npy',
        )
        self.len_ = len(self.dataset_folder.samples)
        self.num_classes = len(self.dataset_folder.classes)

        # minlength guarantees one count per class even when the highest
        # class indices have no samples (plain bincount would come up short).
        counts = np.bincount(
            self.dataset_folder.targets, minlength=self.num_classes
        )
        ends = np.cumsum(counts)

        # Maps label -> half-open [start, end) range of sample indices.
        # NOTE(review): relies on DatasetFolder listing its samples grouped
        # by ascending class index, so that each class is one contiguous run.
        self.label_to_index_range = {
            label: (int(end - count), int(end))
            for label, (count, end) in enumerate(zip(counts, ends))
        }

    @staticmethod
    def _npy_loader(path):
        """Load one ``.npy`` sample and return it as a channel-first tensor.

        Args:
            path: Location of the ``.npy`` file.

        Returns:
            A ``float32`` tensor with the array's third axis moved to the
            front (``(64, 64, 1)`` on disk becomes ``(1, 64, 64)``).

        Raises:
            ValueError: If the array's first three dimensions are not
                ``(64, 64, 1)``.
        """
        sample = np.load(path)
        # Explicit raise rather than assert: asserts vanish under ``-O`` and
        # a malformed sample would then reach the model unnoticed.
        if sample.shape[:3] != FBanksTripletDataset._EXPECTED_SHAPE:
            raise ValueError(
                f"{path}: expected leading shape "
                f"{FBanksTripletDataset._EXPECTED_SHAPE}, got {sample.shape}"
            )
        sample = np.moveaxis(sample, 2, 0)
        return torch.from_numpy(sample).float()

    @staticmethod
    def _randint_excluding(low, high, excluded):
        """Draw uniformly from ``[low, high)``, avoiding ``excluded`` if possible.

        Falls back to a plain draw when ``excluded`` is outside the range or
        is the only candidate.
        """
        if high - low <= 1 or not low <= excluded < high:
            return np.random.randint(low=low, high=high)
        # Draw from a range one shorter, then step over the excluded value.
        drawn = np.random.randint(low=low, high=high - 1)
        return drawn + 1 if drawn >= excluded else drawn

    def __getitem__(self, index):
        """Build a random triplet anchored at ``index``.

        Args:
            index: Position of the anchor sample in the wrapped dataset.

        Returns:
            ``((anchor_x, anchor_y), (positive_x, positive_y),
            (negative_x, negative_y))``.

        Raises:
            ValueError: If the dataset has fewer than two classes, so no
                negative can exist.
        """
        if self.num_classes < 2:
            raise ValueError(
                "triplet sampling needs at least two classes, "
                f"found {self.num_classes}"
            )

        anchor_x, anchor_y = self.dataset_folder[index]

        # Positive: same class, and a different sample than the anchor
        # whenever the class has one (an anchor paired with itself gives a
        # zero-distance, uninformative triplet).
        start, end = self.label_to_index_range[anchor_y]
        positive_index = self._randint_excluding(start, end, index)
        positive_x, positive_y = self.dataset_folder[positive_index]

        # Negative: uniform over the other classes, then uniform within it.
        negative_label = self._randint_excluding(0, self.num_classes, anchor_y)
        start, end = self.label_to_index_range[negative_label]
        negative_index = np.random.randint(low=start, high=end)
        negative_x, negative_y = self.dataset_folder[negative_index]

        return (anchor_x, anchor_y), (positive_x, positive_y), (negative_x, negative_y)

    def __len__(self):
        """Return the number of anchor samples."""
        return self.len_