speaker_identify / data_proc /triplet_loss_dataset.py
DuyTa's picture
Upload folder using huggingface_hub
f831146 verified
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.datasets import DatasetFolder
class FBanksTripletDataset(Dataset):
    """Dataset that yields (anchor, positive, negative) triplets of ``.npy`` samples.

    Samples are discovered with ``torchvision.datasets.DatasetFolder`` under
    ``root``. Every item is a tuple of three ``(tensor, label)`` pairs: the
    anchor at the requested index, a randomly drawn sample with the same
    label, and a randomly drawn sample with a different label.

    Random draws use NumPy's global RNG (``np.random``); seed it if
    reproducible triplets are required.

    NOTE: the per-label index ranges built in ``__init__`` assume that
    ``DatasetFolder`` lists its samples grouped by class in increasing class
    index, so each class occupies one contiguous block of indices.
    """

    def __init__(self, root):
        """Index the ``.npy`` files under ``root`` and precompute class ranges.

        Args:
            root: Directory handed to ``DatasetFolder`` as its ``root``.
        """
        self.dataset_folder = DatasetFolder(
            root=root,
            loader=FBanksTripletDataset._npy_loader,
            extensions='.npy',
        )
        # Total number of samples; also what __len__ reports.
        self.len_ = len(self.dataset_folder.samples)
        self.num_classes = len(self.dataset_folder.classes)

        # Number of samples per class index.
        bin_counts = np.bincount(self.dataset_folder.targets)

        # label -> half-open [start, stop) range of sample indices, obtained
        # by accumulating the per-class counts.
        self.label_to_index_range = {}
        start = 0
        for label in range(self.num_classes):
            stop = start + int(bin_counts[label])
            self.label_to_index_range[label] = (start, stop)
            start = stop

    @staticmethod
    def _npy_loader(path):
        """Load one ``.npy`` file as a float tensor with the last axis moved first.

        Args:
            path: Path of the ``.npy`` file to read.

        Returns:
            A ``torch.float32`` tensor; a ``(64, 64, 1)`` array becomes
            ``(1, 64, 64)``.

        Raises:
            ValueError: If the first three dimensions of the stored array are
                not ``(64, 64, 1)``.
        """
        sample = np.load(path)
        # Explicit check instead of ``assert`` so validation survives ``-O``
        # and arrays with fewer than three dimensions get a clear error.
        if tuple(sample.shape[:3]) != (64, 64, 1):
            raise ValueError(
                f"expected a sample of shape (64, 64, 1), "
                f"got {tuple(sample.shape)} from {path!r}"
            )
        sample = np.moveaxis(sample, 2, 0)
        return torch.from_numpy(sample).float()

    def __getitem__(self, index):
        """Return the triplet anchored at ``index``.

        Args:
            index: Index of the anchor sample in the underlying folder dataset.

        Returns:
            ``((anchor_x, anchor_y), (positive_x, positive_y),
            (negative_x, negative_y))`` where the positive shares the anchor's
            label and the negative has a different label.

        Raises:
            ValueError: If the dataset has fewer than two classes, so no
                negative sample exists.
        """
        anchor_x, anchor_y = self.dataset_folder[index]

        if self.num_classes < 2:
            raise ValueError(
                "triplet sampling needs at least two classes to draw a negative"
            )

        # Find a positive: same class, and a different sample than the anchor
        # whenever the class has more than one member.
        start, end = self.label_to_index_range[anchor_y]
        if end - start > 1:
            # Normalise negative indices so the comparison below is valid.
            anchor_pos = index % self.len_
            # Draw from a range one shorter, then step over the anchor slot.
            i = np.random.randint(low=start, high=end - 1)
            if i >= anchor_pos:
                i += 1
        else:
            # Single-sample class: the anchor is the only possible positive.
            i = start
        positive_x, positive_y = self.dataset_folder[i]

        # Find a negative: pick uniformly among the other classes by drawing
        # from num_classes - 1 slots and stepping over the anchor's label.
        ny_ = np.random.randint(low=0, high=self.num_classes - 1)
        if ny_ >= anchor_y:
            ny_ += 1
        start, end = self.label_to_index_range[ny_]
        i = np.random.randint(low=start, high=end)
        negative_x, negative_y = self.dataset_folder[i]

        return (anchor_x, anchor_y), (positive_x, positive_y), (negative_x, negative_y)

    def __len__(self):
        """Return the number of anchor samples (one per file found)."""
        return self.len_