File size: 4,405 Bytes
d8595ba
2b33988
d8595ba
 
2b33988
d8595ba
 
 
11de0fa
d8595ba
26c3c7a
d8595ba
 
 
 
 
 
 
c2747d4
d8595ba
 
 
 
 
26c3c7a
d8595ba
 
 
 
 
 
 
38b7cd1
26c3c7a
2b33988
38b7cd1
 
 
 
d8595ba
26c3c7a
 
 
 
 
 
 
 
 
 
 
38b7cd1
26c3c7a
 
2b33988
d8595ba
26c3c7a
 
d8595ba
2b33988
d8595ba
38b7cd1
d8595ba
 
 
 
2b33988
 
 
 
d8595ba
 
 
 
 
26c3c7a
2b33988
 
 
 
 
 
c2747d4
 
 
2b33988
d8595ba
 
 
 
 
2b33988
d8595ba
c2747d4
d8595ba
 
 
 
 
 
 
b650afc
26c3c7a
d8595ba
26c3c7a
d8595ba
26c3c7a
d8595ba
 
c2747d4
 
d8595ba
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import streamlit as st
import torchaudio
import torch
import matplotlib.pyplot as plt
import soundfile as sf
from audiosr import build_model, super_resolution, save_wave
import tempfile
import numpy as np
import os

# Set device: CUDA when available, then MPS (Apple Silicon), otherwise CPU.
# The getattr guard keeps this working on torch builds that predate
# torch.backends.mps.
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Page header: app title followed by a short usage note.
st.title("AudioSR: Versatile Audio Super-Resolution")
st.write("""
Upload your low-resolution audio files, and AudioSR will enhance them to high fidelity!
Supports all types of audio (music, speech, sound effects, etc.) with arbitrary sampling rates.
Only the first 10 seconds of the audio will be processed.
""")

# Main input widget: a single WAV upload.
uploaded_file = st.file_uploader(
    "Upload an audio file (WAV format)",
    type=["wav"],
)

# Sidebar controls for the values later passed to the model.
sidebar = st.sidebar
sidebar.title("Model Parameters")
model_name = sidebar.selectbox(
    "Select Model",
    ["basic", "speech"],
    index=0,
)
ddim_steps = sidebar.slider(
    "DDIM Steps",
    min_value=10,
    max_value=100,
    value=50,
)
guidance_scale = sidebar.slider(
    "Guidance Scale",
    min_value=1.0,
    max_value=10.0,
    value=3.5,
)
random_seed = sidebar.number_input(
    "Random Seed",
    min_value=0,
    value=42,
    step=1,
)

# Fixed (not user-adjustable) value forwarded to super_resolution().
latent_t_per_second = 12.8

# Helper function: Plot linear STFT spectrogram
def plot_spectrogram(waveform, sample_rate, title):
    """Render a linear-frequency STFT magnitude spectrogram in the Streamlit app.

    Args:
        waveform: Audio samples as a NumPy array or torch tensor. Singleton
            dimensions are dropped; if more than one channel remains after
            that, only the first one is plotted.
        sample_rate: Sampling rate in Hz, used to scale both plot axes.
        title: Title drawn above the plot.
    """
    # as_tensor avoids the extra copy (and the UserWarning) that
    # torch.tensor() produces when the caller already passes a tensor.
    signal = torch.as_tensor(waveform).detach().cpu().to(torch.float32).squeeze()
    if signal.dim() > 1:
        # Still multi-channel after squeezing: imshow needs a 2D image, so
        # keep only the first channel.
        signal = signal.reshape(-1, signal.shape[-1])[0]

    n_fft = 2048
    magnitude = torch.stft(
        signal,
        n_fft=n_fft,
        hop_length=512,
        win_length=n_fft,
        # Without an explicit window torch.stft applies a rectangular one,
        # which smears energy across frequency bins (spectral leakage).
        window=torch.hann_window(n_fft),
        return_complex=True,
    ).abs()

    # Use an explicit Figure/Axes pair instead of pyplot's global state so
    # the figure can be handed to Streamlit and then released.
    fig, ax = plt.subplots(figsize=(10, 4))
    image = ax.imshow(
        np.log1p(magnitude.numpy()),
        aspect="auto",
        origin="lower",
        extent=[0, signal.numel() / sample_rate, 0, sample_rate / 2],
        cmap="viridis",
    )
    # The plotted values are log(1 + |STFT|), not decibels, so label them
    # as such.
    fig.colorbar(image, ax=ax, label="log(1 + magnitude)")
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Frequency (Hz)")
    fig.tight_layout()
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; close it so
    # repeated reruns of the app do not accumulate figures in memory.
    plt.close(fig)


# Process Button
if uploaded_file and st.button("Enhance Audio"):
    st.write("Processing audio...")

    # Sample rate used below when saving and plotting the enhanced audio.
    target_sample_rate = 48000

    # All intermediate files live in a temporary directory that is removed
    # when this block exits.
    with tempfile.TemporaryDirectory() as temp_dir:
        input_path = os.path.join(temp_dir, "input.wav")
        truncated_path = os.path.join(temp_dir, "truncated.wav")
        output_path = os.path.join(temp_dir, "output.wav")

        # Save uploaded file locally. getvalue() returns the whole upload
        # regardless of the current stream position, unlike read().
        with open(input_path, "wb") as f:
            f.write(uploaded_file.getvalue())

        # Load and truncate the first 10 seconds
        waveform, sample_rate = torchaudio.load(input_path)
        max_samples = sample_rate * 10  # First 10 seconds
        if waveform.size(1) > max_samples:
            waveform = waveform[:, :max_samples]
            st.write("Truncated audio to the first 10 seconds.")
        # Only the first channel is written out and processed.
        sf.write(truncated_path, waveform[0].numpy(), sample_rate)

        # Plot truncated spectrogram
        st.write("Truncated Input Audio Spectrogram (First 10 seconds):")
        plot_spectrogram(waveform[0].numpy(), sample_rate, title="Truncated Input Audio Spectrogram")

        # Build and load the model
        audiosr = build_model(model_name=model_name, device=device)

        # Perform super-resolution
        waveform_sr = super_resolution(
            audiosr,
            truncated_path,
            seed=random_seed,
            guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            latent_t_per_second=latent_t_per_second,
        )

        # Save enhanced audio (expected to land at output_path)
        output_waveform = waveform_sr
        save_wave(
            torch.tensor(output_waveform),
            inputpath=truncated_path,
            savepath=temp_dir,
            name="output",
            samplerate=target_sample_rate,
        )

        # Plot enhanced spectrogram
        st.write("Enhanced Audio Spectrogram:")
        plot_spectrogram(output_waveform, target_sample_rate, title="Enhanced Audio Spectrogram")

        # Display audio players and download link. Each caption is written
        # before its player so the label appears above the widget it
        # describes.
        st.write("Truncated Original Audio (First 10 seconds):")
        st.audio(truncated_path, format="audio/wav")

        st.write("Enhanced Audio:")
        st.audio(output_path, format="audio/wav")

        # Read the result through a context manager so the file handle is
        # closed before the temporary directory is cleaned up.
        with open(output_path, "rb") as enhanced_file:
            enhanced_bytes = enhanced_file.read()
        st.download_button(
            "Download Enhanced Audio",
            data=enhanced_bytes,
            file_name="enhanced_audio.wav",
            mime="audio/wav",
        )

# Footer: credit the frameworks behind this demo.
st.write(
    "Built with [Streamlit](https://streamlit.io) "
    "and [AudioSR](https://audioldm.github.io/audiosr)"
)