haoheliu committed on
Commit
2b33988
·
verified ·
1 Parent(s): c2747d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -26
app.py CHANGED
@@ -1,9 +1,8 @@
1
  import streamlit as st
 
2
  import torch
3
- import os
4
- import librosa
5
- import librosa.display
6
  import matplotlib.pyplot as plt
 
7
  from audiosr import build_model, super_resolution, save_wave
8
  import tempfile
9
  import numpy as np
@@ -31,48 +30,51 @@ random_seed = st.sidebar.number_input("Random Seed", min_value=0, value=42, step
31
  latent_t_per_second = 12.8
32
 
33
  # Helper function to plot spectrogram
34
def plot_spectrogram(audio_path, title):
    """Draw a dB-scaled mel spectrogram of the audio file at ``audio_path``.

    Args:
        audio_path: Path of the audio file to load with ``librosa.load``.
        title: Title placed above the plot.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure as its current
        figure, so the caller can pass it on for display.
    """
    # sr=None asks librosa not to resample the file on load.
    samples, rate = librosa.load(audio_path, sr=None)
    top_freq = rate // 2

    mel_power = librosa.feature.melspectrogram(
        y=samples, sr=rate, n_mels=128, fmax=top_freq
    )
    # Express power relative to the loudest bin, in decibels.
    mel_db = librosa.power_to_db(mel_power, ref=np.max)

    plt.figure(figsize=(10, 4))
    librosa.display.specshow(
        mel_db,
        sr=rate,
        x_axis='time',
        y_axis='mel',
        fmax=top_freq,
        cmap='viridis',
    )
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.tight_layout()
    return plt
45
 
46
  # Process Button
47
  if uploaded_file and st.button("Enhance Audio"):
48
  st.write("Processing audio...")
49
 
50
  # Create temp directory for saving files
51
- with tempfile.TemporaryDirectory() as tmp_dir:
52
- input_path = os.path.join(tmp_dir, "input.wav")
53
- truncated_path = os.path.join(tmp_dir, "truncated.wav")
54
- output_path = os.path.join(tmp_dir, "output.wav")
55
 
56
  # Save uploaded file locally
57
  with open(input_path, "wb") as f:
58
  f.write(uploaded_file.read())
59
 
60
- # Load and truncate the first 10 seconds
61
- y, sr = librosa.load(input_path, sr=None)
62
- max_samples = sr * 10 # First 10 seconds
63
- y_truncated = y[:max_samples]
64
- librosa.output.write_wav(truncated_path, y_truncated, sr)
 
 
65
 
66
  # Plot truncated spectrogram
67
  st.write("Truncated Input Audio Spectrogram (First 10 seconds):")
68
- truncated_spectrogram = plot_spectrogram(truncated_path, title="Truncated Input Audio Spectrogram")
69
- st.pyplot(truncated_spectrogram)
70
 
71
  # Build and load the model
72
  audiosr = build_model(model_name=model_name, device=device)
73
 
74
  # Perform super-resolution
75
- waveform = super_resolution(
76
  audiosr,
77
  truncated_path,
78
  seed=random_seed,
@@ -82,12 +84,11 @@ if uploaded_file and st.button("Enhance Audio"):
82
  )
83
 
84
  # Save enhanced audio
85
- save_wave(waveform, inputpath=truncated_path, savepath=tmp_dir, name="output", samplerate=48000)
86
 
87
  # Plot output spectrogram
88
  st.write("Enhanced Audio Spectrogram:")
89
- output_spectrogram = plot_spectrogram(output_path, title="Enhanced Audio Spectrogram")
90
- st.pyplot(output_spectrogram)
91
 
92
  # Display audio players and download link
93
  st.audio(truncated_path, format="audio/wav")
 
1
  import streamlit as st
2
+ import torchaudio
3
  import torch
 
 
 
4
  import matplotlib.pyplot as plt
5
+ import soundfile as sf
6
  from audiosr import build_model, super_resolution, save_wave
7
  import tempfile
8
  import numpy as np
 
30
  latent_t_per_second = 12.8
31
 
32
  # Helper function to plot spectrogram
33
def plot_spectrogram(waveform, sample_rate, title):
    """Render a dB-scaled mel spectrogram of ``waveform`` in the Streamlit page.

    Args:
        waveform: Audio samples as a NumPy array, a ``torch.Tensor`` or any
            other array-like. Singleton dimensions (for example a
            ``(1, 1, N)`` batch) are squeezed away; if more than one
            dimension remains, only the first channel/item is plotted.
        sample_rate: Sampling rate of ``waveform`` in Hz.
        title: Title placed above the plot.

    Returns:
        None. The figure is sent to Streamlit with ``st.pyplot`` and then
        closed.
    """
    # as_tensor accepts both arrays and tensors without the copy warning that
    # torch.tensor() emits for tensor input; float32 matches the transform's
    # internal window dtype.
    samples = torch.as_tensor(waveform, dtype=torch.float32).detach().cpu().squeeze()
    # Reduce batched / multi-channel input to a single 1-D signal so that the
    # resulting spectrogram is a 2-D image imshow can draw.
    while samples.dim() > 1:
        samples = samples[0]
    if samples.dim() == 0:
        samples = samples.reshape(1)

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate, n_fft=2048, hop_length=512, n_mels=128
    )
    log_spectrogram = torchaudio.transforms.AmplitudeToDB()(mel_transform(samples))

    # Use an explicit figure instead of the global pyplot state: Streamlit
    # reruns the script on every interaction, and figures that are never
    # closed accumulate in memory.
    fig, ax = plt.subplots(figsize=(10, 4))
    image = ax.imshow(
        log_spectrogram.numpy(), aspect="auto", origin="lower", cmap="viridis"
    )
    fig.colorbar(image, ax=ax, format="%+2.0f dB")
    ax.set_title(title)
    ax.set_xlabel("Time")
    ax.set_ylabel("Frequency")
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)
46
 
47
  # Process Button
48
  if uploaded_file and st.button("Enhance Audio"):
49
  st.write("Processing audio...")
50
 
51
  # Create temp directory for saving files
52
+ with tempfile.TemporaryDirectory() as temp_dir:
53
+ input_path = os.path.join(temp_dir, "input.wav")
54
+ truncated_path = os.path.join(temp_dir, "truncated.wav")
55
+ output_path = os.path.join(temp_dir, "output.wav")
56
 
57
  # Save uploaded file locally
58
  with open(input_path, "wb") as f:
59
  f.write(uploaded_file.read())
60
 
61
+ # Load audio and truncate the first 10 seconds
62
+ waveform, sample_rate = torchaudio.load(input_path)
63
+ max_samples = sample_rate * 10 # First 10 seconds
64
+ if waveform.size(1) > max_samples:
65
+ waveform = waveform[:, :max_samples]
66
+ st.write("Truncated audio to the first 10 seconds.")
67
+ sf.write(truncated_path, waveform[0].numpy(), sample_rate)
68
 
69
  # Plot truncated spectrogram
70
  st.write("Truncated Input Audio Spectrogram (First 10 seconds):")
71
+ plot_spectrogram(waveform[0].numpy(), sample_rate, title="Truncated Input Audio Spectrogram")
 
72
 
73
  # Build and load the model
74
  audiosr = build_model(model_name=model_name, device=device)
75
 
76
  # Perform super-resolution
77
+ waveform_sr = super_resolution(
78
  audiosr,
79
  truncated_path,
80
  seed=random_seed,
 
84
  )
85
 
86
  # Save enhanced audio
87
+ save_wave(waveform_sr, inputpath=truncated_path, savepath=tmp_dir, name="output", samplerate=48000)
88
 
89
  # Plot output spectrogram
90
  st.write("Enhanced Audio Spectrogram:")
91
+ plot_spectrogram(waveform_sr.numpy(), 48000, title="Enhanced Audio Spectrogram")
 
92
 
93
  # Display audio players and download link
94
  st.audio(truncated_path, format="audio/wav")