HareemFatima commited on
Commit
dec0715
·
verified ·
1 Parent(s): abe4ba1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -1
app.py CHANGED
@@ -1,3 +1,28 @@
1
  import gradio as gr
 
2
 
3
- gr.load("models/HareemFatima/distilhubert-finetuned-stutterdetection").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
 
4
+ # Load the model pipeline
5
+ model = pipeline("audio-classification", model="HareemFatima/distilhubert-finetuned-stutterdetection")
6
+
7
+ # Define a function to classify the audio and return the predicted label
8
+ def classify_audio(audio_input):
9
+ # Call the model pipeline to classify the audio
10
+ prediction = model(audio_input)
11
+ # Get the predicted label
12
+ predicted_label = prediction[0]["label"]
13
+ # Map the label to the corresponding stutter type
14
+ if predicted_label == 0:
15
+ return "nonstutter"
16
+ elif predicted_label == 1:
17
+ return "prolongation"
18
+ elif predicted_label == 2:
19
+ return "repetition"
20
+ elif predicted_label == 3:
21
+ return "blocks"
22
+ else:
23
+ return "Unknown"
24
+
25
+ # Create the Gradio interface
26
+ audio_input = gr.inputs.Audio(source="microphone", type="file")
27
+ output_label = gr.outputs.Label()
28
+ gr.Interface(fn=classify_audio, inputs=audio_input, outputs=output_label).launch()