jx-yang's picture
<ADD> +app
9d21d47
raw
history blame
1.45 kB
from tasks.base import BaseProbInference
class SST2ProbInferenceForMC(BaseProbInference):
    """Multiple-choice task for GLUE SST-2 binary sentiment classification.

    Each example is posed as a choice between the label strings in
    ``CHOICES`` ("negative" / "positive"). Only prompt versions whose name
    starts with "sp" are supported. In that format the prompt is a
    ``Review: <sentence>`` line followed by a ``Sentiment:`` line. Other
    versions raise ``ValueError`` when a prompt or exemplar separator is
    requested.
    """

    def __init__(self, prompt_version):
        """Set up the task; ``prompt_version`` is forwarded to the base class."""
        super().__init__(prompt_version)
        # Label strings in label-index order. Assumes SST-2 labels are
        # 0 = negative, 1 = positive, so CHOICES[label] is the gold text -- confirm.
        self.CHOICES = ["negative", "positive"]
        # NOTE(review): presumably tells BaseProbInference that demonstrations
        # may be sampled stratified by label -- confirm against the base class.
        self.can_be_stratified = True
        # One base shot per label; how it is consumed is up to the base class.
        self.num_base_shot = len(self.CHOICES)

    def default_prompt_version(self):
        """Return the default prompt version identifier, "sp"."""
        return "sp"

    def dataset_signature(self):
        """Return where this task's data comes from.

        Returns:
            dict mapping a role name to a tuple. ``"result"`` maps to the
            SST-2 validation split and ``"sample"`` to the SST-2 train split.
            NOTE(review): the tuple is presumably ``(dataset, config, split)``
            as passed to ``datasets.load_dataset``, with "result" the
            evaluation set and "sample" the demonstration pool -- confirm.
        """
        return {
            "result": ("glue", "sst2", "validation"),
            "sample": ("glue", "sst2", "train"),
        }

    def dataset_preprocess(self, raw_data):
        """Convert raw SST-2 records into multiple-choice examples.

        Args:
            raw_data: iterable of records, each with a ``"sentence"`` string
                and a ``"label"`` value (presumably an index into ``CHOICES``).

        Returns:
            list of dicts with keys:
            ``"query"``: the sentence with surrounding whitespace stripped.
            ``"choices"``: a per-example copy of ``CHOICES``.
            ``"answer_idx"``: the record's label, unchanged.
        """
        # Give every example its own choices list. A single shared list would
        # let an in-place edit of one example's choices (e.g. reordering)
        # silently change every other example and self.CHOICES as well.
        return [
            {
                "query": record["sentence"].strip(),
                "choices": list(self.CHOICES),
                "answer_idx": record["label"],
            }
            for record in raw_data
        ]

    def handcrafted_exemplars(self):
        """Not provided for SST-2; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def _require_sp_prompt(self):
        """Raise ``ValueError`` unless ``self.prompt_version`` starts with "sp"."""
        if not self.prompt_version.startswith("sp"):
            raise ValueError(f"SST2: Not supported prompt_version: {self.prompt_version}")

    def exemplar_seperator(self):
        """Return the exemplar separator: two newline characters (a blank line).

        The misspelled method name is kept deliberately. It presumably
        overrides a hook of the same name on BaseProbInference -- confirm
        before renaming.

        Raises:
            ValueError: if the prompt version is not an "sp" variant.
        """
        self._require_sp_prompt()
        return "\n\n"

    def multiple_choice_promptify(self, query, choice):
        """Build the prompt for ``query``, both without and with ``choice``.

        Args:
            query: the example text (the ``"query"`` produced by
                ``dataset_preprocess``).
            choice: one candidate label string.

        Returns:
            tuple ``(with_query, with_query_and_choice)``.
            ``with_query`` is ``Review: <query>`` and ``Sentiment:`` on two
            lines. ``with_query_and_choice`` is ``with_query`` followed by a
            space and ``choice``.

        Raises:
            ValueError: if the prompt version is not an "sp" variant.
        """
        self._require_sp_prompt()
        with_query = f"Review: {query}\nSentiment:"
        # NOTE(review): the choice-free prefix is presumably returned so the
        # caller can score only the choice continuation -- confirm.
        return with_query, f"{with_query} {choice}"