mskov committed on
Commit 53eb88c
1 Parent(s): a9ccd21

Update app.py

Files changed (1)
  1. app.py +24 -2
app.py CHANGED
@@ -2,9 +2,16 @@ import evaluate
 from evaluate.utils import launch_gradio_widget
 import gradio as gr
 from transformers import AutoModelForSequenceClassification, pipeline, RobertaForSequenceClassification, RobertaTokenizer, AutoTokenizer
+# pull in emotion detection
+# --- Add element for specification
+# pull in text classification
+# --- Add custom labels
+# --- Associate labels with radio elements
+# add logic to initiate mock notificaiton when detected
+# pull in misophonia-specific model
 
 # Create a Gradio interface with audio file and text inputs
-def classify_toxicity(audio_file, text_input):
+def classify_toxicity(audio_file, text_input, classify_anxiety):
     # Transcribe the audio file using Whisper ASR
     if audio_file != None:
         whisper_module = evaluate.load("whisper")
@@ -23,16 +30,31 @@ def classify_toxicity(audio_file, text_input):
 
     toxicity_score = toxicity_results["toxicity"][0]
     print(toxicity_score)
+
+    # Text classification
+
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+    classifiation_model = pipeline("zero-shot-classification", model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli")
+
+    sequence_to_classify = transcribed_text
+    candidate_labels = classify_anxiety
+    classification_output = classifier(sequence_to_classify, candidate_labels, multi_label=False)
+    print(classification_output)
+
+
     return toxicity_score, transcribed_text
     # return f"Toxicity Score ({available_models[selected_model]}): {toxicity_score:.4f}"
 
 with gr.Blocks() as iface:
+    with gr.Column():
+        classify = gr.Radio(["racial identity hate", "LGBTQ+ hate", "sexually explicit", "misophonia"])
     with gr.Column():
         aud_input = gr.Audio(source="upload", type="filepath", label="Upload Audio File")
         text = gr.Textbox(label="Enter Text", placeholder="Enter text here...")
         submit_btn = gr.Button(label="Run")
     with gr.Column():
         out_text = gr.Textbox()
-    submit_btn.click(fn=classify_toxicity, inputs=[aud_input, text], outputs=out_text)
+    submit_btn.click(fn=classify_toxicity, inputs=[aud_input, text, classify], outputs=out_text)
 
 iface.launch()
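
Note on the added block: the diff defines classifiation_model but then calls classifier, and it selects a device with torch even though torch does not appear in the imports shown above, so the new code path would fail as committed. A minimal standalone sketch of the intended zero-shot step is below; the model id and label handling come from the diff, while the function name, device handling, and everything else are illustrative assumptions rather than the app's actual code.

import torch
from transformers import pipeline

def classify_text(transcribed_text, classify_anxiety):
    # transformers pipelines take a device index: 0 for the first GPU, -1 for CPU
    device = 0 if torch.cuda.is_available() else -1
    # zero-shot classifier named in the commit; in practice, build it once at module
    # level rather than on every call, since loading the model is slow
    classifier = pipeline(
        "zero-shot-classification",
        model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
        device=device,
    )
    # gr.Radio passes a single string; the pipeline expects a list of candidate labels
    candidate_labels = [classify_anxiety] if isinstance(classify_anxiety, str) else list(classify_anxiety)
    return classifier(transcribed_text, candidate_labels, multi_label=False)

# e.g. classify_text("some transcribed speech", "misophonia")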