import numpy as np
import csv
from typing import Optional
from urllib.request import urlopen
import gradio as gr


class SentimentTransform():
    """Negative / neutral / positive sentiment scorer on top of a Hugging Face
    ``transformers`` pipeline.

    The pipeline is created lazily the first time :attr:`classifier` is read,
    so constructing an instance does no model loading. Extra keyword
    arguments are stored verbatim as attributes, which lets callers inject a
    pre-built ``_classifier`` pipeline or an ``explainer`` callable used for
    token-level highlighting.
    """

    def __init__(
            self,
            model_name: str = "cardiffnlp/twitter-roberta-base-sentiment",
            highlight: bool = False,
            positive_sentiment_name: str = "positive",
            max_number_of_shap_documents: Optional[int] = None,
            min_abs_score: float = 0.1,
            sensitivity: float = 0,
            **kwargs,
    ):
        """
        Sentiment Ops.

        Parameters
        -------------
        model_name: str
            Model id passed to ``transformers.pipeline``.
        highlight: bool
            If True, :meth:`transform` also returns per-token SHAP scores.
            This requires an ``explainer`` attribute, e.g. passed via ``kwargs``.
        positive_sentiment_name: str
            Label (after :attr:`label_mapping`) scored as positive. Every other
            non-neutral label is scored as negative.
        max_number_of_shap_documents: Optional[int]
            Maximum number of highlighted tokens to return; ``None`` = no limit.
        min_abs_score: float
            Highlighted tokens whose absolute score is not above this are dropped.
        sensitivity: float
            How confident it is about being `neutral`. A `neutral` prediction
            whose score does not exceed this value is scored by the runner-up
            label instead. If you are dealing with news sources, you probably
            want less sensitivity.
        **kwargs
            Extra attributes set on the instance as-is.
        """
        self.model_name = model_name
        self.highlight = highlight
        self.positive_sentiment_name = positive_sentiment_name
        self.max_number_of_shap_documents = max_number_of_shap_documents
        self.min_abs_score = min_abs_score
        self.sensitivity = sensitivity
        for k, v in kwargs.items():
            setattr(self, k, v)

    def preprocess(self, text: str):
        """Replace ``@mentions`` with ``@user`` and any token starting with
        ``http`` with ``http``.

        Tokens are split on single spaces and re-joined with single spaces.
        NOTE(review): this is not called by :meth:`analyze_sentiment`.
        """
        new_text = []
        for t in text.split(" "):
            t = "@user" if t.startswith("@") and len(t) > 1 else t
            t = "http" if t.startswith("http") else t
            new_text.append(t)
        return " ".join(new_text)

    @property
    def classifier(self):
        """Lazily build (once per instance) and return the transformers pipeline.

        The pipeline is asked for the scores of all labels. A pre-built
        pipeline can be supplied up front via the ``_classifier`` kwarg.
        """
        if not hasattr(self, "_classifier"):
            import transformers

            self._classifier = transformers.pipeline(
                return_all_scores=True,
                model=self.model_name,
            )
        return self._classifier

    def _get_label_mapping(self, task: str, timeout: float = 30.0):
        """Download the tweeteval ``mapping.txt`` for *task* and return its
        label names, i.e. the second tab-separated column of each row.

        Note: this is specific to the cardiffnlp/tweeteval models.
        *timeout* (seconds) bounds the HTTP request so it cannot hang forever.
        """
        mapping_link = (
            "https://raw.githubusercontent.com/cardiffnlp/tweeteval/main/"
            f"datasets/{task}/mapping.txt"
        )
        with urlopen(mapping_link, timeout=timeout) as response:
            # splitlines() also strips '\r' from CRLF files.
            lines = response.read().decode("utf-8").splitlines()
        return [row[1] for row in csv.reader(lines, delimiter="\t") if len(row) > 1]

    @property
    def label_mapping(self):
        """Map the generic ``LABEL_i`` ids to readable names.

        Labels not in this map are passed through unchanged by the callers.
        """
        return {"LABEL_0": "negative", "LABEL_1": "neutral", "LABEL_2": "positive"}

    def analyze_sentiment(
            self,
            text,
            highlight: bool = False,
            positive_sentiment_name: Optional[str] = None,
            max_number_of_shap_documents: Optional[int] = None,
            min_abs_score: float = 0.1,
    ):
        """Classify *text* and turn the result into a signed sentiment score.

        Parameters
        ----------
        text:
            Input text (converted with ``str``); ``None`` returns ``None``.
        highlight: bool
            Also compute per-token SHAP scores via :meth:`get_shap_values`.
        positive_sentiment_name: Optional[str]
            Label treated as positive. ``None`` uses the instance setting.
        max_number_of_shap_documents, min_abs_score:
            Forwarded to :meth:`get_shap_values` when *highlight* is True.

        Returns
        -------
        ``None`` if *text* is ``None``. Otherwise a dict with:

        - ``"sentiment"``: mapped label of the top-scoring class.
        - ``"overall_sentiment_score"``: +score for the positive label,
          -score for other labels, and ``1e-5`` for a confident neutral.
          A neutral whose score is at or below ``self.sensitivity`` uses the
          runner-up label's score instead.

        With *highlight* the dict also contains ``"score"``,
        ``"overall_sentiment"`` (same value as ``"overall_sentiment_score"``)
        and ``"highlight_chunk_"``.
        """
        if text is None:
            return None
        if positive_sentiment_name is None:
            positive_sentiment_name = self.positive_sentiment_name
        scores = self.classifier([str(text)], truncation=True, max_length=512)[0]
        ind_max = int(np.argmax([entry["score"] for entry in scores]))
        max_score = scores[ind_max]["score"]
        raw_label = scores[ind_max]["label"]
        sentiment = self.label_mapping.get(raw_label, raw_label)

        if sentiment.lower() != "neutral":
            overall_sentiment = self._calculate_overall_sentiment(
                max_score, sentiment, positive_sentiment_name
            )
        elif max_score > self.sensitivity:
            # Confidently neutral.
            overall_sentiment = 1e-5
        else:
            # Not confident enough in `neutral`: score by the runner-up label.
            runners_up = scores[:ind_max] + scores[ind_max + 1:]
            if runners_up:
                ind_next = int(np.argmax([entry["score"] for entry in runners_up]))
                next_label = runners_up[ind_next]["label"]
                next_label = self.label_mapping.get(next_label, next_label)
                overall_sentiment = self._calculate_overall_sentiment(
                    runners_up[ind_next]["score"], next_label, positive_sentiment_name
                )
            else:
                # Single-label output: nothing to fall back to.
                overall_sentiment = 1e-5
        # Adjust to avoid bug: an exact 0 score is presumably mishandled
        # downstream, so keep it slightly positive. TODO confirm the consumer.
        if overall_sentiment == 0:
            overall_sentiment = 1e-5
        if not highlight:
            return {
                "sentiment": sentiment,
                "overall_sentiment_score": overall_sentiment,
            }
        shap_documents = self.get_shap_values(
            text,
            sentiment_ind=ind_max,
            max_number_of_shap_documents=max_number_of_shap_documents,
            min_abs_score=min_abs_score,
        )
        return {
            "sentiment": sentiment,
            "score": max_score,
            "overall_sentiment": overall_sentiment,
            # Same key as the non-highlight result so callers can rely on it.
            "overall_sentiment_score": overall_sentiment,
            "highlight_chunk_": shap_documents,
        }

    def _calculate_overall_sentiment(
            self,
            score: float,
            sentiment: str,
            positive_sentiment_name: Optional[str] = None,
    ):
        """Return +score if *sentiment* names the positive label, else -score.

        The comparison ignores case and surrounding whitespace on both sides.
        *positive_sentiment_name* defaults to the instance setting.
        """
        positive = (
            self.positive_sentiment_name
            if positive_sentiment_name is None
            else positive_sentiment_name
        )
        if sentiment.lower().strip() == positive.lower().strip():
            return score
        return -score

    def get_shap_values(
            self,
            text: str,
            sentiment_ind: int = 2,
            max_number_of_shap_documents: Optional[int] = None,
            min_abs_score: float = 0.1,
    ):
        """Get SHAP values: per-token attribution scores for one class.

        Requires an ``explainer`` attribute (e.g. passed as a constructor
        kwarg). It must be callable on a list of texts and return an object
        with ``.values`` and ``.feature_names``.

        Returns a list of ``{"text": token, "score": value}`` dicts for output
        column *sentiment_ind*, built as follows:

        1. Sort by score, descending.
        2. Keep only entries with ``abs(score) > min_abs_score``.
        3. Truncate to at most *max_number_of_shap_documents* entries
           (``None`` = no limit).
        """
        explainer = getattr(self, "explainer", None)
        if explainer is None:
            raise AttributeError(
                "SentimentTransform has no 'explainer'; pass one as a constructor "
                "keyword argument to use highlighting"
            )
        explanation = explainer([text])
        # Attribution rows for the single input text. Each row is indexed by
        # the class index, i.e. assumed (tokens, classes). TODO confirm.
        token_rows = np.asarray(explanation.values)[0].tolist()
        token_names = explanation.feature_names[0]
        shap_docs = [
            {"text": name, "score": row[sentiment_ind]}
            for row, name in zip(token_rows, token_names)
        ]
        shap_docs.sort(key=lambda doc: doc["score"], reverse=True)
        shap_docs = [doc for doc in shap_docs if abs(doc["score"]) > min_abs_score]
        if max_number_of_shap_documents is not None:
            shap_docs = shap_docs[:max_number_of_shap_documents]
        return shap_docs

    def transform(self, text):
        """Run :meth:`analyze_sentiment` on *text* using the instance settings."""
        return self.analyze_sentiment(
            text,
            highlight=self.highlight,
            positive_sentiment_name=self.positive_sentiment_name,
            max_number_of_shap_documents=self.max_number_of_shap_documents,
            min_abs_score=self.min_abs_score,
        )


# UI model choice -> Hugging Face model id. Unknown choices fall back to the
# survey model, matching the previous if/elif/else chain.
_MODEL_NAMES = {
    'Social Media Model': "cardiffnlp/twitter-roberta-base-sentiment",
    'Survey Model': "j-hartmann/sentiment-roberta-large-english-3-classes",
}
_DEFAULT_MODEL_NAME = "j-hartmann/sentiment-roberta-large-english-3-classes"

# Loaded transformers pipelines keyed by model id. Each model is initialised
# once per process instead of on every request.
_PIPELINE_CACHE = {}


def sentiment_classifier(text, model_type, sensitivity):
    """Gradio handler: classify *text* with the model chosen by *model_type*.

    Returns a ``(sentiment_label, overall_sentiment_score)`` tuple.
    ``(None, None)`` is returned when no result is produced (``text is None``).
    """
    model_name = _MODEL_NAMES.get(model_type, _DEFAULT_MODEL_NAME)
    pipeline = _PIPELINE_CACHE.get(model_name)
    if pipeline is None:
        pipeline = SentimentTransform(model_name=model_name).classifier
        _PIPELINE_CACHE[model_name] = pipeline
    # Use a cheap per-request instance so `sensitivity` stays request-local.
    # The shared pipeline is injected through the `_classifier` kwarg.
    model = SentimentTransform(
        model_name=model_name, sensitivity=sensitivity, _classifier=pipeline
    )
    res_dict = model.transform(text)
    if res_dict is None:
        return None, None
    return res_dict['sentiment'], res_dict['overall_sentiment_score']


# Input widgets, in the order sentiment_classifier expects its arguments:
# free text, model choice, and the neutral-sensitivity slider.
text_box = gr.Textbox(
    placeholder="Put the text here and click 'submit' to predict its sentiment",
    label="Input Text",
)
model_picker = gr.Dropdown(
    ["Social Media Model", "Survey Model"],
    value="Survey Model",
    label="Select the Model that you want to use.",
)
sensitivity_slider = gr.Slider(
    0,
    1,
    step=0.01,
    label="Sensitivity (How confident it is about being `neutral`. If you are dealing with news sources, you probably want less sensitivity.)",
)

# Output widgets: the predicted label and the signed overall score.
label_out = gr.Textbox(label='Sentiment')
score_out = gr.Textbox(label='Sentiment Score')

demo = gr.Interface(
    fn=sentiment_classifier,
    inputs=[text_box, model_picker, sensitivity_slider],
    outputs=[label_out, score_out],
)
demo.launch(debug=True)