# Hacked together using the code from https://github.com/nikhilsinghmus/image2reverb
import os, types
import numpy as np
import gradio as gr
import soundfile as sf
import scipy.signal
import librosa.display
from PIL import Image
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from pytorch_lightning import Trainer
from image2reverb.model import Image2Reverb
from image2reverb.stft import STFT
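# Outputs captured by the patched test_step() below; predict() reads them back after trainer.test() runs.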
predicted_ir = None
predicted_spectrogram = None
predicted_depthmap = None
def test_step(self, batch, batch_idx):
    spec, label, paths = batch
    examples = [os.path.splitext(os.path.basename(s))[0] for _, s in zip(*paths)]
    f, img = self.enc.forward(label)
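    # pad the encoder features with random noise up to the model's latent size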
    shape = (
        f.shape[0],
        (self._latent_dimension - f.shape[1]) if f.shape[1] < self._latent_dimension else f.shape[1],
        f.shape[2],
        f.shape[3]
    )
    z = torch.cat((f, torch.randn(shape, device=self.device)), 1)
    fake_spec = self.g(z)
    stft = STFT()
    y_f = [stft.inverse(s.squeeze()) for s in fake_spec]
    # TODO: bit hacky
    global predicted_ir, predicted_spectrogram, predicted_depthmap
    predicted_ir = y_f[0]
    s = fake_spec.squeeze().cpu().numpy()
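    # undo the model's spectrogram scaling: map [-1, 1] back to the log-magnitude range, then exponentiate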
    predicted_spectrogram = np.exp((((s + 1) * 0.5) * 19.5) - 17.5) - 1e-8
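    # the encoder's second output is in [-1, 1]; rescale to [0, 1] and keep its last channel as the depth map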
    img = (img + 1) * 0.5
    predicted_depthmap = img.cpu().squeeze().permute(1, 2, 0)[:, :, -1].squeeze().numpy()
    return {"test_audio": y_f, "test_examples": examples}
def test_epoch_end(self, outputs):
    if not self.test_callback:
        return
    examples = []
    audio = []
    for output in outputs:
        for i in range(len(output["test_examples"])):
            audio.append(output["test_audio"][i])
            examples.append(output["test_examples"][i])
    self.test_callback(examples, audio)
checkpoint_path = "./checkpoints/image2reverb_f22.ckpt"
encoder_path = None
depthmodel_path = "./checkpoints/mono_odom_640x192"
constant_depth = None
latent_dimension = 512
model = Image2Reverb(encoder_path, depthmodel_path)
m = torch.load(checkpoint_path, map_location=model.device)
model.load_state_dict(m["state_dict"])
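# Swap the test hooks defined above into the loaded model so that trainer.test() fills in the globals.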
model.test_step = types.MethodType(test_step, model)
model.test_epoch_end = types.MethodType(test_epoch_end, model)
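# Resize to 224x224 (bicubic) and scale pixel values to [-1, 1] before feeding the encoder.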
image_transforms = transforms.Compose([
    transforms.Resize([224, 224], transforms.functional.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
class Image2ReverbDemoDataset(Dataset):
    def __init__(self, image):
        self.image = Image.fromarray(image)
        self.stft = STFT()

    def __getitem__(self, index):
        img_tensor = image_transforms(self.image.convert("RGB"))
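        # only the image matters at test time: return a silent placeholder "spec" and empty paths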
        return torch.zeros(1, int(5.94 * 22050)), img_tensor, ("", "")

    def __len__(self):
        return 1

    def name(self):
        return "Image2ReverbDemo"
def convolve(audio, reverb):
    # convolve audio with reverb
    wet_audio = np.concatenate((audio, np.zeros(reverb.shape)))
    wet_audio = scipy.signal.oaconvolve(wet_audio, reverb, "full")[:len(wet_audio)]
    # normalize audio to roughly -1 dB peak and remove DC offset
    wet_audio /= np.max(np.abs(wet_audio))
    wet_audio -= np.mean(wet_audio)
    wet_audio *= 0.9
    return wet_audio
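# Offline usage sketch (not called by the app): apply a saved impulse response to a dry recording
# with the same convolve() helper. The file names are hypothetical, mono input is assumed, and the
# IR is resampled to the recording's sample rate if necessary.
def apply_ir_to_file(dry_path="dry.wav", ir_path="ir.wav", out_path="wet.wav"):
    dry, sr = sf.read(dry_path)  # mono float signal
    ir, ir_sr = sf.read(ir_path)
    if ir_sr != sr:
        ir = scipy.signal.resample_poly(ir, up=sr, down=ir_sr)
    ir /= np.max(np.abs(ir))
    sf.write(out_path, convolve(dry, ir), sr)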
def predict(image, audio):
    # image = numpy (height, width, channels)
    # audio = tuple (sample_rate, frames) or (sample_rate, (frames, channels))
    test_set = Image2ReverbDemoDataset(image)
    test_loader = torch.utils.data.DataLoader(test_set, num_workers=0, batch_size=1)
    trainer = Trainer(limit_test_batches=1)
    trainer.test(model, test_loader, verbose=True)

    # depthmap output
    depthmap_fig = plt.figure()
    plt.imshow(predicted_depthmap)
    plt.close()

    # spectrogram output
    spectrogram_fig = plt.figure()
    librosa.display.specshow(predicted_spectrogram, sr=22050, x_axis="time", y_axis="hz")
    plt.close()

    # plot the IR as a waveform
    waveform_fig = plt.figure()
    librosa.display.waveshow(predicted_ir, sr=22050, alpha=0.5)
    plt.close()

    # output audio as 16-bit signed integer
    ir = (22050, (predicted_ir * 32767).astype(np.int16))

    sample_rate, original_audio = audio

    # incoming audio is 16-bit signed integer, convert to float and normalize
    original_audio = original_audio.astype(np.float32) / 32768.0
    original_audio /= np.max(np.abs(original_audio))

    # resample reverb to sample_rate first, also normalize
    reverb = predicted_ir.copy()
    reverb = scipy.signal.resample_poly(reverb, up=sample_rate, down=22050)
    reverb /= np.max(np.abs(reverb))

    # stereo?
    if len(original_audio.shape) > 1:
        wet_left = convolve(original_audio[:, 0], reverb)
        wet_right = convolve(original_audio[:, 1], reverb)
        wet_audio = np.concatenate([wet_left[:, None], wet_right[:, None]], axis=1)
    else:
        wet_audio = convolve(original_audio, reverb)

    # 50% dry-wet mix
    mixed_audio = wet_audio * 0.5
    mixed_audio[:len(original_audio), ...] += original_audio * 0.9 * 0.5

    # convert back to 16-bit signed integer
    wet_audio = (wet_audio * 32767).astype(np.int16)
    mixed_audio = (mixed_audio * 32767).astype(np.int16)

    convolved_audio_100 = (sample_rate, wet_audio)
    convolved_audio_50 = (sample_rate, mixed_audio)
    return depthmap_fig, spectrogram_fig, waveform_fig, ir, convolved_audio_100, convolved_audio_50
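# Smoke-test sketch (not called anywhere): run the full pipeline once on the bundled example files
# without launching the UI. The paths are the ones used in the examples list below.
def smoke_test():
    img = np.array(Image.open("examples/input.4e2f71f6.png").convert("RGB"))
    wav, sr = sf.read("examples/ashesanddreams.wav", dtype="int16")
    outputs = predict(img, (sr, wav))
    print("predict() returned", len(outputs), "outputs; IR length:", len(outputs[3][1]))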
title = "Image2Reverb: Cross-Modal Reverb Impulse Response Synthesis"
description = """
<b>Image2Reverb</b> predicts the acoustic reverberation of a given environment from a 2D image. <a href="https://arxiv.org/abs/2103.14201">Read the paper</a>
How to use: Choose an image of a room or other environment and an audio file.
The model predicts what the reverb of the room sounds like and applies it to the audio file.
First, the image is resized to 224×224. The monodepth model is used to predict a depthmap, which is added as an
additional channel to the image input. A ResNet-based encoder then converts the image into features, and
finally a GAN predicts the spectrogram of the reverb's impulse response.
<center><img src="file/model.jpg" width="870" height="297" alt="model architecture"></center>
The predicted impulse response is mono at 22050 Hz. It is resampled to the sampling rate of the audio
file and applied to both channels if the audio is stereo.
Generating the impulse response involves a certain amount of randomness, making it sound a little
different every time you try it.
"""
article = """
<div style='margin:20px auto;'>
<p>Based on original work by Nikhil Singh, Jeff Mentch, Jerry Ng, Matthew Beveridge, Iddo Drori.
<a href="https://web.media.mit.edu/~nsingh1/image2reverb/">Project Page</a> |
<a href="https://arxiv.org/abs/2103.14201">Paper</a> |
<a href="https://github.com/nikhilsinghmus/image2reverb">GitHub</a></p>
<pre>
@InProceedings{Singh_2021_ICCV,
    author    = {Singh, Nikhil and Mentch, Jeff and Ng, Jerry and Beveridge, Matthew and Drori, Iddo},
    title     = {Image2Reverb: Cross-Modal Reverb Impulse Response Synthesis},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {286-295}
}
</pre>
<p>🌠 Example images from <a href="https://web.media.mit.edu/~nsingh1/image2reverb/">the original project page</a>.</p>
<p>🎶 Example sound from <a href="https://freesound.org/people/ashesanddreams/sounds/610414/">Ashes and Dreams @ freesound.org</a> (CC BY 4.0 license). This is a mono 48 kHz recording that has no reverb on it.</p>
</div>
"""
audio_example = "examples/ashesanddreams.wav"
examples = [
["examples/input.4e2f71f6.png", audio_example],
["examples/input.321eef38.png", audio_example],
["examples/input.2238dc21.png", audio_example],
["examples/input.4d280b40.png", audio_example],
["examples/input.0c3f5013.png", audio_example],
["examples/input.98773b90.png", audio_example],
["examples/input.ac61500f.png", audio_example],
["examples/input.5416407f.png", audio_example],
]
gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Upload Image"),
        gr.Audio(label="Upload Audio", source="upload", type="numpy"),
    ],
    outputs=[
        gr.Plot(label="Depthmap"),
        gr.Plot(label="Impulse Response Spectrogram"),
        gr.Plot(label="Impulse Response Waveform"),
        gr.Audio(label="Impulse Response", type="numpy"),
        gr.Audio(label="Output Audio (100% Wet)", type="numpy"),
        gr.Audio(label="Output Audio (50% Dry, 50% Wet)", type="numpy"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()