# noisyKickGAN / app.py
# Source page header: last commit bbd0842 ("Update app.py"), user ghostofdivinity, file size 2.46 kB.
import gradio as gr
import os
import torch
import torchaudio
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
import IPython.display as ipd
def preview_kick_drums(num_generated_samples=3, sample_rate=16000):
    """Generate kick drum samples, save them as WAV files and play them inline.

    Meant for interactive use in a Jupyter/IPython notebook. It is wrapped in a
    function so that the module-level names it uses (``generator``,
    ``latent_dim``, ``device``) are looked up at call time, after they are
    defined further down this file, instead of at import time.

    Args:
        num_generated_samples: Number of samples to generate. Defaults to 3.
        sample_rate: Sample rate in Hz written to each WAV file. Defaults to 16000.

    Returns:
        Tuple of the written paths ``generated_kick_1.wav`` ...
        ``generated_kick_{num_generated_samples}.wav``. They are relative to the
        current working directory and are overwritten on each call.
    """
    output_files = []
    generator.eval()
    with torch.no_grad():
        for i in range(num_generated_samples):
            noise = torch.randn(1, latent_dim, device=device)
            # torchaudio.save expects a (channels, frames) tensor; write one mono channel.
            waveform = generator(noise).reshape(1, -1).cpu()
            # Save the generated sample
            output_filename = f"generated_kick_{i+1}.wav"
            torchaudio.save(output_filename, waveform, sample_rate)
            output_files.append(output_filename)
            # Play the generated sample; ipd.display renders the Audio widget
            # inline when running under IPython.
            print(f"Generated Sample {i+1}:")
            ipd.display(ipd.Audio(output_filename))
    return tuple(output_files)
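# Generator network definition. Its architecture must match the checkpoint loaded below (generator_model.pkl) for load_state_dict to succeed.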
class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to 8192 output values.

    Layout: Linear(latent_dim -> 1024), ReLU, Linear(1024 -> 4096), ReLU,
    Linear(4096 -> 8192), Tanh. The final Tanh bounds every output to [-1, 1].

    The submodule must stay named ``generator``, with this exact layer order.
    The state_dict keys (``generator.0.weight``, ``generator.2.bias``, ...) are
    derived from that name and order, and saved checkpoints are loaded by those keys.
    """

    def __init__(self, latent_dim):
        super().__init__()
        widths = (latent_dim, 1024, 4096, 8192)
        last_stage = len(widths) - 2
        stages = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            stages.append(nn.Linear(fan_in, fan_out))
            # Hidden stages use ReLU; only the output stage is squashed by Tanh.
            stages.append(nn.Tanh() if idx == last_stage else nn.ReLU())
        self.generator = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` (last dimension of size latent_dim) through the network."""
        out = self.generator(x)
        return out
# Size of the latent noise vector. It fixes the input width of the first
# Linear layer, so it must match the value the checkpoint was trained with.
latent_dim = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(latent_dim).to(device)
generator_model_path = "generator_model.pkl"
# map_location remaps tensors that were saved on another device (e.g. a
# CUDA-trained checkpoint) onto the device chosen above. Without it, loading
# fails on CPU-only hosts.
generator.load_state_dict(torch.load(generator_model_path, map_location=device))
def generate_kick_drums(num_generated_samples=3, sample_rate=16000):
    """Generate kick drum samples with the loaded generator and save them as WAV files.

    Args:
        num_generated_samples: Number of samples to generate. Defaults to 3,
            matching the three audio outputs wired up in the Gradio interface.
        sample_rate: Sample rate in Hz written to each WAV file. Defaults to 16000.

    Returns:
        Tuple of the written paths ``generated_kick_1.wav`` ...
        ``generated_kick_{num_generated_samples}.wav``, in generation order.
        Paths are relative to the current working directory; files are
        overwritten on each call.
    """
    output_files = []
    # Inference only: eval mode for the network, and no autograd graph.
    generator.eval()
    with torch.no_grad():
        for i in range(1, num_generated_samples + 1):
            noise = torch.randn(1, latent_dim, device=device)
            # torchaudio.save expects a (channels, frames) tensor; write one mono channel.
            waveform = generator(noise).reshape(1, -1).cpu()
            output_filename = f"generated_kick_{i}.wav"
            torchaudio.save(output_filename, waveform, sample_rate)
            output_files.append(output_filename)
    return tuple(output_files)
def gradio_interface():
    """Build and launch the Gradio app for the kick drum generator.

    The interface has no inputs. It shows three audio players labelled
    ``generated_kick_1`` .. ``generated_kick_3``, which are filled from the
    file paths returned by generate_kick_drums(). It is created with
    live=True and launched with debug=True.
    """
    audio_players = []
    for i in range(3):
        audio_players.append(gr.Audio(type='filepath', label=f"generated_kick_{i+1}"))
    demo = gr.Interface(
        fn=generate_kick_drums,
        inputs=None,
        outputs=audio_players,
        live=True,
    )
    demo.launch(debug=True)
# Run the Gradio interface only when this file is executed as a script
# (e.g. `python app.py`), so importing the module does not start a server.
if __name__ == "__main__":
    gradio_interface()