Update app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1 |
import streamlit as st
|
|
|
2 |
import torch
|
3 |
-
import os
|
4 |
-
import librosa
|
5 |
-
import librosa.display
|
6 |
import matplotlib.pyplot as plt
|
|
|
7 |
from audiosr import build_model, super_resolution, save_wave
|
8 |
import tempfile
|
9 |
import numpy as np
|
@@ -31,48 +30,51 @@ random_seed = st.sidebar.number_input("Random Seed", min_value=0, value=42, step
|
|
31 |
latent_t_per_second = 12.8
|
32 |
|
33 |
# Helper function to plot spectrogram
|
34 |
-
def plot_spectrogram(
|
35 |
-
y, sr = librosa.load(audio_path, sr=None)
|
36 |
-
S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=sr // 2)
|
37 |
-
S_dB = librosa.power_to_db(S, ref=np.max)
|
38 |
-
|
39 |
plt.figure(figsize=(10, 4))
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
42 |
plt.title(title)
|
|
|
|
|
43 |
plt.tight_layout()
|
44 |
-
|
45 |
|
46 |
# Process Button
|
47 |
if uploaded_file and st.button("Enhance Audio"):
|
48 |
st.write("Processing audio...")
|
49 |
|
50 |
# Create temp directory for saving files
|
51 |
-
with tempfile.TemporaryDirectory() as
|
52 |
-
input_path = os.path.join(
|
53 |
-
truncated_path = os.path.join(
|
54 |
-
output_path = os.path.join(
|
55 |
|
56 |
# Save uploaded file locally
|
57 |
with open(input_path, "wb") as f:
|
58 |
f.write(uploaded_file.read())
|
59 |
|
60 |
-
# Load and truncate the first 10 seconds
|
61 |
-
|
62 |
-
max_samples =
|
63 |
-
|
64 |
-
|
|
|
|
|
65 |
|
66 |
# Plot truncated spectrogram
|
67 |
st.write("Truncated Input Audio Spectrogram (First 10 seconds):")
|
68 |
-
|
69 |
-
st.pyplot(truncated_spectrogram)
|
70 |
|
71 |
# Build and load the model
|
72 |
audiosr = build_model(model_name=model_name, device=device)
|
73 |
|
74 |
# Perform super-resolution
|
75 |
-
|
76 |
audiosr,
|
77 |
truncated_path,
|
78 |
seed=random_seed,
|
@@ -82,12 +84,11 @@ if uploaded_file and st.button("Enhance Audio"):
|
|
82 |
)
|
83 |
|
84 |
# Save enhanced audio
|
85 |
-
save_wave(
|
86 |
|
87 |
# Plot output spectrogram
|
88 |
st.write("Enhanced Audio Spectrogram:")
|
89 |
-
|
90 |
-
st.pyplot(output_spectrogram)
|
91 |
|
92 |
# Display audio players and download link
|
93 |
st.audio(truncated_path, format="audio/wav")
|
|
|
1 |
import streamlit as st
|
2 |
+
import torchaudio
|
3 |
import torch
|
|
|
|
|
|
|
4 |
import matplotlib.pyplot as plt
|
5 |
+
import soundfile as sf
|
6 |
from audiosr import build_model, super_resolution, save_wave
|
7 |
import tempfile
|
8 |
import numpy as np
|
|
|
30 |
latent_t_per_second = 12.8
|
31 |
|
32 |
# Helper function to plot spectrogram
|
33 |
+
def plot_spectrogram(waveform, sample_rate, title):
|
|
|
|
|
|
|
|
|
34 |
plt.figure(figsize=(10, 4))
|
35 |
+
spectrogram = torchaudio.transforms.MelSpectrogram(
|
36 |
+
sample_rate=sample_rate, n_fft=2048, hop_length=512, n_mels=128
|
37 |
+
)(torch.tensor(waveform))
|
38 |
+
log_spectrogram = torchaudio.transforms.AmplitudeToDB()(spectrogram)
|
39 |
+
plt.imshow(log_spectrogram.numpy(), aspect="auto", origin="lower", cmap="viridis")
|
40 |
+
plt.colorbar(format="%+2.0f dB")
|
41 |
plt.title(title)
|
42 |
+
plt.xlabel("Time")
|
43 |
+
plt.ylabel("Frequency")
|
44 |
plt.tight_layout()
|
45 |
+
st.pyplot(plt)
|
46 |
|
47 |
# Process Button
|
48 |
if uploaded_file and st.button("Enhance Audio"):
|
49 |
st.write("Processing audio...")
|
50 |
|
51 |
# Create temp directory for saving files
|
52 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
53 |
+
input_path = os.path.join(temp_dir, "input.wav")
|
54 |
+
truncated_path = os.path.join(temp_dir, "truncated.wav")
|
55 |
+
output_path = os.path.join(temp_dir, "output.wav")
|
56 |
|
57 |
# Save uploaded file locally
|
58 |
with open(input_path, "wb") as f:
|
59 |
f.write(uploaded_file.read())
|
60 |
|
61 |
+
# Load audio and truncate the first 10 seconds
|
62 |
+
waveform, sample_rate = torchaudio.load(input_path)
|
63 |
+
max_samples = sample_rate * 10 # First 10 seconds
|
64 |
+
if waveform.size(1) > max_samples:
|
65 |
+
waveform = waveform[:, :max_samples]
|
66 |
+
st.write("Truncated audio to the first 10 seconds.")
|
67 |
+
sf.write(truncated_path, waveform[0].numpy(), sample_rate)
|
68 |
|
69 |
# Plot truncated spectrogram
|
70 |
st.write("Truncated Input Audio Spectrogram (First 10 seconds):")
|
71 |
+
plot_spectrogram(waveform[0].numpy(), sample_rate, title="Truncated Input Audio Spectrogram")
|
|
|
72 |
|
73 |
# Build and load the model
|
74 |
audiosr = build_model(model_name=model_name, device=device)
|
75 |
|
76 |
# Perform super-resolution
|
77 |
+
waveform_sr = super_resolution(
|
78 |
audiosr,
|
79 |
truncated_path,
|
80 |
seed=random_seed,
|
|
|
84 |
)
|
85 |
|
86 |
# Save enhanced audio
|
87 |
+
save_wave(waveform_sr, inputpath=truncated_path, savepath=tmp_dir, name="output", samplerate=48000)
|
88 |
|
89 |
# Plot output spectrogram
|
90 |
st.write("Enhanced Audio Spectrogram:")
|
91 |
+
plot_spectrogram(waveform_sr.numpy(), 48000, title="Enhanced Audio Spectrogram")
|
|
|
92 |
|
93 |
# Display audio players and download link
|
94 |
st.audio(truncated_path, format="audio/wav")
|