|
import logging |
|
import random |
|
|
|
from torch.utils.data import DataLoader |
|
|
|
from ..hparams import HParams |
|
from .dataset import Dataset |
|
from .utils import mix_fg_bg, rglob_audio_files |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def _create_datasets(hp: HParams, mode, val_size=10, seed=123): |
|
paths = rglob_audio_files(hp.fg_dir) |
|
logger.info(f"Found {len(paths)} audio files in {hp.fg_dir}") |
|
|
|
random.Random(seed).shuffle(paths) |
|
train_paths = paths[:-val_size] |
|
val_paths = paths[-val_size:] |
|
|
|
train_ds = Dataset(train_paths, hp, training=True, mode=mode) |
|
val_ds = Dataset(val_paths, hp, training=False, mode=mode) |
|
|
|
logger.info(f"Train set: {len(train_ds)} samples - Val set: {len(val_ds)} samples") |
|
|
|
return train_ds, val_ds |
|
|
|
|
|
def create_dataloaders(hp: HParams, mode): |
|
train_ds, val_ds = _create_datasets(hp=hp, mode=mode) |
|
|
|
train_dl = DataLoader( |
|
train_ds, |
|
batch_size=hp.batch_size_per_gpu, |
|
shuffle=True, |
|
num_workers=hp.nj, |
|
drop_last=True, |
|
collate_fn=train_ds.collate_fn, |
|
) |
|
val_dl = DataLoader( |
|
val_ds, |
|
batch_size=1, |
|
shuffle=False, |
|
num_workers=hp.nj, |
|
drop_last=False, |
|
collate_fn=val_ds.collate_fn, |
|
) |
|
return train_dl, val_dl |
|
|