# Spaces:
# Sleeping
# Sleeping
import itertools
import os
from contextlib import contextmanager
from email.policy import strict  # NOTE(review): appears unused in the visible code; likely an editor auto-import.

import numpy as np
import pytorch_lightning as pl
import soundfile as sf
import torch
import torch.nn.functional as F
import wandb
from torch.optim.lr_scheduler import LambdaLR

from qa_mdt.audioldm_train.modules.diffusionmodules.distributions import (
    DiagonalGaussianDistribution,
)
from qa_mdt.audioldm_train.modules.diffusionmodules.ema import *
from qa_mdt.audioldm_train.modules.diffusionmodules.model import Encoder, Decoder
from qa_mdt.audioldm_train.utilities.model_util import get_vocoder
from qa_mdt.audioldm_train.utilities.model_util import instantiate_from_config
from qa_mdt.audioldm_train.utilities.tools import synth_one_sample
class AutoencoderKL(pl.LightningModule):
    """KL-regularised convolutional autoencoder for spectrogram-like inputs.

    ``encode`` turns an input into a ``DiagonalGaussianDistribution``; a latent
    is sampled (or its mode taken) and ``decode`` maps it back to the input
    domain. Training uses Lightning manual optimisation with two optimisers:
    one for the autoencoder, one for the discriminator(s) held on ``self.loss``.
    """

    def __init__(
        self,
        ddconfig=None,
        lossconfig=None,
        batchsize=None,
        embed_dim=None,
        time_shuffle=1,
        subband=1,
        sampling_rate=16000,
        ckpt_path=None,
        reload_from_ckpt=None,
        ignore_keys=None,
        image_key="fbank",
        colorize_nlabels=None,
        monitor=None,
        base_learning_rate=1e-5,
    ):
        """Build encoder/decoder, loss and (for ``fbank``) the vocoder.

        Args:
            ddconfig: Keyword arguments for ``Encoder``/``Decoder``; must contain
                ``mel_bins``, ``z_channels`` and a truthy ``double_z``.
            lossconfig: Config given to ``instantiate_from_config`` to build ``self.loss``.
            batchsize: Accepted for config compatibility; unused here.
            embed_dim: Latent channel count (``quant_conv`` emits ``2 * embed_dim``
                moment channels, ``post_quant_conv`` consumes ``embed_dim``).
            time_shuffle: Stored only; the time-shuffle operations are disabled.
            subband: Number of frequency sub-bands (only used when ``image_key == "stft"``).
            sampling_rate: Sample rate used when writing / logging audio.
            ckpt_path: Optional checkpoint restored through ``init_from_ckpt``.
            reload_from_ckpt: Optional checkpoint whose name- and shape-matching
                tensors are copied into the freshly built model.
            ignore_keys: State-dict key prefixes dropped when loading ``ckpt_path``
                (``None`` means no prefixes).
            image_key: ``"fbank"`` or ``"stft"``; selects the input tensor and the
                waveform-decoding path.
            colorize_nlabels: If given, registers a random ``colorize`` buffer used by ``to_rgb``.
            monitor: Stored as ``self.monitor`` when not ``None``.
            base_learning_rate: Learning rate of both Adam optimisers.
        """
        super().__init__()
        # Both optimisers are stepped by hand in ``training_step``.
        self.automatic_optimization = False
        assert (
            "mel_bins" in ddconfig.keys()
        ), "mel_bins is not specified in the Autoencoder config"
        num_mel = ddconfig["mel_bins"]

        self.image_key = image_key
        self.sampling_rate = sampling_rate
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)

        self.subband = int(subband)
        if self.subband > 1:
            print("Use subband decomposition %s" % self.subband)

        # The 2 * embed_dim output channels are the posterior moments consumed by
        # DiagonalGaussianDistribution (presumably mean and log-variance).
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)

        if self.image_key == "fbank":
            self.vocoder = get_vocoder(None, "cpu", num_mel)
        self.embed_dim = embed_dim

        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        self.learning_rate = float(base_learning_rate)
        print("Initial learning rate %s" % self.learning_rate)

        # time_shuffle is kept for config compatibility only: encode/decode do not
        # apply any time shuffling.
        self.time_shuffle = time_shuffle
        self.reload_from_ckpt = reload_from_ckpt
        self.reloaded = False
        self.mean, self.std = None, None
        self.feature_cache = None
        self.flag_first_run = True
        self.train_step = 0

        # Filled in by set_log_dir(); get_log_dir() (and so test_step) needs them.
        self.logger_save_dir = None
        self.logger_exp_name = None
        self.logger_exp_group_name = None

        if self.reload_from_ckpt is not None:
            self._reload_matching_weights(self.reload_from_ckpt)
            self.reloaded = True
        else:
            print("Train from scratch")

    def _reload_matching_weights(self, path):
        """Copy every checkpoint tensor whose name and shape match this model.

        Keys absent from the checkpoint, or with a different shape, are reported
        and keep their freshly initialised values.
        """
        print("--> Reload weight of autoencoder from %s" % path)
        # map_location="cpu" lets checkpoints saved on a GPU load on any host;
        # load_state_dict then copies the values into the existing parameters.
        checkpoint = torch.load(path, map_location="cpu")
        pretrained_state_dict = checkpoint["state_dict"]
        load_todo_keys = {}
        for key, current in self.state_dict().items():
            if (
                key in pretrained_state_dict
                and pretrained_state_dict[key].size() == current.size()
            ):
                load_todo_keys[key] = pretrained_state_dict[key]
            else:
                print("Key %s mismatch during loading, seems fine" % key)
        self.load_state_dict(load_todo_keys, strict=False)

    def get_log_dir(self):
        """Return ``<save_dir>/<exp_group_name>/<exp_name>`` as set by ``set_log_dir``."""
        return os.path.join(
            self.logger_save_dir, self.logger_exp_group_name, self.logger_exp_name
        )

    def set_log_dir(self, save_dir, exp_group_name, exp_name):
        """Record the directory components used by ``get_log_dir``."""
        self.logger_save_dir = save_dir
        self.logger_exp_name = exp_name
        self.logger_exp_group_name = exp_group_name

    def init_from_ckpt(self, path, ignore_keys=None):
        """Load ``path``'s ``state_dict`` non-strictly, dropping ignored prefixes.

        Args:
            path: Checkpoint file containing a ``"state_dict"`` entry.
            ignore_keys: Iterable of key prefixes to drop; ``None`` drops nothing.
        """
        sd = torch.load(path, map_location="cpu")["state_dict"]
        # str.startswith accepts a tuple, so each key is tested against all
        # prefixes at once and deleted at most once (an empty tuple never matches).
        prefixes = tuple(ignore_keys or ())
        for k in list(sd.keys()):
            if k.startswith(prefixes):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        """Encode ``x`` into a ``DiagonalGaussianDistribution`` over the latent."""
        x = self.freq_split_subband(x)
        h = self.encoder(x)
        moments = self.quant_conv(h)
        return DiagonalGaussianDistribution(moments)

    def decode(self, z):
        """Decode latent ``z`` back to the input domain, merging sub-bands."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return self.freq_merge_subband(dec)

    def decode_to_waveform(self, dec):
        """Turn a decoded spectrogram batch into waveforms.

        ``fbank`` uses the vocoder built in ``__init__``; ``stft`` uses
        ``self.wave_decoder``.

        Raises:
            ValueError: if ``image_key`` is neither ``"fbank"`` nor ``"stft"``.
        """
        from qa_mdt.audioldm_train.utilities.model_util import vocoder_infer

        if self.image_key not in ("fbank", "stft"):
            raise ValueError(
                "Unsupported image_key for waveform decoding: %s" % self.image_key
            )
        # Drop the channel axis and swap the last two axes (time <-> freq under the
        # (bs, ch, tstep, fbins) layout used by freq_split_subband).
        dec = dec.squeeze(1).permute(0, 2, 1)
        if self.image_key == "fbank":
            return vocoder_infer(dec, self.vocoder)
        # NOTE(review): ``self.wave_decoder`` is never assigned in this class; the
        # stft path presumably relies on it being attached elsewhere -- confirm.
        return self.wave_decoder(dec)

    def visualize_latent(self, input):
        """Debug helper: encode two masked copies of ``input`` and dump the results.

        Writes ``.npy`` arrays and one PNG per batch item (channel-averaged
        latent) into the current working directory. Masked entries are set to
        -11.59.
        """
        import matplotlib.pyplot as plt

        np.save("input.npy", input.cpu().detach().numpy())

        # Masks the first 32 entries of the last axis (frequency bins under the
        # (bs, ch, tstep, fbins) layout of freq_split_subband).
        # NOTE(review): variable / .npy names say "time" while the PNGs say
        # "freq"; the runtime file names are left unchanged.
        time_input = input.clone()
        time_input[:, :, :, :32] *= 0
        time_input[:, :, :, :32] -= 11.59
        np.save("time_input.npy", time_input.cpu().detach().numpy())
        posterior = self.encode(time_input)
        latent = posterior.sample()
        np.save("time_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("freq_%s.png" % i)
            plt.close()

        # Masks the first 512 entries of axis 2 (time steps under the same layout).
        freq_input = input.clone()
        freq_input[:, :, :512, :] *= 0
        freq_input[:, :, :512, :] -= 11.59
        np.save("freq_input.npy", freq_input.cpu().detach().numpy())
        posterior = self.encode(freq_input)
        latent = posterior.sample()
        np.save("freq_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("time_%s.png" % i)
            plt.close()

    def forward(self, input, sample_posterior=True):
        """Autoencode ``input`` and return ``(reconstruction, posterior)``.

        ``sample_posterior`` chooses ``posterior.sample()`` over ``posterior.mode()``.
        """
        posterior = self.encode(input)
        z = posterior.sample() if sample_posterior else posterior.mode()
        if self.flag_first_run:
            print("Latent size: ", z.size())
            self.flag_first_run = False
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch):
        """Extract model inputs from a dataloader batch.

        Returns a dict with ``fbank``, ``stft`` and ``waveform`` (each with a new
        axis 1 inserted) plus ``fname`` unchanged. ``text`` and ``label_vector``
        are still looked up, so a batch missing them fails here as before.
        """
        fname, _text, _label_indices, waveform, stft, fbank = (
            batch["fname"],
            batch["text"],
            batch["label_vector"],
            batch["waveform"],
            batch["stft"],
            batch["log_mel_spec"],
        )
        return {
            "fbank": fbank.unsqueeze(1),
            "stft": stft.unsqueeze(1),
            "fname": fname,
            "waveform": waveform.unsqueeze(1),
        }

    def freq_split_subband(self, fbank):
        """Split the frequency axis into ``self.subband`` channels.

        (bs, 1, tstep, fbins) -> (bs, subband, tstep, fbins // subband).
        No-op unless ``subband > 1`` and ``image_key == "stft"``.
        """
        if self.subband == 1 or self.image_key != "stft":
            return fbank
        bs, ch, tstep, fbins = fbank.size()
        assert fbank.size(-1) % self.subband == 0
        assert ch == 1
        return (
            fbank.squeeze(1)
            .reshape(bs, tstep, self.subband, fbins // self.subband)
            .permute(0, 2, 1, 3)
        )

    def freq_merge_subband(self, subband_fbank):
        """Inverse of ``freq_split_subband``.

        (bs, subband, tstep, fbins) -> (bs, 1, tstep, subband * fbins).
        """
        if self.subband == 1 or self.image_key != "stft":
            return subband_fbank
        assert subband_fbank.size(1) == self.subband  # Channel dimension
        bs, sub_ch, tstep, fbins = subband_fbank.size()
        return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)

    def training_step(self, batch, batch_idx):
        """One manual-optimisation step: discriminator update, then autoencoder.

        ``self.loss`` is called twice with ``optimizer_idx=1`` (discriminator
        loss) and ``optimizer_idx=0`` (autoencoder loss).
        """
        g_opt, d_opt = self.optimizers()
        inputs_dict = self.get_input(batch)
        inputs = inputs_dict[self.image_key]
        waveform = inputs_dict["waveform"]

        if batch_idx % 5000 == 0 and self.local_rank == 0:
            print("Log train image")
            self.log_images(inputs, waveform=waveform)

        reconstructions, posterior = self(inputs)
        if self.image_key == "stft":
            rec_waveform = self.decode_to_waveform(reconstructions)
        else:
            rec_waveform = None
        # Same Parameter object for both loss calls.
        last_layer = self.get_last_layer()

        # --- Discriminator update ---
        # If working on waveform, inputs is STFT, reconstructions are the waveform
        # If working on the melspec, inputs is melspec, reconstruction are also mel spec
        discloss, log_dict_disc = self.loss(
            inputs=inputs,
            reconstructions=reconstructions,
            posteriors=posterior,
            waveform=waveform,
            rec_waveform=rec_waveform,
            optimizer_idx=1,
            global_step=self.global_step,
            last_layer=last_layer,
            split="train",
        )
        self.log(
            "discloss",
            discloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
        )
        self.log_dict(
            log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
        )
        d_opt.zero_grad()
        self.manual_backward(discloss)
        d_opt.step()

        self.log(
            "train_step",
            self.train_step,
            prog_bar=False,
            logger=False,
            on_step=True,
            on_epoch=False,
        )
        self.log(
            "global_step",
            float(self.global_step),
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
        )

        # --- Autoencoder update ---
        # global_step is re-read here on purpose: it may have advanced after d_opt.step().
        aeloss, log_dict_ae = self.loss(
            inputs=inputs,
            reconstructions=reconstructions,
            posteriors=posterior,
            waveform=waveform,
            rec_waveform=rec_waveform,
            optimizer_idx=0,
            global_step=self.global_step,
            last_layer=last_layer,
            split="train",
        )
        self.log(
            "aeloss",
            aeloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
        )
        # NOTE(review): logged under "posterior_std" but the value is
        # mean(posterior.var) (presumably a variance); key kept for dashboard continuity.
        self.log(
            "posterior_std",
            torch.mean(posterior.var),
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
        )
        self.log_dict(
            log_dict_ae, prog_bar=True, logger=True, on_step=True, on_epoch=False
        )
        self.train_step += 1
        g_opt.zero_grad()
        self.manual_backward(aeloss)
        g_opt.step()

    def validation_step(self, batch, batch_idx):
        """Compute and log validation autoencoder and discriminator losses.

        The first four batches (``batch_idx <= 3``) also log images / audio.

        Returns:
            The merged autoencoder and discriminator metric dictionaries.
        """
        inputs_dict = self.get_input(batch)
        inputs = inputs_dict[self.image_key]
        waveform = inputs_dict["waveform"]

        if batch_idx <= 3:
            print("Log val image")
            self.log_images(inputs, train=False, waveform=waveform)

        reconstructions, posterior = self(inputs)
        if self.image_key == "stft":
            rec_waveform = self.decode_to_waveform(reconstructions)
        else:
            rec_waveform = None
        last_layer = self.get_last_layer()

        aeloss, log_dict_ae = self.loss(
            inputs=inputs,
            reconstructions=reconstructions,
            posteriors=posterior,
            waveform=waveform,
            rec_waveform=rec_waveform,
            optimizer_idx=0,
            global_step=self.global_step,
            last_layer=last_layer,
            split="val",
        )
        discloss, log_dict_disc = self.loss(
            inputs=inputs,
            reconstructions=reconstructions,
            posteriors=posterior,
            waveform=waveform,
            rec_waveform=rec_waveform,
            optimizer_idx=1,
            global_step=self.global_step,
            last_layer=last_layer,
            split="val",
        )
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return {**log_dict_ae, **log_dict_disc}

    def test_step(self, batch, batch_idx):
        """Reconstruct a test batch and write the audio below the log directory.

        Files go to ``<get_log_dir()>/autoencoder_result_audiocaps/<global_step>/``,
        so ``set_log_dir`` must have been called first.
        """
        inputs_dict = self.get_input(batch)
        inputs = inputs_dict[self.image_key]
        fnames = inputs_dict["fname"]

        reconstructions, _ = self(inputs)
        save_path = os.path.join(
            self.get_log_dir(), "autoencoder_result_audiocaps", str(self.global_step)
        )

        if self.image_key == "stft":
            wav_prediction = self.decode_to_waveform(reconstructions)
            self.save_wave(
                wav_prediction, fnames, os.path.join(save_path, "stft_wav_prediction")
            )
        else:
            wav_vocoder_gt, wav_prediction = synth_one_sample(
                inputs.squeeze(1),
                reconstructions.squeeze(1),
                labels="validation",
                vocoder=self.vocoder,
            )
            self.save_wave(
                wav_vocoder_gt, fnames, os.path.join(save_path, "fbank_vocoder_gt_wave")
            )
            self.save_wave(
                wav_prediction, fnames, os.path.join(save_path, "fbank_wav_prediction")
            )

    def save_wave(self, batch_wav, fname, save_dir):
        """Write each waveform to ``save_dir`` under the basename of its file name."""
        os.makedirs(save_dir, exist_ok=True)
        for wav, name in zip(batch_wav, fname):
            name = os.path.basename(name)
            sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)

    def configure_optimizers(self):
        """Return ``[autoencoder Adam, discriminator Adam]`` with no schedulers.

        Raises:
            ValueError: if ``image_key`` is neither ``"fbank"`` nor ``"stft"``.
        """
        lr = self.learning_rate
        params = (
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.quant_conv.parameters())
            + list(self.post_quant_conv.parameters())
        )
        if self.image_key == "stft":
            params += list(self.wave_decoder.parameters())
        opt_ae = torch.optim.Adam(params, lr=lr, betas=(0.5, 0.9))

        if self.image_key == "fbank":
            disc_params = self.loss.discriminator.parameters()
        elif self.image_key == "stft":
            disc_params = itertools.chain(
                self.loss.msd.parameters(), self.loss.mpd.parameters()
            )
        else:
            raise ValueError("Unsupported image_key: %s" % self.image_key)
        opt_disc = torch.optim.Adam(disc_params, lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        """Return the decoder's output-convolution weight (passed to the loss)."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
        """Log input, reconstruction and posterior-sample spectrograms and audio.

        Runs without autograd: this forward pass is only for logging and is
        also invoked from ``training_step``.

        Args:
            batch: Input tensor (the ``image_key`` entry of ``get_input``).
            train: Selects the ``train`` / ``val`` suffix of logged keys.
            only_inputs: Skip the forward pass. NOTE(review): ``_log_img`` still
                reads ``reconstructions`` / ``samples``, so ``True`` raises KeyError.
            waveform: Ground-truth waveform, used by the ``stft`` path.

        Returns:
            ``(wav_original, wav_prediction, wav_samples)`` from ``_log_img``.
        """
        log = dict()
        x = batch.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            log["samples"] = self.decode(posterior.sample())
            log["reconstructions"] = xrec
        log["inputs"] = x
        wavs = self._log_img(log, train=train, index=0, waveform=waveform)
        return wavs

    def _log_img(self, log, train=True, index=0, waveform=None):
        """Log item ``index`` of ``log`` as images and audio when a logger is attached.

        Returns ``(wav_original, wav_prediction, wav_samples)``.
        """
        images_input = self.tensor2numpy(log["inputs"][index, 0]).T
        images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
        images_samples = self.tensor2numpy(log["samples"][index, 0]).T

        name = "train" if train else "val"

        # NOTE(review): log_image / experiment.log assume a Weights & Biases logger.
        if self.logger is not None:
            self.logger.log_image(
                "img_%s" % name,
                [images_input, images_reconstruct, images_samples],
                caption=["input", "reconstruct", "samples"],
            )

        inputs, reconstructions, samples = (
            log["inputs"],
            log["reconstructions"],
            log["samples"],
        )

        if self.image_key == "fbank":
            wav_original, wav_prediction = synth_one_sample(
                inputs[index],
                reconstructions[index],
                labels="validation",
                vocoder=self.vocoder,
            )
            wav_original, wav_samples = synth_one_sample(
                inputs[index], samples[index], labels="validation", vocoder=self.vocoder
            )
            wav_original, wav_samples, wav_prediction = (
                wav_original[0],
                wav_samples[0],
                wav_prediction[0],
            )
        elif self.image_key == "stft":
            wav_prediction = (
                self.decode_to_waveform(reconstructions)[index, 0]
                .cpu()
                .detach()
                .numpy()
            )
            wav_samples = (
                self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
            )
            wav_original = waveform[index, 0].cpu().detach().numpy()

        if self.logger is not None:
            self.logger.experiment.log(
                {
                    "original_%s"
                    % name: wandb.Audio(
                        wav_original, caption="original", sample_rate=self.sampling_rate
                    ),
                    "reconstruct_%s"
                    % name: wandb.Audio(
                        wav_prediction,
                        caption="reconstruct",
                        sample_rate=self.sampling_rate,
                    ),
                    "samples_%s"
                    % name: wandb.Audio(
                        wav_samples, caption="samples", sample_rate=self.sampling_rate
                    ),
                }
            )

        return wav_original, wav_prediction, wav_samples

    def tensor2numpy(self, tensor):
        """Detach ``tensor``, move it to CPU and return it as a NumPy array."""
        return tensor.cpu().detach().numpy()

    def to_rgb(self, x):
        """Project a segmentation map to 3 channels and rescale to [-1, 1]."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x
class IdentityFirstStage(torch.nn.Module):
    """Pass-through first stage: ``encode``, ``decode`` and ``forward`` return the input.

    With ``vq_interface=True``, ``quantize`` returns ``(x, None, [None, None, None])``,
    presumably to match the return shape of a VQ first stage.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        """Accept and ignore arbitrary config arguments; only ``vq_interface`` is used."""
        # Initialise nn.Module before assigning attributes so Module.__setattr__
        # operates on a fully set-up instance.
        super().__init__()
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff

    def encode(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x

    def decode(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Return ``x``, or ``(x, None, [None, None, None])`` when ``vq_interface`` is set."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x