import streamlit as st
import torchaudio
import torch
import librosa
import librosa.display
import matplotlib.pyplot as plt
from semanticodec import SemantiCodec
import numpy as np
import tempfile
import os

# Set default parameters
DEFAULT_TOKEN_RATE = 100
DEFAULT_SEMANTIC_VOCAB_SIZE = 16384
DEFAULT_SAMPLE_RATE = 16000
MAX_DURATION_SECONDS = 30  # Maximum allowed duration
device = "cuda" if torch.cuda.is_available() else "cpu"

# Title and Description
st.title("SemantiCodec: Ultra-Low Bitrate Neural Audio Codec")
st.write("""
Upload your audio file (up to 30 seconds), adjust the codec parameters, and compare the original and reconstructed audio. 
SemantiCodec achieves high-quality audio reconstruction with ultra-low bitrates!
""")

# Sidebar: Parameters
st.sidebar.title("Codec Parameters")
token_rate = st.sidebar.selectbox("Token Rate (tokens/sec)", [25, 50, 100], index=2)
semantic_vocab_size = st.sidebar.selectbox(
    "Semantic Vocabulary Size",
    [4096, 8192, 16384, 32768],
    index=2,
)
ddim_steps = st.sidebar.slider("DDIM Sampling Steps", 10, 100, 50, step=5)
guidance_scale = st.sidebar.slider("CFG Guidance Scale", 0.5, 5.0, 2.0, step=0.1)
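# Rough intuition (not an exact figure): the bitstream size scales with the token rate and with
# log2 of the vocabulary size, so lower settings shrink the bitrate at some cost to reconstruction
# quality; fewer DDIM steps speed up decoding but may reduce fidelity.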

# Upload Audio File
uploaded_file = st.file_uploader("Upload an audio file (WAV format, up to 30 seconds)", type=["wav"])

# Helper function: Plot spectrogram
def plot_spectrogram(waveform, sample_rate, title):
    fig, ax = plt.subplots(figsize=(10, 4))
    S = librosa.feature.melspectrogram(y=waveform, sr=sample_rate, n_mels=128, fmax=sample_rate // 2)
    S_dB = librosa.power_to_db(S, ref=np.max)
    img = librosa.display.specshow(S_dB, sr=sample_rate, x_axis='time', y_axis='mel', cmap='viridis', ax=ax)
    fig.colorbar(img, ax=ax, format='%+2.0f dB')
    ax.set_title(title)
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)  # Release the figure so repeated runs don't accumulate open figures

# Process Audio
if uploaded_file and st.button("Run SemantiCodec"):
    with tempfile.TemporaryDirectory() as temp_dir:
        # Save uploaded file
        input_path = os.path.join(temp_dir, "input.wav")
        with open(input_path, "wb") as f:
            f.write(uploaded_file.read())

        # Load audio
        waveform, sample_rate = torchaudio.load(input_path)

        # Check if resampling is needed
        if sample_rate != DEFAULT_SAMPLE_RATE:
            st.write(f"Resampling audio from {sample_rate} Hz to {DEFAULT_SAMPLE_RATE} Hz...")
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=DEFAULT_SAMPLE_RATE)
            waveform = resampler(waveform)
            sample_rate = DEFAULT_SAMPLE_RATE  # Update sample rate to 16kHz
        
        # Check and limit duration
        num_samples = waveform.size(1)
        max_samples = MAX_DURATION_SECONDS * sample_rate  # 30 seconds limit
        if num_samples > max_samples:
            st.write(f"Truncating audio to the first {MAX_DURATION_SECONDS} seconds...")
            waveform = waveform[:, :max_samples]

        # Use the first channel and convert to numpy for librosa compatibility
        waveform_np = waveform[0].numpy()

        # Plot Original Spectrogram (16kHz resampled and truncated)
        st.write(f"Original Audio Spectrogram (Resampled and limited to {MAX_DURATION_SECONDS} seconds):")
        plot_spectrogram(waveform_np, sample_rate, f"Original Audio Spectrogram (Resampled to {DEFAULT_SAMPLE_RATE} Hz)")

        # Save truncated audio for processing
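        # (encode() below is given a file path, so the preprocessed clip is written to disk first)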
        truncated_path = os.path.join(temp_dir, "truncated_input.wav")
        torchaudio.save(truncated_path, waveform, sample_rate)

        # Initialize SemantiCodec
        st.write("Initializing SemantiCodec...")
        semanticodec = SemantiCodec(
            token_rate=token_rate,
            semantic_vocab_size=semantic_vocab_size,
            ddim_sample_step=ddim_steps,
            cfg_scale=guidance_scale,
        )
        semanticodec.device = device
        semanticodec.encoder = semanticodec.encoder.to(device)
        semanticodec.decoder = semanticodec.decoder.to(device)

        # Encode and Decode
        st.write("Encoding and Decoding Audio...")
        tokens = semanticodec.encode(truncated_path)
        reconstructed_waveform = semanticodec.decode(tokens)[0, 0]

        # Save reconstructed audio (reshape the 1-D output to (channels, samples) for torchaudio.save)
        reconstructed_path = os.path.join(temp_dir, "reconstructed.wav")
        torchaudio.save(reconstructed_path, torch.as_tensor(reconstructed_waveform).unsqueeze(0), sample_rate)

        # Plot Reconstructed Spectrogram
        st.write("Reconstructed Audio Spectrogram:")
        plot_spectrogram(reconstructed_waveform, sample_rate, "Reconstructed Audio Spectrogram")

        # Display latent code shape
        st.write(f"Shape of Latent Code: {tokens.shape}")

        # Audio Players
        st.write("Original Audio (Truncated)")
        st.audio(truncated_path, format="audio/wav")
        st.write("Reconstructed Audio")
        st.audio(reconstructed_path, format="audio/wav")
        
        # Download Button for Reconstructed Audio
        with open(reconstructed_path, "rb") as f:
            st.download_button(
                "Download Reconstructed Audio",
                data=f.read(),
                file_name="reconstructed_audio.wav",
                mime="audio/wav",
            )

# Footer
st.write("Built with [Streamlit](https://streamlit.io) and SemantiCodec")