"""Streamlit demo app for SemantiCodec.

Upload a WAV file, choose codec parameters in the sidebar, and compare the
original (mono, resampled, truncated) audio with the SemantiCodec
reconstruction, both as mel spectrograms and as playable audio.
"""

import os
import tempfile

import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torchaudio
from semanticodec import SemantiCodec

# Set default parameters
DEFAULT_TOKEN_RATE = 100
DEFAULT_SEMANTIC_VOCAB_SIZE = 16384
DEFAULT_SAMPLE_RATE = 16000
MAX_DURATION_SECONDS = 30  # Maximum allowed duration

# Choices offered in the sidebar; the defaults above must be members.
TOKEN_RATE_OPTIONS = [25, 50, 100]
SEMANTIC_VOCAB_SIZE_OPTIONS = [4096, 8192, 16384, 32768]

device = "cuda" if torch.cuda.is_available() else "cpu"

# Title and Description
st.title("SemantiCodec: Ultra-Low Bitrate Neural Audio Codec")
st.write(f"""
Upload your audio file (up to {MAX_DURATION_SECONDS} seconds), adjust the codec parameters, and compare the original and reconstructed audio.
SemantiCodec achieves high-quality audio reconstruction with ultra-low bitrates!
""")

# Sidebar: Parameters
st.sidebar.title("Codec Parameters")
token_rate = st.sidebar.selectbox(
    "Token Rate (tokens/sec)",
    TOKEN_RATE_OPTIONS,
    index=TOKEN_RATE_OPTIONS.index(DEFAULT_TOKEN_RATE),
)
semantic_vocab_size = st.sidebar.selectbox(
    "Semantic Vocabulary Size",
    SEMANTIC_VOCAB_SIZE_OPTIONS,
    index=SEMANTIC_VOCAB_SIZE_OPTIONS.index(DEFAULT_SEMANTIC_VOCAB_SIZE),
)
ddim_steps = st.sidebar.slider("DDIM Sampling Steps", 10, 100, 50, step=5)
guidance_scale = st.sidebar.slider("CFG Guidance Scale", 0.5, 5.0, 2.0, step=0.1)

# Upload Audio File
uploaded_file = st.file_uploader(
    f"Upload an audio file (WAV format, up to {MAX_DURATION_SECONDS} seconds)",
    type=["wav"],
)


@st.cache_resource(max_entries=1, show_spinner=False)
def load_codec(token_rate, semantic_vocab_size, ddim_steps, guidance_scale):
    """Build a SemantiCodec for the given settings and move it to ``device``.

    Cached with ``st.cache_resource`` so that re-running with unchanged
    sidebar settings reuses the loaded model instead of re-initialising it
    on every click. ``max_entries=1`` keeps only the most recent
    configuration resident, so changing the sliders does not accumulate
    models in memory.
    """
    codec = SemantiCodec(
        token_rate=token_rate,
        semantic_vocab_size=semantic_vocab_size,
        ddim_sample_step=ddim_steps,
        cfg_scale=guidance_scale,
    )
    codec.device = device
    codec.encoder = codec.encoder.to(device)
    codec.decoder = codec.decoder.to(device)
    return codec


def _to_waveform_array(audio):
    """Return ``audio`` (torch tensor or array-like) as a 1-D float32 NumPy array."""
    if isinstance(audio, torch.Tensor):
        audio = audio.detach().cpu().numpy()
    return np.asarray(audio, dtype=np.float32).reshape(-1)


def _prepare_waveform(path):
    """Load ``path``, downmix to mono, resample to 16 kHz and truncate.

    Returns:
        ``(waveform, sample_rate)`` where ``waveform`` is a ``(1, num_samples)``
        tensor and ``sample_rate`` equals ``DEFAULT_SAMPLE_RATE``.
    """
    waveform, sample_rate = torchaudio.load(path)

    # Downmix so the plotted, played and encoded signals are the same mono
    # signal (previously only channel 0 was plotted).
    if waveform.size(0) > 1:
        st.write(f"Downmixing {waveform.size(0)} channels to mono...")
        waveform = waveform.mean(dim=0, keepdim=True)

    # Check if resampling is needed
    if sample_rate != DEFAULT_SAMPLE_RATE:
        st.write(f"Resampling audio from {sample_rate} Hz to {DEFAULT_SAMPLE_RATE} Hz...")
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=DEFAULT_SAMPLE_RATE
        )
        waveform = resampler(waveform)
        sample_rate = DEFAULT_SAMPLE_RATE

    # Check and limit duration
    max_samples = MAX_DURATION_SECONDS * sample_rate
    if waveform.size(1) > max_samples:
        st.write(f"Truncating audio to the first {MAX_DURATION_SECONDS} seconds...")
        waveform = waveform[:, :max_samples]

    return waveform, sample_rate


def _read_bytes(path):
    """Return the full contents of ``path`` (closing the file afterwards)."""
    with open(path, "rb") as f:
        return f.read()


# Helper function: Plot spectrogram
def plot_spectrogram(waveform, sample_rate, title):
    """Render a mel spectrogram of a mono waveform into the Streamlit page.

    Args:
        waveform: 1-D float NumPy array of audio samples.
        sample_rate: Sampling rate of ``waveform`` in Hz.
        title: Figure title.
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    mel = librosa.feature.melspectrogram(
        y=waveform, sr=sample_rate, n_mels=128, fmax=sample_rate // 2
    )
    mel_db = librosa.power_to_db(mel, ref=np.max)
    img = librosa.display.specshow(
        mel_db, sr=sample_rate, x_axis="time", y_axis="mel", cmap="viridis", ax=ax
    )
    fig.colorbar(img, ax=ax, format="%+2.0f dB")
    ax.set_title(title)
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed, and Streamlit
    # re-executes this script on each interaction, so close explicitly.
    plt.close(fig)


# Process Audio
if uploaded_file and st.button("Run SemantiCodec"):
    with tempfile.TemporaryDirectory() as temp_dir:
        # Save uploaded file. getvalue() returns the whole upload regardless of
        # the buffer's read position, unlike read().
        input_path = os.path.join(temp_dir, "input.wav")
        with open(input_path, "wb") as f:
            f.write(uploaded_file.getvalue())

        # Load, downmix, resample and truncate
        try:
            waveform, sample_rate = _prepare_waveform(input_path)
        except (RuntimeError, OSError, ValueError) as err:
            st.error(f"Could not read the uploaded file as audio: {err}")
            st.stop()

        if waveform.size(1) == 0:
            st.error("The uploaded file contains no audio samples.")
            st.stop()

        # Plot Original Spectrogram (mono, 16kHz resampled and truncated)
        st.write(f"Original Audio Spectrogram (Resampled and limited to {MAX_DURATION_SECONDS} seconds):")
        plot_spectrogram(
            _to_waveform_array(waveform),
            sample_rate,
            f"Original Audio Spectrogram (Resampled to {DEFAULT_SAMPLE_RATE} Hz)",
        )

        # Save truncated audio for processing
        truncated_path = os.path.join(temp_dir, "truncated_input.wav")
        torchaudio.save(truncated_path, waveform, sample_rate)

        # Initialize SemantiCodec (cached across reruns with identical settings)
        st.write("Initializing SemantiCodec...")
        semanticodec = load_codec(token_rate, semantic_vocab_size, ddim_steps, guidance_scale)

        # Encode and Decode; no autograd graph is needed for inference.
        st.write("Encoding and Decoding Audio...")
        with torch.no_grad():
            tokens = semanticodec.encode(truncated_path)
            decoded = semanticodec.decode(tokens)
        reconstructed_waveform = _to_waveform_array(decoded[0, 0])

        # Save reconstructed audio as a (1, num_samples) mono tensor
        reconstructed_path = os.path.join(temp_dir, "reconstructed.wav")
        torchaudio.save(
            reconstructed_path,
            torch.from_numpy(reconstructed_waveform).unsqueeze(0),
            sample_rate,
        )

        # Plot Reconstructed Spectrogram
        st.write("Reconstructed Audio Spectrogram:")
        plot_spectrogram(reconstructed_waveform, sample_rate, "Reconstructed Audio Spectrogram")

        # Display latent code shape
        st.write(f"Shape of Latent Code: {tokens.shape}")

        # Read outputs into memory while the temporary directory still exists.
        original_bytes = _read_bytes(truncated_path)
        reconstructed_bytes = _read_bytes(reconstructed_path)

    # Audio Players
    st.write("Original Audio (Truncated)")
    st.audio(original_bytes, format="audio/wav")
    st.write("Reconstructed Audio")
    st.audio(reconstructed_bytes, format="audio/wav")

    # Download Button for Reconstructed Audio
    st.download_button(
        "Download Reconstructed Audio",
        data=reconstructed_bytes,
        file_name="reconstructed_audio.wav",
    )

# Footer
st.write("Built with [Streamlit](https://streamlit.io) and SemantiCodec")