|
import io
import wave

import numpy as np
import streamlit as st
import torch
from transformers import (
    AutoModelForAudioClassification,
    AutoModelForTextToWaveform,
    AutoProcessor,
    AutoTokenizer,
    pipeline,
)
|
|
|
|
|
@st.cache_resource(show_spinner="Loading models...")
def _load_models():
    """Load the stutter classifier and the text-to-speech model.

    Streamlit re-executes this whole script on every user interaction, so
    loading at module level without caching would re-instantiate all four
    objects on each rerun. ``st.cache_resource`` keeps them across reruns.

    Returns:
        (processor, model, tts_tokenizer, tts_model)
    """
    classifier_id = "HareemFatima/distilhubert-finetuned-stutterdetection"
    tts_id = "facebook/mms-tts-eng"
    return (
        AutoProcessor.from_pretrained(classifier_id),
        AutoModelForAudioClassification.from_pretrained(classifier_id),
        AutoTokenizer.from_pretrained(tts_id),
        AutoModelForTextToWaveform.from_pretrained(tts_id),
    )


# Module-level names kept so the rest of the script can use them directly.
processor, model, tts_tokenizer, tts_model = _load_models()
|
|
|
|
|
def _read_audio_bytes(audio_input):
    """Return the raw bytes held by *audio_input*, or None if it holds nothing."""
    if audio_input is None:
        return None
    if isinstance(audio_input, (bytes, bytearray)):
        return bytes(audio_input)
    # st.audio_input returns an UploadedFile (a BytesIO subclass).
    if hasattr(audio_input, "getvalue"):
        return audio_input.getvalue()
    if hasattr(audio_input, "read"):
        return audio_input.read()
    return None


def _decode_wav(data, target_rate):
    """Decode PCM WAV bytes into a mono float32 array at *target_rate* Hz.

    Supports 8/16/32-bit integer PCM. Raises ValueError for other sample
    widths; the ``wave`` module raises ``wave.Error`` for non-PCM files.
    """
    with wave.open(io.BytesIO(data), "rb") as wav:
        n_channels = wav.getnchannels()
        sample_width = wav.getsampwidth()
        source_rate = wav.getframerate()
        frames = wav.readframes(wav.getnframes())

    # Scale integer PCM to roughly [-1.0, 1.0).
    if sample_width == 1:
        samples = (np.frombuffer(frames, dtype=np.uint8).astype(np.float32) - 128.0) / 128.0
    elif sample_width == 2:
        samples = np.frombuffer(frames, dtype="<i2").astype(np.float32) / 32768.0
    elif sample_width == 4:
        samples = np.frombuffer(frames, dtype="<i4").astype(np.float32) / 2147483648.0
    else:
        raise ValueError(f"Unsupported WAV sample width: {sample_width} bytes")

    if n_channels > 1:
        # Frames are interleaved; average the channels down to mono.
        samples = samples.reshape(-1, n_channels).mean(axis=1)

    if source_rate != target_rate and samples.size:
        # Simple linear-interpolation resampling (no anti-alias filter) so the
        # classifier receives audio at the rate its feature extractor expects.
        n_out = max(1, int(round(samples.size * target_rate / source_rate)))
        old_t = np.arange(samples.size) / source_rate
        new_t = np.arange(n_out) / target_rate
        samples = np.interp(new_t, old_t, samples)

    return samples.astype(np.float32)


def classify_and_speak(audio_input):
    """Classify the stutter type in a recording and play spoken feedback.

    Args:
        audio_input: WAV data as bytes, or a file-like object exposing
            ``getvalue()``/``read()`` (e.g. the value of ``st.audio_input``).

    Returns:
        The predicted label, or None when there was nothing to classify.

    Side effects: writes the prediction and an audio player (or a
    warning/error) to the Streamlit page.
    """
    data = _read_audio_bytes(audio_input)
    if not data:
        st.warning("No audio recorded yet.")
        return None

    # AutoProcessor may return either a bare feature extractor or a processor
    # wrapping one; either way it knows the sampling rate the model expects.
    feature_extractor = getattr(processor, "feature_extractor", processor)
    target_rate = feature_extractor.sampling_rate

    try:
        waveform = _decode_wav(data, target_rate)
    except (wave.Error, EOFError, ValueError) as err:
        st.error(f"Could not decode the recording as WAV: {err}")
        return None
    if waveform.size == 0:
        st.warning("The recording is empty.")
        return None

    inputs = processor(waveform, sampling_rate=target_rate, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class = model.config.id2label[int(logits.argmax(dim=-1)[0])]

    speech_text = map_class_to_speech(predicted_class)

    # MMS-TTS is a VITS model: the waveform comes from a single forward pass
    # (``.waveform``), not from ``generate()``.
    tts_inputs = tts_tokenizer(speech_text, return_tensors="pt")
    with torch.no_grad():
        speech = tts_model(**tts_inputs).waveform

    st.write("Predicted Stutter Type:", predicted_class)
    # st.audio needs an explicit sample_rate when given a NumPy array.
    st.audio(speech.squeeze().cpu().numpy(), sample_rate=tts_model.config.sampling_rate)
    return predicted_class
|
|
|
|
|
def map_class_to_speech(predicted_class):
    """Translate a predicted stutter label into a feedback sentence to speak.

    Labels that are not recognised yield "Unknown stutter type".
    """
    feedback = dict(
        nonstutter="You are speaking fluently without any stutter.",
        prolongation="You are experiencing prolongation stutter. Try to relax and speak slowly.",
        repetition="You are experiencing repetition stutter. Focus on your breathing and try to speak smoothly.",
        blocks="You are experiencing block stutter. Take a deep breath and try to speak slowly and calmly.",
    )
    if predicted_class in feedback:
        return feedback[predicted_class]
    return "Unknown stutter type"
|
|
|
|
|
def main():
    """Render the app: record a clip, then classify it and play spoken feedback."""
    st.title("Stutter Classification and Therapy App")

    # st.audio only *plays* audio and has no recording options; recording
    # needs the audio-input widget (experimental_ prefix on older Streamlit).
    record = getattr(st, "audio_input", None) or st.experimental_audio_input
    audio_input = record("Capture Audio")

    # The widget stops recording itself, so classify as soon as a clip exists.
    if audio_input is None:
        st.info("Record a clip with the microphone above to classify it.")
        return

    with st.spinner("Classifying and speaking..."):
        classify_and_speak(audio_input)
|
|
|
# Render the app (main) when this file is executed as the main script.
if __name__ == "__main__":

    main()
|
|