# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import numpy as np
import torch
from fairseq.data import FairseqDataset


class BlockPairDataset(FairseqDataset):
    """Break a Dataset of tokens into sentence pair blocks for next sentence
       prediction as well as masked language model.

       High-level logics are:
       1. break input tensor to tensor blocks
       2. pair the blocks with 50% next sentence and 50% random sentence
       3. return paired blocks as well as related segment labels

    Args:
        dataset (~torch.utils.data.Dataset): dataset to break into blocks
        sizes: array of sentence lengths
        dictionary: dictionary for the task
        block_size: maximum block size
        break_mode: mode for breaking copurs into block pairs. currently we support
            2 modes
            doc: respect document boundaries and each part of the pair should belong to on document
            none: don't respect any boundary and cut tokens evenly
        short_seq_prob: probability for generating shorter block pairs
        doc_break_size: Size for empty line separating documents. Typically 1 if
                        the sentences have eos, 0 otherwise.
    """

    def __init__(
        self,
        dataset,
        dictionary,
        sizes,
        block_size,
        break_mode="doc",
        short_seq_prob=0.1,
        doc_break_size=1,
    ):
        super().__init__()
        self.dataset = dataset
        self.pad = dictionary.pad()
        self.eos = dictionary.eos()
        self.cls = dictionary.cls()
        self.mask = dictionary.mask()
        self.sep = dictionary.sep()
        self.break_mode = break_mode
        self.dictionary = dictionary
        self.short_seq_prob = short_seq_prob
        self.block_indices = []

        assert len(dataset) == len(sizes)

        if break_mode == "doc":
            cur_doc = []
            for sent_id, sz in enumerate(sizes):
                assert doc_break_size == 0 or sz != 0, (
                    "when doc_break_size is non-zero, we expect documents to be "
                    "separated by a blank line with a single eos."
                )
                # empty line as document separator
                if sz == doc_break_size:
                    if len(cur_doc) == 0:
                        continue
                    self.block_indices.append(cur_doc)
                    cur_doc = []
                else:
                    cur_doc.append(sent_id)
            max_num_tokens = block_size - 3  # Account for [CLS], [SEP], [SEP]
            self.sent_pairs = []
            self.sizes = []
            for doc_id, doc in enumerate(self.block_indices):
                self._generate_sentence_pair(doc, doc_id, max_num_tokens, sizes)
        elif break_mode is None or break_mode == "none":
            # each block gets half of the block size since we are constructing block pairs
            sent_length = (block_size - 3) // 2
            total_len = sum(dataset.sizes)
            length = math.ceil(total_len / sent_length)

            def block_at(i):
                start = i * sent_length
                end = min(start + sent_length, total_len)
                return (start, end)

            sent_indices = np.array([block_at(i) for i in range(length)])
            sent_sizes = np.array([e - s for s, e in sent_indices])
            dataset_index = self._sent_to_dataset_index(sent_sizes)

            # pair sentences
            self._pair_sentences(dataset_index)
        else:
            raise ValueError("Invalid break_mode: " + break_mode)

    def _pair_sentences(self, dataset_index):
        """
        Given a list of evenly cut blocks/sentences, pair each sentence with the
        next sentence 50% of the time and with a random sentence 50% of the time.
        This is used for the 'none' break mode.
        """
        # pair sentences
        for sent_id, sent in enumerate(dataset_index):
            next_sent_label = (
                1 if np.random.rand() > 0.5 and sent_id != len(dataset_index) - 1 else 0
            )
            if next_sent_label:
                next_sent = dataset_index[sent_id + 1]
            else:
                next_sent = dataset_index[
                    self._skip_sampling(len(dataset_index), [sent_id, sent_id + 1])
                ]
            self.sent_pairs.append((sent, next_sent, next_sent_label))

            # The blocks themselves don't include the special tokens, but the
            # stored sizes account for [CLS] and the two [SEP]s (hence the +3)
            self.sizes.append(3 + sent[3] + next_sent[3])
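
    # Illustrative example (toy numbers): with sent[3] == 4 and
    # next_sent[3] == 5 tokens, the stored size is 3 + 4 + 5 = 12, which
    # reserves room for [CLS] and the two [SEP] tokens.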

    def _sent_to_dataset_index(self, sent_sizes):
        """
        Build index mapping block indices to the underlying dataset indices
        """
        dataset_index = []
        ds_idx, ds_remaining = -1, 0
        for to_consume in sent_sizes:
            sent_size = to_consume
            # ds_remaining tracks how many tokens of the current dataset
            # sentence (ds_idx) have not yet been assigned to a block
            if ds_remaining == 0:
                ds_idx += 1
                ds_remaining = self.dataset.sizes[ds_idx]
            start_ds_idx = ds_idx
            start_offset = self.dataset.sizes[ds_idx] - ds_remaining
            while to_consume > ds_remaining:
                to_consume -= ds_remaining
                ds_idx += 1
                ds_remaining = self.dataset.sizes[ds_idx]
            ds_remaining -= to_consume
            dataset_index.append(
                (
                    start_ds_idx,  # starting index in dataset
                    start_offset,  # starting offset within starting index
                    ds_idx,  # ending index in dataset
                    sent_size,  # sentence length
                )
            )
        assert ds_remaining == 0
        assert ds_idx == len(self.dataset) - 1
        return dataset_index
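
    # Illustrative example (toy numbers, not from the original code): with
    # dataset sentence sizes [4, 3, 5] and evenly cut block sizes [4, 4, 4],
    # the mapping produced above is
    #   (0, 0, 0, 4)  -> all of sentence 0
    #   (1, 0, 2, 4)  -> sentence 1 plus the first token of sentence 2
    #   (2, 1, 2, 4)  -> the remaining 4 tokens of sentence 2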

    def _generate_sentence_pair(self, doc, doc_id, max_num_tokens, sizes):
        """
        Go through a single document and genrate sentence paris from it
        """
        current_chunk = []
        current_length = 0
        curr = 0
        # To provide more randomness, we decrease the target sequence length for
        # a fraction of the samples (10% by default). Note that max_num_tokens
        # is the hard threshold for batching and will never be changed.
        target_seq_length = max_num_tokens
        if np.random.random() < self.short_seq_prob:
            target_seq_length = np.random.randint(2, max_num_tokens)
        # loop through all sentences in document
        while curr < len(doc):
            sent_id = doc[curr]
            current_chunk.append(sent_id)
            current_length = sum(sizes[current_chunk])
            # split the chunk and generate a pair when its length exceeds
            # target_seq_length or we reach the end of the document
            if curr == len(doc) - 1 or current_length >= target_seq_length:
                # split the chunk into 2 parts
                a_end = 1
                if len(current_chunk) > 2:
                    a_end = np.random.randint(1, len(current_chunk) - 1)
                sent_a = current_chunk[:a_end]
                len_a = sum(sizes[sent_a])
                # generate next sentence label, note that if there is only 1 sentence
                # in current chunk, label is always 0
                next_sent_label = (
                    1 if np.random.rand() > 0.5 and len(current_chunk) != 1 else 0
                )
                if not next_sent_label:
                    # if next sentence label is 0, sample sent_b from a random doc
                    target_b_length = target_seq_length - len_a
                    rand_doc_id = self._skip_sampling(len(self.block_indices), [doc_id])
                    random_doc = self.block_indices[rand_doc_id]
                    random_start = np.random.randint(0, len(random_doc))
                    sent_b = []
                    len_b = 0
                    for j in range(random_start, len(random_doc)):
                        sent_b.append(random_doc[j])
                        len_b = sum(sizes[sent_b])
                        if len_b >= target_b_length:
                            break
                    # push the unused second part of the chunk back so that it
                    # will be processed again
                    num_unused_segments = len(current_chunk) - a_end
                    curr -= num_unused_segments
                else:
                    # if next sentence label is 1, use the second part of the chunk as sent_b
                    sent_b = current_chunk[a_end:]
                    len_b = sum(sizes[sent_b])
                # sent_a and sent_b may still be longer than max_num_tokens;
                # truncate them and return block indices and offsets for them
                sent_a, sent_b = self._truncate_sentences(
                    sent_a, sent_b, max_num_tokens
                )
                self.sent_pairs.append((sent_a, sent_b, next_sent_label))
                self.sizes.append(3 + sent_a[3] + sent_b[3])
                current_chunk = []
            curr += 1
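
    # Illustrative walk-through (toy numbers): for a document with sentence
    # sizes [60, 70, 80] and max_num_tokens = 125 (so target_seq_length is
    # usually 125), the chunk [60, 70] triggers a split at length 130. sent_a
    # becomes the first sentence; sent_b is either the second sentence
    # (label 1) or a span sampled from another document (label 0), in which
    # case the unused second sentence is pushed back and reprocessed. The pair
    # is finally truncated down to max_num_tokens.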

    def _skip_sampling(self, total, skip_ids):
        """
        Generate a random integer that is not in skip_ids. The sample range is
        [0, total).
        TODO: ids in skip_ids must be consecutive; we can extend this to a more
        generic version later
        """
        rand_id = np.random.randint(total - len(skip_ids))
        return rand_id if rand_id < min(skip_ids) else rand_id + len(skip_ids)
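
    # Illustrative example (toy numbers): with total=10 and skip_ids=[3, 4],
    # rand_id is drawn from [0, 8); draws below 3 are returned as-is, while a
    # draw of e.g. 5 is shifted to 5 + 2 = 7, so ids 3 and 4 are never returned
    # and every other id in [0, 10) remains reachable.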

    def _truncate_sentences(self, sent_a, sent_b, max_num_tokens):
        """
        Truncate a pair of sentences to keep the total length under max_num_tokens
        Logic:
            1. Truncate the longer sentence
            2. Tokens may be removed from either the beginning or the end of
               the sentence
        Returns:
            Truncated sentences represented by dataset indices
        """
        len_a, len_b = sum(self.dataset.sizes[sent_a]), sum(self.dataset.sizes[sent_b])
        front_cut_a = front_cut_b = end_cut_a = end_cut_b = 0

        while True:
            total_length = (
                len_a + len_b - front_cut_a - front_cut_b - end_cut_a - end_cut_b
            )
            if total_length <= max_num_tokens:
                break

            if len_a - front_cut_a - end_cut_a > len_b - front_cut_b - end_cut_b:
                if np.random.rand() < 0.5:
                    front_cut_a += 1
                else:
                    end_cut_a += 1
            else:
                if np.random.rand() < 0.5:
                    front_cut_b += 1
                else:
                    end_cut_b += 1

        # calculate ds indices as well as offsets and return
        truncated_sent_a = self._cut_sentence(sent_a, front_cut_a, end_cut_a)
        truncated_sent_b = self._cut_sentence(sent_b, front_cut_b, end_cut_b)
        return truncated_sent_a, truncated_sent_b
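
    # Illustrative example (toy numbers): with len_a=10, len_b=5 and
    # max_num_tokens=12, three tokens are removed from sent_a (always the
    # longer side), each cut randomly from its front or its back.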

    def _cut_sentence(self, sent, front_cut, end_cut):
        """
        Cut a sentence given the number of tokens to be removed from the
        beginning and the end. The result is returned as dataset indices
        together with an offset and a target length.
        """
        start_ds_idx, end_ds_idx, offset = sent[0], sent[-1], 0
        target_len = sum(self.dataset.sizes[sent]) - front_cut - end_cut
        while front_cut > 0:
            if self.dataset.sizes[start_ds_idx] > front_cut:
                offset += front_cut
                break
            else:
                front_cut -= self.dataset.sizes[start_ds_idx]
                start_ds_idx += 1
        while end_cut > 0:
            if self.dataset.sizes[end_ds_idx] > end_cut:
                break
            else:
                end_cut -= self.dataset.sizes[end_ds_idx]
                end_ds_idx -= 1
        return start_ds_idx, offset, end_ds_idx, target_len
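
    # Illustrative example (toy numbers): for a sentence spanning dataset
    # indices [1, 2] with dataset sizes [3, 5], front_cut=4 and end_cut=0
    # consume all of sentence 1 plus one token of sentence 2, giving
    # (start_ds_idx=2, offset=1, end_ds_idx=2, target_len=4).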

    def _fetch_block(self, start_ds_idx, offset, end_ds_idx, length):
        """
        Fetch a block of tokens based on its dataset idx
        """
        buffer = torch.cat(
            [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
        )
        s, e = offset, offset + length
        return buffer[s:e]
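
    # Illustrative example (toy numbers): the block (1, 2, 3, 6) concatenates
    # dataset items 1 through 3 and returns the 6 tokens starting at offset 2,
    # i.e. buffer[2:8].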

    def __getitem__(self, index):
        block1, block2, next_sent_label = self.sent_pairs[index]
        block1 = self._fetch_block(*block1)
        block2 = self._fetch_block(*block2)
        return block1, block2, next_sent_label

    def __len__(self):
        return len(self.sizes)

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        prefetch_idx = set()
        for index in indices:
            block1, block2, _ = self.sent_pairs[index]
            for ds_idx in range(block1[0], block1[2] + 1):
                prefetch_idx.add(ds_idx)
            for ds_idx in range(block2[0], block2[2] + 1):
                prefetch_idx.add(ds_idx)
        self.dataset.prefetch(prefetch_idx)
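

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the stub dataset and dictionary
    # classes below are assumptions made for this example, not fairseq APIs).
    # It builds two tiny "documents" separated by a blank line containing only
    # eos and then draws one sentence-pair block.

    class _StubDictionary:
        """Toy dictionary exposing just the special symbols used above."""

        def pad(self):
            return 0

        def eos(self):
            return 1

        def cls(self):
            return 2

        def mask(self):
            return 3

        def sep(self):
            return 4

    class _StubTokenDataset:
        """Toy sentence dataset with the sizes/__getitem__/__len__ interface."""

        def __init__(self, sentences):
            self.tokens = [torch.tensor(s, dtype=torch.long) for s in sentences]
            self.sizes = np.array([len(s) for s in sentences])

        def __getitem__(self, index):
            return self.tokens[index]

        def __len__(self):
            return len(self.tokens)

    sentences = [
        [5, 6, 7, 1],  # doc 1, sentence 1 (1 == eos)
        [8, 9, 1],  # doc 1, sentence 2
        [1],  # blank line: size == doc_break_size, separates documents
        [10, 11, 12, 13, 1],  # doc 2, sentence 1
        [14, 15, 1],  # doc 2, sentence 2
        [1],  # trailing separator so the last document is flushed
    ]
    token_ds = _StubTokenDataset(sentences)
    pair_ds = BlockPairDataset(
        dataset=token_ds,
        dictionary=_StubDictionary(),
        sizes=token_ds.sizes,
        block_size=16,
        break_mode="doc",
    )
    block_a, block_b, next_sent_label = pair_ds[0]
    print(block_a, block_b, next_sent_label)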