import pandas as pd
import streamlit as st
from langchain import PromptTemplate, HuggingFaceHub, LLMChain
from langchain.llms import OpenAI
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import os
import re


def extract_positive_negative(text):
    """Return every whole-word occurrence of "positive" or "negative" in *text*.

    Matching is case-sensitive and respects word boundaries, so "Positive"
    or "nonnegative" are not reported. Matches are returned in the order
    they appear; an empty list means nothing was found.
    """
    sentiment_words = re.compile(r'\b(?:positive|negative)\b')
    return [hit.group(0) for hit in sentiment_words.finditer(text)]

def classify_text(text, llm_chain, api):
    """Classify a single text with the given LLM chain and return a lowercase label.

    Args:
        text: The value to classify; it is converted with ``str()`` before
            being handed to the chain.
        llm_chain: Object exposing ``run(str) -> str`` (an ``LLMChain``).
        api: Either ``"HuggingFace"`` or ``"OpenAI"``.

    Returns:
        The chain output lowercased. For ``"OpenAI"`` all whitespace is
        removed first; for ``"HuggingFace"`` the output is left untouched
        apart from lowercasing.

    Raises:
        ValueError: If *api* is not one of the supported names. Previously
            this surfaced as an opaque ``UnboundLocalError``.
    """
    if api not in ("HuggingFace", "OpenAI"):
        raise ValueError(
            f"Unsupported API {api!r}; expected 'HuggingFace' or 'OpenAI'."
        )

    classification = llm_chain.run(str(text))
    if api == "OpenAI":
        # OpenAI completions are stripped of every whitespace character
        # (including inner ones) before being used as a label.
        classification = re.sub(r'\s', '', classification)
    return classification.lower()

def classify_csv(df, llm_chain, api):
    """Annotate a gold-labelled DataFrame with model predictions.

    The "label" column is moved to "label_gold", and a new "label_pred"
    column is filled by running ``classify_text`` on every entry of the
    "text" column. The frame is modified in place and also returned.
    """
    # pop() removes "label" and hands back its values, which are then
    # stored under the gold-label name.
    df["label_gold"] = df.pop("label")
    predictions = df["text"].apply(
        lambda entry: classify_text(entry, llm_chain=llm_chain, api=api)
    )
    df["label_pred"] = predictions
    return df

def classify_csv_zero(zero_file, llm_chain, api):
    """Load a semicolon-separated CSV and label its "text" column.

    *zero_file* is passed straight to ``pandas.read_csv`` with ``sep=';'``.
    Each entry of the "text" column is run through ``classify_text`` and the
    result is stored in a "label" column of the returned DataFrame.
    """
    frame = pd.read_csv(zero_file, sep=';')

    def _predict(entry):
        return classify_text(entry, llm_chain=llm_chain, api=api)

    frame["label"] = frame["text"].apply(_predict)
    return frame

def evaluate_performance(df):
    """Return the percentage of rows whose predicted label equals the gold label.

    Args:
        df: DataFrame containing "label_gold" and "label_pred" columns.

    Returns:
        A float in the range 0-100. An empty frame yields ``0.0`` instead of
        raising ``ZeroDivisionError``.
    """
    total_preds = len(df)
    if total_preds == 0:
        # Nothing to compare; avoid dividing by zero.
        return 0.0

    # Vectorised count of matching rows.
    correct_preds = (df["label_gold"] == df["label_pred"]).sum()
    return float(correct_preds / total_preds * 100)

def display_home():
    """Render the annotation page and run a classification when requested.

    The page collects, in order: the API and model, the API key, the sampling
    temperature, the setup ("Test" with gold labels or "Zero-Shot" without),
    the CSV upload and the prompt template. When the run button is pressed,
    the inputs are validated one by one; the first missing input produces a
    warning and stops the run. Otherwise the uploaded CSV is classified,
    shown, offered for download and, in the "Test" setup, scored against the
    gold labels.
    """
    st.write("Please select an API and a model to classify the text. We currently support HuggingFace and OpenAI.")
    api = st.selectbox("Select an API", ["HuggingFace", "OpenAI"])

    # API keys are secrets, so the input is masked.
    if api == "HuggingFace":
        model = st.selectbox("Select a model", ["google/flan-t5-xl", "databricks/dolly-v1-6b"])
        api_key = st.text_input("HuggingFace API Key", type="password")
    else:
        model = None
        api_key = st.text_input("OpenAI API Key", type="password")

    st.write("Please select a temperature for the model. The higher the temperature, the more creative the model will be.")
    temperature = st.slider("Set the temperature", min_value=0.0, max_value=1.0, value=0.0, step=0.01)

    st.write("We provide two different setups for the annotation task. In the first setup (**Test**), you can upload a CSV file with gold labels and evaluate the performance of the model. In the second setup (**Zero-Shot**), you can upload a CSV file without gold labels and use the model to classify the text.")
    setup = st.selectbox("Setup", ["Test", "Zero-Shot"])

    # Only one uploader is shown; the other handle stays None.
    gold_file = None
    zero_file = None
    if setup == "Test":
        gold_file = st.file_uploader("Upload Gold Labels CSV file with a text and a label column", type=["csv"])
    else:
        zero_file = st.file_uploader("Upload CSV file with a text column", type=["csv"])

    st.write("Please enter the prompt template below. You can use the following variables: {text} (text to classify).")
    prompt_template = st.text_area("Enter your task description", """Instruction: Identify the sentiment of a text. Please read the text and provide one of these responses: "positive" or "negative".\nText to classify in "positive" or "negative": {text}\nAnswer:""", height=200)

    if not st.button("Run Classification/ Annotation"):
        return

    # Validate the template itself; `prompt` does not exist yet at this point.
    if not prompt_template:
        st.warning("Please enter a prompt question to classify the text.")
        return

    prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["text"]
    )

    # Without a key no chain can be built, so stop here instead of falling
    # through to code that needs `llm_chain`.
    if not api_key:
        st.warning(f"Please enter your {api} API key to classify the text.")
        return

    if api == "HuggingFace":
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
        llm = HuggingFaceHub(repo_id=model, model_kwargs={"temperature": temperature, "max_length": 128})
    else:
        os.environ["OPENAI_API_KEY"] = api_key
        llm = OpenAI(temperature=temperature)
    llm_chain = LLMChain(prompt=prompt, llm=llm)

    if setup == "Zero-Shot":
        if zero_file is None:
            st.warning("Please upload a CSV file with a text column to classify.")
            return
        df_predicted = classify_csv_zero(zero_file, llm_chain, api)
        st.write(df_predicted)
        st.download_button(
            label="Download CSV",
            data=df_predicted.to_csv(index=False),
            file_name="classified_zero-shot_data.csv",
            mime="text/csv"
        )
        return

    # "Test" setup from here on.
    if gold_file is None:
        st.warning("Please upload a gold labels CSV file to evaluate the performance of the model.")
        return

    df = pd.read_csv(gold_file, sep=';')
    if "label" not in df.columns:
        st.warning("Please make sure that the gold labels CSV file contains a column named 'label'.")
        return
    if "text" not in df.columns:
        st.warning("Please make sure that the gold labels CSV file contains a column named 'text'.")
        return

    df = classify_csv(df, llm_chain, api)
    st.write(df)
    st.download_button(
        label="Download CSV",
        data=df.to_csv(index=False),
        file_name="classified_test_data.csv",
        mime="text/csv"
    )
    percentage_overlap = evaluate_performance(df)
    st.write("**Performance Evaluation**")
    st.write(f"Percentage overlap between gold labels and predicted labels: {percentage_overlap:.2f}%")

def main():
    """Configure the page, seed session state, draw the sidebar and render the home view."""
    st.set_page_config(page_title="PromptCards Playground", page_icon=":pencil2:")
    st.title("AInnotator")

    # Seed session-state defaults only when they are not already present.
    if "current_page" not in st.session_state:
        st.session_state.current_page = "homepage"
    if "selected_prompt" not in st.session_state:
        st.session_state.selected_prompt = ""

    # Sidebar: short description of the tool and a pointer to the prompt archive.
    st.sidebar.title("About")
    st.sidebar.write("AInnotator 🤖🏷️ is a tool for creating artificial labels/ annotations. It is based on the concept of PromptCards, which are small, self-contained descriptions of a task that can be used to generate labels for a wide range of NLP tasks. Check out the GitHub repository and the PromptCards Archive for more information.")
    st.sidebar.write("---")
    st.sidebar.write("Check out the [PromptCards archive](https://huggingface.co/spaces/chkla/AnnotationPromptCards) to find a wide range of prompts for different NLP tasks.")
    st.sidebar.write("---")
    st.sidebar.write("Made with ❤️ and 🤖.")

    display_home()

# Script entry point: build the page only when this file is executed directly.
if __name__ == "__main__":
    main()