haoheliu commited on
Commit
26c3c7a
·
verified ·
1 Parent(s): 56f0306

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -15
app.py CHANGED
@@ -8,7 +8,7 @@ import tempfile
8
  import numpy as np
9
  import os
10
 
11
- # Set MPS device if available (for Mac M-Series GPUs)
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
  # Title and Description
@@ -22,7 +22,7 @@ Only the first 10 seconds of the audio will be processed.
22
  # Upload audio file
23
  uploaded_file = st.file_uploader("Upload an audio file (WAV format)", type=["wav"])
24
 
25
- # Model Parameters
26
  st.sidebar.title("Model Parameters")
27
  model_name = st.sidebar.selectbox("Select Model", ["basic", "speech"], index=0)
28
  ddim_steps = st.sidebar.slider("DDIM Steps", min_value=10, max_value=100, value=50)
@@ -30,18 +30,27 @@ guidance_scale = st.sidebar.slider("Guidance Scale", min_value=1.0, max_value=10
30
  random_seed = st.sidebar.number_input("Random Seed", min_value=0, value=42, step=1)
31
  latent_t_per_second = 12.8
32
 
33
- # Helper function to plot spectrogram
34
  def plot_spectrogram(waveform, sample_rate, title):
35
  plt.figure(figsize=(10, 4))
36
- spectrogram = torchaudio.transforms.MelSpectrogram(
37
- sample_rate=sample_rate, n_fft=2048, hop_length=512, n_mels=128
38
- )(torch.tensor(waveform))
39
- log_spectrogram = torchaudio.transforms.AmplitudeToDB()(spectrogram)
40
- plt.imshow(log_spectrogram.numpy(), aspect="auto", origin="lower", cmap="viridis")
 
 
 
 
 
 
 
 
 
41
  plt.colorbar(format="%+2.0f dB")
42
  plt.title(title)
43
- plt.xlabel("Time")
44
- plt.ylabel("Frequency")
45
  plt.tight_layout()
46
  st.pyplot(plt)
47
 
@@ -49,7 +58,6 @@ def plot_spectrogram(waveform, sample_rate, title):
49
  if uploaded_file and st.button("Enhance Audio"):
50
  st.write("Processing audio...")
51
 
52
- # Create temp directory for saving files
53
  with tempfile.TemporaryDirectory() as temp_dir:
54
  input_path = os.path.join(temp_dir, "input.wav")
55
  truncated_path = os.path.join(temp_dir, "truncated.wav")
@@ -59,7 +67,7 @@ if uploaded_file and st.button("Enhance Audio"):
59
  with open(input_path, "wb") as f:
60
  f.write(uploaded_file.read())
61
 
62
- # Load audio and truncate the first 10 seconds
63
  waveform, sample_rate = torchaudio.load(input_path)
64
  max_samples = sample_rate * 10 # First 10 seconds
65
  if waveform.size(1) > max_samples:
@@ -85,11 +93,12 @@ if uploaded_file and st.button("Enhance Audio"):
85
  )
86
 
87
  # Save enhanced audio
88
- save_wave(waveform_sr, inputpath=truncated_path, savepath=temp_dir, name="output", samplerate=48000)
 
89
 
90
- # Plot output spectrogram
91
  st.write("Enhanced Audio Spectrogram:")
92
- plot_spectrogram(waveform_sr.numpy(), 48000, title="Enhanced Audio Spectrogram")
93
 
94
  # Display audio players and download link
95
  st.audio(truncated_path, format="audio/wav")
 
8
  import numpy as np
9
  import os
10
 
11
+ # Set device (MPS for Mac, CUDA for other GPUs, otherwise CPU)
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
  # Title and Description
 
22
  # Upload audio file
23
  uploaded_file = st.file_uploader("Upload an audio file (WAV format)", type=["wav"])
24
 
25
+ # Sidebar: Model Parameters
26
  st.sidebar.title("Model Parameters")
27
  model_name = st.sidebar.selectbox("Select Model", ["basic", "speech"], index=0)
28
  ddim_steps = st.sidebar.slider("DDIM Steps", min_value=10, max_value=100, value=50)
 
30
  random_seed = st.sidebar.number_input("Random Seed", min_value=0, value=42, step=1)
31
  latent_t_per_second = 12.8
32
 
33
+ # Helper function: Plot linear STFT spectrogram
34
  def plot_spectrogram(waveform, sample_rate, title):
35
  plt.figure(figsize=(10, 4))
36
+ spectrogram = torch.stft(
37
+ torch.tensor(waveform),
38
+ n_fft=2048,
39
+ hop_length=512,
40
+ win_length=2048,
41
+ return_complex=True,
42
+ ).abs().numpy()
43
+ plt.imshow(
44
+ np.log1p(spectrogram),
45
+ aspect="auto",
46
+ origin="lower",
47
+ extent=[0, waveform.shape[-1] / sample_rate, 0, sample_rate / 2],
48
+ cmap="viridis",
49
+ )
50
  plt.colorbar(format="%+2.0f dB")
51
  plt.title(title)
52
+ plt.xlabel("Time (s)")
53
+ plt.ylabel("Frequency (Hz)")
54
  plt.tight_layout()
55
  st.pyplot(plt)
56
 
 
58
  if uploaded_file and st.button("Enhance Audio"):
59
  st.write("Processing audio...")
60
 
 
61
  with tempfile.TemporaryDirectory() as temp_dir:
62
  input_path = os.path.join(temp_dir, "input.wav")
63
  truncated_path = os.path.join(temp_dir, "truncated.wav")
 
67
  with open(input_path, "wb") as f:
68
  f.write(uploaded_file.read())
69
 
70
+ # Load and truncate the first 10 seconds
71
  waveform, sample_rate = torchaudio.load(input_path)
72
  max_samples = sample_rate * 10 # First 10 seconds
73
  if waveform.size(1) > max_samples:
 
93
  )
94
 
95
  # Save enhanced audio
96
+ output_waveform = waveform_sr.detach().numpy()
97
+ save_wave(torch.tensor(output_waveform), inputpath=truncated_path, savepath=temp_dir, name="output", samplerate=48000)
98
 
99
+ # Plot enhanced spectrogram
100
  st.write("Enhanced Audio Spectrogram:")
101
+ plot_spectrogram(output_waveform, 48000, title="Enhanced Audio Spectrogram")
102
 
103
  # Display audio players and download link
104
  st.audio(truncated_path, format="audio/wav")