# musika_api/utils.py
# Revision d10ce51 ("Update utils.py") by nakas.
import os
import time
import datetime
from glob import glob
from tqdm import tqdm
import librosa
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import random_seed
import gradio as gr
from scipy.io.wavfile import write as write_wav
class Utils_functions:
    """Signal-processing, batching and generation helpers for the musika models.

    Hyper-parameters are read from ``self.args`` (stored in ``__init__``). The
    attributes used across the methods include ``hop``, ``sr``, ``mel_bins``,
    ``mu_rescale``, ``sigma_rescale``, ``ref_level_db``, ``shape``, ``window``,
    ``latlen``, ``latdepth``, ``coordlen``, ``coorddepth``, ``bs``,
    ``totsamples`` and the generation settings (``seconds``, ``num_samples``,
    ``truncation``, ``save_path``, ``files_path``, ``small``).

    Model bundles (``models_ls``) are 9-element sequences unpacked as
    ``(critic, gen, enc, dec, enc2, dec2, gen_ema, [opt_dec, opt_disc], switch)``.
    """

    def __init__(self, args):
        """Store ``args`` and precompute the mel filterbank and its inverse.

        Args:
            args: configuration namespace; this method reads ``mel_bins``,
                ``hop`` and ``sr``.

        Sets:
            self.melmat: spectrogram-to-mel matrix with ``4 * hop + 1`` rows and
                ``mel_bins`` columns; each column sums to 1 (columns that would
                divide by zero are set to 0).
            self.melmatinv: same shape; each row of ``self.melmat`` rescaled to
                sum to 1 (all-zero rows stay zero). Used by ``mel2spec``.
        """
        self.args = args
        # (4 * hop * 2) // 2 + 1 == 4 * hop + 1 bins, i.e. the bin count of an
        # STFT with fft_length 8 * hop (wv2spec_hop called with hopsize = 2 * hop).
        melmat = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=args.mel_bins,
            num_spectrogram_bins=(4 * args.hop * 2) // 2 + 1,
            sample_rate=args.sr,
            lower_edge_hertz=0.0,
            upper_edge_hertz=args.sr // 2,
        )
        # Scale each mel filter by 2 / (width of its band), with band edges taken
        # from librosa.mel_frequencies.
        mel_f = tf.convert_to_tensor(librosa.mel_frequencies(n_mels=args.mel_bins + 2, fmin=0.0, fmax=args.sr // 2))
        enorm = tf.cast(
            tf.expand_dims(
                tf.constant(2.0 / (mel_f[2 : args.mel_bins + 2] - mel_f[: args.mel_bins])),
                0,
            ),
            tf.float32,
        )
        melmat = tf.multiply(melmat, enorm)
        # Normalise every mel column to sum to 1; an empty column gives 0/0 = NaN,
        # which is replaced by 0.
        melmat = tf.divide(melmat, tf.reduce_sum(melmat, axis=0))
        self.melmat = tf.where(tf.math.is_nan(melmat), tf.zeros_like(melmat), melmat)
        # Build the inverse from the NaN-free matrix. With the raw matrix, a single
        # NaN column makes every row sum NaN, and nan_to_num then zeroes the whole
        # inverse. Without NaNs both versions give the same result.
        melmat_np = self.melmat.numpy()
        with np.errstate(divide="ignore", invalid="ignore"):
            self.melmatinv = tf.constant(np.nan_to_num(np.divide(melmat_np.T, np.sum(melmat_np, axis=1))).T)

    def conc_tog_specphase(self, S, P):
        """Rebuild a waveform from batched magnitude and phase spectrogram chunks.

        Args:
            S: normalised dB magnitude chunks. They are denormalised, converted
                from dB power to amplitude (with a 1e-7 floor) and cast to float32.
            P: phase chunks, presumably in units of pi (they are multiplied by pi
                here).

        Batch entries (axis 0) are concatenated along axis 1 before an inverse
        STFT with frame length / fft length ``4 * hop`` and step ``hop``.

        Returns:
            np.ndarray: the squeezed waveform.
        """
        S = tf.cast(S, tf.float32)
        P = tf.cast(P, tf.float32)
        S = self.denormalize(S, clip=False)
        S = tf.math.sqrt(self.db2power(S) + 1e-7)
        P = P * np.pi

        def _join_batch(t):
            # Concatenate the batch entries along axis 1, then drop the batch axis.
            return tf.squeeze(tf.concat(tf.split(t, t.shape[0], 0), 1), 0)

        S = _join_batch(S)
        P = _join_batch(P)
        SP = tf.cast(S, tf.complex64) * tf.math.exp(1j * tf.cast(P, tf.complex64))
        wv = tf.signal.inverse_stft(
            SP,
            4 * self.args.hop,
            self.args.hop,
            fft_length=4 * self.args.hop,
            window_fn=tf.signal.inverse_stft_window_fn(self.args.hop),
        )
        return np.squeeze(wv)

    def _tf_log10(self, x):
        """Return log10(x) computed in x's dtype."""
        numerator = tf.math.log(x)
        denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
        return numerator / denominator

    def normalize(self, S, clip=False):
        """Standardise ``S`` with ``args.mu_rescale`` / ``args.sigma_rescale``.

        If ``clip`` is true, the result is clipped to [-1, 1].
        """
        S = (S - self.args.mu_rescale) / self.args.sigma_rescale
        if clip:
            S = tf.clip_by_value(S, -1.0, 1.0)
        return S

    def normalize_rel(self, S):
        """Rescale ``S`` relative to its own min/max into roughly (0, 1].

        Small 1e-7 offsets keep the result strictly positive and avoid division
        by zero.
        """
        S = S - tf.math.reduce_min(S + 1e-7)
        S = (S / (tf.math.reduce_max(S + 1e-7) + 1e-7)) + 1e-7
        return S

    def denormalize(self, S, clip=False):
        """Invert ``normalize``, optionally clipping the input to [-1, 1] first."""
        if clip:
            S = tf.clip_by_value(S, -1.0, 1.0)
        return (S * self.args.sigma_rescale) + self.args.mu_rescale

    def amp2db(self, x):
        """Convert amplitude to dB: 20 * log10(|x|), with |x| clipped to [1e-5, 1e100]."""
        return 20 * self._tf_log10(tf.clip_by_value(tf.abs(x), 1e-5, 1e100))

    def db2amp(self, x):
        """Convert dB to amplitude: 10 ** (x / 20)."""
        return tf.pow(tf.ones(tf.shape(x)) * 10.0, x * 0.05)

    def power2db(self, power, ref_value=1.0, amin=1e-10, top_db=None, norm=False):
        """Convert power to dB relative to ``ref_value``.

        Both ``power`` and ``ref_value`` are floored at ``amin``. If ``top_db`` is
        given, values are floored at (global max - top_db). ``norm`` is accepted
        but not used.
        """
        log_spec = 10.0 * self._tf_log10(tf.maximum(amin, power))
        log_spec -= 10.0 * self._tf_log10(tf.maximum(amin, ref_value))
        if top_db is not None:
            log_spec = tf.maximum(log_spec, tf.reduce_max(log_spec) - top_db)
        return log_spec

    def power2db_batch(self, power, ref_value=1.0, amin=1e-10, top_db=None, norm=False):
        """Like ``power2db``, but apply ``top_db`` per item.

        The max is taken over the last two axes of each item. ``norm`` is
        accepted but not used.
        """
        log_spec = 10.0 * self._tf_log10(tf.maximum(amin, power))
        log_spec -= 10.0 * self._tf_log10(tf.maximum(amin, ref_value))
        if top_db is not None:
            log_spec = tf.maximum(log_spec, tf.reduce_max(log_spec, [-2, -1], keepdims=True) - top_db)
        return log_spec

    def db2power(self, S_db, ref=1.0):
        """Convert dB to power: ref * 10 ** (S_db / 10)."""
        return ref * tf.math.pow(10.0, 0.1 * S_db)

    def wv2mel(self, wv, topdb=80.0):
        """Compute a normalised dB spectrogram of ``wv`` and project it with ``self.melmat``.

        The STFT uses frame/fft length ``4 * hop`` and step ``hop``. The result
        is shifted by ``args.ref_level_db`` before normalisation.

        NOTE(review): this STFT has ``2 * hop + 1`` bins, while ``self.melmat``
        has ``4 * hop + 1`` rows, so the tensordot below cannot line up. Confirm
        whether this method is used anywhere.
        """
        X = tf.signal.stft(
            wv,
            frame_length=4 * self.args.hop,
            frame_step=self.args.hop,
            fft_length=4 * self.args.hop,
            window_fn=tf.signal.hann_window,
            pad_end=False,
        )
        S = self.normalize(self.power2db(tf.abs(X) ** 2, top_db=topdb) - self.args.ref_level_db)
        SM = tf.tensordot(S, self.melmat, 1)
        return SM

    def mel2spec(self, SM):
        """Map mel-domain values back to spectrogram bins via ``self.melmatinv``."""
        return tf.tensordot(SM, tf.transpose(self.melmatinv), 1)

    def spec2mel(self, S):
        """Project spectrogram bins (last axis) onto mel bins via ``self.melmat``."""
        return tf.tensordot(S, self.melmat, 1)

    def wv2spec(self, wv, hop_size=256, fac=4):
        """Return the normalised dB power spectrogram of ``wv`` (no top_db floor, no mel).

        The STFT uses frame/fft length ``fac * hop_size`` and step ``hop_size``.
        """
        X = tf.signal.stft(
            wv,
            frame_length=fac * hop_size,
            frame_step=hop_size,
            fft_length=fac * hop_size,
            window_fn=tf.signal.hann_window,
            pad_end=False,
        )
        return self.normalize(self.power2db(tf.abs(X) ** 2, top_db=None))

    def wv2spec_hop(self, wv, topdb=80.0, hopsize=256):
        """Return a normalised dB mel spectrogram of ``wv`` using step ``hopsize``.

        The STFT has ``2 * hopsize + 1`` bins. ``self.melmat`` has
        ``4 * hop + 1`` rows, so the two line up when ``hopsize == 2 * hop``,
        which is how this class calls it.
        """
        X = tf.signal.stft(
            wv,
            frame_length=4 * hopsize,
            frame_step=hopsize,
            fft_length=4 * hopsize,
            window_fn=tf.signal.hann_window,
            pad_end=False,
        )
        S = self.normalize(self.power2db(tf.abs(X) ** 2, top_db=topdb))
        return tf.tensordot(S, self.melmat, 1)

    def rand_channel_swap(self, x):
        """Swap the two halves of the last axis of ``x`` with probability 0.5."""
        s_l, s_r = tf.split(x, 2, -1)
        if tf.random.uniform((), dtype=tf.float32) > 0.5:
            sl = s_l
            sr = s_r
        else:
            sl = s_r
            sr = s_l
        return tf.concat([sl, sr], -1)

    @staticmethod
    def _batch_starts(bdim, bs):
        """Return the start index of every ``bs``-row chunk that covers ``bdim`` rows.

        This is ceiling chunking: a trailing partial chunk is always included.
        The former count ``((bdim - 2) // bs) + 1`` skipped the last row whenever
        ``bdim % bs == 1`` (including ``bdim == 1``).
        """
        return range(0, bdim, bs)

    def distribute(self, x, model, bs=32, dual_out=False):
        """Run ``model`` (with ``training=False``) over ``x`` in chunks of ``bs`` rows.

        Args:
            x: an array, or a list of arrays that are sliced in lockstep along
                axis 0.
            model: callable accepting the sliced batch and ``training=False``.
            bs: chunk size along axis 0.
            dual_out: if true, ``model`` returns a pair and both parts are
                concatenated separately.

        Returns:
            np.ndarray, or a pair of np.ndarray when ``dual_out`` is true.
        """
        is_list = isinstance(x, list)
        bdim = x[0].shape[0] if is_list else x.shape[0]
        outls = []
        for start in self._batch_starts(bdim, bs):
            batch = [el[start : start + bs] for el in x] if is_list else x[start : start + bs]
            outls.append(model(batch, training=False))
        if dual_out:
            return (
                np.concatenate([out[0] for out in outls], 0),
                np.concatenate([out[1] for out in outls], 0),
            )
        return np.concatenate(outls, 0)

    def distribute_enc(self, x, model, bs=32):
        """Run an encoder over ``x`` in chunks of ``bs`` rows and regroup its outputs.

        Each chunk's output is split into ``args.shape // args.window`` groups
        along axis 0. The groups are laid side by side along axis -2, and the
        chunks are then concatenated along axis 0. ``x`` may be an array or a
        list of arrays sliced in lockstep.
        """
        is_list = isinstance(x, list)
        bdim = x[0].shape[0] if is_list else x.shape[0]
        n_groups = self.args.shape // self.args.window
        outls = []
        for start in self._batch_starts(bdim, bs):
            batch = [el[start : start + bs] for el in x] if is_list else x[start : start + bs]
            res = model(batch, training=False)
            outls.append(tf.concat(tf.split(res, n_groups, 0), -2))
        return tf.concat(outls, 0)

    def distribute_dec(self, x, model, bs=32):
        """Run a two-output decoder over ``x`` in chunks of ``bs`` rows.

        Each chunk is cut into two halves along axis -2, and the halves are
        stacked along axis 0 before calling ``model``.

        Returns:
            (np.ndarray, np.ndarray): the first and second model outputs, each
            concatenated along axis 0.
        """
        outls = []
        for start in self._batch_starts(x.shape[0], bs):
            inp = tf.concat(tf.split(x[start : start + bs], 2, -2), 0)
            outls.append(model(inp, training=False))
        return (
            np.concatenate([out[0] for out in outls], 0),
            np.concatenate([out[1] for out in outls], 0),
        )

    def distribute_dec2(self, x, model, bs=32):
        """Run a single-output decoder over ``x`` in chunks of ``bs`` rows.

        Each chunk is cut into two halves along axis -2, and the halves are
        stacked along axis 0 before calling ``model``.

        Returns:
            tf.Tensor: the outputs concatenated along axis 0.
        """
        outls = []
        for start in self._batch_starts(x.shape[0], bs):
            inp = tf.concat(tf.split(x[start : start + bs], 2, -2), 0)
            outls.append(model(inp, training=False))
        return tf.concat(outls, 0)

    def center_coordinate(self, x):
        """Average each step with its successor along axis -2, dropping the last step.

        This lets even-length sequences keep their anchor in the middle. It
        expects a rank-3 input (the slice is ``[:, :-1, :]``).
        """
        return tf.reduce_mean(tf.stack([x, tf.roll(x, -1, -2)], 0), 0)[:, :-1, :]

    def crop_coordinate(self, x):
        """Randomly crop a ``latlen``-long window from a conditioning sequence.

        The window starts at ``latlen // 4 + k * (latlen // 2)``. ``k`` is drawn
        uniformly from ``[0, coordlen // (latlen // 2))``, and any draw >= 2 uses
        ``k = 2``. The goal is that the anchor vector sits at the centre of the
        generator's receptive field.

        Returns:
            tf.Tensor reshaped to ``[-1, latlen, x.shape[-1]]``.
        """
        latlen = self.args.latlen

        def _window(k):
            # k is a Python int, so the slice bounds stay static.
            start = latlen // 4 + k * (latlen // 2)
            return tf.reshape(x[:, start : start + latlen, :], [-1, latlen, x.shape[-1]])

        fac = tf.random.uniform((), 0, self.args.coordlen // (latlen // 2), dtype=tf.int32)
        # Hard-coded branches, one static slice each, instead of a slice driven by
        # the sampled tensor: the original notes that generalising this broke jit
        # compiling.
        if fac == 0:
            return _window(0)
        elif fac == 1:
            return _window(1)
        else:
            return _window(2)

    def update_switch(self, switch, ca, cab, learning_rate_switch=0.0001, stable_point=0.9):
        """Nudge ``switch`` by ``learning_rate_switch`` based on critic certainty.

        ``cert = clip(mean(ca) - mean(cab), 0, 2) / 2``. The switch decreases
        when ``cert > stable_point`` and increases otherwise. The result is
        clipped to [-1, 0].
        """
        cert = tf.math.minimum(tf.math.maximum(tf.reduce_mean(ca) - tf.reduce_mean(cab), 0.0), 2.0) / 2.0
        if cert > stable_point:
            switch_new = switch - learning_rate_switch
        else:
            switch_new = switch + learning_rate_switch
        return tf.math.maximum(tf.math.minimum(switch_new, 0.0), -1.0)

    def get_noise_interp(self):
        """Sample a cropped, interpolated noise sequence for one generator call.

        Three anchors (left, centre, right) each concatenate a fresh
        ``coorddepth``-dim normal vector with one shared 64-dim normal vector.
        The sequence interpolates linearly left -> centre -> right over
        ``coordlen`` steps per half, then goes through ``center_coordinate`` and
        ``crop_coordinate``.
        """
        noiseg = tf.random.normal([1, 64], dtype=tf.float32)
        noisel = tf.concat([tf.random.normal([1, self.args.coorddepth], dtype=tf.float32), noiseg], -1)
        noisec = tf.concat([tf.random.normal([1, self.args.coorddepth], dtype=tf.float32), noiseg], -1)
        noiser = tf.concat([tf.random.normal([1, self.args.coorddepth], dtype=tf.float32), noiseg], -1)
        rl = tf.linspace(noisel, noisec, self.args.coordlen + 1, axis=-2)[:, :-1, :]
        rr = tf.linspace(noisec, noiser, self.args.coordlen + 1, axis=-2)
        noisetot = tf.concat([rl, rr], -2)
        noisetot = self.center_coordinate(noisetot)
        return self.crop_coordinate(noisetot)

    def _decode_channels(self, abi, dec, dec2, batch_size):
        """Decode both stereo channels of a latent batch and stack them.

        Channel ``c`` occupies last-axis slice
        ``[c * latdepth, (c + 1) * latdepth)`` of ``abi``. Each channel goes
        through ``dec2`` and then ``dec`` (both in chunks of ``batch_size``),
        and is turned into a waveform by ``conc_tog_specphase``.

        Returns:
            np.ndarray: the two channel waveforms stacked on the last axis.
        """
        latdepth = self.args.latdepth
        chls = []
        for channel in range(2):
            ab = self.distribute_dec2(
                abi[:, :, :, channel * latdepth : channel * latdepth + latdepth],
                dec2,
                bs=batch_size,
            )
            ab = tf.concat(tf.split(ab, ab.shape[-2] // self.args.shape, -2), 0)
            ab_m, ab_p = self.distribute_dec(ab, dec, bs=batch_size)
            chls.append(self.conc_tog_specphase(ab_m, ab_p))
        return np.stack(chls, -1)

    def generate_example_stereo(self, models_ls):
        """Generate one stereo example with ``gen_ema`` (used for training previews).

        Returns:
            np.ndarray: channel waveforms stacked on the last axis (unclipped).
        """
        (
            _critic,
            _gen,
            _enc,
            dec,
            _enc2,
            dec2,
            gen_ema,
            [_opt_dec, _opt_disc],
            _switch,
        ) = models_ls
        abb = gen_ema(self.get_noise_interp(), training=False)
        # Cut the latent into 8-step pieces along axis -2 and stack them on the batch axis.
        abb = tf.concat(tf.split(abb, abb.shape[-2] // 8, -2), 0)
        return self._decode_channels(abb, dec, dec2, batch_size=32)

    # Save in training loop
    def save_test_image_full(self, path, models_ls=None):
        """Generate four stereo examples and save them under ``path``.

        Writes ``out1.wav`` ... ``out4.wav`` and ``output.png``. The PNG shows
        the mel spectrogram (mono mix, hopsize ``2 * hop``) of each example.
        """
        # Generate every example before writing anything, as before.
        waves = [self.generate_example_stereo(models_ls) for _ in range(4)]
        for idx, wv in enumerate(waves, start=1):
            write_wav(f"{path}/out{idx}.wav", self.args.sr, np.squeeze(wv))
        _, axs = plt.subplots(nrows=4, ncols=1, figsize=(20, 20))
        for idx, (ax, wv) in enumerate(zip(axs, waves), start=1):
            mono = (wv[:, 0] + wv[:, 1]) / 2.0
            spec = self.wv2spec_hop(mono, 80.0, self.args.hop * 2)
            ax.imshow(np.flip(np.array(tf.transpose(spec, [1, 0])), -2), cmap=None)
            ax.axis("off")
            ax.set_title(f"Generated{idx}")
        plt.savefig(f"{path}/output.png")
        plt.close()

    @staticmethod
    def _object_array(arrays):
        """Pack a list of arrays into a 1-D object array without NumPy shape inference.

        Passing a ragged list straight to ``np.save`` raises ``ValueError`` on
        NumPy >= 1.24. This yields the 1-D object array that older NumPy built
        implicitly.
        """
        packed = np.empty(len(arrays), dtype=object)
        for i, arr in enumerate(arrays):
            packed[i] = arr
        return packed

    def save_end(
        self,
        epoch,
        gloss,
        closs,
        mloss,
        models_ls=None,
        n_save=3,
        save_path="checkpoints",
    ):
        """Every ``n_save`` epochs, checkpoint the models and save preview outputs.

        Creates a new directory under ``save_path`` (``os.mkdir`` fails if it
        already exists). Writes the critic/gen/gen_ema weights, both optimizer
        states (as pickled object arrays; presumably loaded with
        ``np.load(..., allow_pickle=True)`` — confirm against the loader), the
        switch value and the ``save_test_image_full`` outputs.
        """
        (critic, gen, _enc, _dec, _enc2, _dec2, gen_ema, [opt_dec, opt_disc], switch) = models_ls
        if epoch % n_save == 0:
            print("Saving...")
            path = f"{save_path}/MUSIKA_iterations-{((epoch+1)*self.args.totsamples)//(self.args.bs*1000)}k_losses-{str(gloss)[:9]}-{str(closs)[:9]}-{str(mloss)[:9]}"
            os.mkdir(path)
            critic.save_weights(path + "/critic.h5")
            gen.save_weights(path + "/gen.h5")
            gen_ema.save_weights(path + "/gen_ema.h5")
            np.save(path + "/opt_dec.npy", self._object_array(opt_dec.get_weights()))
            np.save(path + "/opt_disc.npy", self._object_array(opt_disc.get_weights()))
            np.save(path + "/switch.npy", switch.numpy())
            self.save_test_image_full(path, models_ls=models_ls)

    def truncated_normal(self, shape, bound=2.0, dtype=tf.float32):
        """Sample N(0, 1) truncated to [-bound, bound] with a freshly drawn seed.

        ``dtype`` is accepted but not forwarded to the sampling op.
        """
        seed1, seed2 = random_seed.get_seed(tf.random.uniform((), tf.int32.min, tf.int32.max, dtype=tf.int32))
        return tf.random.stateless_parameterized_truncated_normal(shape, [seed1, seed2], 0.0, 1.0, -bound, bound)

    def distribute_gen(self, x, model, bs=32):
        """Run the generator over ``x`` in chunks of ``bs`` rows; concatenate on axis 0.

        Ceiling chunking already covers ``bdim == 1``, so the old special case
        for it is no longer needed.
        """
        outls = [model(x[start : start + bs], training=False) for start in self._batch_starts(x.shape[0], bs)]
        return tf.concat(outls, 0)

    def generate_waveform(self, inp, gen_ema, dec, dec2, batch_size=64):
        """Generate a stereo waveform from a batch of noise sequences.

        The generated segments are laid end to end along axis -2, cut into
        8-step latent pieces and decoded per channel.

        Returns:
            np.ndarray: squeezed waveform clipped to [-1, 1], channels on the
            last axis.
        """
        ab = self.distribute_gen(inp, gen_ema, bs=batch_size)
        ab = tf.concat(tf.split(ab, ab.shape[0], 0), -2)
        abi = tf.concat(tf.split(ab, ab.shape[-2] // 8, -2), 0)
        return np.clip(np.squeeze(self._decode_channels(abi, dec, dec2, batch_size)), -1.0, 1.0)

    def decode_waveform(self, lat, dec, dec2, batch_size=64):
        """Decode a rank-4 latent into a stereo waveform.

        Axis -2 is trimmed to a multiple of 8 and cut into 8-step pieces
        stacked on the batch axis.

        Returns:
            np.ndarray: squeezed waveform clipped to [-1, 1], channels on the
            last axis.
        """
        lat = lat[:, :, : (lat.shape[-2] // 8) * 8, :]
        abi = tf.concat(tf.split(lat, lat.shape[-2] // 8, -2), 0)
        return np.clip(np.squeeze(self._decode_channels(abi, dec, dec2, batch_size)), -1.0, 1.0)

    def get_noise_interp_multi(self, fac=1, var=2.0):
        """Build ``fac`` consecutive ``latlen``-long noise windows, stacked on axis 0.

        Anchors concatenate a fresh 64-dim truncated-normal vector with one
        shared ``coorddepth``-dim vector (truncation bound ``var``). Neighbouring
        anchors are linearly interpolated over ``coordlen`` steps, centred,
        offset by ``latlen // 4`` and cut into windows.

        NOTE(review): ``get_noise_interp`` orders the parts the other way
        (per-anchor ``coorddepth`` dims + shared 64 dims). The two only agree in
        shape when ``coorddepth == 64`` — confirm this is intended.
        """
        noiseg = self.truncated_normal([1, self.args.coorddepth], var, dtype=tf.float32)
        coordratio = self.args.coordlen // self.args.latlen
        n_anchors = 3 + ((fac - 1) // coordratio)
        noisels = [
            tf.concat([self.truncated_normal([1, 64], var, dtype=tf.float32), noiseg], -1) for _ in range(n_anchors)
        ]
        rls = tf.concat(
            [
                tf.linspace(left, right, self.args.coordlen + 1, axis=-2)[:, :-1, :]
                for left, right in zip(noisels[:-1], noisels[1:])
            ],
            -2,
        )
        rls = self.center_coordinate(rls)
        rls = rls[:, self.args.latlen // 4 :, :]
        rls = rls[:, : (rls.shape[-2] // self.args.latlen) * self.args.latlen, :]
        rls = tf.split(rls, rls.shape[-2] // self.args.latlen, -2)
        return tf.concat(rls[:fac], 0)

    def get_noise_interp_loop(self, fac=1, var=2.0):
        """Build ``fac`` noise windows that alternate between two anchors.

        Two anchors (a fresh 64-dim truncated-normal part + one shared
        ``coorddepth``-dim part) repeat ``fac + 2`` times. They are interpolated
        over ``latlen // 2`` steps per transition, centred, offset by
        ``latlen // 2`` and cut into ``latlen``-long windows stacked on axis 0.
        """
        noiseg = self.truncated_normal([1, self.args.coorddepth], var, dtype=tf.float32)
        noisels_pre = [tf.concat([self.truncated_normal([1, 64], var, dtype=tf.float32), noiseg], -1) for _ in range(2)]
        noisels = [noisels_pre[0], noisels_pre[1]] * (fac + 2)
        rls = tf.concat(
            [
                tf.linspace(left, right, self.args.latlen // 2 + 1, axis=-2)[:, :-1, :]
                for left, right in zip(noisels[:-1], noisels[1:])
            ],
            -2,
        )
        rls = self.center_coordinate(rls)
        rls = rls[:, self.args.latlen // 2 :, :]
        rls = rls[:, : (rls.shape[-2] // self.args.latlen) * self.args.latlen, :]
        rls = tf.split(rls, rls.shape[-2] // self.args.latlen, -2)
        return tf.concat(rls[:fac], 0)

    def generate(self, models_ls):
        """Write ``args.num_samples`` generated clips of ``args.seconds`` seconds to ``args.save_path``.

        Enough segments are generated (``seconds // 23 + 1``; each presumably
        covers ~23 s), and each clip is trimmed to ``seconds * sr`` samples.
        """
        _critic, _gen, _enc, dec, _enc2, dec2, gen_ema, [_opt_dec, _opt_disc], _switch = models_ls
        os.makedirs(self.args.save_path, exist_ok=True)
        fac = (self.args.seconds // 23) + 1
        print(f"Generating {self.args.num_samples} samples...")
        for i in tqdm(range(self.args.num_samples)):
            wv = self.generate_waveform(
                self.get_noise_interp_multi(fac, self.args.truncation), gen_ema, dec, dec2, batch_size=64
            )
            dt = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
            write_wav(
                f"{self.args.save_path}/{i}_{dt}.wav", self.args.sr, np.squeeze(wv)[: self.args.seconds * self.args.sr]
            )

    def decode_path(self, models_ls):
        """Decode every ``*.npy`` latent in ``args.files_path`` to a WAV in ``args.save_path``.

        Each file must hold a 2-D array; two leading axes are added before
        ``decode_waveform``. The output keeps the input's base name.
        """
        _critic, _gen, _enc, dec, _enc2, dec2, _gen_ema, [_opt_dec, _opt_disc], _switch = models_ls
        os.makedirs(self.args.save_path, exist_ok=True)
        pathls = glob(self.args.files_path + "/*.npy")
        print(f"Decoding {len(pathls)} samples...")
        for p in tqdm(pathls):
            bname = os.path.basename(os.path.splitext(p)[0])
            # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
            # objects. Only point files_path at trusted files; plain numeric
            # latents load fine with allow_pickle=False.
            lat = np.load(p, allow_pickle=True)
            lat = tf.expand_dims(tf.expand_dims(lat, 0), 0)
            wv = self.decode_waveform(lat, dec, dec2, batch_size=64)
            write_wav(f"{self.args.save_path}/{bname}.wav", self.args.sr, np.squeeze(wv))

    def stfunc(self, genre, z, var, models_ls_1, models_ls_2, models_ls_3):
        """Gradio callback: generate audio for the chosen genre, length and variance.

        Args:
            genre: radio index selecting the generator (0 -> models_ls_1,
                1 -> models_ls_2, anything else -> models_ls_3).
            z: radio index selecting the number of segments (0 -> 1, 1 -> 5,
                anything else -> 10).
            var: truncation bound for the latent noise (converted to float).
            models_ls_1, models_ls_2, models_ls_3: model bundles. Each supplies
                its ``gen_ema``; ``dec``/``dec2`` are taken from models_ls_3.

        Returns:
            (spec, path): the normalised mel spectrogram of the mono mix of the
            first 23 s (mel axis reversed for display, clipped to [-1, 1]), and
            the path of the 16-bit WAV written for the whole clip.
        """
        _, _, _, _, _, _, gen_ema_1, [_, _], _ = models_ls_1
        _, _, _, _, _, _, gen_ema_2, [_, _], _ = models_ls_2
        _, _, _, dec, _, dec2, gen_ema_3, [_, _], _ = models_ls_3
        if genre == 0:
            gen_ema = gen_ema_1
        elif genre == 1:
            gen_ema = gen_ema_2
        else:
            gen_ema = gen_ema_3
        var = float(var)
        if z == 0:
            fac = 1
        elif z == 1:
            fac = 5
        else:
            fac = 10
        noiseinp = self.get_noise_interp_multi(fac, var)
        abwvc = self.generate_waveform(noiseinp, gen_ema, dec, dec2, batch_size=64)
        n_preview = 23 * self.args.sr
        mono = (abwvc[:n_preview, 0] + abwvc[:n_preview, 1]) / 2.0
        spec = np.flip(np.array(tf.transpose(self.wv2spec_hop(mono, 80.0, self.args.hop * 2), [1, 0])), -2)
        # NOTE(review): every request writes to the same fixed path, so concurrent
        # requests could overwrite each other's audio.
        output = "/tmp/outputfile.wav"
        # abwvc is already clipped to [-1, 1], so this scaling stays in int16 range.
        write_wav(output, self.args.sr, np.int16(abwvc * 32767.0))
        return (
            np.clip(spec, -1.0, 1.0),
            output,
        )

    def render_gradio(self, models_ls_1, models_ls_2, models_ls_3, train=True):
        """Build and launch the Gradio interface around ``stfunc``.

        With ``train`` true, the interface is launched with
        ``prevent_thread_lock=True``; otherwise it is launched with the queue
        enabled.
        """
        article_text = "Original work by Marco Pasini ([Twitter](https://twitter.com/marco_ppasini)) at the Institute of Computational Perception, JKU Linz. Supervised by Jan Schlüter."

        def gradio_func(genre, x, y):
            return self.stfunc(genre, x, y, models_ls_1, models_ls_2, models_ls_3)

        def _banner():
            for _ in range(5):
                print("--------------------------------")

        if self.args.small:
            durations = ["11s", "59s", "1m 58s"]
            durations_default = "59s"
        else:
            durations = ["23s", "1m 58s", "3m 57s"]
            durations_default = "1m 58s"
        iface = gr.Interface(
            fn=gradio_func,
            inputs=[
                gr.Radio(
                    choices=["Techno/Experimental", "Death Metal (finetuned)", "Misc"],
                    type="index",
                    value="Techno/Experimental",
                    label="Music Genre to Generate",
                ),
                gr.Radio(
                    choices=durations,
                    type="index",
                    value=durations_default,
                    label="Generated Music Length",
                ),
                gr.Slider(
                    minimum=0.1,
                    maximum=3.9,
                    step=0.1,
                    value=1.8,
                    label="How much do you want the music style to be varied? (Stddev truncation for random vectors)",
                ),
            ],
            outputs=[
                gr.Image(label="Log-MelSpectrogram of Generated Audio (first 23 s)"),
                gr.Audio(type="numpy", label="Generated Audio"),
            ],
            title="musika!",
            description="Blazingly Fast 44.1 kHz Stereo Waveform Music Generation of Arbitrary Length. Be patient and enjoy the weirdness!",
            article=article_text,
        )
        _banner()
        print("CLICK ON LINK BELOW TO OPEN GRADIO INTERFACE")
        if train:
            iface.launch(prevent_thread_lock=True)
        else:
            iface.launch(enable_queue=True)
        _banner()