import random

import evaluate
import gradio as gr
import torch
from speechbrain.pretrained.interfaces import foreign_class
from transformers import pipeline

# local helper modules
import classify
import replace_explitives


# pull in emotion detection
# --- Add element for specification
# pull in text classification
# --- Add custom labels
# --- Associate labels with radio elements
# add logic to initiate mock notification when detected
# pull in misophonia-specific model

# placeholder for lazily loaded models (not used yet)
model_cache = {}

# Map IEMOCAP emotion codes returned by the SpeechBrain classifier to display labels
emo_dict = {
    'sad': 'Sad', 
    'hap': 'Happy',
    'ang': 'Anger',
    'neu': 'Neutral'
}

# static classes for now, but it would be best to have the user select from multiple, and to enter their own
class_options = {
    "Racism": ["racism", "hate speech", "bigotry", "racially targeted", "racial slur", "ethnic slur", "ethnic hate", "pro-white nationalism"],
    "LGBTQ+ Hate": ["gay slur", "trans slur", "homophobic slur", "transphobia", "anti-LBGTQ+"],
    "Sexually Explicit": ["sexually explicit", "sexually coercive", "sexual exploitation", "vulgar", "raunchy", "sexist", "sexually demeaning", "sexual violence", "victim blaming"],
    "Pregnancy Complications": ["miscarriage", "child loss", "child death", "abortion", "pregnancy", "childbirth", "baby shower", "postpartum"],
}

# Load all models once at module scope so they are shared across requests.
device = 0 if torch.cuda.is_available() else -1  # transformers pipelines take a GPU index, or -1 for CPU

pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large", device=device)
toxicity_module = evaluate.load("toxicity", "facebook/roberta-hate-speech-dynabench-r4-target")
emotion_classifier = foreign_class(source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP", pymodule_file="custom_interface.py", classname="CustomEncoderWav2vec2Classifier")
text_classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=device)

def classify_emotion(audio):
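    """Classify the dominant emotion of an audio clip with the SpeechBrain IEMOCAP model."""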
    out_prob, score, index, text_lab = emotion_classifier.classify_file(audio)
    return emo_dict[text_lab[0]]

def slider_logic(slider):
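    """Map the 1-5 sensitivity slider to a toxicity threshold (higher sensitivity -> lower threshold)."""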
    thresholds = {1: 0.90, 2: 0.80, 3: 0.60, 4: 0.50, 5: 0.40}
    # default to the mid-range threshold if the slider value is somehow out of range
    return thresholds.get(slider, 0.60)

# Gradio callback wired to the Run button below
def classify_toxicity(audio_file, classify_anxiety, emo_class, explitive_selection, slider):
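    """Transcribe the audio, score it for toxicity or the selected anxiety subclass, and decide whether to intervene.

    Returns the transcript, the top score, a label-to-score dict, and an affirmation string (empty when no intervention).
    """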
    
    # Transcribe the audio file using Whisper ASR
    transcribed_text = pipe(audio_file)["text"]
    
    ## SLIDER ##
    threshold = slider_logic(slider)
    
    # ------- expletive filtering ---------------
    # Only scrub expletives when the user picked a preference and is not tracking negative emotionality.
    if explitive_selection is not None and emo_class is None:
        transcribed_text = replace_explitives.sub_explitives(transcribed_text, explitive_selection)
    
    #### Toxicity Classifier ####
    # toxicity_module is loaded once at module scope; an alternative measurement model would be
    # evaluate.load("toxicity", "DaNLP/da-electra-hatespeech-detection", module_type="measurement")
    toxicity_results = toxicity_module.compute(predictions=[transcribed_text])
    toxicity_score = toxicity_results["toxicity"][0]
    print("toxicity score:", toxicity_score)
    
    # emotion call (result is computed but not currently surfaced in the UI)
    if emo_class is not None:
        classify_emotion(audio_file)

    #### Text classification ####
    if classify_anxiety is not None:
        # Zero-shot classify the transcript against the candidate labels for the selected subclass;
        # text_classifier (facebook/bart-large-mnli) is loaded once at module scope.
        candidate_labels = class_options.get(classify_anxiety, [])
        classification_output = text_classifier(transcribed_text, candidate_labels, multi_label=True)

        # Build a label -> score dict and pull out the highest-scoring label.
        label_score_dict = dict(zip(classification_output["labels"], classification_output["scores"]))
        top_label = max(label_score_dict, key=label_score_dict.get)
        maxval = label_score_dict[top_label]
        print("top label:", top_label, "score:", maxval)

        topScore = maxval
        # Intervene only when the top label clears the sensitivity threshold.
        affirm = positive_affirmations() if maxval > threshold else ""
    else:
        topScore = ""
        affirm = ""
        if toxicity_score > threshold:
            affirm = positive_affirmations()
            topScore = toxicity_score
        else:
            affirm = ""
            topScore = toxicity_score
        label_score_dict = {"toxicity" : toxicity_score}

    return transcribed_text, topScore, label_score_dict, affirm
    # return f"Toxicity Score ({available_models[selected_model]}): {toxicity_score:.4f}"
    
def positive_affirmations():
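    """Return a randomly chosen grounding affirmation to display as the intervention."""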
    affirmations = [
        "I have survived my anxiety before and I will survive again now",
        "I am not in danger; I am just uncomfortable; this too will pass",
        "I forgive and release the past and look forward to the future",
        "I can't control what other people say but I can control my breathing and my response"
    ]
    selected_affirm = random.choice(affirmations)
    return selected_affirm
    
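# ---- Gradio UI: user preferences, audio input, and results ----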
with gr.Blocks() as iface:
    show_state = gr.State([])
    with gr.Column():
        anxiety_class = gr.Radio(label="Specify Subclass", choices=["Racism", "LGBTQ+ Hate", "Sexually Explicit", "Pregnancy Complications"])
        explit_preference = gr.Radio(choices=["N-Word", "B-Word", "All Explitives"], label="Words to omit from general anxiety classes", info="Certain words may be acceptable within certain contexts for given groups of people, and some people may be unbothered by expletives broadly speaking.")
        emo_class = gr.Radio(choices=["negative emotionality"], label="Negative Emotionality", info="Select if you would like expletives to be considered anxiety-inducing in the case of anger / negative emotionality.")
        sense_slider = gr.Slider(minimum=1, maximum=5, step=1.0, label="How readily do you want the tool to intervene? 1 = in extreme cases and 5 = at every opportunity")
    with gr.Column():
        aud_input = gr.Audio(source="upload", type="filepath", label="Upload Audio File")
        submit_btn = gr.Button(label="Run")
    with gr.Column():
        out_text = gr.Textbox(label="Transcribed Audio")
        out_val = gr.Textbox(label="Overall Toxicity")
        out_affirm = gr.Textbox(label="Intervention")
        out_class = gr.Label(label="Toxicity Class Breakdown")
    submit_btn.click(fn=classify_toxicity, inputs=[aud_input, anxiety_class, emo_class, explit_preference, sense_slider], outputs=[out_text, out_val, out_class, out_affirm])

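# Pass share=True to iface.launch() if a temporary public URL is needed for testing.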
iface.launch()