# app.py -- uploaded by HareemFatima; commit 5a7c9db ("Update app.py", verified).
import gradio as gr
# Identifier of the fine-tuned checkpoint used for stutter detection.
MODEL_ID = "HareemFatima/distilhubert-finetuned-stutterdetection"

# Build the audio-classification pipeline once at import time; classify_audio
# reuses this module-level object on every call.
# NOTE(review): `pipeline` is not imported in the visible source -- presumably
# it is `transformers.pipeline`; confirm and add the import.
model = pipeline("audio-classification", model=MODEL_ID)
def classify_audio(audio_input):
    """Classify a recorded audio clip and return the stutter-type name.

    Args:
        audio_input: The recording handed over by the Gradio audio input.
            It is forwarded unchanged to the module-level ``model``
            pipeline. ``None`` means that nothing was recorded.

    Returns:
        One of ``"nonstutter"``, ``"prolongation"``, ``"repetition"`` or
        ``"blocks"``. ``"Unknown"`` is returned when there is no recording,
        when the model yields no prediction, or when the predicted label
        cannot be mapped to one of the four classes.
    """
    # Class index -> human-readable stutter type.
    label_map = {
        0: "nonstutter",
        1: "prolongation",
        2: "repetition",
        3: "blocks",
    }

    # Nothing to classify (e.g. the user submitted without recording).
    if audio_input is None:
        return "Unknown"

    prediction = model(audio_input)
    if not prediction:
        return "Unknown"

    # The first entry is treated as the winning prediction.
    predicted_label = prediction[0]["label"]

    # Audio-classification pipelines typically report the label as a string
    # (the class name itself, or a generic "LABEL_<index>"), so an int-keyed
    # lookup alone would never match. Normalise the string forms first.
    if isinstance(predicted_label, str):
        cleaned = predicted_label.strip()

        # Already one of the known class names (compared case-insensitively).
        known_names = {name.lower(): name for name in label_map.values()}
        if cleaned.lower() in known_names:
            return known_names[cleaned.lower()]

        # "LABEL_2" -> "2"; a bare "2" is left as it is.
        index_text = cleaned.rpartition("_")[2]
        try:
            predicted_label = int(index_text)
        except ValueError:
            return "Unknown"

    return label_map.get(predicted_label, "Unknown")
# Assemble the web UI: a microphone audio input wired to classify_audio,
# with the returned class name shown in a label component.
audio_input = gr.inputs.Audio(type="file", source="microphone")
output_label = gr.outputs.Label()

demo = gr.Interface(
    fn=classify_audio,
    inputs=audio_input,
    outputs=output_label,
)

# Start serving the interface.
demo.launch()