OpenMusic / qa_mdt /audioldm_train /dataset_plugin.py
jadechoghari's picture
add model
9b9e0ee verified
import os
import torch
import numpy as np
import torchaudio
import matplotlib.pyplot as plt
# Module-level cache holding the symbol inventory used by the
# ``get_vits_phoneme_ids*`` add-ons, so the lookup table is built once at
# import time instead of on every sample.
CACHE = {
    "get_vits_phoneme_ids": {
        # Length the phoneme-id sequences are padded (or cut) to.
        "PAD_LENGTH": 310,
        "_pad": "_",
        "_punctuation": ';:,.!?¡¿—…"«»“” ',
        "_letters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
        "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ",
        "_special": "♪☎☒☝⚠",
    }
}

# Flat, ordered symbol list: the pad symbol first (so it gets id 0), followed
# by every character of each symbol group in turn.
CACHE["get_vits_phoneme_ids"]["symbols"] = [
    CACHE["get_vits_phoneme_ids"]["_pad"],
    *CACHE["get_vits_phoneme_ids"]["_punctuation"],
    *CACHE["get_vits_phoneme_ids"]["_letters"],
    *CACHE["get_vits_phoneme_ids"]["_letters_ipa"],
    *CACHE["get_vits_phoneme_ids"]["_special"],
]

# Symbol -> position in ``symbols``. A symbol that occurs more than once in
# the list ends up mapped to its last position.
CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = dict(
    zip(
        CACHE["get_vits_phoneme_ids"]["symbols"],
        range(len(CACHE["get_vits_phoneme_ids"]["symbols"])),
    )
)
def get_vits_phoneme_ids(config, dl_output, metadata):
    """Turn ``metadata["phonemes"]`` into a zero-interleaved, padded id tensor.

    Every symbol of the phoneme string is looked up in the cached symbol
    table, a ``0`` is placed before, between and after the ids, and the result
    is right-padded with ``0`` up to the cached ``PAD_LENGTH``.

    Args:
        config: Unused.
        dl_output: Unused.
        metadata: Mapping that must contain the key ``"phonemes"``; its value
            is iterated symbol by symbol.

    Returns:
        ``{"phoneme_idx": LongTensor}``. Sequences already longer than
        ``PAD_LENGTH`` are neither padded nor truncated.

    Raises:
        AssertionError: If ``"phonemes"`` is missing from ``metadata``.
        KeyError: If a symbol is not in the symbol table.
    """
    vits_cache = CACHE["get_vits_phoneme_ids"]
    target_length = vits_cache["PAD_LENGTH"]
    symbol_table = vits_cache["_symbol_to_id"]
    filler_id = 0

    assert (
        "phonemes" in metadata.keys()
    ), "You must provide vits phonemes on using addon get_vits_phoneme_ids"

    symbol_ids = [symbol_table[char] for char in metadata["phonemes"]]

    # Build [0, id_1, 0, id_2, ..., id_n, 0].
    interleaved = [0]
    for symbol_id in symbol_ids:
        interleaved.extend((symbol_id, 0))

    # A negative repeat count yields an empty list, so over-long sequences
    # pass through unchanged.
    padded = interleaved + [filler_id] * (target_length - len(interleaved))
    return {"phoneme_idx": torch.LongTensor(padded)}
def get_vits_phoneme_ids_no_padding(config, dl_output, metadata):
    """Turn ``metadata["phonemes"]`` into a fixed-length id tensor.

    The character ``"⚠"`` is appended to the phoneme string, each symbol is
    mapped through the cached symbol table (unknown symbols are reported on
    stdout and replaced by ``"_"``), and the id list is truncated to, then
    right-padded with ``0`` up to, the cached ``PAD_LENGTH``.

    Args:
        config: Unused.
        dl_output: Unused.
        metadata: Mapping that must contain the key ``"phonemes"`` holding a
            string.

    Returns:
        ``{"phoneme_idx": LongTensor}`` of length ``PAD_LENGTH``.

    Raises:
        AssertionError: If ``"phonemes"`` is missing from ``metadata``.
    """
    vits_cache = CACHE["get_vits_phoneme_ids"]
    target_length = vits_cache["PAD_LENGTH"]
    symbol_table = vits_cache["_symbol_to_id"]
    filler_id = 0

    assert (
        "phonemes" in metadata.keys()
    ), "You must provide vits phonemes on using addon get_vits_phoneme_ids"

    text = metadata["phonemes"] + "⚠"

    symbol_ids = []
    for char in text:
        if char in symbol_table:
            symbol_ids.append(symbol_table[char])
            continue
        # Out-of-vocabulary symbol: report it and fall back to the pad symbol.
        print("%s is not in the vocabulary. %s" % (char, text))
        symbol_ids.append(symbol_table["_"])

    symbol_ids = symbol_ids[:target_length]
    padded = symbol_ids + [filler_id] * (target_length - len(symbol_ids))
    return {"phoneme_idx": torch.LongTensor(padded)}
def calculate_relative_bandwidth(config, dl_output, metadata):
    """Locate the 5% and 95% cumulative-energy points of ``dl_output["stft"]``.

    The feature is summed over dimension 0, accumulated along the remaining
    axis, and the bins whose cumulative energy is closest to 5% and 95% of the
    total are expressed on a 0-1000 scale relative to the size of the last
    dimension.

    Args:
        config: Unused.
        dl_output: Mapping that must contain ``"stft"`` (a tensor whose last
            dimension is treated as frequency).
        metadata: Unused.

    Returns:
        ``{"freq_energy_percentile": LongTensor([lower, higher])}``.

    Raises:
        AssertionError: If ``"stft"`` is missing from ``dl_output``.
    """
    assert "stft" in dl_output.keys()

    spectrogram = dl_output["stft"]
    # The last dimension of the stft feature is the frequency dimension
    n_freq_bins = spectrogram.size(-1)

    cumulative_energy = torch.cumsum(torch.sum(spectrogram, dim=0), dim=0)
    energy_total = cumulative_energy[-1]

    def _nearest_bin_permille(fraction):
        # Bin whose cumulative energy is closest to `fraction` of the total,
        # rescaled to the 0-1000 range.
        nearest_bin = torch.argmin(
            torch.abs(energy_total * fraction - cumulative_energy)
        )
        return int((nearest_bin / n_freq_bins) * 1000)

    lower = _nearest_bin_permille(0.05)
    higher = _nearest_bin_permille(0.95)
    return {"freq_energy_percentile": torch.LongTensor([lower, higher])}
def calculate_mel_spec_relative_bandwidth_as_extra_channel(config, dl_output, metadata):
    """Build a latent-sized band mask from the energy spread of the mel spectrogram.

    The log-mel spectrogram is exponentiated, summed over dimension 0 and
    accumulated along the remaining axis. The bins whose cumulative energy is
    closest to 5% and 95% of the total are rescaled to the latent frequency
    axis, and the columns between them are set to 1 in an otherwise zero
    ``(latent_t_size, latent_f_size)`` tensor.

    Args:
        config: Nested mapping providing
            ``config["model"]["params"]["latent_t_size"]`` and
            ``config["model"]["params"]["latent_f_size"]``.
        dl_output: Mapping that must contain ``"log_mel_spec"``
            (assumed to be laid out as [time, mel bins] — TODO confirm).
        metadata: Unused.

    Returns:
        A dict with
        ``"mel_spec_bandwidth_cond_extra_channel"``: the float band mask, and
        ``"freq_energy_percentile"``: ``LongTensor([lower, higher])`` holding
        the band edges in latent-frequency units. The mask is empty when both
        edges fall on the same latent column.

    Raises:
        AssertionError: If ``"log_mel_spec"`` is missing from ``dl_output``.
    """
    # Guard on the feature this function actually reads. (It previously
    # checked for "stft", which is never used here.)
    assert "log_mel_spec" in dl_output.keys()

    # Clamp before exponentiating so very large log values cannot blow up.
    linear_mel_spec = torch.exp(torch.clip(dl_output["log_mel_spec"], max=10))

    # The last dimension of the mel spectrogram is the frequency dimension
    freq_dimensions = linear_mel_spec.size(-1)
    freq_energy_dist = torch.sum(linear_mel_spec, dim=0)
    freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0)
    total_energy = freq_energy_dist[-1]

    percentile_5th = total_energy * 0.05
    percentile_95th = total_energy * 0.95

    # Bins whose cumulative energy is closest to the two thresholds.
    lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist))
    higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist))

    latent_t_size = config["model"]["params"]["latent_t_size"]
    latent_f_size = config["model"]["params"]["latent_f_size"]

    # Rescale the mel-bin indices onto the latent frequency axis.
    lower_idx = int(latent_f_size * float((lower_idx / freq_dimensions)))
    higher_idx = int(latent_f_size * float((higher_idx / freq_dimensions)))

    bandwidth_condition = torch.zeros((latent_t_size, latent_f_size))
    bandwidth_condition[:, lower_idx:higher_idx] += 1.0

    return {
        "mel_spec_bandwidth_cond_extra_channel": bandwidth_condition,
        "freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx]),
    }
def waveform_rs_48k(config, dl_output, metadata):
    """Provide the waveform at a 48 kHz sampling rate.

    Args:
        config: Unused.
        dl_output: Mapping with ``"waveform"`` (annotated ``[1, samples]`` in
            the original code) and ``"sampling_rate"``.
        metadata: Unused.

    Returns:
        ``{"waveform_48k": waveform}``. The input waveform object is returned
        as-is when it is already at 48000 Hz; otherwise it is resampled with
        ``torchaudio.functional.resample``.
    """
    target_rate = 48000
    audio = dl_output["waveform"]  # [1, samples]
    source_rate = dl_output["sampling_rate"]

    if source_rate == target_rate:
        return {"waveform_48k": audio}

    resampled = torchaudio.functional.resample(
        audio, orig_freq=source_rate, new_freq=target_rate
    )
    return {"waveform_48k": resampled}
def extract_vits_phoneme_and_flant5_text(config, dl_output, metadata):
    """Produce phoneme ids for a sample, with or without phoneme metadata.

    When ``metadata`` carries ``"phonemes"``, the ids are computed from it and
    ``"text"`` is set to the empty string. Otherwise the ids are computed from
    an empty phoneme string and no ``"text"`` entry is added.

    Args:
        config: Forwarded to ``get_vits_phoneme_ids_no_padding``.
        dl_output: Forwarded to ``get_vits_phoneme_ids_no_padding``.
        metadata: Sample metadata; must not contain the key ``"phoneme"``.

    Returns:
        The dict returned by ``get_vits_phoneme_ids_no_padding``, plus
        ``"text": ""`` when phonemes were provided.

    Raises:
        AssertionError: If ``metadata`` contains the key ``"phoneme"``.
    """
    assert (
        "phoneme" not in metadata.keys()
    ), "The metadata of speech you use seems belong to fastspeech. Please check dataset_root.json"

    has_phonemes = "phonemes" in metadata.keys()
    # Add empty phoneme sequence when the sample has none
    phoneme_source = metadata if has_phonemes else {"phonemes": ""}

    new_item = get_vits_phoneme_ids_no_padding(config, dl_output, phoneme_source)
    if has_phonemes:
        # We assume TTS data does not have text description
        new_item["text"] = ""
    return new_item
def extract_fs2_phoneme_and_flant5_text(config, dl_output, metadata):
    """Produce phoneme ids for a sample, with or without ``"phoneme"`` metadata.

    When ``metadata`` carries ``"phoneme"``, the ids are computed from it and
    ``"text"`` is set to the empty string. Otherwise the ids are computed from
    an empty phoneme list and no ``"text"`` entry is added.

    Args:
        config: Forwarded to ``extract_fs2_phoneme_g2p_en_feature``.
        dl_output: Forwarded to ``extract_fs2_phoneme_g2p_en_feature``.
        metadata: Sample metadata.

    Returns:
        The dict returned by ``extract_fs2_phoneme_g2p_en_feature``, plus
        ``"text": ""`` when phonemes were provided.
    """
    has_phonemes = "phoneme" in metadata.keys()
    phoneme_source = metadata if has_phonemes else {"phoneme": []}

    new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, phoneme_source)
    if has_phonemes:
        new_item["text"] = ""
    return new_item
def extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata):
    """Map ``metadata["phoneme"]`` to a fixed-length tensor of phoneme ids.

    Phonemes missing from the inventory are silently dropped. The id list is
    cut to 135 entries and right-padded with the pad id, which is one past the
    largest phoneme id.

    Args:
        config: Unused.
        dl_output: Unused.
        metadata: Mapping that must contain ``"phoneme"``, an iterable of
            phoneme symbols.

    Returns:
        ``{"phoneme_idx": LongTensor}`` of length 135.

    Raises:
        AssertionError: If ``"phoneme"`` is missing from ``metadata``.
    """
    max_phonemes = 135

    # The position of a symbol in this tuple is its phoneme id.
    inventory = (
        "K", "IH2", "NG", "OW2", "AH2", "F", "AE0", "IY0", "SH", "G",
        "W", "UW1", "AO2", "AW2", "UW0", "EY2", "UW2", "AE2", "IH0", "P",
        "D", "ER1", "AA1", "EH0", "UH1", "N", "V", "AY1", "EY1", "UH2",
        "EH1", "L", "AA2", "R", "OY1", "Y", "ER2", "S", "AE1", "AH1",
        "JH", "ER0", "EH2", "IY2", "OY2", "AW1", "IH1", "IY1", "OW0", "AO0",
        "AY0", "EY0", "AY2", "UH0", "M", "TH", "T", "OY0", "AW0", "DH",
        "Z", "spn", "AH0", "sp", "AO1", "OW1", "ZH", "B", "AA0", "CH",
        "HH",
    )
    phoneme_to_id = {symbol: idx for idx, symbol in enumerate(inventory)}
    padding_id = len(phoneme_to_id)

    assert (
        "phoneme" in metadata.keys()
    ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset"

    phoneme_ids = []
    for symbol in metadata["phoneme"]:
        if symbol in phoneme_to_id:
            phoneme_ids.append(phoneme_to_id[symbol])

    # Only complain when the sequence is more than five times the limit.
    if len(phoneme_ids) > 5 * max_phonemes:
        print(
            "Warning: Phonemes length is too long and is truncated too much! %s"
            % metadata
        )

    phoneme_ids = phoneme_ids[:max_phonemes]
    padded = phoneme_ids + [padding_id] * (max_phonemes - len(phoneme_ids))
    return {"phoneme_idx": torch.LongTensor(padded)}
def extract_phoneme_g2p_en_feature(config, dl_output, metadata):
    """Map ``metadata["phoneme"]`` to a fixed-length tensor of phoneme ids.

    Phonemes missing from the inventory are silently dropped. The id list is
    cut to 250 entries and right-padded with the pad id, which is one past the
    largest phoneme id.

    Args:
        config: Unused.
        dl_output: Unused.
        metadata: Mapping that must contain ``"phoneme"``, an iterable of
            phoneme symbols.

    Returns:
        ``{"phoneme_idx": LongTensor}`` of length 250.

    Raises:
        AssertionError: If ``"phoneme"`` is missing from ``metadata``.
    """
    max_phonemes = 250

    # The position of a symbol in this tuple is its phoneme id.
    inventory = (
        " ", "AA", "AE", "AH", "AO", "AW", "AY", "B", "CH", "D",
        "DH", "EH", "ER", "EY", "F", "G", "HH", "IH", "IY", "JH",
        "K", "L", "M", "N", "NG", "OW", "OY", "P", "R", "S",
        "SH", "T", "TH", "UH", "UW", "V", "W", "Y", "Z", "ZH",
    )
    phoneme_to_id = {symbol: idx for idx, symbol in enumerate(inventory)}
    padding_id = len(phoneme_to_id)

    assert (
        "phoneme" in metadata.keys()
    ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset"

    phoneme_ids = []
    for symbol in metadata["phoneme"]:
        if symbol in phoneme_to_id:
            phoneme_ids.append(phoneme_to_id[symbol])

    # Only complain when the sequence is more than five times the limit.
    if len(phoneme_ids) > 5 * max_phonemes:
        print(
            "Warning: Phonemes length is too long and is truncated too much! %s"
            % metadata
        )

    phoneme_ids = phoneme_ids[:max_phonemes]
    padded = phoneme_ids + [padding_id] * (max_phonemes - len(phoneme_ids))
    return {"phoneme_idx": torch.LongTensor(padded)}
def extract_kaldi_fbank_feature(config, dl_output, metadata):
    """Compute a normalized 128-bin Kaldi filterbank feature at 16 kHz.

    The waveform is resampled to 16 kHz when needed, mean-centred, converted
    with ``torchaudio.compliance.kaldi.fbank``, then zero-padded or cut along
    the frame axis to match ``dl_output["log_mel_spec"].size(0)`` and finally
    normalized with fixed mean/std constants.

    Args:
        config: Unused.
        dl_output: Mapping with ``"waveform"`` (annotated ``[1, samples]`` in
            the original code), ``"sampling_rate"`` and ``"log_mel_spec"``.
        metadata: Unused.

    Returns:
        ``{"ta_kaldi_fbank": tensor}`` (annotated ``[1024, 128]`` in the
        original code; the frame count follows ``log_mel_spec``).
    """
    target_rate = 16000
    # Fixed normalization statistics (presumably dataset-level — TODO confirm).
    feature_mean = -4.2677393
    feature_std = 4.5689974

    audio = dl_output["waveform"]  # [1, samples]
    source_rate = dl_output["sampling_rate"]
    mel_reference = dl_output["log_mel_spec"]

    if source_rate != target_rate:
        audio = torchaudio.functional.resample(
            audio, orig_freq=source_rate, new_freq=target_rate
        )

    # Remove the DC offset before feature extraction.
    audio = audio - audio.mean()

    features = torchaudio.compliance.kaldi.fbank(
        audio,
        htk_compat=True,
        sample_frequency=target_rate,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    # cut and pad so the frame count matches the mel spectrogram
    wanted_frames = mel_reference.size(0)
    missing_frames = wanted_frames - features.shape[0]
    if missing_frames > 0:
        features = torch.nn.ZeroPad2d((0, 0, 0, missing_frames))(features)
    elif missing_frames < 0:
        features = features[:wanted_frames, :]

    normalized = (features - feature_mean) / (feature_std * 2)
    return {"ta_kaldi_fbank": normalized}  # [1024, 128]
def extract_kaldi_fbank_feature_32k(config, dl_output, metadata):
    """Compute a normalized 128-bin Kaldi filterbank feature at 32 kHz.

    The waveform is resampled to 32 kHz when needed, mean-centred, converted
    with ``torchaudio.compliance.kaldi.fbank``, then zero-padded or cut along
    the frame axis to match ``dl_output["log_mel_spec"].size(0)`` and finally
    normalized with fixed mean/std constants.

    Args:
        config: Unused.
        dl_output: Mapping with ``"waveform"`` (annotated ``[1, samples]`` in
            the original code), ``"sampling_rate"`` and ``"log_mel_spec"``.
        metadata: Unused.

    Returns:
        ``{"ta_kaldi_fbank": tensor}`` (annotated ``[1024, 128]`` in the
        original code; the frame count follows ``log_mel_spec``).
    """
    target_rate = 32000
    # Fixed normalization statistics (presumably dataset-level — TODO confirm).
    feature_mean = -4.2677393
    feature_std = 4.5689974

    audio = dl_output["waveform"]  # [1, samples]
    source_rate = dl_output["sampling_rate"]
    mel_reference = dl_output["log_mel_spec"]

    if source_rate != target_rate:
        audio = torchaudio.functional.resample(
            audio, orig_freq=source_rate, new_freq=target_rate
        )

    # Remove the DC offset before feature extraction.
    audio = audio - audio.mean()

    features = torchaudio.compliance.kaldi.fbank(
        audio,
        htk_compat=True,
        sample_frequency=target_rate,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    # cut and pad so the frame count matches the mel spectrogram
    wanted_frames = mel_reference.size(0)
    missing_frames = wanted_frames - features.shape[0]
    if missing_frames > 0:
        features = torch.nn.ZeroPad2d((0, 0, 0, missing_frames))(features)
    elif missing_frames < 0:
        features = features[:wanted_frames, :]

    normalized = (features - feature_mean) / (feature_std * 2)
    return {"ta_kaldi_fbank": normalized}  # [1024, 128]
# Use the beat and downbeat information as music conditions
def extract_drum_beat(config, dl_output, metadata):
    """Rasterize beat and downbeat positions onto a latent-sized grid.

    Beat and downbeat sample positions from ``metadata`` are shifted by the
    segment's start sample, restricted to the segment, and mapped
    proportionally onto the rows of a ``(latent_t_size, latent_f_size)`` zero
    tensor. Each beat subtracts 0.5 from its whole row and each downbeat adds
    1.0, so rows read 0 (none), -0.5 (beat), 1.0 (downbeat) or 0.5 (both) when
    hit once.

    Args:
        config: Nested mapping providing
            ``config["model"]["params"]["latent_t_size"]`` and
            ``config["model"]["params"]["latent_f_size"]``.
        dl_output: Mapping with ``"duration"`` and
            ``"random_start_sample_in_original_audio_file"``.
        metadata: Mapping with ``"sample_rate"``, ``"beat"`` and
            ``"downbeat"`` (positions in samples of the original file).

    Returns:
        ``{"cond_beat_downbeat": tensor}``.

    Raises:
        AssertionError: If any of the three metadata keys is missing.
    """

    def visualization(conditional_signal, mel_spectrogram, filename):
        # Debug helper (not called by default): writes the segment audio next
        # to a two-panel figure of the condition and the mel spectrogram.
        import soundfile as sf

        audio_path = os.path.basename(dl_output["fname"])
        sf.write(
            audio_path,
            np.array(dl_output["waveform"])[0],
            dl_output["sampling_rate"],
        )

        panels = (
            (211, conditional_signal, "Conditional Signal"),
            (212, mel_spectrogram, "Mel Spectrogram"),
        )
        plt.figure(figsize=(10, 10))
        for position, image, title in panels:
            plt.subplot(position)
            plt.imshow(np.array(image).T, aspect="auto")
            plt.title(title)
        plt.savefig(filename)
        plt.close()

    assert "sample_rate" in metadata and "beat" in metadata and "downbeat" in metadata

    # The dataloader segment length before performing torch resampling
    segment_samples = int(metadata["sample_rate"] * dl_output["duration"])
    segment_start = int(dl_output["random_start_sample_in_original_audio_file"])

    def _positions_in_segment(absolute_positions):
        # Sample indices relative to the segment start, keeping only those
        # that fall inside the segment (both ends inclusive).
        shifted = (position - segment_start for position in absolute_positions)
        return [offset for offset in shifted if 0 <= offset <= segment_samples]

    beat = _positions_in_segment(metadata["beat"])
    downbeat = _positions_in_segment(metadata["downbeat"])

    n_time_steps = config["model"]["params"]["latent_t_size"]
    n_freq_bins = config["model"]["params"]["latent_f_size"]
    conditional_signal = torch.zeros((n_time_steps, n_freq_bins))
    last_row = conditional_signal.size(0) - 1

    # 0: none; -0.5: beat; 1.0: downbeat; 0.5: downbeat+beat
    for positions, increment in ((beat, -0.5), (downbeat, 1.0)):
        for offset in positions:
            # A position exactly at the segment end would index one past the
            # grid, so clamp it to the last row.
            row = min(int((offset / segment_samples) * n_time_steps), last_row)
            conditional_signal[row, :] += increment

    # To inspect the result, call:
    # visualization(conditional_signal, dl_output["log_mel_spec"], filename = os.path.basename(dl_output["fname"])+".png")
    return {"cond_beat_downbeat": conditional_signal}