Spaces:
Build error
Build error
import os | |
import random | |
import uuid | |
from collections import Counter, defaultdict | |
from shutil import rmtree | |
from typing import Dict, List, Optional, Tuple | |
import pandas as pd | |
import pkg_resources | |
import yaml | |
from jinja2 import BaseLoader, Environment, meta | |
# Truncation of jinja template variables | |
# 1710 = 300 words x 4.7 avg characters per word + 300 spaces | |
TEXT_VAR_LENGTH = 2048 | |
# Local path to the folder containing the templates | |
TEMPLATES_FOLDER_PATH = pkg_resources.resource_filename(__name__, "templates") | |
env = Environment(loader=BaseLoader) | |
# Allow the python function zip() | |
env.globals.update(zip=zip) | |
# These are users whose datasets should be included in the results returned by | |
# filter_english_datasets (regardless of their metadata) | |
INCLUDED_USERS = {"Zaid", "craffel"} | |
def highlight(input): | |
return "<span style='color: #F08080'>" + input + "</span>" | |
def choice(choices): | |
return random.choice(choices) | |
def most_frequent(items): | |
"""Returns the set of items which appear most frequently in the input""" | |
if not items: | |
return | |
item_counts = Counter(items).most_common() | |
max_freq = item_counts[0][1] | |
most_frequent_items = [c[0] for c in item_counts if c[1] == max_freq] | |
return most_frequent_items | |
env.filters["highlight"] = highlight | |
env.filters["choice"] = choice | |
env.filters["most_frequent"] = most_frequent | |
class Template(yaml.YAMLObject): | |
""" | |
A prompt template. | |
""" | |
yaml_tag = "!Template" | |
def __init__(self, name, jinja, reference, metadata=None, answer_choices=None): | |
""" | |
Creates a prompt template. | |
A prompt template is expressed in Jinja. It is rendered using an example | |
from the corresponding Hugging Face datasets library (a dictionary). The | |
separator ||| should appear once to divide the template into prompt and | |
output. Generally, the prompt should provide information on the desired | |
behavior, e.g., text passage and instructions, and the output should be | |
a desired response. | |
:param name: unique name (per dataset) for template | |
:param jinja: template expressed in Jinja | |
:param reference: string describing author or paper reference for template | |
:param metadata: a Metadata object with template annotations | |
:param answer_choices: Jinja expression for answer choices. Should produce | |
a ||| delimited string of choices that enumerates | |
the possible completions for templates that should | |
be evaluated as ranked completions. If None, then | |
the template is open-ended. This list is accessible | |
from within Jinja as the variable `answer_choices`. | |
""" | |
self.id = str(uuid.uuid4()) | |
self.name = name | |
self.jinja = jinja | |
self.reference = reference | |
self.metadata = metadata if metadata is not None else Template.Metadata() | |
self.answer_choices = answer_choices | |
def get_id(self): | |
""" | |
Returns the id of the template | |
:return: unique id for template | |
""" | |
return self.id | |
def get_name(self): | |
""" | |
Returns the name of the template | |
:return: unique (per dataset) name for template | |
""" | |
return self.name | |
def get_reference(self): | |
""" | |
Returns the bibliographic reference (or author) for the template | |
:return: reference as a string | |
""" | |
return self.reference | |
def get_answer_choices_expr(self): | |
""" | |
Returns a Jinja expression for computing the answer choices from an example. | |
:return: String, or None if no answer choices | |
""" | |
return self.answer_choices | |
def get_answer_choices_list(self, example): | |
""" | |
Returns a list of answer choices for a given example | |
:return: list of strings, or None if get_answer_choices_expr is None | |
""" | |
jinja = self.get_answer_choices_expr() | |
if jinja is None: | |
return None | |
rtemplate = env.from_string(jinja) | |
protected_example = self._escape_pipe(example) | |
rendered_choices = rtemplate.render(**protected_example) | |
return [self._unescape_pipe(answer_choice.strip()) for answer_choice in rendered_choices.split("|||")] | |
def get_fixed_answer_choices_list(self): | |
""" | |
Returns a list of answer choices that is static across examples, if possible | |
:return: list of strings, or None if no static list exists | |
""" | |
jinja = self.get_answer_choices_expr() | |
if jinja is None: | |
return None | |
parse = env.parse(jinja) | |
variables = meta.find_undeclared_variables(parse) | |
if len(variables) == 0: | |
rtemplate = env.from_string(jinja) | |
rendered_choices = rtemplate.render() | |
return [answer_choice.strip() for answer_choice in rendered_choices.split("|||")] | |
else: | |
return None | |
def apply(self, example, truncate=True, highlight_variables=False): | |
""" | |
Creates a prompt by applying this template to an example | |
:param example: the dataset example to create a prompt for | |
:param truncate: if True, example fields will be truncated to TEXT_VAR_LENGTH chars | |
:param highlight_variables: highlight the added variables | |
:return: tuple of 2 strings, for prompt and output | |
""" | |
jinja = self.jinja | |
# Truncates the prompt if needed | |
if truncate: | |
trunc_command = ( | |
f" | string | truncate({TEXT_VAR_LENGTH}) }}}}" # Escaping curly braces requires doubling them | |
) | |
jinja = jinja.replace("}}", trunc_command) | |
# Highlights text that was substituted for variables, if requested | |
if highlight_variables: | |
jinja = jinja.replace("}}", " | highlight }}") | |
rtemplate = env.from_string(jinja) | |
protected_example = self._escape_pipe(example) | |
# Adds in answer_choices variable | |
if "answer_choices" in protected_example: | |
raise ValueError("Example contains the restricted key 'answer_choices'.") | |
protected_example["answer_choices"] = self.get_answer_choices_list(example) | |
# Renders the Jinja template | |
rendered_example = rtemplate.render(**protected_example) | |
# Splits on the separator, and then replaces back any occurrences of the | |
# separator in the original example | |
return [self._unescape_pipe(part).strip() for part in rendered_example.split("|||")] | |
pipe_protector = "3ed2dface8203c4c9dfb1a5dc58e41e0" | |
def _escape_pipe(cls, example): | |
# Replaces any occurrences of the "|||" separator in the example, which | |
# which will be replaced back after splitting | |
protected_example = { | |
key: value.replace("|||", cls.pipe_protector) if isinstance(value, str) else value | |
for key, value in example.items() | |
} | |
return protected_example | |
def _unescape_pipe(cls, string): | |
# replaces back any occurrences of the separator in a string | |
return string.replace(cls.pipe_protector, "|||") | |
class Metadata(yaml.YAMLObject): | |
""" | |
Metadata for a prompt template. | |
""" | |
yaml_tag = "!TemplateMetadata" | |
def __init__( | |
self, | |
original_task: Optional[bool] = None, | |
choices_in_prompt: Optional[bool] = None, | |
metrics: Optional[List[str]] = None, | |
): | |
""" | |
Initializes template metadata. | |
In the following, trivial choices are defined as Yes/No, True/False, | |
etc. and nontrivial choices are other types of choices denoted in | |
the answer_choices field. | |
:param original_task: If True, this prompt asks a model to perform the original task designed for | |
this dataset. | |
:param choices_in_prompt: If True, the answer choices are included in the templates such that models | |
see those choices in the input. Only applicable to classification tasks. | |
:param metrics: List of strings denoting metrics to use for evaluation | |
""" | |
self.original_task = original_task | |
self.choices_in_prompt = choices_in_prompt | |
self.metrics = metrics | |
class TemplateCollection: | |
""" | |
This helper class wraps the DatasetTemplates class | |
- Initialized the DatasetTemplates for all existing template folder | |
- Give access to each DatasetTemplates | |
- Provides aggregated counts over all DatasetTemplates | |
""" | |
def __init__(self): | |
# Dict of all the DatasetTemplates, key is the tuple (dataset_name, subset_name) | |
self.datasets_templates: Dict[(str, Optional[str]), DatasetTemplates] = self._collect_datasets() | |
def keys(self): | |
return list(self.datasets_templates.keys()) | |
def __len__(self) -> int: | |
return len(self.datasets_templates) | |
def remove(self, dataset_name: str, subset_name: Optional[str] = None) -> None: | |
del self.datasets_templates[dataset_name, subset_name] | |
def _collect_datasets(self) -> Dict[Tuple[str, str], "DatasetTemplates"]: | |
""" | |
Initialize a DatasetTemplates object for each templates.yaml detected in the templates folder | |
Returns: a dict with key=(dataset_name, subset_name) | |
""" | |
dataset_folders = os.listdir(TEMPLATES_FOLDER_PATH) | |
dataset_folders = [folder for folder in dataset_folders if not folder.startswith(".")] | |
output = {} # format is {(dataset_name, subset_name): DatasetsTemplates} | |
for dataset in dataset_folders: | |
if dataset in INCLUDED_USERS: | |
for filename in os.listdir(os.path.join(TEMPLATES_FOLDER_PATH, dataset)): | |
output = {**output, **self._collect_dataset(dataset + "/" + filename)} | |
else: | |
output = {**output, **self._collect_dataset(dataset)} | |
return output | |
def _collect_dataset(self, dataset): | |
output = {} # format is {(dataset_name, subset_name): DatasetsTemplates} | |
for filename in os.listdir(os.path.join(TEMPLATES_FOLDER_PATH, dataset)): | |
if filename.endswith(".yaml"): | |
# If there is no sub-folder, there is no subset for this dataset | |
output[(dataset, None)] = DatasetTemplates(dataset) | |
else: | |
# This is a subfolder, and its name corresponds to the subset name | |
output[(dataset, filename)] = DatasetTemplates(dataset_name=dataset, subset_name=filename) | |
return output | |
def get_dataset(self, dataset_name: str, subset_name: Optional[str] = None) -> "DatasetTemplates": | |
""" | |
Return the DatasetTemplates object corresponding to the dataset name | |
:param dataset_name: name of the dataset to get | |
:param subset_name: name of the subset | |
""" | |
# if the dataset does not exist, we add it | |
if dataset_name not in self.keys: | |
self.datasets_templates[(dataset_name, subset_name)] = DatasetTemplates(dataset_name, subset_name) | |
return self.datasets_templates[(dataset_name, subset_name)] | |
def get_templates_count(self) -> Dict: | |
""" | |
Return the overall number count over all datasets | |
NB: we don't breakdown datasets into subsets for the count, i.e subsets count are included | |
into the dataset count | |
""" | |
count_dict = defaultdict(int) | |
for k, v in self.datasets_templates.items(): | |
# Subsets count towards dataset count | |
count_dict[k[0]] += len(v) | |
# converting to regular dict | |
return dict(count_dict) | |
class DatasetTemplates: | |
""" | |
Class that wraps all templates for a specific dataset/subset and implements all the helper | |
functions necessary to read/write to the yaml file | |
""" | |
TEMPLATES_KEY = "templates" | |
DATASET_KEY = "dataset" | |
SUBSET_KEY = "subset" | |
TEMPLATE_FILENAME = "templates.yaml" | |
def __init__(self, dataset_name: str, subset_name: str = None): | |
self.dataset_name: str = dataset_name | |
self.subset_name: str = subset_name | |
# dictionary is keyed by template name. | |
self.templates: Dict = self.read_from_file() | |
# Mapping from template name to template id | |
self.name_to_id_mapping = {} | |
self.sync_mapping() | |
def sync_mapping(self) -> None: | |
""" | |
Re-compute the name_to_id_mapping to ensure it is in sync with self.templates | |
""" | |
self.name_to_id_mapping = {template.name: template.id for template in self.templates.values()} | |
def all_template_names(self) -> List[str]: | |
""" | |
Sorted list of all templates names for this dataset | |
""" | |
return sorted([template.name for template in self.templates.values()]) | |
def folder_path(self) -> str: | |
if self.subset_name: | |
return os.path.join(TEMPLATES_FOLDER_PATH, self.dataset_name, self.subset_name) | |
else: | |
return os.path.join(TEMPLATES_FOLDER_PATH, self.dataset_name) | |
def yaml_path(self) -> str: | |
return os.path.join(self.folder_path, self.TEMPLATE_FILENAME) | |
def format_for_dump(self) -> Dict: | |
""" | |
Create a formatted dictionary for the class attributes | |
""" | |
formatted_dict = {self.DATASET_KEY: self.dataset_name, self.TEMPLATES_KEY: self.templates} | |
if self.subset_name: | |
formatted_dict[self.SUBSET_KEY] = self.subset_name | |
return formatted_dict | |
def read_from_file(self) -> Dict: | |
""" | |
Reads a file containing a prompt collection. | |
""" | |
if not os.path.exists(self.yaml_path): | |
return {} | |
yaml_dict = yaml.load(open(self.yaml_path, "r"), Loader=yaml.FullLoader) | |
return yaml_dict[self.TEMPLATES_KEY] | |
def write_to_file(self) -> None: | |
""" | |
Writes to a file with the current prompt collection. | |
""" | |
# Sync the mapping | |
self.sync_mapping() | |
# We only create the folder if a template is written | |
if not os.path.exists(self.folder_path): | |
os.makedirs(self.folder_path) | |
yaml.dump(self.format_for_dump(), open(self.yaml_path, "w")) | |
def add_template(self, template: "Template") -> None: | |
""" | |
Adds a new template for the dataset | |
:param template: template | |
""" | |
self.templates[template.get_id()] = template | |
self.write_to_file() | |
def remove_template(self, template_name: str) -> None: | |
""" | |
Deletes a template | |
:param template_name: name of template to remove | |
""" | |
# Even if we have an ID, we want to check for duplicate names | |
if template_name not in self.all_template_names: | |
raise ValueError(f"No template with name {template_name} for dataset {self.dataset_name} exists.") | |
del self.templates[self.name_to_id_mapping[template_name]] | |
if len(self.templates) == 0: | |
# There is no remaining template, we can remove the entire folder | |
self.delete_folder() | |
else: | |
# We just update the file | |
self.write_to_file() | |
def update_template( | |
self, | |
current_template_name: str, | |
new_template_name: str, | |
jinja: str, | |
reference: str, | |
metadata: Template.Metadata, | |
answer_choices: str, | |
) -> None: | |
""" | |
Updates a pre-existing template and writes changes | |
:param current_template_name: current name of the template stored in self.templates | |
:param new_template_name: new name for the template | |
:param jinja: new jinja entry | |
:param reference: new reference entry | |
:param metadata: a Metadata object with template annotations | |
:param answer_choices: new answer_choices string | |
""" | |
template_id = self.name_to_id_mapping[current_template_name] | |
self.templates[template_id].name = new_template_name | |
self.templates[template_id].jinja = jinja | |
self.templates[template_id].reference = reference | |
self.templates[template_id].metadata = metadata | |
self.templates[template_id].answer_choices = answer_choices | |
self.write_to_file() | |
def delete_folder(self) -> None: | |
""" | |
Delete the folder corresponding to self.folder_path | |
""" | |
self.sync_mapping() | |
rmtree(self.folder_path) | |
# If it is a subset, we have to check whether to remove the dataset folder | |
if self.subset_name: | |
# have to check for other folders | |
base_dataset_folder = os.path.join(TEMPLATES_FOLDER_PATH, self.dataset_name) | |
if len(os.listdir(base_dataset_folder)) == 0: | |
rmtree(base_dataset_folder) | |
def __getitem__(self, template_key: str) -> "Template": | |
return self.templates[self.name_to_id_mapping[template_key]] | |
def __len__(self) -> int: | |
return len(self.templates) | |
def get_templates_data_frame(): | |
""" | |
Gathers all template information into a Pandas DataFrame. | |
:return: Pandas DataFrame | |
""" | |
data = { | |
"id": [], | |
"dataset": [], | |
"subset": [], | |
"name": [], | |
"reference": [], | |
"original_task": [], | |
"choices_in_prompt": [], | |
"metrics": [], | |
"answer_choices": [], | |
"jinja": [], | |
} | |
template_collection = TemplateCollection() | |
for key in template_collection.keys: | |
templates = template_collection.get_dataset(key[0], key[1]) | |
for template_name in templates.all_template_names: | |
template = templates[template_name] | |
data["id"].append(template.get_id()) | |
data["dataset"].append(key[0]) | |
data["subset"].append(key[1]) | |
data["name"].append(template.get_name()) | |
data["reference"].append(template.get_reference()) | |
data["original_task"].append(template.metadata.original_task) | |
data["choices_in_prompt"].append(template.metadata.choices_in_prompt) | |
data["metrics"].append(template.metadata.metrics) | |
data["answer_choices"].append(template.get_answer_choices_expr()) | |
data["jinja"].append(template.jinja) | |
return pd.DataFrame(data) | |