# Spaces: Sleeping  (page-status text captured with this file; not part of the module)
import logging | |
import random | |
from pathlib import Path | |
import numpy as np | |
import torch | |
import torchaudio | |
import torchaudio.functional as AF | |
from torch.nn.utils.rnn import pad_sequence | |
from torch.utils.data import Dataset as DatasetBase | |
from ..hparams import HParams | |
from .distorter import Distorter | |
from .utils import rglob_audio_files | |
logger = logging.getLogger(__name__) | |
def _normalize(x): | |
return x / (np.abs(x).max() + 1e-7) | |
def _collate(batch, key, tensor=True, pad=True): | |
l = [d[key] for d in batch] | |
if l[0] is None: | |
return None | |
if tensor: | |
l = [torch.from_numpy(x) for x in l] | |
if pad: | |
assert tensor, "Can't pad non-tensor" | |
l = pad_sequence(l, batch_first=True) | |
return l | |
def praat_augment(wav, sr):
    """Apply a random Praat "Change gender" transform to a 1-D waveform.

    The formant shift ratio is drawn uniformly from [1.1, 1.5] and the pitch
    range factor from [0.5, 2.0]; the remaining arguments passed to Praat are
    fixed.

    Args:
        wav: 1-D array of samples.
        sr: Sampling rate handed to ``parselmouth.Sound``.

    Returns:
        The first row of the transformed sound's values as ``np.float32``.

    Raises:
        ImportError: If ``parselmouth`` is not installed.
    """
    try:
        import parselmouth
    except ImportError:
        raise ImportError("Please install parselmouth>=0.5.0 to use Praat augmentation")

    # "praat-parselmouth @ git+https://github.com/YannickJadoul/Parselmouth@0bbcca69705ed73322f3712b19d71bb3694b2540",
    # https://github.com/YannickJadoul/Parselmouth/issues/68
    # note that this function may hang if the praat version is 0.4.3
    assert wav.ndim == 1, f"wav.ndim must be 1 but got {wav.ndim}"

    source = parselmouth.Sound(wav, sr)

    formant_shift_ratio = random.uniform(1.1, 1.5)
    pitch_range_factor = random.uniform(0.5, 2.0)

    # Per the Praat manual the arguments are: pitch floor, pitch ceiling,
    # formant shift ratio, new pitch median (0 keeps the original median),
    # pitch range factor, duration factor.
    change_gender_args = (75, 600, formant_shift_ratio, 0, pitch_range_factor, 1.0)
    shifted = parselmouth.praat.call(source, "Change gender", *change_gender_args)

    samples = np.array(shifted.values)
    return samples[0].astype(np.float32)
class Dataset(DatasetBase):
    """Dataset of foreground/background waveforms and their distorted versions.

    Every item is a dict with the keys ``fg_wav``, ``bg_wav``, ``fg_dwav`` and
    ``bg_dwav``. When ``hp.load_fg_only`` is set, only ``fg_wav`` is populated
    and the other three entries are ``None``.
    """

    def __init__(
        self,
        fg_paths: list[Path],
        hp: "HParams",
        training=True,
        max_retries=100,
        silent_fg_prob=0.01,
        mode=False,
    ):
        """Set up file lists and the distorter.

        Args:
            fg_paths: Foreground audio files; must be non-empty.
            hp: Hyper-parameters. This class reads ``bg_dir``, ``fg_dir``,
                ``wav_rate``, ``training_seconds``, ``praat_augment_prob``
                and ``load_fg_only`` from it.
            training: Enables random cropping, random background choice,
                silent-foreground sampling and Praat augmentation.
            max_retries: Number of load attempts ``__getitem__`` makes before
                giving up (at least one attempt is always made).
            silent_fg_prob: Probability (training only) of replacing the
                foreground with silence.
            mode: Must be ``"enhancer"`` or ``"denoiser"``. The default
                ``False`` is rejected, so callers have to pass it explicitly.

        Raises:
            AssertionError: If ``mode`` is not one of the accepted values.
            ValueError: If no foreground or background files are available.
        """
        super().__init__()

        assert mode in ("enhancer", "denoiser"), f"Invalid mode: {mode}"

        self.hp = hp
        self.fg_paths = fg_paths
        self.bg_paths = rglob_audio_files(hp.bg_dir)

        if len(self.fg_paths) == 0:
            raise ValueError(f"No foreground audio files found in {hp.fg_dir}")

        if len(self.bg_paths) == 0:
            raise ValueError(f"No background audio files found in {hp.bg_dir}")

        logger.info(
            "Found %s foreground files and %s background files",
            len(self.fg_paths),
            len(self.bg_paths),
        )

        self.training = training
        self.max_retries = max_retries
        self.silent_fg_prob = silent_fg_prob
        self.mode = mode
        self.distorter = Distorter(hp, training=training, mode=mode)

    def _load_wav(self, path, length=None, random_crop=True):
        """Load ``path``, resample to ``hp.wav_rate`` and peak-normalize.

        Args:
            path: Audio file to read with ``torchaudio.load``.
            length: Target length in samples. When ``None`` during training it
                defaults to ``hp.training_seconds * hp.wav_rate``; when ``None``
                otherwise the full clip is returned.
            random_crop: Pick a random window of ``length`` samples instead of
                the leading ``length`` samples.

        Returns:
            A 1-D numpy array, cropped and zero-padded on the right to
            ``length`` when a length applies, and scaled by ``_normalize``.
        """
        wav, sr = torchaudio.load(path)

        # NOTE(review): these constants look like the "kaiser_best" settings
        # from torchaudio's resampling tutorial; confirm before changing them.
        wav = AF.resample(
            waveform=wav,
            orig_freq=sr,
            new_freq=self.hp.wav_rate,
            lowpass_filter_width=64,
            rolloff=0.9475937167399596,
            resampling_method="sinc_interp_kaiser",
            beta=14.769656459379492,
        )
        wav = wav.float().numpy()

        # A 2-D array is collapsed to 1-D by averaging over the first axis.
        if wav.ndim == 2:
            wav = np.mean(wav, axis=0)

        if length is None and self.training:
            length = int(self.hp.training_seconds * self.hp.wav_rate)

        if length is not None:
            if random_crop:
                # max(0, ...) keeps the range valid for clips shorter than `length`.
                start = random.randint(0, max(0, len(wav) - length))
                wav = wav[start : start + length]
            else:
                wav = wav[:length]

        # Short clips are right-padded with zeros up to the target length.
        if length is not None and len(wav) < length:
            wav = np.pad(wav, (0, length - len(wav)))

        wav = _normalize(wav)

        return wav

    def _getitem_unsafe(self, index: int):
        """Build one sample; any exception propagates to ``__getitem__``.

        Returns:
            dict with ``fg_wav``, ``bg_wav``, ``fg_dwav`` and ``bg_dwav``.
        """
        fg_path = self.fg_paths[index]

        if self.training and random.random() < self.silent_fg_prob:
            # Occasionally train on a completely silent foreground.
            fg_wav = np.zeros(
                int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32
            )
        else:
            fg_wav = self._load_wav(fg_path)
            if random.random() < self.hp.praat_augment_prob and self.training:
                fg_wav = praat_augment(fg_wav, self.hp.wav_rate)

        if self.hp.load_fg_only:
            bg_wav = None
            fg_dwav = None
            bg_dwav = None
        else:
            fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(
                np.float32
            )
            if self.training:
                bg_path = random.choice(self.bg_paths)
            else:
                # Deterministic for validation
                bg_path = self.bg_paths[index % len(self.bg_paths)]
            # The background is cut/padded to the foreground's length.
            bg_wav = self._load_wav(
                bg_path, length=len(fg_wav), random_crop=self.training
            )
            bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(
                np.float32
            )

        return dict(
            fg_wav=fg_wav,
            bg_wav=bg_wav,
            fg_dwav=fg_dwav,
            bg_dwav=bg_dwav,
        )

    def __getitem__(self, index: int):
        """Return the sample at ``index``, retrying with random indices on error.

        Raises:
            RuntimeError: If the final attempt fails; the underlying exception
                is attached as ``__cause__``.
        """
        # Always try at least once: with max_retries <= 0 a bare
        # range(max_retries) loop would fall through and return None.
        attempts = max(1, self.max_retries)

        for attempt in range(attempts):
            try:
                return self._getitem_unsafe(index)
            except Exception as e:
                failed_path = self.fg_paths[index]
                if attempt == attempts - 1:
                    raise RuntimeError(
                        f"Failed to load {failed_path} after {self.max_retries} retries"
                    ) from e
                logger.debug("Error loading %s: %s, skipping", failed_path, e)
                # Use the same `random` generator as the rest of this class
                # rather than mixing in NumPy's separate global RNG state.
                index = random.randrange(len(self))

    def __len__(self):
        """Number of foreground files."""
        return len(self.fg_paths)
def collate_fn(batch):
    """Collate a list of sample dicts into one dict of batched fields.

    Each per-sample key (``fg_wav``, ``bg_wav``, ``fg_dwav``, ``bg_dwav``) is
    passed through ``_collate`` with its defaults and stored under the
    pluralized key (``fg_wavs``, ...).
    """
    # NOTE(review): indentation was lost in this copy of the file. If this is
    # meant to live inside ``Dataset`` and be called on an instance, it needs
    # ``@staticmethod`` — confirm against the caller.
    field_map = (
        ("fg_wavs", "fg_wav"),
        ("bg_wavs", "bg_wav"),
        ("fg_dwavs", "fg_dwav"),
        ("bg_dwavs", "bg_dwav"),
    )
    return {batched_key: _collate(batch, item_key) for batched_key, item_key in field_map}