Spaces:
Sleeping
Sleeping
File size: 7,043 Bytes
20e2dd0 3be8896 20e2dd0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import pandas as pd
import streamlit as st
from langchain import PromptTemplate, HuggingFaceHub, LLMChain
from langchain.llms import OpenAI
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import os
import re
def extract_positive_negative(text):
    """Return every standalone "positive" / "negative" token found in *text*.

    Matching is case-sensitive and respects word boundaries, so "Positive"
    and "positively" are not returned. Matches are listed in the order they
    appear in *text*.
    """
    sentiment_words = re.compile(r'\b(?:positive|negative)\b')
    return sentiment_words.findall(text)
def classify_text(text, llm_chain, api):
    """Send *text* through *llm_chain* and normalise the model's answer.

    Args:
        text: Value to classify. It is converted with ``str()`` before being
            sent, so non-string values are accepted.
        llm_chain: Object exposing ``run(str) -> str`` (in this file, the
            LangChain ``LLMChain`` built in ``display_home``).
        api: Backend name, "HuggingFace" or "OpenAI". Both backends are
            driven the same way; the name is only validated.

    Returns:
        The model output with every whitespace character removed, lower-cased.

    Raises:
        ValueError: If *api* is not a supported backend name.
    """
    if api not in ("HuggingFace", "OpenAI"):
        raise ValueError(f"Unsupported API: {api!r}")
    classification = llm_chain.run(str(text))
    # Remove all whitespace, not just at the ends, so an answer such as
    # " Positive\n" compares equal to a gold label "positive".
    classification = re.sub(r'\s', '', classification)
    return classification.lower()
def classify_csv(df, llm_chain, api):
    """Move the gold labels aside and add model predictions to *df* in place.

    The ``label`` column is removed and re-added at the end of the frame as
    ``label_gold``. A new ``label_pred`` column is then filled by calling
    ``classify_text(value, llm_chain, api)`` on every ``text`` value.

    Returns:
        The same DataFrame object that was passed in, now modified.

    Raises:
        KeyError: If *df* has no ``label`` or no ``text`` column.
    """
    df["label_gold"] = df.pop("label")
    df["label_pred"] = df["text"].apply(classify_text, args=(llm_chain, api))
    return df
def classify_csv_zero(zero_file, llm_chain, api):
    """Read an unlabeled, semicolon-separated CSV and predict a label per row.

    Args:
        zero_file: Anything ``pandas.read_csv`` accepts (in this file, the
            Streamlit upload from ``display_home``). Must have a ``text``
            column.
        llm_chain: Chain passed through to ``classify_text``.
        api: Backend name passed through to ``classify_text``.

    Returns:
        The parsed DataFrame with a ``label`` column holding the predictions.
        Any existing ``label`` column is overwritten.
    """
    frame = pd.read_csv(zero_file, sep=';')
    predictions = frame["text"].apply(classify_text, args=(llm_chain, api))
    frame["label"] = predictions
    return frame
def evaluate_performance(df):
    """Return the percentage of rows whose prediction equals the gold label.

    Args:
        df: DataFrame with ``label_gold`` and ``label_pred`` columns, as
            produced by ``classify_csv``.

    Returns:
        Exact-match accuracy in percent (0-100) as a float. An empty frame has
        nothing to score, so ``nan`` is returned instead of raising
        ZeroDivisionError. ``f"{x:.2f}"`` renders it as "nan".
    """
    total_preds = len(df)
    if total_preds == 0:
        return float("nan")
    correct_preds = (df["label_gold"] == df["label_pred"]).sum()
    return float(correct_preds / total_preds * 100)
def _build_llm_chain(api, model, temperature, prompt, api_key):
    """Return an LLMChain for *api*, or None (after a UI warning) if it can't be built."""
    if api == "HuggingFace":
        if not api_key:
            st.warning("Please enter your HuggingFace API key to classify the text.")
            return None
        # NOTE(review): the key is exported to the process environment, which is
        # shared by every session served by this process. Consider passing it to
        # the wrapper directly instead. The export presumably is how the wrapper
        # finds the token; confirm before changing.
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
        llm = HuggingFaceHub(repo_id=model, model_kwargs={"temperature": temperature, "max_length": 128})
    elif api == "OpenAI":
        if not api_key:
            st.warning("Please enter your OpenAI API key to classify the text.")
            return None
        # NOTE(review): same process-wide environment concern as above.
        os.environ["OPENAI_API_KEY"] = api_key
        llm = OpenAI(temperature=temperature)
    else:
        st.warning(f"Unsupported API: {api}")
        return None
    return LLMChain(prompt=prompt, llm=llm)


def _run_zero_shot(zero_file, llm_chain, api):
    """Classify an unlabeled upload, show the result and offer it for download."""
    if zero_file is None:
        st.warning("Please upload a CSV file with a text column to classify the text.")
        return
    df_predicted = classify_csv_zero(zero_file, llm_chain, api)
    st.write(df_predicted)
    st.download_button(
        label="Download CSV",
        data=df_predicted.to_csv(index=False),
        file_name="classified_zero-shot_data.csv",
        mime="text/csv"
    )


def _run_test(gold_file, llm_chain, api):
    """Classify a gold-labelled upload, show/download it and report accuracy."""
    if gold_file is None:
        st.warning("Please upload a gold labels CSV file to evaluate the performance of the model.")
        return
    df = pd.read_csv(gold_file, sep=';')
    if "label" not in df.columns:
        st.warning("Please make sure that the gold labels CSV file contains a column named 'label'.")
        return
    df = classify_csv(df, llm_chain, api)
    st.write(df)
    st.download_button(
        label="Download CSV",
        data=df.to_csv(index=False),
        file_name="classified_test_data.csv",
        mime="text/csv"
    )
    percentage_overlap = evaluate_performance(df)
    st.write("**Performance Evaluation**")
    st.write(f"Percentage overlap between gold labels and predicted labels: {percentage_overlap:.2f}%")


def display_home():
    """Render the annotation page and, on button press, run the chosen setup.

    The page collects:
      * the API backend (HuggingFace or OpenAI);
      * a model name (HuggingFace only) and the matching API key;
      * a sampling temperature;
      * the setup ("Test" with gold labels, or "Zero-Shot" without);
      * the CSV upload for that setup;
      * a prompt template using ``{text}``.

    Nothing is sent to a model until "Run Classification/ Annotation" is
    pressed. A missing prompt, API key or upload shows a Streamlit warning
    and stops the run instead of raising.
    """
    st.write("Please select an API and a model to classify the text. We currently support HuggingFace and OpenAI.")
    api = st.selectbox("Select an API", ["HuggingFace", "OpenAI"])
    model = None
    api_key = ""
    if api == "HuggingFace":
        model = st.selectbox("Select a model", ["google/flan-t5-xl", "databricks/dolly-v1-6b"])
        api_key = st.text_input("HuggingFace API Key")
    elif api == "OpenAI":
        # No model name is passed to the OpenAI wrapper, so it uses its default.
        api_key = st.text_input("OpenAI API Key")

    st.write("Please select a temperature for the model. The higher the temperature, the more creative the model will be.")
    temperature = st.slider("Set the temperature", min_value=0.0, max_value=1.0, value=0.0, step=0.01)

    st.write("We provide two different setups for the annotation task. In the first setup (**Test**), you can upload a CSV file with gold labels and evaluate the performance of the model. In the second setup (**Zero-Shot**), you can upload a CSV file without gold labels and use the model to classify the text.")
    setup = st.selectbox("Setup", ["Test", "Zero-Shot"])
    gold_file = None
    zero_file = None
    if setup == "Test":
        gold_file = st.file_uploader("Upload Gold Labels CSV file with a text and a label column", type=["csv"])
    elif setup == "Zero-Shot":
        zero_file = st.file_uploader("Upload CSV file with a text column", type=["csv"])

    st.write("Please enter the prompt template below. You can use the following variables: {text} (text to classify).")
    prompt_template = st.text_area("Enter your task description", """Instruction: Identify the sentiment of a text. Please read the text and provide one of these responses: "positive" or "negative".\nText to classify in "positive" or "negative": {text}\nAnswer:""", height=200)

    if not st.button("Run Classification/ Annotation"):
        return
    if not prompt_template:
        st.warning("Please enter a prompt question to classify the text.")
        return

    prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["text"]
    )
    llm_chain = _build_llm_chain(api, model, temperature, prompt, api_key)
    if llm_chain is None:
        # Nothing can be classified without a chain; the helper already warned.
        return

    if setup == "Zero-Shot":
        _run_zero_shot(zero_file, llm_chain, api)
    elif setup == "Test":
        _run_test(gold_file, llm_chain, api)
def main():
    """Configure the Streamlit page and sidebar, then render the home view."""
    # set_page_config is the first Streamlit call made here.
    st.set_page_config(page_title="PromptCards Playground", page_icon=":pencil2:")
    st.title("AInnotator")
    # Seed per-session defaults only when missing, so existing values are kept.
    # NOTE(review): neither key is read in the visible code. They are presumably
    # reserved for page navigation / prompt selection; confirm or remove.
    if "current_page" not in st.session_state:
        st.session_state.current_page = "homepage"
    if "selected_prompt" not in st.session_state:
        st.session_state.selected_prompt = ""
    # Sidebar: static description and links only (no navigation menu is rendered).
    st.sidebar.title("About")
    st.sidebar.write("AInnotator 🤖🏷️ is a tool for creating artificial labels/ annotations. It is based on the concept of PromptCards, which are small, self-contained descriptions of a task that can be used to generate labels for a wide range of NLP tasks. Check out the GitHub repository and the PromptCards Archive for more information.")
    st.sidebar.write("---")
    st.sidebar.write("Check out the [PromptCards archive](https://huggingface.co/spaces/chkla/AnnotationPromptCards) to find a wide range of prompts for different NLP tasks.")
    st.sidebar.write("---")
    st.sidebar.write("Made with ❤️ and 🤖.")
    display_home()


if __name__ == "__main__":
    main()
|