File size: 18,012 Bytes
dd9b3ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07eef75
dd9b3ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07eef75
 
 
 
dd9b3ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
"""Wrapper of Seq2Labels model. Fixes errors based on model predictions"""
from collections import defaultdict
from difflib import SequenceMatcher
import logging
import re
from time import time
from typing import List, Union
import warnings

import torch
from transformers import AutoTokenizer
from modeling_seq2labels import Seq2LabelsModel
from vocabulary import Vocabulary
from utils import PAD, UNK, START_TOKEN, get_target_sent_by_edits

# Only ERROR-and-above records from the "werkzeug" logger are emitted
# (presumably to quiet request logging when served by a werkzeug/Flask app).
logging.getLogger("werkzeug").setLevel(logging.ERROR)
# Module logger named after the import path (not the file path) so it fits the
# standard dotted logger hierarchy and can be configured by module name.
logger = logging.getLogger(__name__)


class GecBERTModel(torch.nn.Module):
    """Grammatical-error-correction wrapper around one or more Seq2Labels models.

    Input text is split into words and tagged by every model; label
    probabilities are averaged with ``weights``. The predicted edit labels are
    then applied. This is repeated for up to ``iterations`` rounds, or until no
    sentence changes any more.
    """

    def __init__(
        self,
        vocab_path=None,
        model_paths=None,
        weights=None,
        device=None,
        max_len=64,
        min_len=3,
        lowercase_tokens=False,
        log=False,
        iterations=3,
        min_error_probability=0.0,
        confidence=0,
        resolve_cycles=False,
        split_chunk=False,
        chunk_size=48,
        overlap_size=12,
        min_words_cut=6,
        punc_dict=None,
    ):
        r"""
        Args:
            vocab_path (`str`):
                Path to the vocabulary directory.
            model_paths (`str` or `List[str]`):
                Path, or list of paths, of the Seq2Labels models to load.
            weights (`List[float]`, *optional*, defaults to None):
                Weight of each model when averaging predictions. Must have exactly one
                entry per model. If not set, all models are weighted equally.
            device (`str`, `int` or `torch.device`, *optional*, defaults to None):
                Device to load the models on. If not set, CUDA is used when available,
                otherwise CPU.
            max_len (`int`, defaults to 64):
                Max sentence length (in words) used for prediction. Words beyond it are
                left unchanged.
            min_len (`int`, defaults to 3):
                Min sentence length (in words) to be processed. Shorter sentences are
                returned unchanged.
            lowercase_tokens (`bool`, defaults to False):
                Whether to lowercase tokens.
            log (`bool`, defaults to False):
                Whether to print timing and progress information.
            iterations (`int`, defaults to 3):
                Max number of correction rounds to run during inference.
            min_error_probability (`float`, defaults to `0.0`):
                Minimum probability for an individual action to be applied. It is also
                the minimum sentence-level error probability for a sentence to be edited.
            confidence (`float`, defaults to `0.0`):
                Probability bias added to the $KEEP label before picking the best label.
                Higher values make corrections more conservative.
            resolve_cycles (`bool`, defaults to False):
                Stored on the instance but not read by this class. Cycles between
                iterations are always stopped by `update_final_batch`.
            split_chunk (`bool`, defaults to False):
                Whether to split long sentences into multiple segments of `chunk_size` words.
                !Warning: if `chunk_size > max_len`, each segment will be truncated to `max_len`.
            chunk_size (`int`, defaults to 48):
                Length of each segment (in words). Only relevant if `split_chunk is True`.
            overlap_size (`int`, defaults to 12):
                Overlap size (in words) between two consecutive segments. Must not exceed
                `chunk_size // 2`. Only relevant if `split_chunk is True`.
            min_words_cut (`int`, defaults to 6):
                Minimum number of words to be cut while merging two consecutive segments.
                Only relevant if `split_chunk is True`.
            punc_dict (`Iterable[str]`, defaults to `{':', ".", ",", "?"}`):
                Punctuation tokens. They are ignored when aligning chunk overlaps, and
                whitespace in front of them is removed from the output.

        Note:
            Whether the [CLS]/[SEP] special-tokens fix is applied is read from each
            model's config (`special_tokens_fix`); it is not an argument here.

        Raises:
            ValueError: `model_paths` is empty, `weights` does not match the number of
                models, or `chunk_size` / `overlap_size` are inconsistent.
        """
        super().__init__()
        if isinstance(model_paths, str):
            model_paths = [model_paths]
        if not model_paths:
            raise ValueError("`model_paths` must contain at least one model path.")
        if weights:
            model_weights = [float(w) for w in weights]
            if len(model_weights) != len(model_paths):
                raise ValueError(
                    f"Got {len(model_weights)} weights for {len(model_paths)} models; "
                    "provide exactly one weight per model."
                )
        else:
            model_weights = [1] * len(model_paths)
        # Validate before any (slow) vocabulary / model loading happens.
        if not (chunk_size > 0 and chunk_size // 2 >= overlap_size):
            raise ValueError(
                "Chunk merging requires chunk_size > 0 and overlap_size <= chunk_size // 2"
            )

        self.model_weights = model_weights
        self.device = (
            torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
        )
        self.max_len = max_len
        self.min_len = min_len
        self.lowercase_tokens = lowercase_tokens
        self.min_error_probability = min_error_probability
        self.vocab = Vocabulary.from_files(vocab_path)
        self.incorr_index = self.vocab.get_token_index("INCORRECT", "d_tags")
        self.log = log
        self.iterations = iterations
        self.confidence = confidence
        self.resolve_cycles = resolve_cycles

        self.split_chunk = split_chunk
        self.chunk_size = chunk_size
        self.overlap_size = overlap_size
        self.min_words_cut = min_words_cut
        self.stride = chunk_size - overlap_size
        # Fresh set per instance (never share a mutable default between models).
        self.punc_dict = {':', ".", ",", "?"} if punc_dict is None else set(punc_dict)
        self.punc_str = self._build_punc_pattern(self.punc_dict)

        self.indexers = []
        self.models = []
        for model_path in model_paths:
            model = Seq2LabelsModel.from_pretrained(model_path)
            config = model.config
            self.indexers.append(self._get_indexer(config.pretrained_name_or_path, config.special_tokens_fix))
            model.eval().to(self.device)
            self.models.append(model)

    @staticmethod
    def _build_punc_pattern(punctuations):
        """Return a regex matching any single character of `punctuations`.

        Every character is escaped with `re.escape`, so regex-special characters
        (`]`, `-`, `^`, `\\`) are safe. An empty input yields a pattern that never
        matches, instead of the invalid character class `[]`.
        """
        chars = sorted(set("".join(punctuations)))
        if not chars:
            return "(?!)"
        return "[" + "".join(re.escape(c) for c in chars) + "]"

    def _get_indexer(self, weights_name, special_tokens_fix):
        """Load the tokenizer for `weights_name` and expose a `vocab` mapping on it.

        When `special_tokens_fix` is truthy, START_TOKEN is added as a new token.
        """
        tokenizer = AutoTokenizer.from_pretrained(
            weights_name, do_basic_tokenize=False, do_lower_case=self.lowercase_tokens, model_max_length=1024
        )
        # to adjust all tokenizers: give BPE (`encoder`) and sentencepiece
        # (`sp_model`) tokenizers a token -> id `vocab` mapping.
        if hasattr(tokenizer, 'encoder'):
            tokenizer.vocab = tokenizer.encoder
        if hasattr(tokenizer, 'sp_model'):
            tokenizer.vocab = defaultdict(lambda: 1)
            for i in range(tokenizer.sp_model.get_piece_size()):
                tokenizer.vocab[tokenizer.sp_model.id_to_piece(i)] = i

        if special_tokens_fix:
            tokenizer.add_tokens([START_TOKEN])
            tokenizer.vocab[START_TOKEN] = len(tokenizer) - 1
        return tokenizer

    def forward(self, text: Union[str, List[str], List[List[str]]], is_split_into_words=False):
        """Correct one text or a batch of texts.

        Args:
            text: A single string, a batch of strings, a single pre-tokenized example
                (`List[str]`, with `is_split_into_words=True`), or a batch of
                pre-tokenized examples (`List[List[str]]`). Tuples are accepted
                wherever lists are.
            is_split_into_words: Whether `text` is already split into words. A plain
                string passed with this flag is split on whitespace.

        Returns:
            `List[str]`: one corrected string per example. A single example still
            yields a one-element list.

        Raises:
            ValueError: `text` is not one of the accepted types.
        """
        # Input type checking for clearer error
        def _is_valid_text_input(t):
            if isinstance(t, str):
                return True
            if isinstance(t, (list, tuple)):
                # Lists are fine when empty, a list of strings, or a list whose
                # first element is an empty list or a list of strings.
                if len(t) == 0 or isinstance(t[0], str):
                    return True
                if isinstance(t[0], (list, tuple)):
                    return len(t[0]) == 0 or isinstance(t[0][0], str)
            return False

        if not _is_valid_text_input(text):
            raise ValueError(
                "text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) "
                "or `List[List[str]]` (batch of pretokenized examples)."
            )

        if isinstance(text, str):
            # Single raw example (a string cannot be pre-tokenized, so split it).
            batch = [text.split()]
        elif is_split_into_words:
            if text and isinstance(text[0], (list, tuple)):
                batch = [list(words) for words in text]
            else:
                batch = [list(text)]
        else:
            # Batch of raw strings; already-split examples are accepted as well.
            batch = [x.split() if isinstance(x, str) else list(x) for x in text]

        return self.handle_batch(batch)

    def split_chunks(self, batch):
        """Split long sentences into overlapping chunks of at most `chunk_size` words.

        Args:
            batch (`List[List[str]]`): sentences as lists of words.

        Returns:
            `(chunks, spans)`: the flat list of chunks, and for every input sentence a
            `(start, end)` pair such that `chunks[start:end]` are its chunks.
        """
        chunks = []
        spans = []
        for tokens in batch:
            first = len(chunks)
            num_token = len(tokens)
            if num_token <= self.chunk_size:
                chunks.append(tokens)
            elif num_token < self.chunk_size * 2 - self.overlap_size:
                # Two roughly equal halves sharing `overlap_size` words.
                split_idx = (num_token + self.overlap_size + 1) // 2
                chunks.append(tokens[:split_idx])
                chunks.append(tokens[split_idx - self.overlap_size :])
            else:
                chunks.extend(
                    tokens[i : i + self.chunk_size] for i in range(0, num_token - self.overlap_size, self.stride)
                )
            spans.append((first, len(chunks)))
        return chunks, spans

    def check_alnum(self, s):
        """Return True if `s` has at least 2 characters and is neither all-alphabetic nor all-digit."""
        if len(s) < 2:
            return False
        return not (s.isalpha() or s.isdigit())

    def apply_chunk_merging(self, tokens, next_tokens):
        """Append `next_tokens` to `tokens`, removing the words the two chunks share.

        Consecutive chunks from `split_chunks` overlap by `overlap_size` words, and
        each may have been corrected independently. The last `overlap_size`
        non-punctuation words of `tokens` are aligned (case-insensitively, with
        `difflib.SequenceMatcher`) against the first `overlap_size`
        non-punctuation words of `next_tokens`. Both lists are then cut at the first
        matched word whose position in the source overlap is at least
        `overlap_size - min_words_cut`. That threshold is shifted later by half of
        any unmatched leading source words.

        Args:
            tokens (`List[str]`): words merged so far (may be empty).
            next_tokens (`List[str]`): words of the next chunk.

        Returns:
            `List[str]`: the merged words. If no usable matching word exists, the two
            lists are concatenated, so overlapping words may be duplicated but no
            text is lost.
        """
        # Return next tokens if current tokens list is empty
        if not tokens:
            return next_tokens

        # Last `overlap_size` non-punctuation words of `tokens` (negative indices).
        source_token_idx = []
        source_tokens = []
        i = 0
        while len(source_token_idx) < self.overlap_size and -i < len(tokens):
            i -= 1
            if tokens[i] not in self.punc_dict:
                source_token_idx.insert(0, i)
                source_tokens.insert(0, tokens[i].lower())

        # First `overlap_size` non-punctuation words of `next_tokens`.
        target_token_idx = []
        target_tokens = []
        for j, token in enumerate(next_tokens):
            if len(target_token_idx) >= self.overlap_size:
                break
            if token not in self.punc_dict:
                target_token_idx.append(j)
                target_tokens.append(token.lower())

        num_keep = self.overlap_size - self.min_words_cut
        tail_idx = head_idx = None
        matcher = SequenceMatcher(None, source_tokens, target_tokens)
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == "equal":
                if i1 >= num_keep:
                    tail_idx = source_token_idx[i1]
                    head_idx = target_token_idx[j1]
                    break
                if i2 > num_keep:
                    # The matching run straddles `num_keep`: cut exactly there.
                    tail_idx = source_token_idx[num_keep]
                    head_idx = target_token_idx[j2 - i2 + num_keep]
                    break
            elif tag == "delete" and i1 == 0:
                # Leading source words missing from the next chunk: cut later.
                num_keep += i2 // 2

        if tail_idx is None:
            # No anchor (e.g. heavy edits in the overlap). Keep both sides rather
            # than losing a whole chunk.
            return list(tokens) + list(next_tokens)
        return tokens[:tail_idx] + next_tokens[head_idx:]

    def merge_chunks(self, batch):
        """Merge the corrected chunks of one sentence back into a single string.

        Args:
            batch (`List[List[str]]`): consecutive chunks of one sentence.

        Returns:
            `str`: the merged words joined by single spaces.
        """
        if len(batch) == 1 or self.overlap_size == 0:
            merged = [token for chunk in batch for token in chunk]
        else:
            merged = []
            for chunk in batch:
                try:
                    merged = self.apply_chunk_merging(merged, chunk)
                except Exception:
                    # Best effort: never drop text because merging failed.
                    logger.exception("Chunk merging failed; appending the chunk without de-duplication")
                    merged = list(merged) + list(chunk)
        return " ".join(merged)

    def predict(self, batches):
        """Run every model on its own pre-processed batch and combine the outputs.

        Args:
            batches: one tokenizer batch per model, in the order of `self.models`
                (as returned by `preprocess`).

        Returns:
            `(probs, idx, error_probs)` as returned by `_convert`.
        """
        started = time()
        predictions = []
        for batch, model in zip(batches, self.models):
            batch = batch.to(self.device)
            with torch.no_grad():
                # Call the module (not `.forward`) so registered hooks run.
                predictions.append(model(**batch))

        preds, idx, error_probs = self._convert(predictions)
        if self.log:
            print(f"Inference time {time() - started}")
        return preds, idx, error_probs

    def get_token_action(self, token, index, prob, sugg_token):
        """Get the edit suggested for one position.

        Args:
            token: the word at this position (unused; kept for API compatibility).
            index: position in the START-prefixed sequence.
            prob: probability of the suggested label.
            sugg_token: suggested label, e.g. `$KEEP`, `$DELETE`, `$APPEND_x`.

        Returns:
            `(start, end, replacement, prob)` with word positions relative to the
            sentence without START, or None when no action should be applied
            (low probability, $KEEP/padding/unknown labels, or an unrecognized label).
        """
        # cases when we don't need to do anything
        if prob < self.min_error_probability or sugg_token in [UNK, PAD, '$KEEP']:
            return None

        if sugg_token.startswith(('$REPLACE_', '$TRANSFORM_')) or sugg_token == '$DELETE':
            # Acts on the current word.
            start_pos, end_pos = index, index + 1
        elif sugg_token.startswith(("$APPEND_", "$MERGE_")):
            # Inserts right after the current word.
            start_pos, end_pos = index + 1, index + 1
        else:
            return None

        if sugg_token == "$DELETE":
            sugg_token_clear = ""
        elif sugg_token.startswith(('$TRANSFORM_', "$MERGE_")):
            # Transform/merge labels are passed through whole.
            sugg_token_clear = sugg_token
        else:
            sugg_token_clear = sugg_token[sugg_token.index('_') + 1 :]

        return start_pos - 1, end_pos - 1, sugg_token_clear, prob

    def preprocess(self, token_batch):
        """Tokenize a batch of word lists once per model's tokenizer.

        Every sentence is truncated to `max_len` words and prefixed with
        START_TOKEN. `input_offsets` holds, for each word, the index of its first
        sub-token (padded with 0).

        Args:
            token_batch (`List[List[str]]`): sentences as lists of words.

        Returns:
            One tokenizer batch per indexer, or `[]` if every sentence is empty.
        """
        seq_lens = [len(sequence) for sequence in token_batch if sequence]
        if not seq_lens:
            return []
        max_len = min(max(seq_lens), self.max_len)
        # Built once: every indexer must see the same START-prefixed words.
        prefixed_batch = [[START_TOKEN] + sequence[:max_len] for sequence in token_batch]
        batches = []
        for indexer in self.indexers:
            batch = indexer(
                prefixed_batch,
                return_tensors="pt",
                padding=True,
                is_split_into_words=True,
                truncation=True,
                add_special_tokens=False,
            )
            offset_batch = []
            for row in range(len(prefixed_batch)):
                word_ids = batch.word_ids(batch_index=row)
                offsets = [0]
                for pos in range(1, len(word_ids)):
                    # First sub-token of each new word; padding (None) is not a word.
                    if word_ids[pos] is not None and word_ids[pos] != word_ids[pos - 1]:
                        offsets.append(pos)
                offset_batch.append(torch.LongTensor(offsets))

            batch["input_offsets"] = torch.nn.utils.rnn.pad_sequence(
                offset_batch, batch_first=True, padding_value=0
            ).to(torch.long)

            batches.append(batch)

        return batches

    def _convert(self, data):
        """Average model outputs and pick the best label per position.

        Label probabilities and the per-sentence maximum INCORRECT (detection)
        probability are averaged with `self.model_weights`. `self.confidence` is
        then added to the $KEEP probability.

        Returns:
            `(probs, idx, error_probs)` as nested Python lists: the best label's
            probability and index per position, and one error probability per sentence.
        """
        total_weight = sum(self.model_weights)
        all_class_probs = torch.zeros_like(data[0]['logits'])
        error_probs = torch.zeros_like(data[0]['max_error_probability'])
        for output, weight in zip(data, self.model_weights):
            class_probabilities_labels = torch.softmax(output['logits'], dim=-1)
            all_class_probs += weight * class_probabilities_labels / total_weight
            class_probabilities_d = torch.softmax(output['detect_logits'], dim=-1)
            incorr_prob = torch.max(class_probabilities_d[:, :, self.incorr_index], dim=-1)[0]
            error_probs += weight * incorr_prob / total_weight

        if self.confidence:
            # Bias towards $KEEP to make corrections more conservative.
            keep_index = self.vocab.get_token_index("$KEEP", "labels")
            all_class_probs[:, :, keep_index] += self.confidence

        max_probs, max_idx = torch.max(all_class_probs, dim=-1)
        return max_probs.tolist(), max_idx.tolist(), error_probs.tolist()

    def update_final_batch(self, final_batch, pred_ids, pred_batch, prev_preds_dict):
        """Store new predictions and decide which sentences need another round.

        A changed prediction is always written to `final_batch`. It is only
        scheduled for another iteration when it was not produced before for that
        sentence, which stops A -> B -> A cycles.

        Returns:
            `(final_batch, new_pred_ids, total_updated)`.
        """
        new_pred_ids = []
        total_updated = 0
        for orig_id, pred in zip(pred_ids, pred_batch):
            if final_batch[orig_id] == pred:
                continue
            final_batch[orig_id] = pred
            total_updated += 1
            prev_preds = prev_preds_dict[orig_id]
            if pred not in prev_preds:
                new_pred_ids.append(orig_id)
                prev_preds.append(pred)
        return final_batch, new_pred_ids, total_updated

    def postprocess_batch(self, batch, all_probabilities, all_idxs, error_probs):
        """Apply the predicted edit labels to every sentence of `batch`.

        Returns:
            `List[List[str]]`: the edited word lists. A sentence is returned unchanged
            when every label is $KEEP, or when its error probability is below
            `min_error_probability`.
        """
        all_results = []
        noop_index = self.vocab.get_token_index("$KEEP", "labels")
        for tokens, probabilities, idxs, error_prob in zip(batch, all_probabilities, all_idxs, error_probs):
            # Skip the whole sentence when no error is predicted, or when the
            # sentence-level error probability is below the threshold.
            if all(label == noop_index for label in idxs) or error_prob < self.min_error_probability:
                all_results.append(tokens)
                continue

            length = min(len(tokens), self.max_len)
            edits = []
            # Position 0 is the START token; position p > 0 is tokens[p - 1].
            for pos in range(length + 1):
                if idxs[pos] == noop_index:
                    continue
                token = START_TOKEN if pos == 0 else tokens[pos - 1]
                sugg_token = self.vocab.get_token_from_index(idxs[pos], namespace='labels')
                action = self.get_token_action(token, pos, probabilities[pos], sugg_token)
                if action:
                    edits.append(action)
            all_results.append(get_target_sent_by_edits(tokens, edits))
        return all_results

    def handle_batch(self, full_batch, merge_punc=True):
        """Correct a batch of tokenized sentences.

        Args:
            full_batch (`List[List[str]]`): sentences as lists of words.
            merge_punc (`bool`, defaults to True): remove whitespace in front of
                punctuation in the output.

        Returns:
            `List[str]`: one corrected, space-joined sentence per input sentence.
        """
        chunk_spans = None
        if self.split_chunk:
            full_batch, chunk_spans = self.split_chunks(full_batch)
        final_batch = full_batch[:]
        batch_size = len(full_batch)
        # Every prediction seen so far per sentence, used to stop cycles.
        prev_preds_dict = {i: [sentence] for i, sentence in enumerate(final_batch)}
        # Sentences shorter than `min_len` are never sent to the model.
        pred_ids = [i for i, sentence in enumerate(full_batch) if len(sentence) >= self.min_len]

        for n_iter in range(self.iterations):
            orig_batch = [final_batch[i] for i in pred_ids]

            sequences = self.preprocess(orig_batch)
            if not sequences:
                break
            probabilities, idxs, error_probs = self.predict(sequences)

            pred_batch = self.postprocess_batch(orig_batch, probabilities, idxs, error_probs)
            if self.log:
                print(f"Iteration {n_iter + 1}. Predicted {round(100*len(pred_ids)/batch_size, 1)}% of sentences.")

            final_batch, pred_ids, _ = self.update_final_batch(final_batch, pred_ids, pred_batch, prev_preds_dict)
            if not pred_ids:
                break

        if self.split_chunk:
            final_batch = [self.merge_chunks(final_batch[start:end]) for (start, end) in chunk_spans]
        else:
            final_batch = [" ".join(x) for x in final_batch]
        if merge_punc:
            final_batch = [re.sub(r'\s+(%s)' % self.punc_str, r'\1', x) for x in final_batch]

        return final_batch