# haoheliu's picture
# Update app.py
# c8989e8 verified
import streamlit as st
import torchaudio
import torch
import librosa
import librosa.display
import matplotlib.pyplot as plt
from semanticodec import SemantiCodec
import numpy as np
import tempfile
import os
# Default codec parameters and app-wide limits.
DEFAULT_TOKEN_RATE: int = 100  # tokens per second
DEFAULT_SEMANTIC_VOCAB_SIZE: int = 16384
DEFAULT_SAMPLE_RATE: int = 16000  # Hz; uploads at other rates are resampled to this
MAX_DURATION_SECONDS: int = 30  # longer uploads are truncated to this many seconds

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Title and Description. The duration limit in the text comes from
# MAX_DURATION_SECONDS so it always matches the truncation actually applied.
st.title("SemantiCodec: Ultra-Low Bitrate Neural Audio Codec")
st.write(f"""
Upload your audio file (up to {MAX_DURATION_SECONDS} seconds), adjust the codec parameters, and compare the original and reconstructed audio.
SemantiCodec achieves high-quality audio reconstruction with ultra-low bitrates!
""")
# Sidebar: Parameters.
# The preselected options are looked up from the DEFAULT_* constants so the
# defaults live in one place. list.index raises ValueError if a default is
# changed to a value that is not offered, which surfaces the misconfiguration.
st.sidebar.title("Codec Parameters")
TOKEN_RATE_OPTIONS = [25, 50, 100]
SEMANTIC_VOCAB_SIZE_OPTIONS = [4096, 8192, 16384, 32768]
token_rate = st.sidebar.selectbox(
    "Token Rate (tokens/sec)",
    TOKEN_RATE_OPTIONS,
    index=TOKEN_RATE_OPTIONS.index(DEFAULT_TOKEN_RATE),
)
semantic_vocab_size = st.sidebar.selectbox(
    "Semantic Vocabulary Size",
    SEMANTIC_VOCAB_SIZE_OPTIONS,
    index=SEMANTIC_VOCAB_SIZE_OPTIONS.index(DEFAULT_SEMANTIC_VOCAB_SIZE),
)
# Diffusion decoder settings: number of DDIM sampling steps and the
# classifier-free guidance scale passed to SemantiCodec.
ddim_steps = st.sidebar.slider("DDIM Sampling Steps", 10, 100, 50, step=5)
guidance_scale = st.sidebar.slider("CFG Guidance Scale", 0.5, 5.0, 2.0, step=0.1)
# Upload Audio File. Only .wav files are accepted; the duration in the label
# comes from MAX_DURATION_SECONDS so it matches the truncation applied later.
uploaded_file = st.file_uploader(
    f"Upload an audio file (WAV format, up to {MAX_DURATION_SECONDS} seconds)",
    type=["wav"],
)
# Helper function: Plot spectrogram
def plot_spectrogram(waveform, sample_rate, title):
    """Render a dB-scaled mel spectrogram of ``waveform`` into the Streamlit page.

    Args:
        waveform: Audio signal passed to ``librosa.feature.melspectrogram`` as
            ``y`` (a 1-D NumPy array at the call sites in this app).
        sample_rate: Sampling rate of ``waveform`` in Hz. It is also used to
            cap the mel filterbank at the Nyquist frequency.
        title: Title drawn above the plot.
    """
    # Use an explicit Figure rather than the global pyplot state. st.pyplot
    # expects a Figure, and the figure must be closed afterwards or matplotlib
    # keeps every figure created across reruns alive.
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        mel_power = librosa.feature.melspectrogram(
            y=waveform, sr=sample_rate, n_mels=128, fmax=sample_rate // 2
        )
        # Convert power to dB relative to the loudest bin (0 dB = peak).
        mel_db = librosa.power_to_db(mel_power, ref=np.max)
        image = librosa.display.specshow(
            mel_db,
            sr=sample_rate,
            x_axis="time",
            y_axis="mel",
            cmap="viridis",
            ax=ax,
        )
        fig.colorbar(image, ax=ax, format="%+2.0f dB")
        ax.set_title(title)
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        plt.close(fig)
# Build (or reuse) the codec for one parameter combination. Loading the model
# is the slowest step, so the instance is cached across reruns instead of being
# rebuilt on every button press. Only the most recently used combination is
# kept, which bounds the memory held by cached models.
# NOTE(review): st.cache_resource shares this object across all user sessions;
# confirm SemantiCodec.encode/decode are safe to call concurrently.
@st.cache_resource(max_entries=1)
def _load_semanticodec(token_rate, semantic_vocab_size, ddim_steps, guidance_scale):
    """Return a SemantiCodec built with the given settings and moved to ``device``."""
    codec = SemantiCodec(
        token_rate=token_rate,
        semantic_vocab_size=semantic_vocab_size,
        ddim_sample_step=ddim_steps,
        cfg_scale=guidance_scale,
    )
    codec.device = device
    codec.encoder = codec.encoder.to(device)
    codec.decoder = codec.decoder.to(device)
    return codec


# Process Audio: runs only when a file is uploaded and the button is pressed.
# All intermediate files live in a temporary directory that is removed when
# the block finishes (including when st.stop() ends the run early).
if uploaded_file and st.button("Run SemantiCodec"):
    with tempfile.TemporaryDirectory() as temp_dir:
        # Save uploaded file. getvalue() returns the whole upload regardless
        # of the buffer's current read position, unlike read().
        input_path = os.path.join(temp_dir, "input.wav")
        with open(input_path, "wb") as f:
            f.write(uploaded_file.getvalue())

        # Load audio, reporting undecodable uploads instead of crashing.
        try:
            waveform, sample_rate = torchaudio.load(input_path)
        except (RuntimeError, OSError) as err:
            st.error(f"Could not read the uploaded file as audio: {err}")
            st.stop()

        if waveform.size(1) == 0:
            st.error("The uploaded file contains no audio samples.")
            st.stop()

        # Keep only the first channel. It is the channel plotted below, so the
        # spectrogram, the "original" player and the codec input all match.
        if waveform.size(0) > 1:
            st.write("Using the first channel of multi-channel audio...")
            waveform = waveform[:1]

        # Check if resampling is needed
        if sample_rate != DEFAULT_SAMPLE_RATE:
            st.write(f"Resampling audio from {sample_rate} Hz to {DEFAULT_SAMPLE_RATE} Hz...")
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=DEFAULT_SAMPLE_RATE)
            waveform = resampler(waveform)
            sample_rate = DEFAULT_SAMPLE_RATE

        # Check and limit duration
        max_samples = MAX_DURATION_SECONDS * sample_rate
        if waveform.size(1) > max_samples:
            st.write(f"Truncating audio to the first {MAX_DURATION_SECONDS} seconds...")
            waveform = waveform[:, :max_samples]

        # librosa operates on NumPy arrays; plot the (mono) signal.
        waveform_np = waveform[0].numpy()
        st.write(f"Original Audio Spectrogram (Resampled and limited to {MAX_DURATION_SECONDS} seconds):")
        plot_spectrogram(waveform_np, sample_rate, f"Original Audio Spectrogram (Resampled to {DEFAULT_SAMPLE_RATE} Hz)")

        # SemantiCodec.encode takes a file path, so write the prepared audio out.
        truncated_path = os.path.join(temp_dir, "truncated_input.wav")
        torchaudio.save(truncated_path, waveform, sample_rate)

        # Initialize SemantiCodec (cached across reruns)
        st.write("Initializing SemantiCodec...")
        semanticodec = _load_semanticodec(token_rate, semantic_vocab_size, ddim_steps, guidance_scale)

        # Encode and Decode
        st.write("Encoding and Decoding Audio...")
        tokens = semanticodec.encode(truncated_path)
        # Presumably decode() returns a NumPy array shaped (batch, channel,
        # samples); [0, 0] picks the single 1-D signal. TODO confirm.
        reconstructed_waveform = np.asarray(semanticodec.decode(tokens)[0, 0], dtype=np.float32)

        # Save reconstructed audio. torch.from_numpy avoids torch.tensor's slow
        # list-of-ndarray path, and unsqueeze(0) adds the leading channel axis
        # torchaudio.save expects.
        reconstructed_path = os.path.join(temp_dir, "reconstructed.wav")
        torchaudio.save(
            reconstructed_path,
            torch.from_numpy(reconstructed_waveform).unsqueeze(0),
            sample_rate,
        )

        # Plot Reconstructed Spectrogram
        st.write("Reconstructed Audio Spectrogram:")
        plot_spectrogram(reconstructed_waveform, sample_rate, "Reconstructed Audio Spectrogram")

        # Display latent code shape
        st.write(f"Shape of Latent Code: {tokens.shape}")

        # Audio Players
        st.audio(truncated_path, format="audio/wav")
        st.write("Original Audio (Truncated)")
        st.audio(reconstructed_path, format="audio/wav")
        st.write("Reconstructed Audio")

        # Download Button for Reconstructed Audio. Read the bytes under a
        # context manager so the file handle is closed before the temporary
        # directory is cleaned up.
        with open(reconstructed_path, "rb") as f:
            reconstructed_bytes = f.read()
        st.download_button(
            "Download Reconstructed Audio",
            data=reconstructed_bytes,
            file_name="reconstructed_audio.wav",
            mime="audio/wav",
        )
# Footer: a credit line rendered at the bottom of the page.
_footer_markdown = "Built with [Streamlit](https://streamlit.io) and SemantiCodec"
st.write(_footer_markdown)