"""Transforms relate to noising from BART: based on code of fairseq."""
import math
import numpy as np
import torch

from typing import Sequence, Callable
from onmt.constants import DefaultTokens, SubwordMarker
from onmt.transforms import register_transform
from .transform import Transform


def _subword_start_by_joiner(tokens: Sequence[str]) -> Sequence[bool]:
    """Find word start in a subword list marked by joiner."""
    flag = [True] * len(tokens)
    for i, token in enumerate(tokens):
        if token.startswith(SubwordMarker.JOINER) and i != 0:
            flag[i] = False
        if token.endswith(SubwordMarker.JOINER):
            try:
                flag[i + 1] = False
            except IndexError:
                print("Sentence `{}` not correct!".format(" ".join(token)))
                raise
    return flag
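# Illustrative example: with J = SubwordMarker.JOINER,
#   _subword_start_by_joiner(["he", J + "llo", "wor", J + "ld"])
# would return [True, False, True, False].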


def _subword_start_by_spacer(tokens: Sequence[str]) -> Sequence[bool]:
    """Find word start in a subword list marked by spacer(as prefix)."""
    flag = [x.startswith(SubwordMarker.SPACER) for x in tokens]
    flag[0] = True
    return flag
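# Illustrative example: with S = SubwordMarker.SPACER (e.g. the SentencePiece
# "▁" prefix),
#   _subword_start_by_spacer(["he", "llo", S + "world"])
# would return [True, False, True].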


def word_start_finder(ignore_subword=False, is_joiner=False) -> Callable:
    """Return callable to find all word start in the token list."""
    if not ignore_subword:
        if is_joiner:
            return _subword_start_by_joiner
        else:
            return _subword_start_by_spacer
    else:
        return lambda tokens: [True] * len(tokens)
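# Illustrative usage: word_start_finder(is_joiner=True) selects the
# joiner-based detector, the default selects the spacer-based one, and
# word_start_finder(ignore_subword=True) marks every token as a word start.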


class BARTNoising(object):
    """Noise from BART."""

    def __init__(
        self,
        vocab,
        mask_tok=DefaultTokens.MASK,
        mask_ratio=0.0,
        insert_ratio=0.0,
        permute_sent_ratio=0.0,
        poisson_lambda=3.0,
        replace_length=-1,
        rotate_ratio=0.0,
        mask_length="subword",
        random_ratio=0.0,
        is_joiner=False,
        full_stop_token=DefaultTokens.SENT_FULL_STOPS,
    ):
        if vocab is None:
            raise ValueError("Inject BART noise requires a valid vocabulary.")
        self.vocab = vocab

        self.mask_tok = mask_tok

        self.mask_ratio = mask_ratio
        self.random_ratio = random_ratio
        self.insert_ratio = insert_ratio
        self.rotate_ratio = rotate_ratio
        self.permute_sent_ratio = permute_sent_ratio

        self.full_stop_token = full_stop_token

        # -1: keep length, replacing every masked token with mask_tok
        #     (i.e. 1 mask per token)
        #  0: delete masked tokens entirely (i.e. no mask)
        #  1: replace each masked span with a single mask_tok (1 mask per span)
        if replace_length not in [-1, 0, 1]:
            raise ValueError(f"invalid arg: replace_length={replace_length}")
        self.replace_length = replace_length

        if mask_length not in ["subword", "word", "span-poisson"]:
            raise ValueError(f"invalid arg: mask-length={mask_length}")
        if mask_length == "subword" and replace_length not in [0, 1]:
            raise ValueError("if using subwords, use replace-length=1 or 0")

        if mask_length == "subword" or is_joiner is None:
            # treat every token as a word start (subword-level masking,
            # or the input is already word-level tokens)
            self._is_word_start = word_start_finder(ignore_subword=True)
        else:
            self._is_word_start = word_start_finder(is_joiner=is_joiner)

        self.mask_span_distribution = None
        if mask_length == "span-poisson":
            self.mask_span_distribution = self._make_poisson(poisson_lambda)
        self.mask_length = mask_length
        self.poisson_lambda = poisson_lambda

    @staticmethod
    def set_random_seed(seed):
        """Call this before use to ensure reproducibility."""
        np.random.seed(seed)
        torch.manual_seed(seed)

    def _make_poisson(self, poisson_lambda):
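        # Build a categorical distribution over span lengths k = 0..127 using
        # truncated Poisson probabilities p(k) = exp(-lambda) * lambda**k / k!,
        # computed iteratively; stop early once the tail probability is tiny.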
        lambda_to_the_k = 1
        e_to_the_minus_lambda = math.exp(-poisson_lambda)
        k_factorial = 1
        ps = []
        for k in range(0, 128):
            ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
            lambda_to_the_k *= poisson_lambda
            k_factorial *= k + 1
            if ps[-1] < 0.0000001:
                break
        ps = torch.FloatTensor(ps)
        return torch.distributions.Categorical(ps)

    def _get_sentence_borders(self, tokens):
        """Return the end offset of each sentence in the token sequence."""
        full_stops = np.array(
            [token in self.full_stop_token for token in tokens]
        )
        # Pretend it ends with a full stop so the last span is a sentence
        full_stops[-1] = True
        # End offsets: positions right after a full stop whose predecessor
        # is not itself a full stop
        sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero()[0] + 2
        return sentence_ends
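    # Illustrative example: if full_stop_token contains ".", the tokens
    # ["a", ".", "b", "c", "."] yield sentence end offsets [2, 5], i.e. the
    # sentences tokens[0:2] and tokens[2:5].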

    def permute_sentences(self, tokens, p=1.0):
        if len(tokens) == 1:
            return tokens
        sentence_ends = self._get_sentence_borders(tokens)
        n_sentences = sentence_ends.size
        if n_sentences == 1:
            return tokens

        n_to_permute = math.ceil((n_sentences * 2 * p) / 2.0)

        substitutions = np.random.permutation(n_sentences)[:n_to_permute]
        ordering = np.arange(0, n_sentences)
        ordering[substitutions] = substitutions[np.random.permutation(n_to_permute)]

        result = [tok for tok in tokens]
        index = 0
        for i in ordering:
            sentence = tokens[(sentence_ends[i - 1] if i > 0 else 0) : sentence_ends[i]]
            result[index : index + len(sentence)] = sentence
            index += len(sentence)
        assert len(result) == len(tokens), "Error when permuting sentences."
        return result
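    # Illustrative example: with p=1.0 and the two sentences in
    # ["a", ".", "b", "."], permute_sentences returns either the same list or
    # ["b", ".", "a", "."], depending on the sampled permutation.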

    def whole_word_mask(self, tokens, p=1.0):  # text span mask/infilling
        is_word_start = torch.tensor(self._is_word_start(tokens)).int()
        n_mask = int(math.ceil(is_word_start.sum() * p))
        n_insert = 0
        if n_mask == 0:
            return tokens

        if self.mask_span_distribution is not None:  # Text (span) Infilling
            lengths = self.mask_span_distribution.sample(sample_shape=(n_mask,))

            # Make sure we have enough to mask
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < n_mask:
                lengths = torch.cat(
                    [
                        lengths,
                        self.mask_span_distribution.sample(sample_shape=(n_mask,)),
                    ],
                    dim=0,
                )
                cum_length = torch.cumsum(lengths, 0)

            # Trim to masking budget
            i = 0
            while cum_length[i] < n_mask:
                i += 1
            lengths[i] = n_mask - (0 if i == 0 else cum_length[i - 1])
            n_mask = i + 1
            lengths = lengths[:n_mask]

            # Handle 0-length mask (inserts) separately
            lengths = lengths[lengths > 0]
            n_insert = n_mask - lengths.size(0)
            n_mask -= n_insert
            if n_mask == 0:
                return self.insertion_noise(tokens, n_insert / len(tokens))

            assert (lengths > 0).all()
        else:  # Token Masking
            lengths = torch.ones((n_mask,)).long()
        # assert is_word_start[-1] == 0
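        # Randomly pick n_mask word-start positions to begin masked spans;
        # mask_random flags which of those spans get a random vocab token
        # instead of mask_tok.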
        word_starts = is_word_start.nonzero(as_tuple=False)
        indices = word_starts[torch.randperm(word_starts.size(0))[:n_mask]].squeeze(1)
        mask_random = torch.FloatTensor(n_mask).uniform_() < self.random_ratio

        tokens_length = len(tokens)
        # assert tokens_length - 1 not in indices
        to_keep = torch.ones(tokens_length, dtype=torch.bool)

        if self.replace_length == 0:
            to_keep[indices] = 0
        else:
            # keep index, but replace it with [MASK]
            for i in indices.tolist():
                tokens[i] = self.mask_tok
            random_tok_ids = torch.randint(
                0, len(self.vocab), size=(mask_random.sum(),)
            ).tolist()
            for i, rid in zip(indices[mask_random].tolist(), random_tok_ids):
                tokens[i] = self.vocab[rid]

        if tokens_length - 1 in indices:
            uncompleted = indices != tokens_length - 1
            indices = indices[uncompleted]
            mask_random = mask_random[uncompleted]
            lengths = lengths[uncompleted]

        # acts as a long length, so spans don't go over the end of doc
        is_word_start[-1] = 255

        if self.mask_span_distribution is not None:
            assert len(lengths.size()) == 1
            assert lengths.size() == indices.size()
            lengths -= 1  # 1 for the position already masked
            while indices.size(0) > 0:
                assert lengths.size() == indices.size()
                # next position from each word_start
                lengths -= is_word_start[indices + 1].long()
                uncompleted = lengths >= 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                lengths = lengths[uncompleted]
                if self.replace_length != -1:
                    # delete token: 1 mask/remove per span
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]: 1 mask per token
                    for i in indices.tolist():
                        tokens[i] = self.mask_tok
                    random_tok_ids = torch.randint(
                        0, len(self.vocab), size=(mask_random.sum(),)
                    ).tolist()
                    for i, rid in zip(indices[mask_random].tolist(), random_tok_ids):
                        tokens[i] = self.vocab[rid]
        else:
            # A bit faster when all lengths are 1
            while indices.size(0) > 0:
                # keep extending spans until the whole word is covered
                uncompleted = is_word_start[indices + 1] == 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    for i in indices.tolist():
                        tokens[i] = self.mask_tok
                    random_tok_ids = torch.randint(
                        0, len(self.vocab), size=(mask_random.sum(),)
                    ).tolist()
                    for i, rid in zip(indices[mask_random].tolist(), random_tok_ids):
                        tokens[i] = self.vocab[rid]

                # assert tokens_length - 1 not in indices

        tokens = [tok for tok, keep in zip(tokens, to_keep.tolist()) if keep]

        if n_insert > 0:
            tokens = self.insertion_noise(tokens, n_insert / len(tokens))

        return tokens

    def insertion_noise(self, tokens, p=1.0):
        n_tokens = len(tokens)
        n_insert = math.ceil(n_tokens * p)
        if n_insert == 0:
            return tokens
        n_random = math.ceil(n_insert * self.random_ratio)

        noise_indices = np.random.permutation(n_tokens + n_insert)[:n_insert]
        noise_mask = np.zeros(shape=(n_tokens + n_insert,), dtype=bool)
        noise_mask[noise_indices] = 1

        result = np.empty(shape=(n_tokens + n_insert,), dtype=object)
        result[noise_indices[n_random:]] = self.mask_tok
        if n_random > 0:
            result[noise_indices[:n_random]] = np.random.choice(
                self.vocab, size=n_random
            )
        result[~noise_mask] = tokens

        assert all([item is not None for item in result]), "Error when inserting noise."
        return result.tolist()
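    # Illustrative example: insertion_noise(["a", "b"], p=0.5) inserts one
    # extra token at a random position (mask_tok, or a random vocab token
    # with probability random_ratio), e.g. ["a", mask_tok, "b"].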

    def rolling_noise(self, tokens, p=1.0):
        if np.random.random() >= p:
            return tokens
        offset = np.random.randint(0, max(1, len(tokens) - 1) + 1)
        return tokens[offset:] + tokens[0:offset]
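    # Illustrative example: with probability p, rolling_noise(["a", "b", "c"])
    # rotates the list around a random offset, e.g. ["b", "c", "a"]; otherwise
    # the tokens are returned unchanged.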

    def apply(self, tokens):
        if self.permute_sent_ratio > 0.0:
            tokens = self.permute_sentences(tokens, self.permute_sent_ratio)

        if self.mask_ratio > 0.0:
            tokens = self.whole_word_mask(tokens, self.mask_ratio)

        if self.insert_ratio > 0.0:
            tokens = self.insertion_noise(tokens, self.insert_ratio)

        if self.rotate_ratio > 0.0:
            tokens = self.rolling_noise(tokens, self.rotate_ratio)
        return tokens

    def __repr__(self):
        cls_name = type(self).__name__
        kwargs = {}
        if self.permute_sent_ratio > 0.0:
            kwargs["permute_sent_ratio"] = self.permute_sent_ratio
            kwargs["full_stop_token"] = self.full_stop_token
        if self.insert_ratio > 0.0:
            kwargs["insert_ratio"] = self.insert_ratio
        if self.rotate_ratio > 0.0:
            kwargs["rotate_ratio"] = self.rotate_ratio
        if self.random_ratio > 0.0:
            kwargs["random_ratio"] = self.random_ratio
        if self.mask_ratio > 0.0:
            kwargs["mask_ratio"] = self.mask_ratio
            kwargs["mask_length"] = self.mask_length
            kwargs["poisson_lambda"] = self.poisson_lambda
            kwargs["replace_length"] = self.replace_length
        cls_args = ", ".join([f"{kw}={arg}" for kw, arg in kwargs.items()])
        return "{}({})".format(cls_name, cls_args)


@register_transform(name="bart")
class BARTNoiseTransform(Transform):
    def __init__(self, opts):
        super().__init__(opts)

    def _set_seed(self, seed):
        """set seed to ensure reproducibility."""
        BARTNoising.set_random_seed(seed)

    @classmethod
    def add_options(cls, parser):
        """Avalilable options relate to BART."""
        group = parser.add_argument_group("Transform/BART")
        group.add(
            "--permute_sent_ratio",
            "-permute_sent_ratio",
            type=float,
            default=0.0,
            help="Permute this proportion of sentences "
            "(boundaries defined by {}) in all inputs.".format(
                DefaultTokens.SENT_FULL_STOPS
            ),
        )
        group.add(
            "--rotate_ratio",
            "-rotate_ratio",
            type=float,
            default=0.0,
            help="Rotate this proportion of inputs.",
        )
        group.add(
            "--insert_ratio",
            "-insert_ratio",
            type=float,
            default=0.0,
            help="Insert this percentage of additional random tokens.",
        )
        group.add(
            "--random_ratio",
            "-random_ratio",
            type=float,
            default=0.0,
            help="Instead of using {}, use random token "
            "this often.".format(DefaultTokens.MASK),
        )

        group.add(
            "--mask_ratio",
            "-mask_ratio",
            type=float,
            default=0.0,
            help="Fraction of words/subwords that will be masked.",
        )
        group.add(
            "--mask_length",
            "-mask_length",
            type=str,
            default="subword",
            choices=["subword", "word", "span-poisson"],
            help="Length of masking window to apply.",
        )
        group.add(
            "--poisson_lambda",
            "-poisson_lambda",
            type=float,
            default=3.0,
            help="Lambda for Poisson distribution to sample span length "
            "if `-mask_length` set to span-poisson.",
        )
        group.add(
            "--replace_length",
            "-replace_length",
            type=int,
            default=-1,
            choices=[-1, 0, 1],
            help="When masking N tokens, replace with 0, 1, "
            "or N tokens. (use -1 for N)",
        )

    @classmethod
    def require_vocab(cls):
        """Override this method to inform it need vocab to start."""
        return True

    def warm_up(self, vocabs):
        super().warm_up(vocabs)

        subword_type = self.opts.src_subword_type
        if self.opts.mask_length == "subword":
            if subword_type == "none":
                raise ValueError(
                    f"src_subword_type={subword_type} incompatible with "
                    f"mask_length={self.opts.mask_length}!"
                )
        is_joiner = (subword_type == "bpe") if subword_type != "none" else None
        self.bart_noise = BARTNoising(
            self.vocabs["src"].ids_to_tokens,
            mask_tok=DefaultTokens.MASK,
            mask_ratio=self.opts.mask_ratio,
            insert_ratio=self.opts.insert_ratio,
            permute_sent_ratio=self.opts.permute_sent_ratio,
            poisson_lambda=self.opts.poisson_lambda,
            replace_length=self.opts.replace_length,
            rotate_ratio=self.opts.rotate_ratio,
            mask_length=self.opts.mask_length,
            random_ratio=self.opts.random_ratio,
            is_joiner=is_joiner,
        )

    def apply(self, example, is_train=False, stats=None, **kwargs):
        """Apply BART noise to src side tokens."""
        if is_train:
            src = self.bart_noise.apply(example["src"])
            example["src"] = src
        return example

    def _repr_args(self):
        """Return str represent key arguments for BART."""
        return repr(self.bart_noise)
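# Illustrative standalone usage of BARTNoising (a minimal sketch; the vocab,
# ratios, and tokens below are made-up values, not defaults used elsewhere):
#
#     BARTNoising.set_random_seed(42)
#     noiser = BARTNoising(
#         vocab=["<blank>", "<s>", "</s>", "hello", "world", "."],
#         mask_ratio=0.3,
#         mask_length="word",
#         replace_length=1,
#         is_joiner=None,  # tokens below are plain words, not subwords
#     )
#     noised = noiser.apply(["hello", "world", "."])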