import logging
from typing import List

import torch
from transformers import (
    LogitsProcessor,
)


class StopAfterTokenIsGenerated(LogitsProcessor):
    """Force EOS once a sequence ends with one of the given stop token-id sequences."""

    def __init__(self, stops: List[torch.Tensor], eos_token_id: int):
        super().__init__()

        self.stops = stops
        self.eos_token_id = eos_token_id
        logging.info("Stopping criteria words ids: %s", self.stops)
        self.first_batch = True

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary token when
                not using beam search, or log softmax for each vocabulary token when using beam search.

        Return:
            `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.

        """
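        # Skip the first call: input_ids then holds only the prompt, and a stop
        # sequence that ends the prompt should not trigger EOS.
        if self.first_batch:
            self.first_batch = False
            return scores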

        for seq_no, seq in enumerate(input_ids):
            for stop in self.stops:
                stop = stop.to(device=seq.device, dtype=seq.dtype)
                # Match when the sequence ends with the stop token ids.
                if (
                    len(seq) >= len(stop)
                    and torch.all((stop == seq[-len(stop) :])).item()
                ):
                    # Force EOS: mask every token except eos_token_id.
                    scores[seq_no, :] = -float("inf")
                    scores[seq_no, self.eos_token_id] = 0
                    logging.info(f"Stopping criteria found: {stop}")
                    break

        return scores

    def reset(self):
        """Re-arm the first-call skip; call this before each new generation."""
        self.first_batch = True
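

# Illustrative sketch, not part of the original class: the suffix test used in
# StopAfterTokenIsGenerated.__call__, restated on plain Python lists so the
# matching rule can be checked without torch. The name `ends_with_stop` and the
# sample token ids below are hypothetical.
def ends_with_stop(seq, stops):
    """Return True if `seq` ends with any of the token-id lists in `stops`."""
    for stop in stops:
        # A stop sequence can only match once enough tokens exist.
        if len(seq) >= len(stop) and seq[-len(stop):] == stop:
            return True
    return False


# With made-up ids: [5, 6] is a suffix of [1, 2, 5, 6] but not of [5, 6, 1].
assert ends_with_stop([1, 2, 5, 6], [[9], [5, 6]])
assert not ends_with_stop([5, 6, 1], [[9], [5, 6]])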