Spaces:
Runtime error
Runtime error
from tasks.base import BaseProbInference | |
class SST2ProbInferenceForMC(BaseProbInference):
    """SST-2 sentiment classification framed as a two-way multiple-choice task.

    Every example is offered the same label words (``CHOICES``), and the
    dataset's integer label is used directly as ``answer_idx``. The code
    therefore assumes label 0 means "negative" and label 1 means "positive"
    (NOTE(review): confirm against the GLUE SST-2 label encoding).

    Only prompt versions starting with ``"sp"`` are supported; any other
    version raises ``ValueError`` when a separator or prompt is requested.
    """

    def __init__(self, prompt_version):
        super().__init__(prompt_version)
        # Label words; list position i is the verbalizer for dataset label i.
        self.CHOICES = ["negative", "positive"]
        self.can_be_stratified = True
        # Set to the number of classes (presumably one demonstration per
        # class; confirm against BaseProbInference).
        self.num_base_shot = len(self.CHOICES)

    def default_prompt_version(self):
        """Return the default prompt version identifier, ``"sp"``."""
        return "sp"

    def dataset_signature(self):
        """Return ``(dataset, config, split)`` tuples keyed by role.

        ``"result"`` uses the GLUE SST-2 validation split. ``"sample"`` uses
        the train split (presumably the pool demonstrations are drawn from).
        """
        return {
            "result": ("glue", "sst2", "validation"),
            "sample": ("glue", "sst2", "train"),
        }

    def dataset_preprocess(self, raw_data):
        """Convert raw SST-2 rows into multiple-choice records.

        Args:
            raw_data: iterable of mappings with ``"sentence"`` and ``"label"``
                keys.

        Returns:
            A list of dicts with ``"query"`` (the sentence, whitespace-stripped),
            ``"choices"`` (a fresh copy of ``self.CHOICES``) and
            ``"answer_idx"`` (the raw label value).
        """
        return [
            {
                "query": row["sentence"].strip(),
                # Copy per record: sharing one list object would let a consumer
                # that mutates one record's choices corrupt every record and
                # self.CHOICES itself.
                "choices": list(self.CHOICES),
                "answer_idx": row["label"],
            }
            for row in raw_data
        ]

    def handcrafted_exemplars(self):
        """Not provided for SST-2.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def _require_sp_prompt(self):
        """Raise ``ValueError`` unless ``self.prompt_version`` starts with ``"sp"``."""
        if not self.prompt_version.startswith("sp"):
            raise ValueError(f"SST2: Not supported prompt_version: {self.prompt_version}")

    def exemplar_seperator(self):
        """Return the string placed between demonstrations (a blank line).

        The misspelled name is kept deliberately: it presumably overrides a
        BaseProbInference hook, so renaming it would break dispatch.

        Raises:
            ValueError: if the prompt version is not an ``"sp"`` variant.
        """
        self._require_sp_prompt()
        return "\n\n"

    def multiple_choice_promptify(self, query, choice):
        """Build the prompt for ``query``, without and with a candidate answer.

        Args:
            query: the review text.
            choice: one label word from ``CHOICES``.

        Returns:
            ``(with_query, with_query_and_choice)``. The first ends at
            ``"Sentiment:"``; the second appends a space and ``choice``.

        Raises:
            ValueError: if the prompt version is not an ``"sp"`` variant.
        """
        self._require_sp_prompt()
        with_query = f"Review: {query}\nSentiment:"
        with_query_and_choice = f"{with_query} {choice}"
        return with_query, with_query_and_choice