HareemFatima committed on
Commit
eda6be5
·
verified ·
1 Parent(s): 2da2570

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -45
app.py CHANGED
@@ -1,46 +1,48 @@
1
  import streamlit as st
2
- from transformers import pipeline
3
-
4
# Load the pretrained stutter-detection classifier as an audio-classification
# pipeline (downloads weights from the Hugging Face Hub on first run).
audio_classifier = pipeline(
    task="audio-classification",
    model="HareemFatima/distilhubert-finetuned-stutterdetection",
)
8
-
9
# Placeholder text-to-speech helper — swap in a real TTS backend later.
def tts(text):
    """Return a placeholder string standing in for synthesized therapy speech."""
    return "Synthesized speech for therapy: {}".format(text)
15
-
16
# Therapy prompts keyed by the classifier's predicted stutter label.
# Extend this mapping as additional stutter types are supported.
therapy_text = dict(
    Repetition="Your speech sounds great! Keep practicing!",
    Blocks="Take a deep breath and try speaking slowly. You can do it!",
    Prolongation="Focus on relaxing your mouth muscles and speaking smoothly.",
)
23
-
24
# --- Streamlit UI (legacy upload-and-classify flow) ------------------------
st.title("Stuttering Therapy Assistant")
st.write("This app helps you identify stuttering types and provides personalized therapy suggestions.")

uploaded_audio = st.file_uploader("Upload Audio Clip")

if uploaded_audio is not None:
    # Read the raw bytes once; the HF audio pipeline accepts in-memory bytes.
    audio_bytes = uploaded_audio.read()

    # Classify stuttering type — pipeline results are sorted best-first.
    prediction = audio_classifier(audio_bytes)
    stutter_type = prediction[0]["label"]

    # Retrieve therapy text, falling back to a generic tip for unknown labels.
    therapy = therapy_text.get(stutter_type, "General therapy tip: Practice slow, relaxed speech.")

    # Generate synthesized speech (placeholder implementation for now).
    synthesized_speech = tts(therapy)

    st.write(f"Predicted Stutter Type: {stutter_type}")
    st.write(f"Therapy Tip: {therapy}")
    # BUG FIX: st.audio() expects audio bytes/array/URL, but the placeholder
    # tts() returns a plain sentence, which st.audio would try (and fail) to
    # load as a media source. Show the placeholder as text until a real TTS
    # backend produces a waveform.
    st.write(synthesized_speech)
46
-
 
 
 
1
  import streamlit as st
2
+ from transformers import pipeline, AutoTokenizer, AutoModelForTextToWaveform
3
+
4
# Cache the heavyweight model loads across Streamlit reruns: the script
# re-executes top to bottom on every user interaction, and re-constructing
# (or re-downloading) the models each time would dominate latency.
@st.cache_resource
def _load_models():
    """Load and return (audio classifier pipeline, TTS tokenizer, TTS model)."""
    classifier = pipeline(
        "audio-classification",
        model="HareemFatima/distilhubert-finetuned-stutterdetection",
    )
    tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
    tts = AutoModelForTextToWaveform.from_pretrained("facebook/mms-tts-eng")
    return classifier, tokenizer, tts


audio_classification_model, tts_tokenizer, tts_model = _load_models()
10
+
11
# Classify an audio clip and speak the matching therapy tip.
def classify_and_speak(audio_input):
    """Classify *audio_input* for stutter type and play a spoken therapy tip.

    Parameters
    ----------
    audio_input:
        Raw audio bytes (or any input the HF audio-classification pipeline
        accepts). Results are rendered directly into the Streamlit page.
    """
    import torch  # local import keeps this fix self-contained

    # Classify the audio; pipeline predictions are sorted best-first.
    classification_result = audio_classification_model(audio_input)
    predicted_class = classification_result[0]["label"]

    # Map predicted class to corresponding speech text.
    speech_text = map_class_to_speech(predicted_class)

    # Generate speech. BUG FIX: MMS-TTS ("facebook/mms-tts-eng") is a VITS
    # model — it is not autoregressive and does not support .generate();
    # the forward pass returns the waveform directly.
    input_ids = tts_tokenizer(speech_text, return_tensors="pt").input_ids
    with torch.no_grad():
        waveform = tts_model(input_ids).waveform

    # Display classification result and play speech. st.audio needs a 1-D
    # numpy array plus the model's sampling rate, not a raw torch tensor.
    st.write("Predicted Stutter Type:", predicted_class)
    st.audio(
        waveform.squeeze().cpu().numpy(),
        sample_rate=tts_model.config.sampling_rate,
    )
27
+
28
# Spoken therapy prompts for each class the stutter classifier can emit.
def map_class_to_speech(predicted_class):
    """Return the therapy sentence matching *predicted_class*.

    Unrecognized labels fall back to "Unknown stutter type".
    """
    if predicted_class == "nonstutter":
        return "You are speaking fluently without any stutter."
    if predicted_class == "prolongation":
        return "You are experiencing prolongation stutter. Try to relax and speak slowly."
    if predicted_class == "repetition":
        return "You are experiencing repetition stutter. Focus on your breathing and try to speak smoothly."
    if predicted_class == "blocks":
        return "You are experiencing block stutter. Take a deep breath and try to speak slowly and calmly."
    return "Unknown stutter type"
38
+
39
# Streamlit app entry point.
def main():
    """Render the app: collect an audio clip, then classify it and respond."""
    st.title("Stutter Classification and Therapy App")
    # BUG FIX: st.audio() is a playback widget — it cannot record audio and
    # has no start_recording/channels parameters, so the original call raised
    # TypeError at runtime. Accept an uploaded clip instead.
    audio_input = st.file_uploader("Upload Audio Clip", type=["wav", "mp3", "flac"])
    if audio_input is not None and st.button("Analyze"):
        with st.spinner("Classifying and speaking..."):
            classify_and_speak(audio_input.read())
46
+
47
+ if __name__ == "__main__":
48
+ main()