File size: 784 Bytes
abe4ba1
 
dec0715
 
 
 
 
7248f4f
 
 
 
 
 
 
 
 
 
 
 
 
dec0715
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import gradio as gr

# Checkpoint id handed to the audio-classification pipeline below.
# NOTE(review): `pipeline` is never imported in this file (the only import is
# gradio), so this line raises NameError as written. It presumably needs
# `from transformers import pipeline` added to the imports — confirm.
MODEL_ID = "HareemFatima/distilhubert-finetuned-stutterdetection"
model = pipeline("audio-classification", model=MODEL_ID)

# Mapping from the classifier's class index to a human-readable stutter type.
LABEL_MAP = {
    0: "nonstutter",
    1: "prolongation",
    2: "repetition",
    3: "blocks",
}


def _to_label_name(raw_label):
    """Translate one raw pipeline label into a name from ``LABEL_MAP``.

    Accepts an int class index, a numeric string (``"2"``), a generic
    ``"LABEL_<n>"`` string, or a label that is already one of the mapped
    names. Anything else yields ``"Unknown"``.
    """
    if isinstance(raw_label, int):
        return LABEL_MAP.get(raw_label, "Unknown")
    text = str(raw_label).strip()
    if text in LABEL_MAP.values():
        return text
    # Generic checkpoints without a custom id2label report "LABEL_<index>".
    if text.upper().startswith("LABEL_"):
        text = text[len("LABEL_"):]
    if text.isdecimal():
        return LABEL_MAP.get(int(text), "Unknown")
    return "Unknown"


# Define a function to classify the audio and return the predicted label
def classify_audio(audio_input, *, classifier=None):
    """Classify a recording and return its stutter-type label.

    Args:
        audio_input: Whatever the Gradio audio component supplies; it is
            passed unchanged to the classifier.
        classifier: Optional callable used instead of the module-level
            ``model`` (e.g. for testing or swapping checkpoints). It must
            return a list of dicts with a ``"label"`` key, as a
            transformers audio-classification pipeline does.

    Returns:
        One of the names in ``LABEL_MAP``, or ``"Unknown"`` when the
        prediction is empty or its label cannot be mapped.
    """
    if classifier is None:
        classifier = model
    prediction = classifier(audio_input)
    if not prediction:
        return "Unknown"
    # Only the first entry is used (presumably the top-scoring class). The
    # pipeline reports labels as strings (from the model config's id2label,
    # such as "LABEL_1" or "prolongation"), never as ints. The previous
    # int-keyed lookup therefore always fell through to "Unknown".
    return _to_label_name(prediction[0]["label"])

# Create the Gradio interface.
# `type="filepath"` hands classify_audio a path string, which the transformers
# pipeline can load directly. `type="file"` passed a temp-file object instead.
# The legacy `gr.inputs` / `gr.outputs` namespaces were removed in Gradio 4,
# where `Audio(source=...)` also became `Audio(sources=[...])`. The keyword is
# therefore chosen by the installed major version.
_GRADIO_MAJOR = int(gr.__version__.split(".", 1)[0])
if _GRADIO_MAJOR >= 4:
    audio_input = gr.Audio(sources=["microphone"], type="filepath")
else:
    audio_input = gr.Audio(source="microphone", type="filepath")
output_label = gr.Label()
demo = gr.Interface(fn=classify_audio, inputs=audio_input, outputs=output_label)

if __name__ == "__main__":
    demo.launch()