"""Audio loading helpers and a mel-spectrogram dataset.

The module reads WAV files (whole or as random segments), resamples them,
computes log-mel spectrograms with torch, and exposes ``MelDataset``, which
pairs a band-limited (down- then up-sampled) mel input with the full-band
waveform and its mel target.
"""
import math
import os
import random
import wave

import librosa
import numpy as np
import scipy
import scipy.signal
import torch
import torch.utils.data
from librosa.util import normalize
from pydub import AudioSegment
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0

# Files whose header reports a lower rate than this are rejected by load_segment.
MIN_SOURCE_SAMPLING_RATE = 44100

# Upper bound on replacement files tried by MelDataset before giving up.
MAX_LOAD_ATTEMPTS = 100


def load_wav(full_path):
    """Read a whole WAV file and scale it by ``1 / MAX_WAV_VALUE``.

    :param full_path: Path to the WAV file.
    :return: ``(data, sampling_rate)``; ``(None, None)`` when the file cannot
        be read or is shorter than 0.5 seconds.
    """
    try:
        sampling_rate, data = read(full_path)
        if max(data.shape) / sampling_rate < 0.5:
            return None, None
    except FileNotFoundError:
        print(f"File not found: {full_path}")
        return None, None
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return None, None

    if len(data.shape) > 1:
        # Keep a single channel: the first column when the second axis holds
        # at most two entries, otherwise the first row.
        if data.shape[1] <= 2:
            data = data[..., 0]
        else:
            data = data[0, ...]
    # NOTE(review): the fixed divisor presumes 16-bit PCM samples - confirm
    # for other sample formats.
    return data / MAX_WAV_VALUE, sampling_rate


def get_wave_duration(file_path):
    """
    Gets the duration of a WAV file in seconds.

    :param file_path: Path to the WAV file.
    :return: ``(duration, frame_rate, num_frames)``, or ``(None, None, None)``
        when the header cannot be read.
    """
    try:
        with wave.open(file_path, 'rb') as wf:
            num_frames = wf.getnframes()
            frame_rate = wf.getframerate()
            duration = num_frames / float(frame_rate)
            return duration, frame_rate, num_frames
    except wave.Error as e:
        print(f"Error reading {file_path}: {e}")
        return None, None, None
    except FileNotFoundError:
        print(f"File not found: {file_path}")
        return None, None, None
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return None, None, None


def read_audio_segment(file_path, start_ms, end_ms):
    """
    Reads a segment from a WAV file and returns its first channel.

    :param file_path: Path to the WAV file.
    :param start_ms: Start time of the segment in milliseconds.
    :param end_ms: End time of the segment in milliseconds.
    :return: The first channel as a NumPy array divided by ``MAX_WAV_VALUE``,
        or ``None`` when reading fails.
    """
    try:
        audio = AudioSegment.from_wav(file_path)
        segment = audio[start_ms:end_ms]

        # The buffer is decoded as int16 below, so force 2-byte samples first;
        # any other width would be misread.
        if segment.sample_width != 2:
            segment = segment.set_sample_width(2)

        channels = segment.channels
        audio_array = np.frombuffer(segment.raw_data, dtype=np.int16)

        # Interleaved multi-channel data: split per channel, keep the first.
        if channels > 1:
            audio_array = audio_array.reshape((-1, channels))
            audio_array = audio_array[..., 0]

        return audio_array / MAX_WAV_VALUE
    except Exception as e:
        print(f"An error occurred: {e}")
        return None


def resample(audio, sr_in, sr_out, target_len=None):
    """Resample ``audio`` with ``scipy.signal.resample``.

    :param audio: Samples to resample.
    :param sr_in: Source sampling rate; ignored when ``target_len`` is given.
    :param sr_out: Target sampling rate; ignored when ``target_len`` is given.
    :param target_len: If not ``None``, the exact number of output samples.
    :return: The resampled signal.
    """
    if target_len is not None:
        return scipy.signal.resample(audio, target_len)

    resample_factor = sr_out / sr_in
    new_samples = int(len(audio) * resample_factor)
    return scipy.signal.resample(audio, new_samples)


def _pad_to_length(data, length):
    """Zero-pad ``data`` at the end to ``length`` samples; returns float32 NumPy."""
    padded = torch.FloatTensor(data).unsqueeze(0)
    padded = torch.nn.functional.pad(
        padded, (0, length - padded.size(1)), 'constant')
    return padded.squeeze(0).numpy()


def _match_sampling_rate(data, sampling_rate, target_sampling_rate):
    """Resample ``data`` to ``target_sampling_rate`` when one is set and differs."""
    if target_sampling_rate is not None and sampling_rate != target_sampling_rate:
        data = resample(data, sampling_rate, target_sampling_rate)
        sampling_rate = target_sampling_rate
    return data, sampling_rate


def load_segment(full_path, target_sampling_rate=None, segment_size=None):
    """Load a file, or a random ``segment_size``-sample excerpt of it.

    Files whose header is unreadable or whose sampling rate is below
    ``MIN_SOURCE_SAMPLING_RATE`` are rejected.

    :param full_path: Path to the WAV file.
    :param target_sampling_rate: Rate to resample to; ``None`` keeps the
        file's own rate.
    :param segment_size: Number of output samples. ``None`` returns the whole
        file. Shorter material is zero-padded at the end.
    :return: ``(data, sampling_rate)``, or ``(None, None)`` when the file is
        rejected or cannot be read.
    """
    dur, sampling_rate, len_data = get_wave_duration(full_path)
    if sampling_rate is None or sampling_rate < MIN_SOURCE_SAMPLING_RATE:
        return None, None

    if segment_size is None:
        data, sampling_rate = load_wav(full_path)
        if data is None:
            return None, None
        return _match_sampling_rate(data, sampling_rate, target_sampling_rate)

    # Without an explicit target rate the segment is measured at the file's
    # own rate (previously this divided by None).
    rate_out = sampling_rate if target_sampling_rate is None else target_sampling_rate
    target_dur = segment_size / rate_out

    if dur < target_dur:
        # The file is shorter than one segment: load all of it and pad.
        data, sampling_rate = load_wav(full_path)
        if data is None:
            return None, None
        data, sampling_rate = _match_sampling_rate(
            data, sampling_rate, target_sampling_rate)
        return _pad_to_length(data, segment_size), sampling_rate

    # Pick a random window of the needed duration, expressed in source frames.
    target_len = int(target_dur * sampling_rate)
    start_idx = random.randint(0, max(0, len_data - target_len))
    start_ms = start_idx / sampling_rate * 1000
    end_ms = start_ms + target_dur * 1000

    data = read_audio_segment(full_path, start_ms, end_ms)
    if data is None:
        return None, None
    data, sampling_rate = _match_sampling_rate(
        data, sampling_rate, target_sampling_rate)

    # The excerpt may come out slightly short or long; force the exact size.
    if len(data) < segment_size:
        data = _pad_to_length(data, segment_size)
    else:
        start_idx = random.randint(0, len(data) - segment_size)
        data = data[start_idx:start_idx + segment_size]
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Return ``log(clip(x, clip_val) * C)`` using NumPy."""
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    """Return ``exp(x) / C`` using NumPy."""
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Return ``log(clamp(x, min=clip_val) * C)`` using torch."""
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    """Return ``exp(x) / C`` using torch."""
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    """Apply ``dynamic_range_compression_torch`` with its default arguments."""
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    """Apply ``dynamic_range_decompression_torch`` with its default arguments."""
    output = dynamic_range_decompression_torch(magnitudes)
    return output


# Caches keyed by ``str(fmax) + '_' + str(device)`` and ``str(device)``.
mel_basis = {}
hann_window = {}
# Parameters each cached mel filterbank was built with, under the same key as
# ``mel_basis``, so an entry is rebuilt when a caller changes them.
_mel_basis_params = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size,
                    fmin, fmax, center=False):
    """Compute a log-compressed mel spectrogram of ``y``.

    The mel filterbank and Hann window are cached per device and reused
    across calls with the same parameters.

    :param y: Waveform tensor; presumably ``(batch, samples)`` - TODO confirm.
    :param n_fft: FFT size.
    :param num_mels: Number of mel bands.
    :param sampling_rate: Sampling rate used to build the mel filterbank.
    :param hop_size: STFT hop length.
    :param win_size: STFT window length.
    :param fmin: Lowest mel filter frequency.
    :param fmax: Highest mel filter frequency.
    :param center: Passed through to ``torch.stft``.
    :return: Mel spectrogram after ``spectral_normalize_torch``.
    """
    device_key = str(y.device)
    mel_key = str(fmax) + '_' + device_key
    mel_params = (sampling_rate, n_fft, num_mels, fmin)

    # The lookup must use the same key the entry is stored under; testing the
    # bare ``fmax`` never matched, so the basis was rebuilt on every call.
    if mel_key not in mel_basis or _mel_basis_params.get(mel_key) != mel_params:
        mel = librosa.filters.mel(sr=sampling_rate, n_fft=n_fft,
                                  n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
        _mel_basis_params[mel_key] = mel_params

    cached_window = hann_window.get(device_key)
    if cached_window is None or cached_window.numel() != win_size:
        hann_window[device_key] = torch.hann_window(win_size).to(y.device)

    # Manual reflect padding on both sides, applied before the STFT.
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
    y = y.squeeze(1)

    # ``return_complex=False`` is deprecated; take the complex result and view
    # it as (..., 2) real/imag pairs, which yields the same values.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size,
                      window=hann_window[device_key], center=center,
                      pad_mode='reflect', normalized=False, onesided=True,
                      return_complex=True)
    spec = torch.view_as_real(spec)

    # Magnitude, with a small epsilon inside the square root.
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[mel_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec


def _read_nonempty_lines(path):
    """Return the non-empty ``'\\n'``-separated lines of a UTF-8 text file."""
    with open(path, 'r', encoding='utf-8') as fi:
        return [x for x in fi.read().split('\n') if len(x) > 0]


def get_dataset_filelist_org(a):
    """Build training/validation WAV paths from ``|``-separated list files.

    The first ``|``-separated field of each non-empty line is joined with
    ``a.input_wavs_dir`` and given a ``.wav`` suffix.

    :param a: Object exposing ``input_training_file``,
        ``input_validation_file`` and ``input_wavs_dir``.
    :return: ``(training_files, validation_files)``.
    """
    training_files = [
        os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
        for x in _read_nonempty_lines(a.input_training_file)
    ]
    validation_files = [
        os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
        for x in _read_nonempty_lines(a.input_validation_file)
    ]
    return training_files, validation_files


def get_dataset_filelist(a):
    """Read training/validation list files, one entry per non-empty line.

    :param a: Object exposing ``input_training_file`` and
        ``input_validation_file``.
    :return: ``(training_files, validation_files)``.
    """
    training_files = _read_nonempty_lines(a.input_training_file)
    validation_files = _read_nonempty_lines(a.input_validation_file)
    return training_files, validation_files


class MelDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(mel, audio, filename, mel_loss)`` tuples.

    ``mel`` is computed from a copy of the audio that was resampled down to a
    randomly chosen rate and back up to the original length; ``mel_loss`` is
    computed from the original audio with ``fmax_loss``.
    """

    def __init__(self, training_files, segment_size, n_fft, num_mels,
                 hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
                 device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
        """Store the configuration and optionally shuffle the file list.

        Seeds the global ``random`` module with 1234 and, when ``shuffle`` is
        true, shuffles ``training_files`` in place.
        """
        self.audio_files = training_files
        random.seed(1234)
        if shuffle:
            random.shuffle(self.audio_files)
        self.segment_size = segment_size
        self.sampling_rate = sampling_rate
        self.split = split
        self.n_fft = n_fft
        self.num_mels = num_mels
        self.hop_size = hop_size
        self.win_size = win_size
        self.fmin = fmin
        self.fmax = fmax
        self.fmax_loss = fmax_loss
        self.cached_wav = None
        self.cached_wav_up = None
        self.n_cache_reuse = n_cache_reuse
        self._cache_ref_count = 0
        self.device = device
        self.fine_tuning = fine_tuning
        self.base_mels_path = base_mels_path
        # Rates the degraded mel input is randomly downsampled to.
        self.supported_samples = [16000, 22050, 24000]

    def _load_with_fallback(self, index, loader):
        """Load ``audio_files[index]``; on failure try random other files.

        :param index: Index of the preferred file.
        :param loader: Callable mapping a path to ``(audio, sampling_rate)``,
            with ``audio`` set to ``None`` on failure.
        :return: ``(audio, sampling_rate, filename)`` of the file that loaded.
        :raises RuntimeError: If ``MAX_LOAD_ATTEMPTS`` files in a row fail.
        """
        filename = self.audio_files[index]
        for _ in range(MAX_LOAD_ATTEMPTS):
            audio, sampling_rate = loader(filename)
            if audio is not None:
                return audio, sampling_rate, filename
            # Draw from the whole list: restricting the draw to [0, index]
            # looped forever when index 0 itself was unreadable.
            filename = random.choice(self.audio_files)
        raise RuntimeError(
            f"Could not load a usable audio file after {MAX_LOAD_ATTEMPTS} attempts")

    def _mel_pair(self, audio, audio_up):
        """Return the mel of ``audio_up`` (``fmax``) and of ``audio`` (``fmax_loss``)."""
        mel = mel_spectrogram(audio_up, self.n_fft, self.num_mels,
                              self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
                              center=False)
        mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
                                   self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
                                   center=False)
        return mel, mel_loss

    def __getitem__(self, index):
        """Return ``(mel, audio, filename, mel_loss)`` for one random segment.

        :param index: Index into the file list; if that file cannot be used a
            random replacement is loaded and its name is returned instead.
        :raises RuntimeError: If no usable file is found.
        """
        audio, _, filename = self._load_with_fallback(
            index,
            lambda path: load_segment(path, self.sampling_rate, self.segment_size),
        )

        if not self.fine_tuning:
            audio = normalize(audio) * 0.95

        # Degrade by resampling down to a random rate, then back up to the
        # original number of samples.
        sr_out = random.choice(self.supported_samples)
        audio_down = resample(audio, self.sampling_rate, sr_out)
        target_len = len(audio)
        audio_up = resample(audio_down, None, None, target_len)

        audio = torch.FloatTensor(audio).unsqueeze(0)
        audio_up = torch.FloatTensor(audio_up).unsqueeze(0)

        mel, mel_loss = self._mel_pair(audio, audio_up)
        return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())

    def __getitem__org(self, index):
        """Earlier whole-file variant of ``__getitem__`` with waveform caching.

        Loads the full file, resamples it to ``self.sampling_rate``, builds a
        16 kHz down/up-sampled copy, and reuses both for ``n_cache_reuse``
        further calls before loading again.

        NOTE(review): while the cache is being reused, the returned
        ``filename`` is that of the requested index, not of the cached audio.

        :raises RuntimeError: If no usable file is found.
        """
        filename = self.audio_files[index]
        if self._cache_ref_count == 0:
            audio, sampling_rate, filename = self._load_with_fallback(index, load_wav)
            # load_wav already divides by MAX_WAV_VALUE; the previous second
            # division (and second read of the file) has been dropped.
            if not self.fine_tuning:
                audio = normalize(audio) * 0.95

            if sampling_rate != self.sampling_rate:
                resample_factor = self.sampling_rate / sampling_rate
                new_samples = int(len(audio) * resample_factor)
                audio = scipy.signal.resample(audio, new_samples)

            # Degrade through 16 kHz and back to the original length.
            downsample_factor = 16000 / self.sampling_rate
            new_samples = int(len(audio) * downsample_factor)
            audio_down = scipy.signal.resample(audio, new_samples)
            audio_up = scipy.signal.resample(audio_down, len(audio))

            self.cached_wav = audio
            self.cached_wav_up = audio_up
            self._cache_ref_count = self.n_cache_reuse
        else:
            audio = self.cached_wav
            audio_up = self.cached_wav_up
            self._cache_ref_count -= 1

        audio = torch.FloatTensor(audio).unsqueeze(0)
        audio_up = torch.FloatTensor(audio_up).unsqueeze(0)

        if self.split:
            if audio.size(1) >= self.segment_size:
                # Crop both signals with the same random offset.
                max_audio_start = audio.size(1) - self.segment_size
                audio_start = random.randint(0, max_audio_start)
                audio = audio[:, audio_start:audio_start + self.segment_size]
                audio_up = audio_up[:, audio_start:audio_start + self.segment_size]
            else:
                audio = torch.nn.functional.pad(
                    audio, (0, self.segment_size - audio.size(1)), 'constant')
                audio_up = torch.nn.functional.pad(
                    audio_up, (0, self.segment_size - audio_up.size(1)), 'constant')

        mel, mel_loss = self._mel_pair(audio, audio_up)
        return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.audio_files)