File size: 1,267 Bytes
88b5dc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import logging
import random

from torch.utils.data import DataLoader

from ..hparams import HParams
from .dataset import Dataset
from .utils import mix_fg_bg, rglob_audio_files

logger = logging.getLogger(__name__)


def _create_datasets(hp: HParams, mode, val_size=10, seed=123):
    """Split the audio files under ``hp.fg_dir`` into train and validation datasets.

    Audio files are collected recursively from ``hp.fg_dir`` and shuffled with
    a fixed seed. The last ``val_size`` files are held out for validation and
    the rest are used for training.

    Args:
        hp: Hyper-parameters. ``hp.fg_dir`` is the directory searched for audio
            files, and ``hp`` is forwarded to each ``Dataset``.
        mode: Forwarded unchanged to both ``Dataset`` instances.
        val_size: Number of files held out for validation. ``0`` yields an
            empty validation split.
        seed: Seed for the shuffle, so the split is repeatable for the same
            input file list.

    Returns:
        A ``(train_ds, val_ds)`` tuple. ``train_ds`` is built with
        ``training=True`` and ``val_ds`` with ``training=False``.

    Raises:
        ValueError: If ``val_size`` is negative, or if there are not more than
            ``val_size`` audio files (which would leave no training data).
    """
    if val_size < 0:
        raise ValueError(f"val_size must be non-negative, got {val_size}")

    paths = rglob_audio_files(hp.fg_dir)
    logger.info("Found %d audio files in %s", len(paths), hp.fg_dir)

    # Fail fast rather than silently producing an empty training set.
    if len(paths) <= val_size:
        raise ValueError(
            f"Need more than {val_size} audio files in {hp.fg_dir} to hold out "
            f"a validation split, found {len(paths)}"
        )

    # NOTE(review): the seeded split is only reproducible if rglob_audio_files
    # returns paths in a stable order; confirm it sorts its results.
    random.Random(seed).shuffle(paths)

    # Use an explicit split index. Slicing with ``-val_size`` breaks for
    # val_size == 0, because ``paths[:-0]`` is empty and ``paths[-0:]`` is
    # everything.
    split = len(paths) - val_size
    train_paths = paths[:split]
    val_paths = paths[split:]

    train_ds = Dataset(train_paths, hp, training=True, mode=mode)
    val_ds = Dataset(val_paths, hp, training=False, mode=mode)

    logger.info("Train set: %d samples - Val set: %d samples", len(train_ds), len(val_ds))

    return train_ds, val_ds


def create_dataloaders(hp: HParams, mode):
    """Build the training and validation DataLoaders for ``mode``.

    The datasets come from ``_create_datasets``. Both loaders use ``hp.nj``
    worker processes and their own dataset's ``collate_fn``.

    The training loader:
        - batches ``hp.batch_size_per_gpu`` samples;
        - shuffles the data;
        - drops the last incomplete batch.

    The validation loader:
        - yields one sample per batch;
        - keeps the dataset order;
        - keeps every sample.

    Args:
        hp: Hyper-parameters providing ``batch_size_per_gpu`` and ``nj``.
        mode: Forwarded to ``_create_datasets``.

    Returns:
        A ``(train_dl, val_dl)`` tuple of ``DataLoader`` objects.
    """
    train_ds, val_ds = _create_datasets(hp=hp, mode=mode)

    def _make_loader(dataset, *, batch_size, shuffle, drop_last):
        # Shared construction: only batching, ordering and tail handling
        # differ between the two splits.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=hp.nj,
            drop_last=drop_last,
            collate_fn=dataset.collate_fn,
        )

    train_dl = _make_loader(
        train_ds, batch_size=hp.batch_size_per_gpu, shuffle=True, drop_last=True
    )
    val_dl = _make_loader(val_ds, batch_size=1, shuffle=False, drop_last=False)
    return train_dl, val_dl