|
import numpy as np |
|
import torch |
|
from torch.utils.data import Dataset, DataLoader |
|
from torchvision.datasets import DatasetFolder |
|
import multiprocessing |
|
|
|
|
|
class FBanksCrossEntropyDataset(Dataset):
    """Dataset of filter-bank feature samples for cross-entropy speaker training.

    Expects ``root/<speaker_label>/<sample>.npy`` layout, one sub-directory per
    class, where each ``.npy`` file holds a ``(64, 64, 1)`` array. Each item is
    a ``(tensor, class_index)`` pair with the tensor in channels-first
    ``(1, 64, 64)`` float32 form.
    """

    def __init__(self, root):
        # DatasetFolder discovers one class per sub-directory of `root`.
        # `extensions` is passed as a tuple — the form torchvision documents —
        # rather than a bare string.
        self.dataset_folder = DatasetFolder(root=root,
                                            loader=FBanksCrossEntropyDataset._npy_loader,
                                            extensions=('.npy',))
        self.len_ = len(self.dataset_folder.samples)

        # DatasetFolder lists samples grouped (and sorted) by class, so each
        # label occupies a contiguous half-open index range [start, end) in
        # the sample list. Record that range per label for samplers that need
        # to draw from a single speaker.
        bin_counts = np.bincount(self.dataset_folder.targets)
        self.num_classes = len(self.dataset_folder.classes)
        self.label_to_index_range = {}
        start = 0
        for label in range(self.num_classes):
            # int(...) converts the numpy scalar so the stored tuple holds
            # plain Python ints.
            end = start + int(bin_counts[label])
            self.label_to_index_range[label] = (start, end)
            start = end

    @staticmethod
    def _npy_loader(path):
        """Load one fbank sample and return it as a (1, 64, 64) float32 tensor.

        Raises:
            ValueError: if the array on disk does not have shape (64, 64, 1).
        """
        sample = np.load(path)
        # Validate explicitly instead of `assert`, which is stripped under -O.
        if sample.shape != (64, 64, 1):
            raise ValueError(
                'expected fbank sample of shape (64, 64, 1), got %s from %s'
                % (sample.shape, path))

        # HWC -> CHW for PyTorch's channels-first convention.
        sample = np.moveaxis(sample, 2, 0)
        return torch.from_numpy(sample).float()

    def __getitem__(self, index):
        return self.dataset_folder[index]

    def __len__(self):
        return self.len_
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: load the fbanks test split and pull a single batch.
    use_cuda = False

    # Multi-worker loading with pinned memory only pays off when feeding a GPU.
    loader_kwargs = {}
    if use_cuda:
        loader_kwargs = {'num_workers': multiprocessing.cpu_count(),
                         'pin_memory': True}

    data_test = FBanksCrossEntropyDataset('./dataset-speaker-csf/fbanks-test')
    print(data_test.label_to_index_range)

    test_loader = DataLoader(data_test, batch_size=1, shuffle=True, **loader_kwargs)
    first_batch = next(iter(test_loader))
    print(first_batch[0].shape)