import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer


class TokenizedForMCRightPad(Dataset):
    def __init__(self, data, tok: PreTrainedTokenizer, prompt_fn):
        # data: list of {"query": str, "choices": list[str]}; every example must have
        # the same number of choices.
        # prompt_fn(query, choice) is expected to return (query_only_prompt, full_prompt),
        # where full_prompt is query_only_prompt followed by the choice continuation.
        # The tokenizer must define a pad token and pad on the right (hence the class name).
        self.tok = tok
        self.prompt_fn = prompt_fn
        self.max_length = self._find_max_length(data)
        self.data = self._build_mc_data(data)

    def _find_max_length(self, data):
        # Length (in tokens) of the longest full prompt over all (query, choice) pairs;
        # every tokenized example is padded to this fixed length.
        max_len = 0

        def tok_len(t):
            return len(self.tok.encode(t))

        for ex in data:
            query = ex["query"]
            len_choices = [tok_len(self.prompt_fn(query, c)[1]) for c in ex["choices"]]
            max_len = max(max_len, *len_choices)

        return max_len

    def _build_mc_data(self, data):
        processed = []
        num_choices = {len(e["choices"]) for e in data}
        if len(num_choices) != 1:
            raise ValueError(f"Queries have different numbers of choices, which is not supported! #choices: {num_choices}")
        for ex in data:
            query, choices = ex["query"], ex["choices"]
            processed_input = [self.prompt_fn(query, choice) for choice in choices]
            processed_input = [self.tokenize(t_query, t_full) for t_query, t_full in processed_input]
            processed.append(processed_input)

        return processed

    def tokenize_demonstration(self, demonstration):
        e = self.tok(demonstration)
        return torch.LongTensor(e["input_ids"]), torch.LongTensor(e["attention_mask"])  # no padding

    def tokenize(self, only_query, full_text):
        # Tokenize the query-only prefix and the full prompt (query + choice) without
        # special tokens, then right-pad the full prompt to the dataset-wide max length.
        tok_only_query = self.tok(only_query, add_special_tokens=False)
        tok_full_no_padding = self.tok(full_text, add_special_tokens=False)
        tok_full = self.tok(
            full_text,
            padding="max_length",
            max_length=self.max_length,
            add_special_tokens=False,
        )  # <pad> is not a special token

        # The choice tokens occupy [choice_start, choice_end) in the padded sequence:
        # everything after the query prefix and before the padding.
        len_full = len(tok_full_no_padding.input_ids)
        len_query = len(tok_only_query.input_ids)
        e = {
            "input_ids": tok_full.input_ids,
            "attention_mask": tok_full.attention_mask,
            "choice_start": len_query,
            "choice_end": len_full,
        }
        return e

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        def _get_one_item(e):
            return torch.LongTensor(e["input_ids"]), torch.LongTensor(e["attention_mask"]), e["choice_start"], e["choice_end"]

        es = self.data[idx]
        # Returns num_choices tuples of (input_ids, attention_mask, choice_start, choice_end).
        # Per item, input_ids and attention_mask are [L]; a DataLoader's default collate
        # stacks them into [B, L] tensors, with choice_start / choice_end becoming [B] each.
        return [_get_one_item(e) for e in es]
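

# Illustrative usage sketch. The prompt function, model name, and toy data below are
# placeholders for whatever the surrounding project supplies; they only demonstrate the
# (query_only_prompt, full_prompt) contract that prompt_fn is expected to follow.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    def build_prompt(query, choice):
        # Assumed prompt_fn contract: return (query-only prompt, full prompt with the choice).
        query_prompt = f"Question: {query}\nAnswer:"
        return query_prompt, f"{query_prompt} {choice}"

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token  # GPT-2 has no pad token; reuse EOS for right padding
    tok.padding_side = "right"

    examples = [
        {"query": "The capital of France is", "choices": ["Paris", "Berlin", "Rome"]},
        {"query": "2 + 2 equals", "choices": ["3", "4", "5"]},
    ]
    ds = TokenizedForMCRightPad(examples, tok, build_prompt)
    input_ids, attention_mask, start, end = ds[0][0]  # first example, first choice
    print(input_ids.shape, attention_mask.shape, start, end)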