HareemFatima commited on
Commit
2edf518
1 Parent(s): 538f618

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -0
app.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+
3
+ # Load the model pipeline
4
+ model = pipeline("audio-classification", model="HareemFatima/distilhubert-finetuned-stutterdetection")
5
+
6
+ # Define a function to map predicted labels to types of stuttering
7
+ def map_label_to_stutter_type(label):
8
+ if label == 0:
9
+ return "nonstutter"
10
+ elif label == 1:
11
+ return "prolongation"
12
+ elif label == 2:
13
+ return "repetition"
14
+ elif label == 3:
15
+ return "blocks"
16
+ else:
17
+ return "Unknown"
18
+
19
+ # Function to classify audio input and return the stutter type
20
+ def classify_audio(audio_input):
21
+ # Call your model pipeline to classify the audio
22
+ prediction = model(audio_input)
23
+ # Get the predicted label
24
+ predicted_label = prediction[0]["label"]
25
+ # Map the label to the corresponding stutter type
26
+ stutter_type = map_label_to_stutter_type(predicted_label)
27
+ return stutter_type
28
+
29
+ # Streamlit app
30
+ def main():
31
+ st.title("Stutter Classification App")
32
+ st.audio("path_to_your_audio_file", format="audio/wav") # Add audio input widget here
33
+ if st.button("Classify"):
34
+ audio_input = st.audio("path_to_your_audio_file", format="audio/wav") # Add audio input widget here
35
+ stutter_type = classify_audio(audio_input)
36
+ st.write("Predicted Stutter Type:", stutter_type)
37
+
38
+ if __name__ == "__main__":
39
+ main()