File size: 2,188 Bytes
0f0cd12
eda6be5
 
 
98e7bd0
 
eda6be5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import streamlit as st
import torch
from transformers import (
    AutoModelForAudioClassification,
    AutoModelForTextToWaveform,
    AutoProcessor,
    AutoTokenizer,
    pipeline,
)

# Hugging Face Hub checkpoints used by the app.
STUTTER_MODEL_ID = "HareemFatima/distilhubert-finetuned-stutterdetection"
TTS_MODEL_ID = "facebook/mms-tts-eng"


@st.cache_resource
def _load_stutter_classifier():
    """Load the stutter-detection processor and model plus a pipeline over them.

    Cached with ``st.cache_resource`` because Streamlit re-executes the whole
    script on every user interaction; without caching, the checkpoint would be
    reloaded on each rerun.

    Returns:
        ``(processor, model, classifier)`` where ``classifier`` is an
        "audio-classification" pipeline returning a list of
        ``{"label": ..., "score": ...}`` dicts.
    """
    stutter_processor = AutoProcessor.from_pretrained(STUTTER_MODEL_ID)
    stutter_model = AutoModelForAudioClassification.from_pretrained(STUTTER_MODEL_ID)
    # Reuse the objects loaded above instead of letting pipeline() download and
    # load the checkpoint a second time.
    classifier = pipeline(
        "audio-classification",
        model=stutter_model,
        feature_extractor=stutter_processor,
    )
    return stutter_processor, stutter_model, classifier


@st.cache_resource
def _load_tts():
    """Load the MMS English text-to-speech tokenizer and model (cached)."""
    return (
        AutoTokenizer.from_pretrained(TTS_MODEL_ID),
        AutoModelForTextToWaveform.from_pretrained(TTS_MODEL_ID),
    )


# Load the audio classification model. classify_and_speak() calls
# `audio_classification_model` and indexes its result as [0]["label"], so it
# must be a pipeline (not the bare model) and must exist at module level.
processor, model, audio_classification_model = _load_stutter_classifier()

# Load the TTS tokenizer and model
tts_tokenizer, tts_model = _load_tts()

# Define a function to classify audio and generate speech
def classify_and_speak(audio_input):
    """Classify a recording by stutter type and play spoken feedback.

    Args:
        audio_input: Audio in a form ``audio_classification_model`` accepts
            (e.g. raw encoded file bytes or a file path), or a file-like object
            such as a Streamlit ``UploadedFile``, whose bytes are read first.

    Writes the predicted label to the page and renders an audio player with
    the synthesized feedback sentence. Returns nothing.
    """
    # File-like objects (st.audio_input / st.file_uploader results) expose
    # getvalue(); hand the classifier their raw bytes instead.
    if hasattr(audio_input, "getvalue"):
        audio_input = audio_input.getvalue()

    # Classify the audio
    classification_result = audio_classification_model(audio_input)
    if not classification_result:
        st.error("The classifier returned no prediction for this recording.")
        return
    # The first entry is the top-scoring label.
    predicted_class = classification_result[0]["label"]

    # Map predicted class to corresponding speech text
    speech_text = map_class_to_speech(predicted_class)

    # Generate speech. MMS-TTS is a VITS model: it has no usable generate();
    # a plain forward pass returns the synthesized waveform.
    inputs = tts_tokenizer(speech_text, return_tensors="pt")
    with torch.no_grad():
        waveform = tts_model(**inputs).waveform
    # One sentence was tokenized, so take the single item in the batch.
    speech = waveform[0].cpu().numpy()

    # Display classification result and play speech. st.audio needs an
    # explicit sample rate when given a NumPy array.
    st.write("Predicted Stutter Type:", predicted_class)
    st.audio(speech, sample_rate=tts_model.config.sampling_rate)

# Translate a classifier label into the feedback sentence that gets spoken.
def map_class_to_speech(predicted_class):
    """Return the spoken feedback sentence for a predicted stutter label.

    Args:
        predicted_class: Label string produced by the stutter classifier.

    Returns:
        The feedback sentence for a recognised label, or
        ``"Unknown stutter type"`` for any label missing from the table.
    """
    feedback_by_label = {
        "nonstutter": "You are speaking fluently without any stutter.",
        "prolongation": "You are experiencing prolongation stutter. Try to relax and speak slowly.",
        "repetition": "You are experiencing repetition stutter. Focus on your breathing and try to speak smoothly.",
        "blocks": "You are experiencing block stutter. Take a deep breath and try to speak slowly and calmly.",
    }
    try:
        return feedback_by_label[predicted_class]
    except KeyError:
        return "Unknown stutter type"

# Streamlit app
def main():
    """Render the app: capture a clip, then classify it and speak feedback.

    Uses Streamlit's microphone recorder widget when the installed version
    provides one; otherwise falls back to uploading a WAV file.
    """
    st.title("Stutter Classification and Therapy App")

    # st.audio() only plays audio; it has no recording arguments. Use the
    # recorder widget (st.audio_input, previously st.experimental_audio_input)
    # if available, else a file uploader.
    recorder = getattr(st, "audio_input", None) or getattr(
        st, "experimental_audio_input", None
    )
    if recorder is not None:
        recording = recorder("Capture Audio")
    else:
        recording = st.file_uploader("Capture Audio", type=["wav"])

    # Both widgets return None until a clip exists. The recorder's own stop
    # control ends the recording and triggers a rerun with the clip, so no
    # separate "stop" button is needed.
    if recording is None:
        return

    with st.spinner("Classifying and speaking..."):
        # Pass the raw encoded audio bytes to the classifier.
        classify_and_speak(recording.getvalue())


if __name__ == "__main__":
    main()