import os
os.system("pip install git+https://github.com/openai/whisper.git")
import whisper
import evaluate
from evaluate.utils import launch_gradio_widget
import gradio as gr
import torch
import pandas as pd
import random
import classify
import replace_explitives
from whisper.model import Whisper
from whisper.tokenizer import get_tokenizer
from speechbrain.pretrained.interfaces import foreign_class
from transformers import AutoModelForSequenceClassification, pipeline, WhisperTokenizer, RobertaForSequenceClassification, RobertaTokenizer, AutoTokenizer
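
# This Space transcribes audio with Whisper, then either scores toxicity and
# zero-shot-classifies the transcript for the selected anxiety class, or
# classifies misophonia trigger sounds directly from Whisper log-probabilities.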


# pull in emotion detection
# --- Add element for specification
# pull in text classification
# --- Add custom labels
# --- Associate labels with radio elements
# add logic to initiate mock notification when detected
# pull in misophonia-specific model

# simple in-memory cache so heavyweight models are only loaded once per session
model_cache = {}

# Map SpeechBrain emotion codes to display labels
emo_dict = {
    'sad': 'Sad', 
    'hap': 'Happy',
    'ang': 'Anger',
    'neu': 'Neutral'
}

# static classes for now, but it would be best to have the user select from multiple, and to enter their own
class_options = {
    "racism": ["racism", "hate speech", "bigotry", "racially targeted", "racial slur", "ethnic slur", "ethnic hate", "pro-white nationalism"],
    "LGBTQ+ hate": ["gay slur", "trans slur", "homophobic slur", "transphobia", "anti-LBGTQ+", "hate speech"],
    "sexually explicit": ["sexually explicit", "sexually coercive", "sexual exploitation", "vulgar", "raunchy", "sexist", "sexually demeaning", "sexual violence", "victim blaming"],
    "misophonia": ["chewing", "breathing", "mouthsounds", "popping", "sneezing", "yawning", "smacking", "sniffling", "panting"]
}

# Load the Whisper ASR pipeline once at import time so each request only runs inference
pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large")

def classify_emotion(audio):
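    """Classify the dominant emotion in an audio file.

    Uses SpeechBrain's wav2vec2 IEMOCAP classifier and maps its four-class
    output code (sad/hap/ang/neu) to a display label via emo_dict.
    """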
    #### Emotion classification ####
    emotion_classifier = foreign_class(source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP", pymodule_file="custom_interface.py", classname="CustomEncoderWav2vec2Classifier")
    out_prob, score, index, text_lab = emotion_classifier.classify_file(audio)
    return emo_dict[text_lab[0]]

def slider_logic(slider):
    """Map the 1-5 sensitivity slider to a toxicity threshold (higher slider = lower threshold)."""
    thresholds = {1: 0.98, 2: 0.88, 3: 0.77, 4: 0.66, 5: 0.55}
    # fall back to the most sensitive setting rather than the original empty list,
    # which would crash the numeric comparison downstream
    return thresholds.get(slider, 0.55)

# Create a Gradio interface with audio file and text inputs
def classify_toxicity(audio_file, text_input, classify_anxiety, emo_class, explitive_selection, slider):
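    """Transcribe the audio (or use the raw text input) and score it for the selected class.

    Returns four values matching the Gradio outputs: a toxicity score,
    per-label class scores, the transcribed text, and an affirmation
    (empty unless the intervention threshold is exceeded).
    """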
    # Transcribe the audio file using Whisper ASR
    if audio_file is not None:
        transcribed_text = pipe(audio_file)["text"]
    else:
        transcribed_text = text_input
    if classify_anxiety != "misophonia":
        print("emo_class ", emo_class, "explitive select", explitive_selection)

        ## SLIDER ##
        threshold = slider_logic(slider)
        
        #------- explitive call ---------------
        
        # the original tested the replace_explitives module itself, which is always truthy;
        # the intent appears to be to guard on the user's selection
        if explitive_selection is not None and emo_class is None:
            transcribed_text = replace_explitives.sub_explitives(transcribed_text, explitive_selection)
        
        #### Toxicity Classifier ####
            
        toxicity_module = evaluate.load("toxicity", "facebook/roberta-hate-speech-dynabench-r4-target")
        #toxicity_module = evaluate.load("toxicity", 'DaNLP/da-electra-hatespeech-detection', module_type="measurement")
    
        toxicity_results = toxicity_module.compute(predictions=[transcribed_text])
     
        toxicity_score = toxicity_results["toxicity"][0]
        print(toxicity_score)
        # emotion call (the detected emotion is currently informational only)
        if emo_class is not None:
            detected_emotion = classify_emotion(audio_file)
            print("detected emotion:", detected_emotion)

        #### Text classification #####
    
        # transformers pipelines take a device index: 0 for the first GPU, -1 for CPU
        device = 0 if torch.cuda.is_available() else -1

        text_classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=device)
    
        sequence_to_classify = transcribed_text
        print(classify_anxiety, class_options)
        candidate_labels = class_options.get(classify_anxiety, [])
        # classification_output = classifier(sequence_to_classify, candidate_labels, multi_label=False)
        classification_output = text_classifier(sequence_to_classify, candidate_labels, multi_label=True)
        print("class output ", type(classification_output))
        # classification_df = pd.DataFrame.from_dict(classification_output)
        print("keys ", classification_output.keys())
        
        # formatted_classification_output = "\n".join([f"{key}: {value}" for key, value in classification_output.items()])
        # gr.Label expects a {label: confidence} dict, not a list of tuples
        label_score_pairs = {label: score for label, score in zip(classification_output["labels"], classification_output["scores"])}

        # plot.update(x=classification_df["labels"], y=classification_df["scores"])
        if toxicity_score > threshold:
            print("threshold exceeded!! Launch intervention")
            affirm = positive_affirmations()
        else:
            affirm = ""
 
        return toxicity_score, label_score_pairs, transcribed_text, affirm
        # return f"Toxicity Score ({available_models[selected_model]}): {toxicity_score:.4f}"
    else:
        threshold = slider_logic(slider)
        # cache the Whisper checkpoint in model_cache so repeated calls don't reload it
        if "whisper-large" not in model_cache:
            model_cache["whisper-large"] = whisper.load_model("large")
        model = model_cache["whisper-large"]
        # use the label list directly; the original join/split round-trip
        # appended a trailing empty class name
        class_names = class_options.get(classify_anxiety, [])
        print("class names ", class_names, "classify_anxiety ", classify_anxiety)
        
        # get_tokenizer expects a multilingual flag rather than a model name
        tokenizer = get_tokenizer(model.is_multilingual)
        # tokenizer= WhisperTokenizer.from_pretrained("openai/whisper-large")
    
        internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
            model=model,
            class_names=class_names,
            # class_names=classify_anxiety,
            tokenizer=tokenizer,
        )
        audio_features = classify.calculate_audio_features(audio_file, model)
        average_logprobs = classify.calculate_average_logprobs(
            model=model,
            audio_features=audio_features,
            class_names=class_names,
            tokenizer=tokenizer,
        )
        average_logprobs -= internal_lm_average_logprobs
        scores = average_logprobs.softmax(-1).tolist()
        miso_scores = {class_name: score for class_name, score in zip(class_names, scores)}
        # intervene when the most likely trigger class clears the chosen threshold;
        # the original checked an undefined toxicity_score after an unconditional return
        if scores and max(scores) > threshold:
            print("threshold exceeded!! Launch intervention")
        # return four values so the outputs match the Gradio click handler
        return "", miso_scores, transcribed_text, ""
        
def positive_affirmations():
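    """Return a randomly chosen coping affirmation shown when an intervention fires."""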
    affirmations = [
        "I have survived my anxiety before and I will survive again now",
        "I am not in danger; I am just uncomfortable; this too will pass",
        "I forgive and release the past and look forward to the future",
        "I can't control what other people say but I can control my breathing and my response"
    ]
    selected_affirm = random.choice(affirmations)
    return selected_affirm
    
with gr.Blocks() as iface:
    show_state = gr.State([])
    with gr.Column():
        anxiety_class = gr.Radio(["racism", "LGBTQ+ hate", "sexually explicit", "misophonia"], label="Anxiety class")
        explit_preference = gr.Radio(choices=["N-Word", "B-Word", "All Explitives"], label="Words to omit from general anxiety classes", info="Certain words may be acceptable within certain contexts for given groups of people, and some people may be unbothered by expletives broadly speaking.")
        emo_class = gr.Radio(choices=["negative emotionality"], label="Emotion detection", info="Select if you would like expletives to be considered anxiety-inducing in the case of anger/negative emotionality.")
        sense_slider = gr.Slider(minimum=1, maximum=5, step=1.0, label="How readily do you want the tool to intervene? 1 = in extreme cases and 5 = at every opportunity")
    with gr.Column():
        aud_input = gr.Audio(source="upload", type="filepath", label="Upload Audio File")
        text = gr.Textbox(label="Enter Text", placeholder="Enter text here...")
        submit_btn = gr.Button("Run")
    with gr.Column():
        out_val = gr.Textbox(label="Toxicity Score")
        out_class = gr.Label(label="Class Scores")
        out_text = gr.Textbox(label="Transcription")
        out_affirm = gr.Textbox(label="Affirmation")
    submit_btn.click(fn=classify_toxicity, inputs=[aud_input, text, anxiety_class, emo_class, explit_preference, sense_slider], outputs=[out_val, out_class, out_text, out_affirm])

iface.launch()