# app.py — uploaded by HareemFatima in commit 2edf518 ("Create app.py", verified); original size 1.35 kB.
import streamlit as st
from transformers import pipeline
# Load the model pipeline once per server process.
# Streamlit re-executes this whole script on every widget interaction, so an
# uncached module-level pipeline(...) call would rebuild the model each rerun.
@st.cache_resource
def _load_model():
    """Return the audio-classification pipeline for the stutter-detection checkpoint."""
    return pipeline(
        "audio-classification",
        model="HareemFatima/distilhubert-finetuned-stutterdetection",
    )


model = _load_model()
# Stutter types indexed by class id.
# NOTE(review): assumes the checkpoint's class ids follow this order — confirm
# against the model's config.json (id2label).
_STUTTER_TYPES = ("nonstutter", "prolongation", "repetition", "blocks")


# Map predicted labels to types of stuttering
def map_label_to_stutter_type(label):
    """Translate a classifier label into a human-readable stutter type.

    Hugging Face ``pipeline`` results report ``label`` as a *string* taken from
    the model config's ``id2label`` (``"LABEL_<id>"`` when no names were
    configured, otherwise the class name). Plain integer ids are accepted too,
    so every one of these forms resolves to the same type name.

    Args:
        label: A class id (0-3), its decimal string form, ``"LABEL_<id>"``,
            or one of the stutter type names (case-insensitive).

    Returns:
        ``"nonstutter"``, ``"prolongation"``, ``"repetition"``, ``"blocks"``,
        or ``"Unknown"`` for anything unrecognised.
    """
    if isinstance(label, str):
        text = label.strip()
        lowered = text.lower()
        # Model configured with real class names: pass them through.
        if lowered in _STUTTER_TYPES:
            return lowered
        # Default transformers naming when id2label was not customised.
        if lowered.startswith("label_"):
            text = text[len("label_"):]
        if not text.isdecimal():
            return "Unknown"
        label = int(text)
    # Equality (not dict lookup) keeps the original semantics for any
    # int-like value, e.g. 1.0 or numpy integers.
    for class_id, stutter_type in enumerate(_STUTTER_TYPES):
        if label == class_id:
            return stutter_type
    return "Unknown"
# Classify one audio input and report its stutter type
def classify_audio(audio_input):
    """Classify *audio_input* and return the predicted stutter type name.

    The input is forwarded unchanged to the module-level ``model`` pipeline.
    Only the first entry of the returned list is used (for HF classification
    pipelines this is the highest-scoring result). Its ``"label"`` is then
    converted by ``map_label_to_stutter_type``.
    """
    top_prediction = model(audio_input)[0]
    return map_label_to_stutter_type(top_prediction["label"])
# Streamlit app
def main():
    """Render the stutter-classification page.

    The user uploads an audio clip, can play it back, and presses
    "Classify" to see the predicted stutter type.
    """
    st.title("Stutter Classification App")
    # st.audio only *plays* audio; an uploader is required to obtain input.
    uploaded_file = st.file_uploader(
        "Upload an audio file", type=["wav", "mp3", "flac", "ogg"]
    )
    if uploaded_file is not None:
        st.audio(uploaded_file, format=uploaded_file.type or "audio/wav")
    if st.button("Classify"):
        if uploaded_file is None:
            st.warning("Please upload an audio file first.")
            return
        # Raw file bytes are decoded by the transformers pipeline itself
        # (it shells out to ffmpeg for bytes input — ffmpeg must be installed).
        stutter_type = classify_audio(uploaded_file.getvalue())
        st.write("Predicted Stutter Type:", stutter_type)


if __name__ == "__main__":
    main()