import argparse
import textwrap
from multiprocessing import Manager, Pool
import pandas as pd
import plotly.express as px
import streamlit as st
from datasets import get_dataset_infos
from pygments import highlight
from pygments.formatters import HtmlFormatter
from pygments.lexers import DjangoLexer
from session import _get_state
from templates import Template, TemplateCollection
from utils import (
get_dataset,
get_dataset_confs,
list_datasets,
removeHyphen,
renameDatasetColumn,
render_features,
)
# Add an argument for read-only mode.
# At the moment, Streamlit does not handle Python script arguments gracefully.
# Thus, for read-only mode, you have to use one of the two commands below:
# streamlit run promptsource/app.py -- -r
# streamlit run promptsource/app.py -- --read-only
# Check https://github.com/streamlit/streamlit/issues/337 for more information.
parser = argparse.ArgumentParser(description="run app.py with args")
parser.add_argument("-r", "--read-only", action="store_true", help="whether to run it as read-only mode")
args = parser.parse_args()
if args.read_only:
select_options = ["Helicopter view", "Prompted dataset viewer"]
side_bar_title_prefix = "Promptsource (Read only)"
else:
select_options = ["Helicopter view", "Prompted dataset viewer", "Sourcing"]
side_bar_title_prefix = "Promptsource"
#
# Helper functions for datasets library
#
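# Wrap the dataset loaders in Streamlit's cache so datasets and their configs are
# only downloaded/loaded once per session; allow_output_mutation=True skips hashing
# the returned Dataset object on subsequent reruns.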
get_dataset = st.cache(allow_output_mutation=True)(get_dataset)
get_dataset_confs = st.cache(get_dataset_confs)
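# Clears the prompt-specific fields of the session state, e.g. when switching to
# another dataset or after creating/deleting a prompt.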
def reset_template_state():
state.template_name = None
state.jinja = None
state.reference = None
#
# Loads session state
#
state = _get_state()
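# `state` is a custom session-state object (see session.py) that persists values
# such as template_name, jinja, and reference across Streamlit reruns.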
#
# Initial page setup
#
st.set_page_config(page_title="Promptsource", layout="wide")
st.sidebar.markdown(
"<center><a href='https://github.com/bigscience-workshop/promptsource'>💻Github - Promptsource\n\n</a></center>",
unsafe_allow_html=True,
)
mode = st.sidebar.selectbox(
label="Choose a mode",
options=select_options,
index=0,
key="mode_select",
)
st.sidebar.title(f"{side_bar_title_prefix} 🌸 - {mode}")
#
# Adds pygments styles to the page.
#
st.markdown(
"<style>" + HtmlFormatter(style="friendly").get_style_defs(".highlight") + "</style>", unsafe_allow_html=True
)
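# Maximum character width used when wrapping prompt text for display.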
WIDTH = 80
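# Render a Jinja template string as syntax-highlighted HTML; Pygments' DjangoLexer
# handles Jinja-style syntax.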
def show_jinja(t, width=WIDTH):
wrap = textwrap.fill(t, width=width, replace_whitespace=False)
out = highlight(wrap, DjangoLexer(), HtmlFormatter())
st.write(out, unsafe_allow_html=True)
def show_text(t, width=WIDTH, with_markdown=False):
wrap = [textwrap.fill(subt, width=width, replace_whitespace=False) for subt in t.split("\n")]
wrap = "\n".join(wrap)
if with_markdown:
st.write(wrap, unsafe_allow_html=True)
else:
st.text(wrap)
#
# Loads template data
#
try:
template_collection = TemplateCollection()
except FileNotFoundError:
st.error(
"Unable to find the prompt folder!\n\n"
"We expect the folder to be in the working directory. "
"You might need to restart the app in the root directory of the repo."
)
st.stop()
if mode == "Helicopter view":
st.title("High level metrics")
st.write(
"If you want to contribute, please refer to the instructions in "
+ "[Contributing](https://github.com/bigscience-workshop/promptsource/blob/main/CONTRIBUTING.md)."
)
#
# Global metrics
#
counts = template_collection.get_templates_count()
nb_prompted_datasets = len(counts)
st.write(f"## Number of *prompted datasets*: `{nb_prompted_datasets}`")
nb_prompts = sum(counts.values())
st.write(f"## Number of *prompts*: `{nb_prompts}`")
#
# Metrics per dataset/subset
#
# Download dataset infos (multiprocessing download)
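# A Manager dict is shared across the worker processes so each worker can store the
# dataset infos it fetches; one process per prompted dataset downloads the infos
# concurrently.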
manager = Manager()
all_infos = manager.dict()
all_datasets = list(set([t[0] for t in template_collection.keys]))
def get_infos(d_name):
all_infos[d_name] = get_dataset_infos(d_name)
pool = Pool(processes=len(all_datasets))
pool.map(get_infos, all_datasets)
pool.close()
pool.join()
results = []
for (dataset_name, subset_name) in template_collection.keys:
# Collect split sizes (train, validation and test)
if dataset_name not in all_infos:
infos = get_dataset_infos(dataset_name)
all_infos[dataset_name] = infos
else:
infos = all_infos[dataset_name]
if infos:
if subset_name is None:
subset_infos = infos[list(infos.keys())[0]]
else:
subset_infos = infos[subset_name]
split_sizes = {k: v.num_examples for k, v in subset_infos.splits.items()}
else:
# Zaid/coqa_expanded and Zaid/quac_expanded don't have dataset_infos.json,
# so `infos` is an empty dict and `infos[list(infos.keys())[0]]` would raise an error.
# For simplicity, leave `split_sizes` empty so the displayed split sizes are 0.
split_sizes = {}
# Collect template counts, original task counts and names
dataset_templates = template_collection.get_dataset(dataset_name, subset_name)
results.append(
{
"Dataset name": dataset_name,
"Subset name": "∅" if subset_name is None else subset_name,
"Train size": split_sizes["train"] if "train" in split_sizes else 0,
"Validation size": split_sizes["validation"] if "validation" in split_sizes else 0,
"Test size": split_sizes["test"] if "test" in split_sizes else 0,
"Number of prompts": len(dataset_templates),
"Number of original task prompts": sum(
[bool(t.metadata.original_task) for t in dataset_templates.templates.values()]
),
"Prompt names": [t.name for t in dataset_templates.templates.values()],
}
)
results_df = pd.DataFrame(results)
results_df.sort_values(["Number of prompts"], inplace=True, ascending=False)
results_df.reset_index(drop=True, inplace=True)
nb_training_instances = results_df["Train size"].sum()
st.write(f"## Number of *training instances*: `{nb_training_instances}`")
plot_df = results_df[["Dataset name", "Subset name", "Train size", "Number of prompts"]].copy()
plot_df["Name"] = plot_df["Dataset name"] + " - " + plot_df["Subset name"]
plot_df.sort_values(["Train size"], inplace=True, ascending=False)
fig = px.bar(
plot_df,
x="Name",
y="Train size",
hover_data=["Dataset name", "Subset name", "Number of prompts"],
log_y=True,
title="Number of training instances per data(sub)set - y-axis is in logscale",
)
fig.update_xaxes(visible=False, showticklabels=False)
st.plotly_chart(fig, use_container_width=True)
st.write(
f"- Top 3 training subsets account for `{100*plot_df[:3]['Train size'].sum()/nb_training_instances:.2f}%` of the training instances."
)
biggest_training_subset = plot_df.iloc[0]
st.write(
f"- Biggest training subset is *{biggest_training_subset['Name']}* with `{biggest_training_subset['Train size']}` instances"
)
smallest_training_subset = plot_df[plot_df["Train size"] > 0].iloc[-1]
st.write(
f"- Smallest training subset is *{smallest_training_subset['Name']}* with `{smallest_training_subset['Train size']}` instances"
)
st.markdown("***")
st.write("Details per dataset")
st.table(results_df)
else:
# Combining mode `Prompted dataset viewer` and `Sourcing` since the
# backbone of the interfaces is the same
assert mode in ["Prompted dataset viewer", "Sourcing"], (
f"`mode` ({mode}) should be in `[Helicopter view, Prompted dataset viewer, Sourcing]`"
)
#
# Loads dataset information
#
dataset_list = list_datasets(
template_collection,
state,
)
ag_news_index = dataset_list.index("ag_news")
#
# Select a dataset - starts with ag_news
#
dataset_key = st.sidebar.selectbox(
"Dataset",
dataset_list,
key="dataset_select",
index=ag_news_index,
help="Select the dataset to work on.",
)
#
# If a particular dataset is selected, loads dataset and template information
#
if dataset_key is not None:
#
# Check for subconfigurations (i.e. subsets)
#
configs = get_dataset_confs(dataset_key)
conf_option = None
if len(configs) > 0:
conf_option = st.sidebar.selectbox("Subset", configs, index=0, format_func=lambda a: a.name)
dataset = get_dataset(dataset_key, str(conf_option.name) if conf_option else None)
splits = list(dataset.keys())
index = 0
if "train" in splits:
index = splits.index("train")
split = st.sidebar.selectbox("Split", splits, key="split_select", index=index)
dataset = dataset[split]
dataset = renameDatasetColumn(dataset)
dataset_templates = template_collection.get_dataset(dataset_key, conf_option.name if conf_option else None)
template_list = dataset_templates.all_template_names
num_templates = len(template_list)
st.sidebar.write(
"No of prompts created for "
+ f"`{dataset_key + (('/' + conf_option.name) if conf_option else '')}`"
+ f": **{str(num_templates)}**"
)
if mode == "Prompted dataset viewer":
if num_templates > 0:
template_name = st.sidebar.selectbox(
"Prompt name",
template_list,
key="template_select",
index=0,
help="Select the prompt to visualize.",
)
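# The viewer pages through the split 50 examples at a time: `example_index` selects
# the offset of the page, and the loop below renders `step` consecutive examples.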
step = 50
example_index = st.sidebar.number_input(
f"Select the example index (Size = {len(dataset)})",
min_value=0,
max_value=len(dataset) - step,
value=0,
step=step,
key="example_index_number_input",
help="Offset = 50.",
)
else: # mode = Sourcing
st.sidebar.subheader("Select Example")
example_index = st.sidebar.slider("Select the example index", 0, len(dataset) - 1)
example = dataset[example_index]
example = removeHyphen(example)
st.sidebar.write(example)
st.sidebar.subheader("Dataset Schema")
rendered_features = render_features(dataset.features)
st.sidebar.write(rendered_features)
#
# Display dataset information
#
st.header("Dataset: " + dataset_key + " " + (("/ " + conf_option.name) if conf_option else ""))
st.markdown(
"*Homepage*: "
+ dataset.info.homepage
+ "\n\n*Dataset*: https://github.com/huggingface/datasets/blob/master/datasets/%s/%s.py"
% (dataset_key, dataset_key)
)
md = """
%s
""" % (
dataset.info.description.replace("\\", "") if dataset_key else ""
)
st.markdown(md)
#
# Body of the app: display prompted examples in mode `Prompted dataset viewer`
# or text boxes to create new prompts in mode `Sourcing`
#
if mode == "Prompted dataset viewer":
#
# Display template information
#
if num_templates > 0:
template = dataset_templates[template_name]
st.subheader("Prompt")
st.markdown("##### Name")
st.text(template.name)
st.markdown("##### Reference")
st.text(template.reference)
st.markdown("##### Original Task? ")
st.text(template.metadata.original_task)
st.markdown("##### Choices in template? ")
st.text(template.metadata.choices_in_prompt)
st.markdown("##### Metrics")
st.text(", ".join(template.metadata.metrics) if template.metadata.metrics else None)
st.markdown("##### Answer Choices")
if template.get_answer_choices_expr() is not None:
show_jinja(template.get_answer_choices_expr())
else:
st.text(None)
st.markdown("##### Jinja template")
splitted_template = template.jinja.split("|||")
st.markdown("###### Input template")
show_jinja(splitted_template[0].strip())
if len(splitted_template) > 1:
st.markdown("###### Target template")
show_jinja(splitted_template[1].strip())
st.markdown("***")
#
# Display a couple (steps) examples
#
for ex_idx in range(example_index, example_index + step):
if ex_idx >= len(dataset):
continue
example = dataset[ex_idx]
example = removeHyphen(example)
col1, _, col2 = st.beta_columns([12, 1, 12])
with col1:
st.write(example)
if num_templates > 0:
with col2:
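# Render the selected prompt on this example; apply() returns the rendered
# input (and target, if the template defines one) as a list of strings.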
prompt = template.apply(example, highlight_variables=False)
if prompt == [""]:
st.write("∅∅∅ *Blank result*")
else:
st.write("Input")
show_text(prompt[0])
if len(prompt) > 1:
st.write("Target")
show_text(prompt[1])
st.markdown("***")
else: # mode = Sourcing
st.markdown("## Prompt Creator")
#
# Create a new template or select an existing one
#
col1a, col1b, _, col2 = st.beta_columns([9, 9, 1, 6])
# current_templates_key and state.templates_key are keys for the templates object
current_templates_key = (dataset_key, conf_option.name if conf_option else None)
# Resets state if there has been a change in templates_key
if state.templates_key != current_templates_key:
state.templates_key = current_templates_key
reset_template_state()
with col1a, st.form("new_template_form"):
new_template_name = st.text_input(
"Create a New Prompt",
key="new_template",
value="",
help="Enter name and hit enter to create a new prompt.",
)
new_template_submitted = st.form_submit_button("Create")
if new_template_submitted:
if new_template_name in dataset_templates.all_template_names:
st.error(
f"A prompt with the name {new_template_name} already exists "
f"for dataset {state.templates_key}."
)
elif new_template_name == "":
st.error("Need to provide a prompt name.")
else:
template = Template(new_template_name, "", "")
dataset_templates.add_template(template)
reset_template_state()
state.template_name = new_template_name
else:
state.new_template_name = None
with col1b, st.beta_expander("or Select Prompt", expanded=True):
dataset_templates = template_collection.get_dataset(*state.templates_key)
template_list = dataset_templates.all_template_names
if state.template_name:
index = template_list.index(state.template_name)
else:
index = 0
state.template_name = st.selectbox(
"", template_list, key="template_select", index=index, help="Select the prompt to work on."
)
if st.button("Delete Prompt", key="delete_prompt"):
dataset_templates.remove_template(state.template_name)
reset_template_state()
variety_guideline = """
:heavy_exclamation_mark::question:Creating a diverse set of prompts whose differences go beyond surface wording (i.e. marginally changing 2 or 3 words) is highly encouraged.
Ultimately, the hope is that exposing the model to such diversity will have a non-trivial impact on the model's robustness to the prompt formulation.
\r**To get varied prompts, you can try moving the cursor along these axes**:
\n- **Interrogative vs affirmative form**: Ask a question about an attribute of the inputs, or tell the model to decide something about the input.
\n- **Task description localization**: Where is the task description blended with the inputs? At the beginning, in the middle, at the end?
\n- **Implicit situation or contextualization**: How explicit is the query? For instance, *Given this review, would you buy this product?* is an indirect way to ask whether the review is positive.
"""
col1, _, _ = st.beta_columns([18, 1, 6])
with col1:
if state.template_name is not None:
show_text(variety_guideline, with_markdown=True)
#
# Edit the created or selected template
#
col1, _, col2 = st.beta_columns([18, 1, 6])
with col1:
if state.template_name is not None:
template = dataset_templates[state.template_name]
#
# If template is selected, displays template editor
#
with st.form("edit_template_form"):
updated_template_name = st.text_input("Name", value=template.name)
state.reference = st.text_input(
"Prompt Reference",
help="Short description of the prompt and/or paper reference for the prompt.",
value=template.reference,
)
# Metadata
state.metadata = template.metadata
state.metadata.original_task = st.checkbox(
"Original Task?",
value=template.metadata.original_task,
help="Prompt asks model to perform the original task designed for this dataset.",
)
state.metadata.choices_in_prompt = st.checkbox(
"Choices in Template?",
value=template.metadata.choices_in_prompt,
help="Prompt explicitly lists choices in the template for the output.",
)
# Metrics from here:
# https://github.com/google-research/text-to-text-transfer-transformer/blob/4b580f23968c2139be7fb1cd53b22c7a7f686cdf/t5/evaluation/metrics.py
metrics_choices = [
"BLEU",
"ROUGE",
"Squad",
"Trivia QA",
"Accuracy",
"Pearson Correlation",
"Spearman Correlation",
"MultiRC",
"AUC",
"COQA F1",
"Edit Distance",
]
# Add mean reciprocal rank
metrics_choices.append("Mean Reciprocal Rank")
# Add generic other
metrics_choices.append("Other")
# Sort alphabetically
metrics_choices = sorted(metrics_choices)
state.metadata.metrics = st.multiselect(
"Metrics",
metrics_choices,
default=template.metadata.metrics,
help="Select all metrics that are commonly used (or should "
"be used if a new task) to evaluate this prompt.",
)
# Answer choices
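# get_answer_choices_expr() returns the Jinja expression listing the answer
# choices, or None if the prompt has no fixed choices.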
if template.get_answer_choices_expr() is not None:
answer_choices = template.get_answer_choices_expr()
else:
answer_choices = ""
state.answer_choices = st.text_input(
"Answer Choices",
value=answer_choices,
help="A Jinja expression for computing answer choices. "
"Separate choices with a triple bar (|||).",
)
# Jinja
state.jinja = st.text_area("Template", height=40, value=template.jinja)
# Submit form
if st.form_submit_button("Save"):
if (
updated_template_name in dataset_templates.all_template_names
and updated_template_name != state.template_name
):
st.error(
f"A prompt with the name {updated_template_name} already exists "
f"for dataset {state.templates_key}."
)
elif updated_template_name == "":
st.error("Need to provide a prompt name.")
else:
# Parses state.answer_choices
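# An empty text box means the prompt has no answer choices, so store None instead.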
if state.answer_choices == "":
updated_answer_choices = None
else:
updated_answer_choices = state.answer_choices
dataset_templates.update_template(
state.template_name,
updated_template_name,
state.jinja,
state.reference,
state.metadata,
updated_answer_choices,
)
# Update the state as well
state.template_name = updated_template_name
#
# Displays template output on current example if a template is selected
# (in second column)
#
with col2:
if state.template_name is not None:
st.empty()
template = dataset_templates[state.template_name]
prompt = template.apply(example)
if prompt == [""]:
st.write("∅∅∅ *Blank result*")
else:
st.write("Input")
show_text(prompt[0], width=40)
if len(prompt) > 1:
st.write("Target")
show_text(prompt[1], width=40)
#
# Must sync state at end
#
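# sync() (defined in session.py) writes the values gathered during this run back
# into the session object so they are available on the next rerun.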
state.sync()