import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import streamlit as st
import torch
import torchaudio
from audiosr import build_model, super_resolution, save_wave

# Pick the best available device: CUDA, then MPS (Mac M-series GPUs), then CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Title and description
st.title("AudioSR: Versatile Audio Super-Resolution")
st.write("""
Upload your low-resolution audio files, and AudioSR will enhance them to high fidelity!
Supports all types of audio (music, speech, sound effects, etc.) with arbitrary sampling rates.
Only the first 10 seconds of the audio will be processed.
""")

# Upload audio file
uploaded_file = st.file_uploader("Upload an audio file (WAV format)", type=["wav"])

# Model parameters
st.sidebar.title("Model Parameters")
model_name = st.sidebar.selectbox("Select Model", ["basic", "speech"], index=0)
ddim_steps = st.sidebar.slider("DDIM Steps", min_value=10, max_value=100, value=50)
guidance_scale = st.sidebar.slider("Guidance Scale", min_value=1.0, max_value=10.0, value=3.5)
random_seed = st.sidebar.number_input("Random Seed", min_value=0, value=42, step=1)
latent_t_per_second = 12.8

# Helper function to plot a log-mel spectrogram
def plot_spectrogram(waveform, sample_rate, title):
    plt.figure(figsize=(10, 4))
    spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate, n_fft=2048, hop_length=512, n_mels=128
    )(torch.as_tensor(waveform, dtype=torch.float32))
    log_spectrogram = torchaudio.transforms.AmplitudeToDB()(spectrogram)
    plt.imshow(log_spectrogram.numpy(), aspect="auto", origin="lower", cmap="viridis")
    plt.colorbar(format="%+2.0f dB")
    plt.title(title)
    plt.xlabel("Time")
    plt.ylabel("Frequency")
    plt.tight_layout()
    st.pyplot(plt)

# Process button
if uploaded_file and st.button("Enhance Audio"):
    st.write("Processing audio...")

    # Create a temp directory for intermediate files
    with tempfile.TemporaryDirectory() as temp_dir:
        input_path = os.path.join(temp_dir, "input.wav")
        truncated_path = os.path.join(temp_dir, "truncated.wav")
        output_path = os.path.join(temp_dir, "output.wav")

        # Save the uploaded file locally
        with open(input_path, "wb") as f:
            f.write(uploaded_file.read())

        # Load the audio and keep only the first 10 seconds
        waveform, sample_rate = torchaudio.load(input_path)
        max_samples = sample_rate * 10
        if waveform.size(1) > max_samples:
            waveform = waveform[:, :max_samples]
            st.write("Truncated audio to the first 10 seconds.")
        sf.write(truncated_path, waveform[0].numpy(), sample_rate)

        # Plot the truncated input spectrogram
        st.write("Truncated Input Audio Spectrogram (First 10 seconds):")
        plot_spectrogram(waveform[0].numpy(), sample_rate, title="Truncated Input Audio Spectrogram")

        # Build and load the model
        audiosr = build_model(model_name=model_name, device=device)

        # Perform super-resolution
        waveform_sr = super_resolution(
            audiosr,
            truncated_path,
            seed=random_seed,
            guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            latent_t_per_second=latent_t_per_second,
        )

        # Save the enhanced audio (AudioSR outputs at 48 kHz);
        # save_wave writes "<name>.wav" into savepath, matching output_path above
        save_wave(waveform_sr, inputpath=truncated_path, savepath=temp_dir, name="output", samplerate=48000)

        # Plot the output spectrogram; normalize to a 1-D NumPy array first,
        # since the returned waveform may be a tensor or a batched array
        if torch.is_tensor(waveform_sr):
            waveform_sr = waveform_sr.cpu().numpy()
        st.write("Enhanced Audio Spectrogram:")
        plot_spectrogram(np.squeeze(waveform_sr), 48000, title="Enhanced Audio Spectrogram")

        # Display audio players and a download link
        st.write("Truncated Original Audio (First 10 seconds):")
        st.audio(truncated_path, format="audio/wav")
        st.write("Enhanced Audio:")
        st.audio(output_path, format="audio/wav")
        with open(output_path, "rb") as f:
            st.download_button("Download Enhanced Audio", data=f.read(), file_name="enhanced_audio.wav")

# Footer
st.write("Built with [Streamlit](https://streamlit.io) and [AudioSR](https://audioldm.github.io/audiosr)")