Spaces:
Runtime error
Runtime error
File size: 1,448 Bytes
9d21d47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from tasks.base import BaseProbInference
class SST2ProbInferenceForMC(BaseProbInference):
    """Multiple-choice probability-inference task for GLUE SST-2 sentiment.

    Each example is scored against the label words in ``CHOICES``; the raw
    dataset ``label`` is used directly as an index into ``CHOICES``.
    """

    def __init__(self, prompt_version):
        super().__init__(prompt_version)
        # Label words; position i is the word used for dataset label i.
        self.CHOICES = ["negative", "positive"]
        # NOTE(review): these two attributes are presumably read by
        # BaseProbInference (e.g. for stratified few-shot sampling) — confirm.
        self.can_be_stratified = True
        self.num_base_shot = len(self.CHOICES)

    def default_prompt_version(self):
        """Return the default prompt version identifier, ``"sp"``."""
        return "sp"

    def dataset_signature(self):
        """Return where the evaluation ("result") and exemplar ("sample") data live.

        NOTE(review): each value is presumably a ``(dataset, config, split)``
        triple for a Hugging Face ``datasets`` loader — confirm against the
        base class.
        """
        return {
            "result": ("glue", "sst2", "validation"),
            "sample": ("glue", "sst2", "train"),
        }

    def dataset_preprocess(self, raw_data):
        """Convert raw SST-2 rows into multiple-choice examples.

        Args:
            raw_data: iterable of mappings with a ``"sentence"`` string and an
                integer ``"label"`` that indexes into ``CHOICES``.

        Returns:
            A list of dicts with keys ``"query"`` (stripped sentence),
            ``"choices"`` (the label words) and ``"answer_idx"`` (the label).
        """
        # Each example gets its own copy of the label words so that mutating
        # one example's choices cannot leak into other examples or into
        # self.CHOICES.
        return [
            {
                "query": row["sentence"].strip(),
                "choices": list(self.CHOICES),
                "answer_idx": row["label"],
            }
            for row in raw_data
        ]

    def handcrafted_exemplars(self):
        """Hand-written exemplars are not provided for SST-2."""
        raise NotImplementedError

    def _sst2_require_sp_prompt(self):
        """Raise ``ValueError`` unless the prompt version is an ``"sp"`` variant."""
        # self.prompt_version is presumably set by BaseProbInference.__init__.
        if not self.prompt_version.startswith("sp"):
            raise ValueError(f"SST2: Not supported prompt_version: {self.prompt_version}")

    def exemplar_seperator(self):
        """Return the separator placed between few-shot exemplars (a blank line).

        The (misspelled) method name is kept as-is because it presumably
        overrides a BaseProbInference hook of the same name.

        Raises:
            ValueError: if the prompt version is not an ``"sp"`` variant.
        """
        self._sst2_require_sp_prompt()
        return "\n\n"

    def multiple_choice_promptify(self, query, choice):
        """Build the prompt for ``query``, without and with ``choice`` appended.

        Returns:
            A tuple ``(with_query, with_query_and_choice)``, e.g. for
            ``query="great"`` and ``choice="positive"``:
            ``("Review: great\\nSentiment:", "Review: great\\nSentiment: positive")``.

        Raises:
            ValueError: if the prompt version is not an ``"sp"`` variant.
        """
        self._sst2_require_sp_prompt()
        with_query = f"Review: {query}\nSentiment:"
        return with_query, f"{with_query} {choice}"
|