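# NOTE: the two "Build error" lines that stood here were banner text from the
# page this file was captured from; they are not part of the program.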
import argparse
import textwrap
from multiprocessing import Manager, Pool

import pandas as pd
import plotly.express as px
import streamlit as st
from datasets import get_dataset_infos
from pygments import highlight
from pygments.formatters import HtmlFormatter
from pygments.lexers import DjangoLexer

from promptsource.session import _get_state
from promptsource.templates import Template, TemplateCollection
from promptsource.utils import (
    get_dataset,
    get_dataset_confs,
    list_datasets,
    removeHyphen,
    renameDatasetColumn,
    render_features,
)
# Add a command-line flag for read-only mode.
# At the moment, streamlit does not handle python script arguments gracefully.
# Thus, for read-only mode, the flag has to be passed after a bare `--`, in one
# of the two forms below:
#   streamlit run promptsource/app.py -- -r
#   streamlit run promptsource/app.py -- --read-only
# NOTE(review): the script file name above and the link this comment originally
# pointed to for more information were lost in extraction; `app.py` is presumed.
parser = argparse.ArgumentParser(description="run with args")
parser.add_argument("-r", "--read-only", action="store_true", help="whether to run it as read-only mode")
args = parser.parse_args()

# Read-only mode hides the "Sourcing" mode and labels the sidebar accordingly.
if args.read_only:
    select_options = ["Helicopter view", "Prompted dataset viewer"]
    side_bar_title_prefix = "Promptsource (Read only)"
else:
    select_options = ["Helicopter view", "Prompted dataset viewer", "Sourcing"]
    side_bar_title_prefix = "Promptsource"

#
# Helper functions for datasets library
#
# Wrap the imported loaders with streamlit's cache.
get_dataset = st.cache(allow_output_mutation=True)(get_dataset)
get_dataset_confs = st.cache(get_dataset_confs)
def reset_template_state():
    """Clear the currently selected template from the session state.

    Sets ``template_name``, ``jinja`` and ``reference`` on the module-level
    ``state`` object to ``None``.
    """
    state.template_name = None
    state.jinja = None
    state.reference = None
#
# Loads session state
#
state = _get_state()

#
# Initial page setup
#
st.set_page_config(page_title="Promptsource", layout="wide")
# NOTE(review): the link target was empty in the extracted file; the project
# repository URL below is presumed — confirm.
st.sidebar.markdown(
    "<center><a href='https://github.com/bigscience-workshop/promptsource'>💻Github - Promptsource\n\n</a></center>",
    unsafe_allow_html=True,
)
mode = st.sidebar.selectbox(
    label="Choose a mode",
    options=select_options,
    index=0,
    key="mode_select",
)
st.sidebar.title(f"{side_bar_title_prefix} 🌸 - {mode}")

#
# Adds pygments styles to the page.
#
st.markdown(
    "<style>" + HtmlFormatter(style="friendly").get_style_defs(".highlight") + "</style>", unsafe_allow_html=True
)

# Default wrap width (in characters) used by show_jinja and show_text.
WIDTH = 80
def show_jinja(t, width=WIDTH):
    """Render a Jinja template string on the page with syntax highlighting.

    The text is wrapped to ``width`` columns with ``textwrap.fill`` (existing
    whitespace characters are kept, not replaced by spaces), highlighted with
    pygments' ``DjangoLexer`` into HTML, and written to the page as raw HTML.

    Args:
        t: Template source text to display.
        width: Maximum line width used for wrapping.
    """
    wrap = textwrap.fill(t, width=width, replace_whitespace=False)
    out = highlight(wrap, DjangoLexer(), HtmlFormatter())
    st.write(out, unsafe_allow_html=True)
def show_text(t, width=WIDTH, with_markdown=False):
    """Display text on the page, wrapping each of its lines to ``width`` columns.

    Each line of ``t`` (split on ``"\\n"``) is wrapped on its own, so the
    original line breaks are kept.

    Args:
        t: Text to display.
        width: Maximum line width used for wrapping.
        with_markdown: If true, write the text with ``st.write`` (HTML
            allowed); otherwise show it with ``st.text``.
    """
    wrapped_lines = [textwrap.fill(subt, width=width, replace_whitespace=False) for subt in t.split("\n")]
    wrap = "\n".join(wrapped_lines)
    if with_markdown:
        st.write(wrap, unsafe_allow_html=True)
    else:
        st.text(wrap)
#
# Loads template data
#
# If the prompt folder cannot be found, show an error and stop the script run.
try:
    template_collection = TemplateCollection()
except FileNotFoundError:
    st.error(
        "Unable to find the prompt folder!\n\n"
        "We expect the folder to be in the working directory. "
        "You might need to restart the app in the root directory of the repo."
    )
    st.stop()
# NOTE(review): this section was restored from an extraction that had lost its
# indentation and some tokens (attribute accesses such as `.name` / `.info`,
# `px.bar(`, `pool.map(` and URLs). Restored tokens are flagged where they are
# not implied by the surrounding code — confirm against the upstream file.
if mode == "Helicopter view":
    st.title("High level metrics")
    # NOTE(review): link target was lost in extraction; URL presumed.
    st.write(
        "If you want to contribute, please refer to the instructions in "
        + "[Contributing](https://github.com/bigscience-workshop/promptsource/blob/main/CONTRIBUTING.md)"
    )

    #
    # Global metrics
    #
    counts = template_collection.get_templates_count()
    nb_prompted_datasets = len(counts)
    st.write(f"## Number of *prompted datasets*: `{nb_prompted_datasets}`")
    nb_prompts = sum(counts.values())
    st.write(f"## Number of *prompts*: `{nb_prompts}`")

    #
    # Metrics per dataset/subset
    #
    # Download dataset infos (multiprocessing download)
    manager = Manager()
    all_infos = manager.dict()
    all_datasets = list({t[0] for t in template_collection.keys})

    def get_infos(d_name):
        """Fetch the infos of dataset `d_name` and store them in the shared dict."""
        all_infos[d_name] = get_dataset_infos(d_name)

    # One worker process per dataset.
    pool = Pool(processes=len(all_datasets))
    pool.map(get_infos, all_datasets)
    pool.close()
    pool.join()

    results = []
    for (dataset_name, subset_name) in template_collection.keys:
        # Collect split sizes (train, validation and test)
        if dataset_name not in all_infos:
            infos = get_dataset_infos(dataset_name)
            all_infos[dataset_name] = infos
        else:
            infos = all_infos[dataset_name]
        if infos:
            if subset_name is None:
                # No subset: use the first (presumably only) configuration.
                subset_infos = infos[list(infos.keys())[0]]
            else:
                subset_infos = infos[subset_name]

            split_sizes = {k: v.num_examples for k, v in subset_infos.splits.items()}
        else:
            # Zaid/coqa_expanded and Zaid/quac_expanded don't have dataset_infos.json
            # so infos is an empty dic, and `infos[list(infos.keys())[0]]` raises an error
            # For simplicity, just filling `split_sizes` with nothing, so the displayed split sizes will be 0.
            split_sizes = {}

        # Collect template counts, original task counts and names
        dataset_templates = template_collection.get_dataset(dataset_name, subset_name)
        results.append(
            {
                "Dataset name": dataset_name,
                "Subset name": "∅" if subset_name is None else subset_name,
                "Train size": split_sizes.get("train", 0),
                "Validation size": split_sizes.get("validation", 0),
                "Test size": split_sizes.get("test", 0),
                "Number of prompts": len(dataset_templates),
                "Number of original task prompts": sum(
                    bool(t.metadata.original_task) for t in dataset_templates.templates.values()
                ),
                "Prompt names": [t.name for t in dataset_templates.templates.values()],
            }
        )
    results_df = pd.DataFrame(results)
    results_df.sort_values(["Number of prompts"], inplace=True, ascending=False)
    results_df.reset_index(drop=True, inplace=True)

    nb_training_instances = results_df["Train size"].sum()
    st.write(f"## Number of *training instances*: `{nb_training_instances}`")

    plot_df = results_df[["Dataset name", "Subset name", "Train size", "Number of prompts"]].copy()
    plot_df["Name"] = plot_df["Dataset name"] + " - " + plot_df["Subset name"]
    # Sorted by decreasing train size: the statistics below rely on this order.
    plot_df.sort_values(["Train size"], inplace=True, ascending=False)
    fig = px.bar(
        plot_df,
        x="Name",
        y="Train size",
        hover_data=["Dataset name", "Subset name", "Number of prompts"],
        log_y=True,
        title="Number of training instances per data(sub)set - y-axis is in logscale",
    )
    fig.update_xaxes(visible=False, showticklabels=False)
    st.plotly_chart(fig, use_container_width=True)
    st.write(
        f"- Top 3 training subsets account for `{100*plot_df[:3]['Train size'].sum()/nb_training_instances:.2f}%` of the training instances."
    )
    biggest_training_subset = plot_df.iloc[0]
    st.write(
        f"- Biggest training subset is *{biggest_training_subset['Name']}* with `{biggest_training_subset['Train size']}` instances"
    )
    # Subsets without a train split (size 0) are ignored for the minimum.
    smallest_training_subset = plot_df[plot_df["Train size"] > 0].iloc[-1]
    st.write(
        f"- Smallest training subset is *{smallest_training_subset['Name']}* with `{smallest_training_subset['Train size']}` instances"
    )

    st.markdown("***")
    st.write("Details per dataset")
    st.table(results_df)
else:
    # Combining mode `Prompted dataset viewer` and `Sourcing` since the
    # backbone of the interfaces is the same
    assert mode in ["Prompted dataset viewer", "Sourcing"], ValueError(
        f"`mode` ({mode}) should be in `[Helicopter view, Prompted dataset viewer, Sourcing]`"
    )

    #
    # Loads dataset information
    #
    dataset_list = list_datasets(
        template_collection,
        state,
    )
    # Fall back to the first entry when `ag_news` is not listed, instead of
    # failing with ValueError from `list.index`.
    ag_news_index = dataset_list.index("ag_news") if "ag_news" in dataset_list else 0

    #
    # Select a dataset - starts with ag_news
    #
    dataset_key = st.sidebar.selectbox(
        "Dataset",
        dataset_list,
        key="dataset_select",
        index=ag_news_index,
        help="Select the dataset to work on.",
    )

    #
    # If a particular dataset is selected, loads dataset and template information
    #
    if dataset_key is not None:

        #
        # Check for subconfigurations (i.e. subsets)
        #
        configs = get_dataset_confs(dataset_key)
        conf_option = None
        if len(configs) > 0:
            conf_option = st.sidebar.selectbox("Subset", configs, index=0, format_func=lambda a: a.name)

        dataset = get_dataset(dataset_key, str(conf_option.name) if conf_option else None)
        splits = list(dataset.keys())
        # Preselect the train split when there is one.
        index = 0
        if "train" in splits:
            index = splits.index("train")
        split = st.sidebar.selectbox("Split", splits, key="split_select", index=index)
        dataset = dataset[split]
        dataset = renameDatasetColumn(dataset)

        dataset_templates = template_collection.get_dataset(dataset_key, conf_option.name if conf_option else None)

        template_list = dataset_templates.all_template_names
        num_templates = len(template_list)
        st.sidebar.write(
            "No of prompts created for "
            + f"`{dataset_key + (('/' + conf_option.name) if conf_option else '')}`"
            + f": **{str(num_templates)}**"
        )

        if mode == "Prompted dataset viewer":
            if num_templates > 0:
                template_name = st.sidebar.selectbox(
                    "Prompt name",
                    template_list,
                    key="template_select",
                    index=0,
                    help="Select the prompt to visualize.",
                )

            step = 50
            example_index = st.sidebar.number_input(
                f"Select the example index (Size = {len(dataset)})",
                min_value=0,
                # Clamp at 0: for a split smaller than `step` the upper bound
                # would otherwise be negative, i.e. below min_value and value.
                max_value=max(0, len(dataset) - step),
                value=0,
                step=step,
                key="example_index_number_input",
                help="Offset = 50.",
            )
        else:  # mode = Sourcing
            st.sidebar.subheader("Select Example")
            example_index = st.sidebar.slider("Select the example index", 0, len(dataset) - 1)

            example = dataset[example_index]
            example = removeHyphen(example)

            st.sidebar.write(example)

        st.sidebar.subheader("Dataset Schema")
        rendered_features = render_features(dataset.features)
        st.sidebar.write(rendered_features)

        #
        # Display dataset information
        #
        st.header("Dataset: " + dataset_key + " " + (("/ " + conf_option.name) if conf_option else ""))

        # NOTE(review): the homepage expression and the dataset URL were lost in
        # extraction; restored so the two-value `%` formatting is satisfied — confirm.
        st.markdown(
            "*Homepage*: "
            + dataset.info.homepage
            + "\n\n*Dataset*: https://github.com/huggingface/datasets/blob/master/datasets/%s/%s.py"
            % (dataset_key, dataset_key)
        )

        # NOTE(review): `dataset.info.description.replace` is the presumed receiver of
        # the surviving `("\\", "")` arguments — confirm.
        md = """
%s
""" % (
            dataset.info.description.replace("\\", "") if dataset_key else ""
        )
        st.markdown(md)

        #
        # Body of the app: display prompted examples in mode `Prompted dataset viewer`
        # or text boxes to create new prompts in mode `Sourcing`
        #
        if mode == "Prompted dataset viewer":
            #
            # Display template information
            #
            if num_templates > 0:
                template = dataset_templates[template_name]
                st.subheader("Prompt")
                st.markdown("##### Name")
                st.text(template.name)
                st.markdown("##### Reference")
                st.text(template.reference)
                st.markdown("##### Original Task? ")
                st.text(template.metadata.original_task)
                st.markdown("##### Choices in template? ")
                st.text(template.metadata.choices_in_prompt)
                st.markdown("##### Metrics")
                st.text(", ".join(template.metadata.metrics) if template.metadata.metrics else None)
                st.markdown("##### Answer Choices")
                if template.get_answer_choices_expr() is not None:
                    show_jinja(template.get_answer_choices_expr())
                else:
                    st.text(None)
                st.markdown("##### Jinja template")
                # The part before `|||` is the input template, the part after it the target.
                splitted_template = template.jinja.split("|||")
                st.markdown("###### Input template")
                show_jinja(splitted_template[0].strip())
                if len(splitted_template) > 1:
                    st.markdown("###### Target template")
                    show_jinja(splitted_template[1].strip())
                st.markdown("***")

            #
            # Display a couple (steps) examples
            #
            for ex_idx in range(example_index, example_index + step):
                if ex_idx >= len(dataset):
                    continue
                example = dataset[ex_idx]
                example = removeHyphen(example)
                col1, _, col2 = st.beta_columns([12, 1, 12])
                with col1:
                    st.write(example)
                if num_templates > 0:
                    with col2:
                        prompt = template.apply(example, highlight_variables=False)
                        if prompt == [""]:
                            st.write("∅∅∅ *Blank result*")
                        else:
                            st.write("Input")
                            show_text(prompt[0])
                            if len(prompt) > 1:
                                st.write("Target")
                                show_text(prompt[1])
                st.markdown("***")
        else:  # mode = Sourcing
            st.markdown("## Prompt Creator")

            #
            # Create a new template or select an existing one
            #
            col1a, col1b, _, col2 = st.beta_columns([9, 9, 1, 6])

            # current_templates_key and state.templates_key are keys for the templates object
            current_templates_key = (dataset_key, conf_option.name if conf_option else None)

            # Resets state if there has been a change in templates_key
            if state.templates_key != current_templates_key:
                state.templates_key = current_templates_key
                reset_template_state()

            with col1a, st.form("new_template_form"):
                new_template_name = st.text_input(
                    "Create a New Prompt",
                    key="new_template",
                    value="",
                    help="Enter name and hit enter to create a new prompt.",
                )
                new_template_submitted = st.form_submit_button("Create")
                if new_template_submitted:
                    if new_template_name in dataset_templates.all_template_names:
                        st.error(
                            f"A prompt with the name {new_template_name} already exists "
                            f"for dataset {state.templates_key}."
                        )
                    elif new_template_name == "":
                        st.error("Need to provide a prompt name.")
                    else:
                        template = Template(new_template_name, "", "")
                        dataset_templates.add_template(template)
                        reset_template_state()
                        state.template_name = new_template_name
                else:
                    state.new_template_name = None

            with col1b, st.beta_expander("or Select Prompt", expanded=True):
                dataset_templates = template_collection.get_dataset(*state.templates_key)
                template_list = dataset_templates.all_template_names
                if state.template_name:
                    index = template_list.index(state.template_name)
                else:
                    index = 0
                state.template_name = st.selectbox(
                    "", template_list, key="template_select", index=index, help="Select the prompt to work on."
                )

                if st.button("Delete Prompt", key="delete_prompt"):
                    dataset_templates.remove_template(state.template_name)
                    reset_template_state()

            variety_guideline = """
:heavy_exclamation_mark::question:Creating a diverse set of prompts whose differences go beyond surface wordings (i.e. marginally changing 2 or 3 words) is highly encouraged.
Ultimately, the hope is that exposing the model to such a diversity will have a non-trivial impact on the model's robustness to the prompt formulation.
\r**To get various prompts, you can try moving the cursor along theses axes**:
\n- **Interrogative vs affirmative form**: Ask a question about an attribute of the inputs or tell the model to decide something about the input.
\n- **Task description localization**: where is the task description blended with the inputs? In the beginning, in the middle, at the end?
\n- **Implicit situation or contextualization**: how explicit is the query? For instance, *Given this review, would you buy this product?* is an indirect way to ask whether the review is positive.
"""

            col1, _, _ = st.beta_columns([18, 1, 6])
            with col1:
                if state.template_name is not None:
                    show_text(variety_guideline, with_markdown=True)

            #
            # Edit the created or selected template
            #
            col1, _, col2 = st.beta_columns([18, 1, 6])
            with col1:
                if state.template_name is not None:
                    template = dataset_templates[state.template_name]
                    #
                    # If template is selected, displays template editor
                    #
                    with st.form("edit_template_form"):
                        updated_template_name = st.text_input("Name", value=template.name)
                        state.reference = st.text_input(
                            "Prompt Reference",
                            help="Short description of the prompt and/or paper reference for the prompt.",
                            value=template.reference,
                        )

                        # Metadata
                        state.metadata = template.metadata
                        state.metadata.original_task = st.checkbox(
                            "Original Task?",
                            value=template.metadata.original_task,
                            help="Prompt asks model to perform the original task designed for this dataset.",
                        )
                        state.metadata.choices_in_prompt = st.checkbox(
                            "Choices in Template?",
                            value=template.metadata.choices_in_prompt,
                            help="Prompt explicitly lists choices in the template for the output.",
                        )

                        # Metrics list.
                        # NOTE(review): the link to the source of this list was lost in extraction.
                        metrics_choices = [
                            "BLEU",
                            "ROUGE",
                            "Squad",
                            "Trivia QA",
                            "Accuracy",
                            "Pearson Correlation",
                            "Spearman Correlation",
                            "MultiRC",
                            "AUC",
                            "COQA F1",
                            "Edit Distance",
                        ]
                        # Add mean reciprocal rank
                        metrics_choices.append("Mean Reciprocal Rank")
                        # Add generic other
                        metrics_choices.append("Other")
                        # Sort alphabetically
                        metrics_choices = sorted(metrics_choices)
                        state.metadata.metrics = st.multiselect(
                            "Metrics",
                            metrics_choices,
                            default=template.metadata.metrics,
                            help="Select all metrics that are commonly used (or should "
                            "be used if a new task) to evaluate this prompt.",
                        )

                        # Answer choices
                        if template.get_answer_choices_expr() is not None:
                            answer_choices = template.get_answer_choices_expr()
                        else:
                            answer_choices = ""
                        state.answer_choices = st.text_input(
                            "Answer Choices",
                            value=answer_choices,
                            help="A Jinja expression for computing answer choices. "
                            "Separate choices with a triple bar (|||).",
                        )

                        # Jinja
                        state.jinja = st.text_area("Template", height=40, value=template.jinja)

                        # Submit form
                        if st.form_submit_button("Save"):
                            if (
                                updated_template_name in dataset_templates.all_template_names
                                and updated_template_name != state.template_name
                            ):
                                st.error(
                                    f"A prompt with the name {updated_template_name} already exists "
                                    f"for dataset {state.templates_key}."
                                )
                            elif updated_template_name == "":
                                st.error("Need to provide a prompt name.")
                            else:
                                # Parses state.answer_choices
                                if state.answer_choices == "":
                                    updated_answer_choices = None
                                else:
                                    updated_answer_choices = state.answer_choices

                                dataset_templates.update_template(
                                    state.template_name,
                                    updated_template_name,
                                    state.jinja,
                                    state.reference,
                                    state.metadata,
                                    updated_answer_choices,
                                )
                                # Update the state as well
                                state.template_name = updated_template_name
            #
            # Displays template output on current example if a template is selected
            # (in second column)
            #
            with col2:
                if state.template_name is not None:
                    st.empty()
                    template = dataset_templates[state.template_name]
                    prompt = template.apply(example)
                    if prompt == [""]:
                        st.write("∅∅∅ *Blank result*")
                    else:
                        st.write("Input")
                        show_text(prompt[0], width=40)
                        if len(prompt) > 1:
                            st.write("Target")
                            show_text(prompt[1], width=40)

#
# Must sync state at end
#
state.sync()