File size: 7,043 Bytes
20e2dd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3be8896
20e2dd0
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import pandas as pd
import streamlit as st
from langchain import PromptTemplate, HuggingFaceHub, LLMChain
from langchain.llms import OpenAI
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import os
import re


def extract_positive_negative(text):
    """Return every standalone occurrence of "positive" or "negative" in *text*.

    Matching is case-sensitive and whole-word (word boundaries on both sides),
    so "Positive" or "nonnegative" are not matched. Hits are returned as a
    list of strings in order of appearance, duplicates included.
    """
    sentiment_words = re.compile(r'\b(?:positive|negative)\b')
    return sentiment_words.findall(text)

def classify_text(text, llm_chain, api):
    """Classify a single text with *llm_chain* and return the normalised label.

    Args:
        text: Value to classify; converted with ``str()`` before being sent
            to the chain.
        llm_chain: Object exposing ``run(str) -> str`` (an ``LLMChain`` in
            this app).
        api: Provider name, either ``"HuggingFace"`` or ``"OpenAI"``.

    Returns:
        The chain output lower-cased. For ``"OpenAI"`` every whitespace
        character is removed first; ``"HuggingFace"`` output is only
        lower-cased.

    Raises:
        ValueError: If *api* is not a supported provider. Previously an
            unknown provider surfaced as an ``UnboundLocalError``.
    """
    if api not in ("HuggingFace", "OpenAI"):
        raise ValueError(f"Unsupported API: {api!r}")

    classification = llm_chain.run(str(text))
    if api == "OpenAI":
        # Strip all whitespace (leading newlines, spaces, etc.) so the
        # prediction can be compared directly against gold labels.
        classification = re.sub(r'\s', '', classification)
    return classification.lower()

def classify_csv(df, llm_chain, api):
    """Return a copy of *df* with gold labels renamed and model predictions added.

    The input's ``label`` column becomes a ``label_gold`` column. It is placed
    after the remaining columns, matching the previous layout. Each ``text``
    value is classified into a new ``label_pred`` column via
    ``classify_text``. The caller's DataFrame is not modified; previously it
    was mutated in place.

    Args:
        df: DataFrame containing at least ``text`` and ``label`` columns.
        llm_chain: Chain passed through to ``classify_text``.
        api: Provider name passed through to ``classify_text``.

    Returns:
        A new DataFrame with ``label_gold`` and ``label_pred`` columns.

    Raises:
        KeyError: If *df* lacks a ``label`` or ``text`` column.
    """
    result = df.copy()
    # pop() drops "label"; assigning the popped Series appends it at the end,
    # giving the same column order as before: others, label_gold, label_pred.
    result["label_gold"] = result.pop("label")
    result["label_pred"] = result["text"].apply(classify_text, llm_chain=llm_chain, api=api)
    return result

def classify_csv_zero(zero_file, llm_chain, api):
    """Read a ';'-separated CSV and annotate each ``text`` row with a predicted ``label``.

    Any existing ``label`` column in the file is overwritten. Returns the
    resulting DataFrame.
    """
    frame = pd.read_csv(zero_file, sep=';')
    predictions = frame["text"].apply(classify_text, llm_chain=llm_chain, api=api)
    frame["label"] = predictions
    return frame

def evaluate_performance(df):
    """Return the percentage of rows whose prediction equals the gold label.

    Args:
        df: DataFrame with ``label_gold`` and ``label_pred`` columns.

    Returns:
        Exact-match accuracy in the range ``[0, 100]``. An empty frame yields
        ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    total_preds = len(df)
    if total_preds == 0:
        return 0.0
    correct_preds = int((df["label_gold"] == df["label_pred"]).sum())
    return correct_preds / total_preds * 100

def _build_llm_chain(api, model, api_key, prompt, temperature):
    """Create an ``LLMChain`` for the selected provider, or warn and return ``None``."""
    if not api_key:
        st.warning(f"Please enter your {api} API key to classify the text.")
        return None

    if api == "HuggingFace":
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
        llm = HuggingFaceHub(repo_id=model, model_kwargs={"temperature": temperature, "max_length": 128})
    elif api == "OpenAI":
        os.environ["OPENAI_API_KEY"] = api_key
        llm = OpenAI(temperature=temperature)
    else:
        raise ValueError(f"Unsupported API: {api!r}")
    return LLMChain(prompt=prompt, llm=llm)


def _run_zero_shot(zero_file, llm_chain, api):
    """Classify an unlabeled upload, show it, and offer it as a CSV download."""
    if zero_file is None:
        st.warning("Please upload a CSV file with a text column to classify.")
        return

    df_predicted = classify_csv_zero(zero_file, llm_chain, api)
    st.write(df_predicted)
    st.download_button(
        label="Download CSV",
        data=df_predicted.to_csv(index=False),
        file_name="classified_zero-shot_data.csv",
        mime="text/csv"
    )


def _run_test(gold_file, llm_chain, api):
    """Classify a gold-labelled upload, offer a download, and report accuracy."""
    if gold_file is None:
        st.warning("Please upload a gold labels CSV file to evaluate the performance of the model.")
        return

    df = pd.read_csv(gold_file, sep=';')
    if "label" not in df.columns:
        st.warning("Please make sure that the gold labels CSV file contains a column named 'label'.")
        return
    if "text" not in df.columns:
        st.warning("Please make sure that the gold labels CSV file contains a column named 'text'.")
        return

    df = classify_csv(df, llm_chain, api)
    st.write(df)
    st.download_button(
        label="Download CSV",
        data=df.to_csv(index=False),
        file_name="classified_test_data.csv",
        mime="text/csv"
    )
    percentage_overlap = evaluate_performance(df)
    st.write("**Performance Evaluation**")
    st.write(f"Percentage overlap between gold labels and predicted labels: {percentage_overlap:.2f}%")


def display_home():
    """Render the annotation page and run classification when the button is pressed.

    The user picks a provider/model, enters an API key, sets a temperature,
    chooses the Test (gold labels, with evaluation) or Zero-Shot (unlabeled)
    setup, uploads a ';'-separated CSV and edits the prompt template.
    Missing inputs produce a warning and stop the run. Previously some of
    them crashed with ``UnboundLocalError``.
    """
    st.write("Please select an API and a model to classify the text. We currently support HuggingFace and OpenAI.")
    api = st.selectbox("Select an API", ["HuggingFace", "OpenAI"])

    model = None
    if api == "HuggingFace":
        model = st.selectbox("Select a model", ["google/flan-t5-xl", "databricks/dolly-v1-6b"])
        # Masked input so the key is not shown on screen.
        api_key = st.text_input("HuggingFace API Key", type="password")
    else:
        api_key = st.text_input("OpenAI API Key", type="password")

    st.write("Please select a temperature for the model. The higher the temperature, the more creative the model will be.")
    temperature = st.slider("Set the temperature", min_value=0.0, max_value=1.0, value=0.0, step=0.01)

    st.write("We provide two different setups for the annotation task. In the first setup (**Test**), you can upload a CSV file with gold labels and evaluate the performance of the model. In the second setup (**Zero-Shot**), you can upload a CSV file without gold labels and use the model to classify the text.")
    setup = st.selectbox("Setup", ["Test", "Zero-Shot"])

    gold_file = None
    zero_file = None
    if setup == "Test":
        gold_file = st.file_uploader("Upload Gold Labels CSV file with a text and a label column", type=["csv"])
    else:
        zero_file = st.file_uploader("Upload CSV file with a text column", type=["csv"])

    st.write("Please enter the prompt template below. You can use the following variables: {text} (text to classify).")
    prompt_template = st.text_area("Enter your task description", """Instruction: Identify the sentiment of a text. Please read the text and provide one of these responses: "positive" or "negative".\nText to classify in "positive" or "negative": {text}\nAnswer:""", height=200)

    if not st.button("Run Classification/ Annotation"):
        return

    if not prompt_template:
        st.warning("Please enter a prompt question to classify the text.")
        return
    if "{text}" not in prompt_template:
        # The template is declared with input_variables=["text"]; without the
        # placeholder the row text would never reach the model.
        st.warning("Please include the {text} variable in the prompt template.")
        return

    prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["text"]
    )
    llm_chain = _build_llm_chain(api, model, api_key, prompt, temperature)
    if llm_chain is None:
        return

    if setup == "Zero-Shot":
        _run_zero_shot(zero_file, llm_chain, api)
    else:
        _run_test(gold_file, llm_chain, api)

def main():
    """Configure the page, initialise session state, draw the sidebar and render the home page."""
    st.set_page_config(page_title="PromptCards Playground", page_icon=":pencil2:")
    st.title("AInnotator")

    # Initialise session-state keys on first run only.
    # NOTE(review): neither key is read in this file; confirm they are still needed.
    if "current_page" not in st.session_state:
        st.session_state.current_page = "homepage"
    if "selected_prompt" not in st.session_state:
        st.session_state.selected_prompt = ""

    # Sidebar: static project description and links (no navigation menu).
    st.sidebar.title("About")
    st.sidebar.write("AInnotator 🤖🏷️ is a tool for creating artificial labels/ annotations. It is based on the concept of PromptCards, which are small, self-contained descriptions of a task that can be used to generate labels for a wide range of NLP tasks. Check out the GitHub repository and the PromptCards Archive for more information.")
    st.sidebar.write("---")
    st.sidebar.write("Check out the [PromptCards archive](https://huggingface.co/spaces/chkla/AnnotationPromptCards) to find a wide range of prompts for different NLP tasks.")
    st.sidebar.write("---")
    st.sidebar.write("Made with ❤️ and 🤖.")

    display_home()

# Entry point: build the app only when this file is executed as the main script.
if __name__ == "__main__":
    main()