|
import streamlit as st |
|
import torchaudio |
|
import torch |
|
import librosa |
|
import librosa.display |
|
import matplotlib.pyplot as plt |
|
from semanticodec import SemantiCodec |
|
import numpy as np |
|
import tempfile |
|
import os |
|
|
|
|
|
# ---------------------------------------------------------------------------
# Codec defaults and limits
# ---------------------------------------------------------------------------
# Values pre-selected in the sidebar widgets below.
DEFAULT_TOKEN_RATE = 100
DEFAULT_SEMANTIC_VOCAB_SIZE = 16384
# Uploaded audio is resampled to this rate before it is handed to the codec.
DEFAULT_SAMPLE_RATE = 16000
# Uploads longer than this many seconds are truncated.
MAX_DURATION_SECONDS = 30
# Run the model on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Choices offered in the sidebar. The pre-selected entry is looked up from the
# DEFAULT_* constants above so the constants and the widgets cannot drift apart.
TOKEN_RATE_OPTIONS = [25, 50, 100]
SEMANTIC_VOCAB_SIZE_OPTIONS = [4096, 8192, 16384, 32768]

# ---------------------------------------------------------------------------
# Page header
# ---------------------------------------------------------------------------
st.title("SemantiCodec: Ultra-Low Bitrate Neural Audio Codec")
st.write(f"""
Upload your audio file (up to {MAX_DURATION_SECONDS} seconds), adjust the codec parameters, and compare the original and reconstructed audio.
SemantiCodec achieves high-quality audio reconstruction with ultra-low bitrates!
""")

# ---------------------------------------------------------------------------
# Sidebar: codec parameters chosen by the user
# ---------------------------------------------------------------------------
st.sidebar.title("Codec Parameters")
token_rate = st.sidebar.selectbox(
    "Token Rate (tokens/sec)",
    TOKEN_RATE_OPTIONS,
    index=TOKEN_RATE_OPTIONS.index(DEFAULT_TOKEN_RATE),
)
semantic_vocab_size = st.sidebar.selectbox(
    "Semantic Vocabulary Size",
    SEMANTIC_VOCAB_SIZE_OPTIONS,
    index=SEMANTIC_VOCAB_SIZE_OPTIONS.index(DEFAULT_SEMANTIC_VOCAB_SIZE),
)
# Slider arguments are: label, min, max, initial value.
ddim_steps = st.sidebar.slider("DDIM Sampling Steps", 10, 100, 50, step=5)
guidance_scale = st.sidebar.slider("CFG Guidance Scale", 0.5, 5.0, 2.0, step=0.1)

# ---------------------------------------------------------------------------
# Audio upload (WAV only)
# ---------------------------------------------------------------------------
uploaded_file = st.file_uploader(
    f"Upload an audio file (WAV format, up to {MAX_DURATION_SECONDS} seconds)",
    type=["wav"],
)
|
|
|
|
|
def plot_spectrogram(waveform, sample_rate, title):
    """Render a mel spectrogram of ``waveform`` in the Streamlit page.

    The power mel spectrogram (128 mel bands, up to ``sample_rate // 2`` Hz)
    is converted to decibels relative to its maximum and drawn with a colour
    bar, then handed to ``st.pyplot``.

    Args:
        waveform: Audio samples, passed to librosa as ``y``
            (callers in this file pass a single channel as a NumPy array).
        sample_rate: Sampling rate of ``waveform`` in Hz.
        title: Title drawn above the plot.

    Returns:
        None. The figure is displayed in the app and then closed.
    """
    # Use an explicit figure instead of pyplot's implicit global one: the
    # script is re-executed on every interaction, and figures that are never
    # closed would pile up in memory.
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        mel_power = librosa.feature.melspectrogram(
            y=waveform, sr=sample_rate, n_mels=128, fmax=sample_rate // 2
        )
        # ref=np.max puts the loudest bin at 0 dB; everything else is negative.
        mel_db = librosa.power_to_db(mel_power, ref=np.max)
        image = librosa.display.specshow(
            mel_db,
            sr=sample_rate,
            x_axis="time",
            y_axis="mel",
            cmap="viridis",
            ax=ax,
        )
        fig.colorbar(image, ax=ax, format="%+2.0f dB")
        ax.set_title(title)
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        # Release the figure even if rendering fails.
        plt.close(fig)
|
|
|
|
|
# Run the codec only once a file has been uploaded and the button is clicked.
if uploaded_file and st.button("Run SemantiCodec"):
    # All intermediate WAV files live in a scratch directory that is removed
    # when the block exits.
    with tempfile.TemporaryDirectory() as temp_dir:
        # Persist the upload to disk so torchaudio can load it from a path.
        input_path = os.path.join(temp_dir, "input.wav")
        with open(input_path, "wb") as f:
            # getvalue() returns the whole upload regardless of the current
            # stream position, unlike read().
            f.write(uploaded_file.getvalue())

        waveform, sample_rate = torchaudio.load(input_path)

        # Bring the audio to the sample rate used for the codec input.
        if sample_rate != DEFAULT_SAMPLE_RATE:
            st.write(f"Resampling audio from {sample_rate} Hz to {DEFAULT_SAMPLE_RATE} Hz...")
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=DEFAULT_SAMPLE_RATE)
            waveform = resampler(waveform)
            sample_rate = DEFAULT_SAMPLE_RATE

        # Keep at most MAX_DURATION_SECONDS of audio (dimension 1 is time).
        num_samples = waveform.size(1)
        max_samples = MAX_DURATION_SECONDS * sample_rate
        if num_samples > max_samples:
            st.write(f"Truncating audio to the first {MAX_DURATION_SECONDS} seconds...")
            waveform = waveform[:, :max_samples]

        # Only the first channel is used for the spectrogram.
        waveform_np = waveform[0].numpy()

        st.write(f"Original Audio Spectrogram (Resampled and limited to {MAX_DURATION_SECONDS} seconds):")
        plot_spectrogram(waveform_np, sample_rate, f"Original Audio Spectrogram (Resampled to {DEFAULT_SAMPLE_RATE} Hz)")

        # The codec encodes from a file path, so write the prepared audio out.
        truncated_path = os.path.join(temp_dir, "truncated_input.wav")
        torchaudio.save(truncated_path, waveform, sample_rate)

        st.write("Initializing SemantiCodec...")
        # NOTE(review): the model is rebuilt on every click; consider caching
        # it per parameter set if start-up time becomes a problem.
        semanticodec = SemantiCodec(
            token_rate=token_rate,
            semantic_vocab_size=semantic_vocab_size,
            ddim_sample_step=ddim_steps,
            cfg_scale=guidance_scale,
        )
        # Move the model parts to the selected device.
        semanticodec.device = device
        semanticodec.encoder = semanticodec.encoder.to(device)
        semanticodec.decoder = semanticodec.decoder.to(device)

        st.write("Encoding and Decoding Audio...")
        tokens = semanticodec.encode(truncated_path)
        # decode() output is indexed [0, 0] to get one signal; presumably the
        # leading axes are batch and channel - TODO confirm against the
        # SemantiCodec API. Normalise to a float32 NumPy array so librosa and
        # torchaudio both accept it.
        reconstructed_waveform = np.asarray(
            semanticodec.decode(tokens)[0, 0], dtype=np.float32
        )

        reconstructed_path = os.path.join(temp_dir, "reconstructed.wav")
        # from_numpy avoids the slow list-of-arrays conversion performed by
        # torch.tensor([...]); unsqueeze(0) adds the channel axis for saving.
        torchaudio.save(
            reconstructed_path,
            torch.from_numpy(reconstructed_waveform).unsqueeze(0),
            sample_rate,
        )

        st.write("Reconstructed Audio Spectrogram:")
        plot_spectrogram(reconstructed_waveform, sample_rate, "Reconstructed Audio Spectrogram")

        st.write(f"Shape of Latent Code: {tokens.shape}")

        # Audio players for listening to both versions.
        st.audio(truncated_path, format="audio/wav")
        st.write("Original Audio (Truncated)")
        st.audio(reconstructed_path, format="audio/wav")
        st.write("Reconstructed Audio")

        # Read the bytes through a context manager so no handle stays open
        # when the temporary directory is removed.
        with open(reconstructed_path, "rb") as reconstructed_file:
            reconstructed_bytes = reconstructed_file.read()
        st.download_button(
            "Download Reconstructed Audio",
            data=reconstructed_bytes,
            file_name="reconstructed_audio.wav",
            mime="audio/wav",
        )

# Footer, shown regardless of whether the codec was run.
st.write("Built with [Streamlit](https://streamlit.io) and SemantiCodec")
|
|