noisyKickGAN / app.py
ghostofdivinity's picture
Update app.py
a94f3c2
raw
history blame
1.84 kB
import os
import tempfile

import gradio as gr
import torch
import torchaudio
from torch import nn
# Generator network; its architecture must match the checkpoint loaded below.
class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to 8192 values squashed by Tanh.

    Architecture: Linear(latent_dim->1024), ReLU, Linear(1024->4096), ReLU,
    Linear(4096->8192), Tanh. The layers are stored in ``self.generator``
    (an ``nn.Sequential``), so the state_dict keys are ``generator.0.*``,
    ``generator.2.*`` and ``generator.4.*``. Keep the attribute name and the
    layer order unchanged, or previously saved weights will no longer load.
    """

    def __init__(self, latent_dim):
        super().__init__()
        widths = (latent_dim, 1024, 4096, 8192)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        # The last activation is Tanh rather than ReLU.
        layers[-1] = nn.Tanh()
        self.generator = nn.Sequential(*layers)

    def forward(self, x):
        """Run latent vectors ``x`` (last dimension == latent_dim) through the network."""
        return self.generator(x)
# Build the generator on the best available device and load its saved weights.
latent_dim = 100  # length of each input noise vector
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
generator_model_path = 'generator_model.pkl'
generator = Generator(latent_dim).to(device)
# NOTE(review): torch.load unpickles the file, which is only acceptable because the
# checkpoint ships with the app. Consider weights_only=True if the installed torch
# version supports it.
generator.load_state_dict(
    torch.load(generator_model_path, map_location=device)
)
def generate_kick_drums(num_generated_samples=3, sample_rate=16000):
    """Generate kick-drum samples with the loaded generator and save them as WAV files.

    Each sample comes from an independent standard-normal latent vector of
    length ``latent_dim``. It is written as ``generated_kick_<k>.wav``
    (k = 1..num_generated_samples) inside a temporary directory created fresh
    for this call.

    Args:
        num_generated_samples: Number of samples to generate. The default of 3
            matches the three audio outputs of the Gradio interface in this file.
        sample_rate: Sample rate (Hz) written into the WAV files.

    Returns:
        A tuple of WAV file paths, one per generated sample, in generation order.
    """
    # A per-call directory prevents concurrent requests from overwriting each
    # other's files, which happened when every call wrote the same names into
    # the CWD. NOTE: these files are not deleted here; they stay in the system
    # temp dir until it is cleaned.
    output_dir = tempfile.mkdtemp(prefix="generated_kicks_")
    output_files = []
    generator.eval()
    with torch.no_grad():
        for i in range(num_generated_samples):
            # Create the noise directly on the target device (no host->device copy).
            noise = torch.randn(1, latent_dim, device=device)
            # Drop the batch dimension of size 1 and move to CPU for saving.
            generated_sample = generator(noise).squeeze().cpu()
            output_filename = os.path.join(output_dir, f"generated_kick_{i + 1}.wav")
            # torchaudio.save takes (channels, frames): add a single (mono) channel axis.
            torchaudio.save(output_filename, generated_sample.unsqueeze(0), sample_rate)
            output_files.append(output_filename)
    return tuple(output_files)
# Define the Gradio interface
def gradio_interface():
    """Build the kick-drum generator UI around ``generate_kick_drums`` and launch it.

    The interface takes no inputs and shows three audio players, one for each
    file path returned by ``generate_kick_drums``. It is launched with
    ``debug=True``.
    """
    demo = gr.Interface(
        fn=generate_kick_drums,
        inputs=None,
        # Number the players 1..3 to match the generated_kick_<k>.wav files
        # (k starts at 1) that generate_kick_drums writes.
        outputs=[gr.Audio(type='filepath', label=f"generated_kick_{i}") for i in range(1, 4)],
        live=True,
    )
    demo.launch(debug=True)


# Run the Gradio interface
gradio_interface()