anzorq committed on
Commit
a28d64a
1 Parent(s): c6001c1

+spectral gating filter

Browse files
Files changed (1) hide show
  1. app.py +28 -12
app.py CHANGED
@@ -35,7 +35,12 @@ def preprocess_audio(audio_tensor, original_sample_rate, apply_normalization):
35
  audio_tensor = torchaudio.functional.resample(audio_tensor, orig_freq=original_sample_rate, new_freq=16000) # Resample
36
  return audio_tensor
37
 
38
- def apply_wiener_filter(audio_tensor):
 
 
 
 
 
39
  audio_data = audio_tensor.numpy()
40
  filtered_audio = wiener(audio_data)
41
  return torch.tensor(filtered_audio, dtype=audio_tensor.dtype)
@@ -43,13 +48,13 @@ def apply_wiener_filter(audio_tensor):
43
  @spaces.GPU
44
  def transcribe_speech(audio, progress=gr.Progress()):
45
  if audio is None:
46
- return "No audio received."
47
  progress(0.5, desc="Transcribing audio...")
48
  audio_np = audio.numpy().squeeze()
49
  transcription = pipe(audio_np, chunk_length_s=10)['text']
50
- return replace_symbols_back(transcription)
51
 
52
- def transcribe_from_youtube(url, apply_wiener, apply_normalization, progress=gr.Progress()):
53
  progress(0, "Downloading YouTube audio...")
54
 
55
  yt = YouTube(url)
@@ -62,16 +67,24 @@ def transcribe_from_youtube(url, apply_wiener, apply_normalization, progress=gr.
62
  audio, original_sample_rate = torchaudio.load(audio_data)
63
  audio = preprocess_audio(audio, original_sample_rate, apply_normalization)
64
 
65
- if apply_wiener:
66
  progress(0.4, "Applying Wiener filter...")
67
- audio = apply_wiener_filter(audio)
 
 
 
 
68
 
69
- transcription = transcribe_speech(audio)
 
 
 
 
70
 
71
  except Exception as e:
72
- return str(e)
73
 
74
- return transcription
75
 
76
  def populate_metadata(url):
77
  yt = YouTube(url)
@@ -96,16 +109,18 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
96
  mic_audio = gr.Audio(sources=['microphone','upload'], type="filepath", label="Record or upload an audio")
97
  transcribe_button = gr.Button("Transcribe")
98
  transcription_output = gr.Textbox(label="Transcription")
 
99
 
100
- transcribe_button.click(fn=transcribe_speech, inputs=mic_audio, outputs=transcription_output)
101
 
102
  with gr.Tab("YouTube URL"):
103
  gr.Markdown("## Transcribe speech from YouTube video")
104
  youtube_url = gr.Textbox(label="Enter YouTube video URL")
105
 
106
  with gr.Accordion("Audio Improvements", open=False):
107
- apply_wiener = gr.Checkbox(label="Reduce noise", info="Apply Wiener Filter", value=False)
108
  apply_normalization = gr.Checkbox(label="Normalize audio volume", value=True)
 
 
109
 
110
  with gr.Row():
111
  img = gr.Image(label="Thumbnail", height=240, width=240, scale=1)
@@ -113,8 +128,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
113
 
114
  transcribe_button = gr.Button("Transcribe")
115
  transcription_output = gr.Textbox(label="Transcription", placeholder="Transcription Output", lines=10)
 
116
 
117
- transcribe_button.click(fn=transcribe_from_youtube, inputs=[youtube_url, apply_wiener, apply_normalization], outputs=transcription_output)
118
  youtube_url.change(populate_metadata, inputs=[youtube_url], outputs=[img, title])
119
 
120
  demo.launch()
 
35
  audio_tensor = torchaudio.functional.resample(audio_tensor, orig_freq=original_sample_rate, new_freq=16000) # Resample
36
  return audio_tensor
37
 
38
+ def spectral_gating(audio_tensor):
39
+ audio_data = audio_tensor.numpy()
40
+ reduced_noise = nr.reduce_noise(y=audio_data, sr=16_000)
41
+ return torch.tensor(reduced_noise, dtype=audio_tensor.dtype)
42
+
43
+ def wiener_filter(audio_tensor):
44
  audio_data = audio_tensor.numpy()
45
  filtered_audio = wiener(audio_data)
46
  return torch.tensor(filtered_audio, dtype=audio_tensor.dtype)
 
48
  @spaces.GPU
49
  def transcribe_speech(audio, progress=gr.Progress()):
50
  if audio is None:
51
+ return "No audio received.", None
52
  progress(0.5, desc="Transcribing audio...")
53
  audio_np = audio.numpy().squeeze()
54
  transcription = pipe(audio_np, chunk_length_s=10)['text']
55
+ return replace_symbols_back(transcription), audio
56
 
57
+ def transcribe_from_youtube(url, apply_wiener_filter, apply_normalization, apply_spectral_gating, progress=gr.Progress()):
58
  progress(0, "Downloading YouTube audio...")
59
 
60
  yt = YouTube(url)
 
67
  audio, original_sample_rate = torchaudio.load(audio_data)
68
  audio = preprocess_audio(audio, original_sample_rate, apply_normalization)
69
 
70
+ if apply_wiener_filter:
71
  progress(0.4, "Applying Wiener filter...")
72
+ audio = wiener_filter(audio)
73
+
74
+ if apply_spectral_gating:
75
+ progress(0.4, "Applying Spectral Gating filter...")
76
+ audio = spectral_gating(audio)
77
 
78
+ transcription, processed_audio = transcribe_speech(audio)
79
+ audio_np = processed_audio.numpy().squeeze()
80
+ audio_output = BytesIO()
81
+ torchaudio.save(audio_output, torch.tensor(audio_np).unsqueeze(0), 16000)
82
+ audio_output.seek(0)
83
 
84
  except Exception as e:
85
+ return str(e), None
86
 
87
+ return transcription, audio_output
88
 
89
  def populate_metadata(url):
90
  yt = YouTube(url)
 
109
  mic_audio = gr.Audio(sources=['microphone','upload'], type="filepath", label="Record or upload an audio")
110
  transcribe_button = gr.Button("Transcribe")
111
  transcription_output = gr.Textbox(label="Transcription")
112
+ audio_output = gr.Audio(label="Processed Audio")
113
 
114
+ transcribe_button.click(fn=transcribe_speech, inputs=mic_audio, outputs=[transcription_output, audio_output])
115
 
116
  with gr.Tab("YouTube URL"):
117
  gr.Markdown("## Transcribe speech from YouTube video")
118
  youtube_url = gr.Textbox(label="Enter YouTube video URL")
119
 
120
  with gr.Accordion("Audio Improvements", open=False):
 
121
  apply_normalization = gr.Checkbox(label="Normalize audio volume", value=True)
122
+ apply_spectral_gating = gr.Checkbox(label="Apply Spectral Gating filter", info="Noise reduction", value=True)
123
+ apply_wiener = gr.Checkbox(label="Apply Wiener filter", info="Noise reduction", value=False)
124
 
125
  with gr.Row():
126
  img = gr.Image(label="Thumbnail", height=240, width=240, scale=1)
 
128
 
129
  transcribe_button = gr.Button("Transcribe")
130
  transcription_output = gr.Textbox(label="Transcription", placeholder="Transcription Output", lines=10)
131
+ audio_output = gr.Audio(label="Processed Audio")
132
 
133
+ transcribe_button.click(fn=transcribe_from_youtube, inputs=[youtube_url, apply_wiener, apply_normalization, apply_spectral_gating], outputs=[transcription_output, audio_output])
134
  youtube_url.change(populate_metadata, inputs=[youtube_url], outputs=[img, title])
135
 
136
  demo.launch()