File size: 1,448 Bytes
9d21d47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from tasks.base import BaseProbInference


class SST2ProbInferenceForMC(BaseProbInference):
    """SST-2 sentiment task framed as a two-way multiple-choice problem.

    Each review sentence is paired with the candidate labels in ``CHOICES``
    ("negative", "positive"). Only prompt versions whose name starts with
    ``"sp"`` are handled by the formatting methods of this class.
    """

    def __init__(self, prompt_version):
        """Initialise the base class, then set the task-specific attributes.

        Args:
            prompt_version: Forwarded unchanged to ``BaseProbInference``.
        """
        super().__init__(prompt_version)

        # NOTE(review): `can_be_stratified` and `num_base_shot` are presumably
        # read by the base class or the sampling code; their exact semantics
        # are not visible in this file.
        self.can_be_stratified = True
        self.CHOICES = ["negative", "positive"]
        # One base shot per candidate label.
        self.num_base_shot = len(self.CHOICES)

    def default_prompt_version(self):
        """Return the name of the default prompt version (``"sp"``)."""
        return "sp"

    def dataset_signature(self):
        """Return the dataset identifiers used for evaluation and sampling.

        Returns:
            A dict with two keys, ``"result"`` and ``"sample"``, each mapping
            to a 3-tuple of strings: ``("glue", "sst2", <split>)``, with the
            split being ``"validation"`` and ``"train"`` respectively.
        """
        return {
            "result": ("glue", "sst2", "validation"),
            "sample": ("glue", "sst2", "train"),
        }

    def dataset_preprocess(self, raw_data):
        """Convert raw examples into multiple-choice records.

        Args:
            raw_data: Iterable of mappings providing ``"sentence"`` and
                ``"label"`` entries.

        Returns:
            A list of dicts with keys ``"query"`` (the whitespace-stripped
            sentence), ``"choices"`` (the shared ``self.CHOICES`` list, not a
            copy) and ``"answer_idx"`` (the example's ``"label"`` value as-is).
        """
        # NOTE(review): `answer_idx` is presumed to index into CHOICES
        # (0 -> "negative", 1 -> "positive"); confirm against the dataset.
        return [
            {
                "query": example["sentence"].strip(),
                "choices": self.CHOICES,
                "answer_idx": example["label"],
            }
            for example in raw_data
        ]

    def handcrafted_exemplars(self):
        """Not provided for this task; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def exemplar_seperator(self):
        """Return the string placed between exemplars in a prompt.

        Returns:
            A blank line (``"\\n\\n"``) for ``"sp"``-prefixed prompt versions.

        Raises:
            ValueError: If ``self.prompt_version`` does not start with "sp".
        """
        if not self.prompt_version.startswith("sp"):
            raise ValueError(f"SST2: Not supported prompt_version: {self.prompt_version}")
        return "\n\n"

    def multiple_choice_promptify(self, query, choice):
        """Format one query, with and without a candidate answer appended.

        Args:
            query: Review text to insert into the prompt.
            choice: Candidate label appended after ``"Sentiment:"``.

        Returns:
            A 2-tuple ``(prompt, prompt_with_choice)`` where the second element
            is the first followed by a single space and ``choice``.

        Raises:
            ValueError: If ``self.prompt_version`` does not start with "sp".
        """
        if not self.prompt_version.startswith("sp"):
            raise ValueError(f"SST2: Not supported prompt_version: {self.prompt_version}")

        prompt = f"Review: {query}\nSentiment:"
        return prompt, f"{prompt} {choice}"