import warnings

import torch
from onmt.translate import penalties
from onmt.translate.decode_strategy import DecodeStrategy


class BeamSearchBase(DecodeStrategy):
    """Generation beam search.

    Note that the attributes list is not exhaustive. Rather, it highlights
    tensors to document their shape. (Since the state variables' "batch"
    size decreases as beams finish, we denote this axis with a B rather than
    ``batch_size``).

    Args:
        beam_size (int): Number of beams to use (see base ``parallel_paths``).
        batch_size (int): See base.
        pad (int): See base.
        bos (int): See base.
        eos (int): See base.
        unk (int): See base.
        start (int): See base.
        n_best (int): Number of highest-scoring finished hypotheses
            returned per example.
        global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance.
        min_length (int): See base.
        max_length (int): See base.
        return_attention (bool): See base.
        block_ngram_repeat (int): See base.
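        exclusion_tokens (set[int]): See base.
        stepwise_penalty (bool): If True (and ``global_scorer`` has a
            coverage penalty), apply the coverage penalty to the running
            beam scores at each step; otherwise it is applied only to
            ``topk_scores``.
        ratio (float): If > 0, enables early stopping based on the
            expected output length ``src_len * ratio``.
        ban_unk_token (bool): See base.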

    Attributes:
        top_beam_finished (BoolTensor): Shape ``(B,)``. A ByteTensor on
            PyTorch versions older than 1.2.
        _batch_offset (LongTensor): Shape ``(B,)``.
        _beam_offset (LongTensor): Shape ``(batch_size x beam_size,)``.
        alive_seq (LongTensor): See base.
        topk_log_probs (FloatTensor): Shape ``(B, beam_size,)``. These
            are the scores used for the topk operation.
        src_len (LongTensor): Lengths of encodings. Used for
            masking attentions.
        select_indices (LongTensor or NoneType): Shape
            ``(B x beam_size,)``. This is just a flat view of the
            ``_batch_index``.
        topk_scores (FloatTensor): Shape
            ``(B, beam_size)``. These are the
            scores a sequence will receive if it finishes.
        topk_ids (LongTensor): Shape ``(B, beam_size)``. These are the
            word indices of the topk predictions.
        _batch_index (LongTensor): Shape ``(B, beam_size)``.
        _prev_penalty (FloatTensor or NoneType): Shape
            ``(B, beam_size)``. Initialized to ``None``.
        _coverage (FloatTensor or NoneType): Shape
            ``(B x beam_size, 1, inp_seq_len)``.
        hypotheses (list[list[Tuple[Tensor]]]): Contains a tuple
            of score (float), sequence (long), and attention (float or None).
    """

    def __init__(
        self,
        beam_size,
        batch_size,
        pad,
        bos,
        eos,
        unk,
        start,
        n_best,
        global_scorer,
        min_length,
        max_length,
        return_attention,
        block_ngram_repeat,
        exclusion_tokens,
        stepwise_penalty,
        ratio,
        ban_unk_token,
    ):
        super(BeamSearchBase, self).__init__(
            pad,
            bos,
            eos,
            unk,
            start,
            batch_size,
            beam_size,
            global_scorer,
            min_length,
            block_ngram_repeat,
            exclusion_tokens,
            return_attention,
            max_length,
            ban_unk_token,
        )
        # beam parameters
        self.beam_size = beam_size
        self.n_best = n_best
        self.ratio = ratio

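        # Buffers for recording per-step beam data. They are only filled by
        # the commented-out debugging code at the end of ``advance``.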
        self.topk_scores_list = []
        self.topk_ids_list = []
        self.nbest_beam_sequences = []
        self.sequence_total_scores = []
        self.beam_data = []

        # beam state
        self.top_beam_finished = torch.zeros([batch_size], dtype=torch.uint8)
        # BoolTensor was introduced in pytorch 1.2
        try:
            self.top_beam_finished = self.top_beam_finished.bool()
        except AttributeError:
            pass
        self._batch_offset = torch.arange(batch_size, dtype=torch.long)

        self.select_indices = None
        self.done = False
        # "global state" of the old beam
        self._prev_penalty = None
        self._coverage = None

        self._stepwise_cov_pen = stepwise_penalty and self.global_scorer.has_cov_pen
        self._vanilla_cov_pen = not stepwise_penalty and self.global_scorer.has_cov_pen
        self._cov_pen = self.global_scorer.has_cov_pen

        self.src_len = None

    def initialize(self, *args, **kwargs):
        raise NotImplementedError

    def initialize_(self, enc_out, src_len, src_map, device, target_prefix):
        super(BeamSearchBase, self).initialize(
            enc_out, src_len, src_map, device, target_prefix
        )

        self.best_scores = torch.full(
            [self.batch_size], -1e10, dtype=torch.float, device=device
        )
        self._beam_offset = torch.arange(
            0,
            self.batch_size * self.beam_size,
            step=self.beam_size,
            dtype=torch.long,
            device=device,
        )
        self.topk_log_probs = (
            torch.tensor([0.0] + [float("-inf")] * (self.beam_size - 1), device=device)
            .repeat(self.batch_size)
            .reshape(self.batch_size, self.beam_size)
        )
        # buffers for the topk scores and 'backpointer'
        self.topk_scores = torch.empty(
            (self.batch_size, self.beam_size), dtype=torch.float, device=device
        )
        self.topk_ids = torch.empty(
            (self.batch_size, self.beam_size), dtype=torch.long, device=device
        )
        self._batch_index = torch.empty(
            [self.batch_size, self.beam_size], dtype=torch.long, device=device
        )

    @property
    def current_predictions(self):
        return self.alive_seq[:, -1]

    @property
    def current_backptr(self):
        # for testing
        return self.select_indices.view(self.batch_size, self.beam_size).fmod(
            self.beam_size
        )

    @property
    def batch_offset(self):
        return self._batch_offset

    def _pick(self, log_probs, out=None):
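        """Select the ``beam_size`` best candidates for a step.

        Args:
            log_probs (FloatTensor): (B * beam_size, vocab_size)
            out (Tuple[FloatTensor, LongTensor], optional): Output buffers
                ``(topk_scores, topk_ids)`` to write the result into.

        Returns:
            ``(topk_scores, topk_ids)`` when ``out`` is None, otherwise
            None (the result is written into ``out``).

            topk_scores (FloatTensor): (B, beam_size)
            topk_ids (LongTensor): (B, beam_size), indices into the
                flattened ``beam_size * vocab_size`` candidates.
        """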
        vocab_size = log_probs.size(-1)
        # maybe fix some prediction at this step by modifying log_probs
        log_probs = self.target_prefixing(log_probs)

        # Flatten probs into a list of possibilities.
        curr_scores = log_probs.reshape(-1, self.beam_size * vocab_size)
        if out is not None:
            torch.topk(curr_scores, self.beam_size, dim=-1, out=out)
            return
        topk_scores, topk_ids = torch.topk(curr_scores, self.beam_size, dim=-1)
        return topk_scores, topk_ids

    def update_finished(self):
        # Penalize beams that finished.
        _B_old = self.topk_log_probs.shape[0]
        step = self.alive_seq.shape[-1]  # 1 greater than the step in advance
        self.topk_log_probs.masked_fill_(self.is_finished, -1e10)
        # on real data (newstest2017) with the pretrained transformer,
        # it's faster to not move this back to the original device
        self.is_finished = self.is_finished.to("cpu")
        self.top_beam_finished |= self.is_finished[:, 0].eq(1)
        predictions = self.alive_seq.view(_B_old, self.beam_size, step)
        attention = (
            self.alive_attn.view(
                _B_old, self.beam_size, step - 1, self.alive_attn.size(-1)
            )
            if self.alive_attn is not None
            else None
        )
        non_finished_batch = []
        for i in range(self.is_finished.size(0)):  # Batch level
            b = self._batch_offset[i]

            finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1)
            # Store finished hypotheses for this batch.
            for j in finished_hyp:  # Beam level: finished beam j in batch i
                if self.ratio > 0:
                    s = self.topk_scores[i, j] / (step + 1)
                    if self.best_scores[b] < s:
                        self.best_scores[b] = s
                self.hypotheses[b].append(
                    (
                        self.topk_scores[i, j],
                        predictions[i, j, 1:],  # Ignore start_token.
                        attention[i, j, :, : self.src_len[i]]
                        if attention is not None
                        else None,
                    )
                )
            # End condition: the top beam has finished (or, with ratio > 0,
            # the length-based early-stopping test passes) and at least
            # ``beam_size`` hypotheses have been collected.
            if self.ratio > 0:
                pred_len = self.src_len[i] * self.ratio
                finish_flag = (
                    (self.topk_scores[i, 0] / pred_len) <= self.best_scores[b]
                ) or self.is_finished[i].all()
            else:
                finish_flag = self.top_beam_finished[i] != 0
            if finish_flag and len(self.hypotheses[b]) >= self.beam_size:
                best_hyp = sorted(self.hypotheses[b], key=lambda x: x[0], reverse=True)[
                    : self.n_best
                ]
                for score, pred, attn in best_hyp:
                    self.scores[b].append(score)
                    self.predictions[b].append(pred)  # ``(batch, n_best,)``
                    self.attention[b].append(attn if attn is not None else [])
            else:
                non_finished_batch.append(i)

        non_finished = torch.tensor(non_finished_batch)
        # If all sentences are translated, no need to go further.
        if len(non_finished) == 0:
            self.done = True
            return

        _B_new = non_finished.shape[0]
        self.remove_finished_batches(
            _B_new, _B_old, non_finished, predictions, attention, step
        )

    def remove_finished_batches(
        self, _B_new, _B_old, non_finished, predictions, attention, step
    ):
        # Remove finished batches for the next step.
        self.top_beam_finished = self.top_beam_finished.index_select(0, non_finished)
        self._batch_offset = self._batch_offset.index_select(0, non_finished)
        non_finished = non_finished.to(self.topk_ids.device)
        self.topk_log_probs = self.topk_log_probs.index_select(0, non_finished)
        self._batch_index = self._batch_index.index_select(0, non_finished)
        self.select_indices = self._batch_index.view(_B_new * self.beam_size)
        self.alive_seq = predictions.index_select(0, non_finished).view(
            -1, self.alive_seq.size(-1)
        )
        self.topk_scores = self.topk_scores.index_select(0, non_finished)
        self.topk_ids = self.topk_ids.index_select(0, non_finished)
        self.maybe_update_target_prefix(self.select_indices)
        if self.alive_attn is not None:
            inp_seq_len = self.alive_attn.size(-1)
            self.alive_attn = attention.index_select(0, non_finished).view(
                _B_new * self.beam_size, step - 1, inp_seq_len
            )
            if self._cov_pen:
                self._coverage = (
                    self._coverage.view(_B_old, self.beam_size, 1, inp_seq_len)
                    .index_select(0, non_finished)
                    .view(_B_new * self.beam_size, 1, inp_seq_len)
                )
                if self._stepwise_cov_pen:
                    self._prev_penalty = self._prev_penalty.index_select(
                        0, non_finished
                    )

    def advance(self, log_probs, attn):
        vocab_size = log_probs.size(-1)

        # using integer division to get an integer _B without casting
        _B = log_probs.shape[0] // self.beam_size

        if self._stepwise_cov_pen and self._prev_penalty is not None:
            self.topk_log_probs += self._prev_penalty
            self.topk_log_probs -= self.global_scorer.cov_penalty(
                self._coverage + attn, self.global_scorer.beta
            ).view(_B, self.beam_size)

        # force the output to be longer than self.min_length
        step = len(self)
        self.ensure_min_length(log_probs)
        self.ensure_unk_removed(log_probs)
        # Multiply probs by the beam probability.
        log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)

        # if the sequence ends now, then the penalty is the current
        # length + 1, to include the EOS token
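        length_penalty = self.global_scorer.length_penalty(
            step + 1, alpha=self.global_scorer.alpha
        )

        # Modification: disable length normalization so that the scores stay
        # raw cumulative log-probabilities. This overrides the value computed
        # above; delete the next line to restore the configured penalty.
        length_penalty = 1
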
        curr_scores = log_probs / length_penalty

        # Avoid any direction that would repeat unwanted ngrams
        self.block_ngram_repeats(curr_scores)

        # Pick up candidate token by curr_scores
        self._pick(curr_scores, out=(self.topk_scores, self.topk_ids))

        # Recover log probs.
        # Length penalty is just a scalar. It doesn't matter if it's applied
        # before or after the topk.
        torch.mul(self.topk_scores, length_penalty, out=self.topk_log_probs)

        # Resolve beam origin and map to batch index flat representation.
        self._batch_index = torch.div(self.topk_ids, vocab_size, rounding_mode="trunc")
        self._batch_index += self._beam_offset[:_B].unsqueeze(1)
        self.select_indices = self._batch_index.view(_B * self.beam_size)
        self.topk_ids.fmod_(vocab_size)  # resolve true word ids

        # Append last prediction.
        self.alive_seq = torch.cat(
            [
                self.alive_seq.index_select(0, self.select_indices),
                self.topk_ids.view(_B * self.beam_size, 1),
            ],
            -1,
        )

        self.maybe_update_forbidden_tokens()

        if self.return_attention or self._cov_pen:
            current_attn = attn.index_select(0, self.select_indices)
            if step == 1:
                self.alive_attn = current_attn
                # update global state (step == 1)
                if self._cov_pen:  # coverage penalty
                    self._prev_penalty = torch.zeros_like(self.topk_log_probs)
                    self._coverage = current_attn
            else:
                self.alive_attn = self.alive_attn.index_select(0, self.select_indices)
                self.alive_attn = torch.cat([self.alive_attn, current_attn], 1)
                # update global state (step > 1)
                if self._cov_pen:
                    self._coverage = self._coverage.index_select(0, self.select_indices)
                    self._coverage += current_attn
                    self._prev_penalty = self.global_scorer.cov_penalty(
                        self._coverage, beta=self.global_scorer.beta
                    ).view(_B, self.beam_size)

        if self._vanilla_cov_pen:
            # shape: (batch_size x beam_size, 1)
            cov_penalty = self.global_scorer.cov_penalty(
                self._coverage, beta=self.global_scorer.beta
            )
            self.topk_scores -= cov_penalty.view(_B, self.beam_size).float()

        self.is_finished = self.topk_ids.eq(self.eos)

        # Optional debugging and recording of per-step beam data (disabled).
        # Uncomment to print the step's scores and to fill the buffers created
        # in ``__init__``. ``topk_scores`` is reused as an output buffer across
        # steps, so any stored copy must be cloned.
        # print("self.topk_scores", self.topk_scores)
        # print("length_penalty:", length_penalty)
        # self.topk_scores_list.append(torch.exp(self.topk_scores))
        # self.topk_ids_list.append(self.topk_ids.clone())
        # self.nbest_beam_sequences.append(self.alive_seq.clone())
        # # Record the total scores of the current sequences.
        # self.sequence_total_scores.append(self.topk_scores.clone())
        # print("self.nbest_beam_sequences:", self.nbest_beam_sequences[-1])
        # print("Total score of nbest_beam_sequences:", self.sequence_total_scores[-1])
        # print("exp(self.topk_scores)", torch.exp(self.topk_scores))
        # print("self.topk_ids: ", self.topk_ids)
        # print()
        #
        # # Collect the data of each step.
        # step_data = {
        #     "step": len(self),
        #     "token_ids": self.topk_ids.tolist(),
        #     "nbest_beam_sequences": self.nbest_beam_sequences[-1].to("cpu"),
        #     "total_scores": self.sequence_total_scores[-1].tolist(),
        #     "exp(total_scores)": torch.exp(self.topk_scores).to("cpu"),
        # }
        # self.beam_data.append(step_data)

        self.ensure_max_length()
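

# ---------------------------------------------------------------------------
# Illustrative sketch only; nothing in this module calls it. It restates, in
# plain Python, the flattened top-k performed by ``BeamSearchBase._pick`` and
# the backpointer recovery in ``BeamSearchBase.advance``: the
# ``beam_size x vocab_size`` candidate scores of each example are flattened,
# the ``beam_size`` best are kept, and each flat index ``k`` is split into its
# source beam ``k // vocab_size`` and token id ``k % vocab_size``. The helper
# name and its nested-list input format are assumptions for illustration.
# ---------------------------------------------------------------------------
def _sketch_flat_topk(scores, beam_size, vocab_size):
    """Return ``(topk_scores, token_ids, select_indices)``.

    ``scores[b][j][v]`` is the cumulative score of extending beam ``j`` of
    example ``b`` with token ``v``. ``select_indices`` is flat over
    ``B * beam_size`` rows, like ``self.select_indices``.
    """
    topk_scores, token_ids, select_indices = [], [], []
    for b, beams in enumerate(scores):
        # Flatten beam_size x vocab_size candidates into one list.
        flat = [s for beam in beams for s in beam]
        # Flat indices of the beam_size best candidates, best first.
        best = sorted(range(len(flat)), key=lambda k: flat[k], reverse=True)
        best = best[:beam_size]
        topk_scores.append([flat[k] for k in best])
        token_ids.append([k % vocab_size for k in best])
        # Source beam plus this example's row offset, mirroring
        # ``_batch_index + _beam_offset`` in ``advance``.
        select_indices.extend(k // vocab_size + b * beam_size for k in best)
    return topk_scores, token_ids, select_indices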


class BeamSearch(BeamSearchBase):
    """
    Beam search for seq2seq/encoder-decoder models
    """

    def initialize(
        self, enc_out, src_len, src_map=None, device=None, target_prefix=None
    ):
        """Initialize for decoding.
        Repeat src objects `beam_size` times.
        """

        (fn_map_state, enc_out, src_map, target_prefix) = self.initialize_tile(
            enc_out, src_len, src_map, target_prefix
        )
        if device is None:
            device = self.get_device_from_enc_out(enc_out)

        super(BeamSearch, self).initialize_(
            enc_out, self.src_len, src_map, device, target_prefix
        )

        return fn_map_state, enc_out, self.src_len, src_map
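

# ---------------------------------------------------------------------------
# Illustrative sketch only; nothing in this module calls it. It shows the row
# layout this file relies on after ``initialize`` repeats each source
# ``beam_size`` times: the beams of example ``b`` occupy the consecutive flat
# rows ``b * beam_size`` to ``b * beam_size + beam_size - 1``, which is why
# ``_beam_offset`` is ``arange(0, batch_size * beam_size, beam_size)``. It also
# mirrors the initial ``topk_log_probs``: only beam 0 starts at 0.0 and the
# others at -inf, so the identical copies cannot produce duplicate hypotheses
# at the first step. The helper name is hypothetical.
# ---------------------------------------------------------------------------
def _sketch_initial_beam_layout(batch, beam_size):
    """Return ``(tiled_rows, beam_offset, init_log_probs)`` for a list."""
    # Each example repeated beam_size times, with its copies kept adjacent.
    tiled_rows = [item for item in batch for _ in range(beam_size)]
    # First flat row of each example's group of beams.
    beam_offset = list(range(0, len(batch) * beam_size, beam_size))
    # Per-example starting scores: only the first beam is live.
    init_log_probs = [[0.0] + [float("-inf")] * (beam_size - 1) for _ in batch]
    return tiled_rows, beam_offset, init_log_probs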


class BeamSearchLM(BeamSearchBase):
    """
    Beam search for language/decoder only models
    """

    def initialize(self, src, src_len, src_map=None, device=None, target_prefix=None):
        """Initialize for decoding.
        Repeat src objects `beam_size` times.
        """
        (fn_map_state, _, src_map, target_prefix) = self.initialize_tile(
            None, src_len, src_map, target_prefix
        )
        if device is None:
            device = src.device

        super(BeamSearchLM, self).initialize_(
            None,
            self.src_len,
            src_map=src_map,
            device=device,
            target_prefix=target_prefix,
        )

        return fn_map_state, src, self.src_len, src_map

    def advance(self, log_probs, attn):
        super(BeamSearchLM, self).advance(log_probs, attn)

        # in LM task src_len is associated with currently generated src
        # and therefore needs to follow the generation
        self.src_len += 1

    def remove_finished_batches(
        self, _B_new, _B_old, non_finished, predictions, attention, step
    ):
        super(BeamSearchLM, self).remove_finished_batches(
            _B_new, _B_old, non_finished, predictions, attention, step
        )

        # in LM task src_len is associated with currently generated src
        # and therefore needs to follow the generation
        non_finished = non_finished.to(self.topk_ids.device)
        self.src_len = (
            self.src_len.view(_B_old, self.beam_size)
            .index_select(0, non_finished)
            .view(_B_new * self.beam_size)
        )
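

# ---------------------------------------------------------------------------
# Illustrative sketch only; nothing in this module calls it. It mirrors three
# steps of ``update_finished``: the ``ratio > 0`` early-stopping test, keeping
# the ``n_best`` highest-scoring finished hypotheses of an example, and
# dropping finished examples from a flat ``(B * beam_size)`` state, as
# ``remove_finished_batches`` does with ``view(_B_old, beam_size, ...)``,
# ``index_select`` and a flattening ``view``. Helper names and plain-list
# inputs are assumptions made for illustration.
# ---------------------------------------------------------------------------
def _sketch_ratio_finished(top_score, src_len, ratio, best_score, all_finished):
    """Return True when an example may stop under ``ratio > 0``.

    The alive top beam is normalized by the expected length
    ``src_len * ratio`` and compared with the best length-normalized
    finished score, or every beam of the example has already finished.
    """
    return top_score / (src_len * ratio) <= best_score or all_finished


def _sketch_select_n_best(hypotheses, n_best):
    """Return the ``n_best`` ``(score, ...)`` tuples, best first."""
    return sorted(hypotheses, key=lambda hyp: hyp[0], reverse=True)[:n_best]


def _sketch_prune_finished(flat_rows, beam_size, non_finished):
    """Keep the ``beam_size`` rows of each example index in ``non_finished``."""
    # Regroup the flat rows per example: (B_old, beam_size).
    grouped = [
        flat_rows[i : i + beam_size] for i in range(0, len(flat_rows), beam_size)
    ]
    # Select surviving examples, then flatten back to (B_new * beam_size).
    return [row for i in non_finished for row in grouped[i]]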


class GNMTGlobalScorer(object):
    """NMT re-ranking.

    Args:
        alpha (float): Length parameter.
        beta (float): Coverage parameter.
        length_penalty (str): Length penalty strategy.
        coverage_penalty (str): Coverage penalty strategy.

    Attributes:
        alpha (float): See above.
        beta (float): See above.
        length_penalty (callable): See :class:`penalties.PenaltyBuilder`.
        coverage_penalty (callable): See :class:`penalties.PenaltyBuilder`.
        has_cov_pen (bool): See :class:`penalties.PenaltyBuilder`.
        has_len_pen (bool): See :class:`penalties.PenaltyBuilder`.
    """

    @classmethod
    def from_opt(cls, opt):
        return cls(opt.alpha, opt.beta, opt.length_penalty, opt.coverage_penalty)

    def __init__(self, alpha, beta, length_penalty, coverage_penalty):
        self._validate(alpha, beta, length_penalty, coverage_penalty)
        self.alpha = alpha
        self.beta = beta
        penalty_builder = penalties.PenaltyBuilder(coverage_penalty, length_penalty)
        self.has_cov_pen = penalty_builder.has_cov_pen
        # Term will be subtracted from the log-probability score
        self.cov_penalty = penalty_builder.coverage_penalty

        self.has_len_pen = penalty_builder.has_len_pen
        # The log-probability score will be divided by this
        self.length_penalty = penalty_builder.length_penalty

    @classmethod
    def _validate(cls, alpha, beta, length_penalty, coverage_penalty):
        # these warnings indicate that either the alpha/beta
        # forces a penalty to be a no-op, or a penalty is a no-op but
        # the alpha/beta would suggest otherwise.
        if length_penalty is not None and alpha == 0.0:
            warnings.warn(
                "Using length penalty with alpha==0 "
                "is equivalent to using length penalty none."
            )
        if coverage_penalty is None or coverage_penalty == "none":
            if beta != 0:
                warnings.warn(
                    "Non-default `beta` with no coverage penalty. "
                    "`beta` has no effect."
                )
        else:
            # using some coverage penalty
            if beta == 0.0:
                warnings.warn(
                    "Non-default coverage penalty with beta==0 "
                    "is equivalent to using coverage penalty none."
                )
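

# ---------------------------------------------------------------------------
# Illustrative sketch only; nothing in this module calls it. It writes out the
# GNMT length penalty of Wu et al. (2016), lp(Y) = ((5 + |Y|) / 6) ** alpha,
# the kind of normalizer ``BeamSearchBase.advance`` divides the cumulative
# log-probabilities by. That ``penalties.PenaltyBuilder``'s "wu" strategy uses
# exactly this form is an assumption here. With ``alpha == 0`` the penalty is
# 1, the same effect as the hard-coded ``length_penalty = 1`` in ``advance``.
# Helper names are hypothetical.
# ---------------------------------------------------------------------------
def _sketch_gnmt_length_penalty(cur_len, alpha):
    """Return the GNMT length normalizer for a hypothesis of ``cur_len``."""
    return ((5 + cur_len) / 6.0) ** alpha


def _sketch_normalized_score(sum_log_prob, cur_len, alpha):
    """Length-normalize a cumulative log-probability by division."""
    return sum_log_prob / _sketch_gnmt_length_penalty(cur_len, alpha)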