Shokoufehhh commited on
Commit
dfb36ea
1 Parent(s): c65e0e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -24
app.py CHANGED
@@ -1,34 +1,50 @@
1
- import gradio as gr
2
  import torch
3
- from sgmse.model import SGMSENoiseReducer # Adjust import as per your model structure
4
- import soundfile as sf
 
 
5
 
6
- # Load your pre-trained model
7
  model = SGMSENoiseReducer.from_pretrained("sp-uhh/speech-enhancement-sgmse")
8
 
9
- # Define a function to process the uploaded file
10
- def enhance_speech(noisy_audio):
11
- # Load noisy audio file
12
- noisy, sr = sf.read(noisy_audio)
13
 
14
- # Apply your model to enhance the speech
15
- enhanced_audio = model.enhance(noisy, sr)
16
 
17
- # Save enhanced audio to a temporary file
18
- output_file = "enhanced_output.wav"
19
- sf.write(output_file, enhanced_audio, sr)
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
 
 
 
 
 
 
 
21
  return output_file
22
 
23
- # Set up the Gradio interface
24
- interface = gr.Interface(
25
- fn=enhance_speech,
26
- inputs=gr.Audio(source="upload", type="filepath"),
27
- outputs=gr.Audio(type="file"),
28
- title="SGMSE Speech Enhancement",
29
- description="Upload a noisy audio file and download the enhanced (clean) version."
30
- )
31
 
32
- # Launch the interface
33
- if __name__ == "__main__":
34
- interface.launch()
 
 
1
  import torch
2
+ import torchaudio
3
+ from sgmse.model import ScoreModel
4
+ import gradio as gr
5
+ from sgmse.util.other import pad_spec
6
 
7
+ # Load the pre-trained model
8
  model = SGMSENoiseReducer.from_pretrained("sp-uhh/speech-enhancement-sgmse")
9
 
10
+ def enhance_speech(audio_file):
11
+ # Load and process the audio file
12
+ y, sr = torchaudio.load(audio_file)
 
13
 
14
+ T_orig = y.size(1)
 
15
 
16
+ # Normalize
17
+ norm_factor = y.abs().max()
18
+ y = y / norm_factor
19
+
20
+ # Prepare DNN input
21
+ Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args.device))), 0)
22
+ Y = pad_spec(Y, mode=pad_mode)
23
+
24
+ # Reverse sampling
25
+ sampler = model.get_pc_sampler(
26
+ 'reverse_diffusion', args.corrector, Y.to(args.device), N=args.N,
27
+ corrector_steps=args.corrector_steps, snr=args.snr)
28
+ sample, _ = sampler()
29
+
30
+ # Backward transform in time domain
31
+ x_hat = model.to_audio(sample.squeeze(), T_orig)
32
 
33
+ # Renormalize
34
+ x_hat = x_hat * norm_factor
35
+
36
+ # Save the enhanced audio
37
+ output_file = 'enhanced_output.wav'
38
+ torchaudio.save(output_file, x_hat.cpu().numpy(), sr)
39
+
40
  return output_file
41
 
42
+ # Gradio interface setup
43
+ inputs = gr.Audio(label="Input Audio", type="filepath")
44
+ outputs = gr.Audio(label="Output Audio", type="filepath")
45
+ title = "Speech Enhancement using SGMSE"
46
+ description = "This Gradio demo uses the SGMSE model for speech enhancement. Upload your audio file to enhance it."
47
+ article = "<p style='text-align: center'><a href='https://huggingface.co/SP-UHH/speech-enhancement-sgmse' target='_blank'>Model Card</a></p>"
 
 
48
 
49
+ # Launch without share=True (as it's not supported on Hugging Face Spaces)
50
+ gr.Interface(fn=enhance_speech, inputs=inputs, outputs=outputs, title=title, description=description, article=article).launch()