speaker_identify / data_proc /cross_entropy_dataset.py
DuyTa's picture
Upload folder using huggingface_hub
f831146 verified
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import DatasetFolder
import multiprocessing
class FBanksCrossEntropyDataset(Dataset):
    """Dataset of ``.npy`` filter-bank samples arranged one sub-folder per class.

    The files are discovered and indexed by ``torchvision``'s ``DatasetFolder``;
    this class adds the sample loader and a per-class index-range lookup.

    Attributes set by ``__init__``:
        dataset_folder: the underlying ``DatasetFolder``.
        len_: total number of samples.
        num_classes: number of classes reported by ``dataset_folder.classes``.
        label_to_index_range: ``{class_index: (start, stop)}`` with ``stop``
            exclusive, built from the per-class sample counts.
    """

    # Shape every stored sample must have on its first three axes.
    EXPECTED_SHAPE = (64, 64, 1)

    def __init__(self, root):
        """Index every ``.npy`` file found under ``root``.

        Args:
            root: directory handed to ``DatasetFolder`` as its ``root``.
        """
        super().__init__()
        self.dataset_folder = DatasetFolder(
            root=root,
            loader=FBanksCrossEntropyDataset._npy_loader,
            extensions='.npy',
        )
        self.len_ = len(self.dataset_folder.samples)
        self.num_classes = len(self.dataset_folder.classes)
        self.label_to_index_range = self._label_index_ranges(
            self.dataset_folder.targets, self.num_classes
        )

    @staticmethod
    def _label_index_ranges(targets, num_classes):
        """Map each class index to the half-open range of sample positions it spans.

        The ranges are derived purely from per-class counts, so they only
        describe real sample positions when the samples are grouped by
        ascending class index.
        NOTE(review): this relies on ``DatasetFolder`` ordering its samples
        that way -- confirm against the torchvision version in use.

        Args:
            targets: sequence of non-negative integer class indices.
            num_classes: number of classes to produce ranges for.

        Returns:
            ``{class_index: (start, stop)}`` with plain ``int`` bounds; a class
            without samples gets an empty range (``start == stop``).
        """
        # minlength keeps the lookup below in bounds even when the
        # highest-numbered classes have no samples at all.
        counts = np.bincount(
            np.asarray(targets, dtype=np.int64), minlength=num_classes
        )
        stops = np.cumsum(counts).tolist()
        starts = [0] + stops[:-1]
        return {label: (starts[label], stops[label]) for label in range(num_classes)}

    @staticmethod
    def _npy_loader(path):
        """Load one sample from ``path`` as a channels-first float tensor.

        Args:
            path: location of a ``.npy`` file whose array has shape
                ``(64, 64, 1)`` on its first three axes.

        Returns:
            A ``torch.float32`` tensor with the original axis 2 moved to the
            front, i.e. shape ``(1, 64, 64)`` for a ``(64, 64, 1)`` array.

        Raises:
            ValueError: if the stored array does not have the expected shape.
        """
        sample = np.load(path)
        # Explicit check instead of ``assert``: asserts vanish under ``python -O``
        # and a too-short shape would otherwise surface as an IndexError.
        expected = FBanksCrossEntropyDataset.EXPECTED_SHAPE
        if tuple(sample.shape[:3]) != expected:
            raise ValueError(
                f"expected sample of shape {expected}, "
                f"got {tuple(sample.shape)} from {path!r}"
            )
        # Move the channel axis first, the layout the original comment says
        # PyTorch expects.
        sample = np.moveaxis(sample, 2, 0)
        return torch.from_numpy(sample).float()

    def __getitem__(self, index):
        """Return whatever the underlying ``DatasetFolder`` yields for ``index``."""
        return self.dataset_folder[index]

    def __len__(self):
        """Return the number of samples counted at construction time."""
        return self.len_
if __name__ == '__main__':
    # Manual smoke test: build the test-split dataset, print its per-class
    # index ranges, then print the shape of the samples in one shuffled batch.
    use_cuda = False

    # Worker processes and pinned memory are only requested for the CUDA path.
    loader_options = {}
    if use_cuda:
        loader_options = {
            'num_workers': multiprocessing.cpu_count(),
            'pin_memory': True,
        }

    test_set = FBanksCrossEntropyDataset('./dataset-speaker-csf/fbanks-test')
    print(test_set.label_to_index_range)

    loader = DataLoader(test_set, batch_size=1, shuffle=True, **loader_options)
    first_batch = next(iter(loader))
    # Element 0 of the batch is the sample tensor.
    print(first_batch[0].shape)