Truong-Phuc Nguyen committed on
Commit
c8a708c
1 Parent(s): 2704216

Upload 6 files

plms/__pycache__/exceptions.cpython-38.pyc ADDED
Binary file (1.91 kB)
 
plms/__pycache__/language_model.cpython-38.pyc ADDED
Binary file (21.8 kB)
 
plms/__pycache__/spacy_module.cpython-38.pyc ADDED
Binary file (3.68 kB)
 
plms/exceptions.py ADDED
@@ -0,0 +1,33 @@
""" Custom exceptions """


class ExceedMaxLengthError(Exception):
    """ Token exceeds max length. """

    def __init__(self, max_length=None):
        self.message = f'Input sentence exceeds max length of {max_length}'
        super().__init__(self.message)


class HighlightNotFoundError(Exception):
    """ Highlight is not in the sentence. """

    def __init__(self, highlight: str, input_sentence: str):
        self.message = f'Highlight `{highlight}` not found in the input sentence `{input_sentence}`'
        super().__init__(self.message)


class AnswerNotFoundError(Exception):
    """ Answer cannot be found in the context. """

    def __init__(self, context: str):
        self.message = f'Model cannot find any answer candidates in `{context}`'
        super().__init__(self.message)


class APIError(Exception):
    """ Error from the Hugging Face inference API. """

    def __init__(self, context: str):
        self.message = f'Huggingface API Error:\n`{context}`'
        super().__init__(self.message)
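
For reference, a minimal sketch (not part of the upload) of how these exceptions surface to callers; the triggering calls are illustrative only:

from plms.exceptions import ExceedMaxLengthError, HighlightNotFoundError

# Illustrative: ExceedMaxLengthError is raised when an encoded input is longer than max_length.
try:
    raise ExceedMaxLengthError(max_length=512)
except ExceedMaxLengthError as error:
    print(error)  # Input sentence exceeds max length of 512

# Illustrative: HighlightNotFoundError is raised when the answer span is absent from the sentence.
try:
    raise HighlightNotFoundError('Hanoi', 'I live in Tokyo.')
except HighlightNotFoundError as error:
    print(error)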
plms/language_model.py ADDED
@@ -0,0 +1,613 @@
import os
import logging
import pickle
import re
import urllib
from itertools import chain
from typing import List, Dict
from multiprocessing import Pool
import numpy as np
from tqdm import tqdm
import torch
from torch.nn import functional
import transformers
from .exceptions import ExceedMaxLengthError, HighlightNotFoundError, AnswerNotFoundError
from .spacy_module import SpacyPipeline, VALID_METHODS

__all__ = ('TransformersQG', 'ADDITIONAL_SP_TOKENS', 'TASK_PREFIX', 'clean', 'internet_connection')

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # to turn off warning message
TASK_PREFIX = {
    "ae": "extract answers",
    "qg": "generate question",
    "qag": "generate question and answer",
    "qa": "answer question"
}
CE_IGNORE_INDEX = -100
ADDITIONAL_SP_TOKENS = {'hl': '<hl>'}
NUM_WORKERS = int(os.getenv('NUM_WORKERS', '0'))
PARALLEL_PROCESSING = bool(int(os.getenv('PARALLEL_PROCESSING', '0')))
DEFAULT_MODELS = {
    'vi': 'VietAI/vit5-base'
}


def pickle_save(obj, path: str):
    with open(path, "wb") as fp:
        pickle.dump(obj, fp)


def pickle_load(path: str):
    with open(path, "rb") as fp:  # Unpickling
        return pickle.load(fp)


def clean(string):
    string = re.sub(r'\A\s*', '', string)
    string = re.sub(r'\s*\Z', '', string)
    if len(string) > 0:
        return string
    return None


def internet_connection(host='http://google.com'):
    try:
        urllib.request.urlopen(host)
        return True
    except Exception:
        return False


def load_language_model(model_name,
                        cache_dir: str = None,
                        use_auth_token: bool = False,
                        torch_dtype=None,
                        device_map: str = None,
                        low_cpu_mem_usage: bool = False):
    """ Load a language model from the Hugging Face model hub. """
    # tokenizer
    local_files_only = not internet_connection()
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name, cache_dir=cache_dir, local_files_only=local_files_only, use_auth_token=use_auth_token)
    config = transformers.AutoConfig.from_pretrained(
        model_name, local_files_only=local_files_only, cache_dir=cache_dir, use_auth_token=use_auth_token)
    # model
    if config.model_type == 't5':  # T5 model requires T5ForConditionalGeneration class
        model_class = transformers.T5ForConditionalGeneration.from_pretrained
    elif config.model_type == 'mt5':
        model_class = transformers.MT5ForConditionalGeneration.from_pretrained
    elif config.model_type == 'bart':
        model_class = transformers.BartForConditionalGeneration.from_pretrained
    elif config.model_type == 'mbart':
        model_class = transformers.MBartForConditionalGeneration.from_pretrained
    elif config.model_type == 'switch_transformers':
        model_class = transformers.SwitchTransformersForConditionalGeneration.from_pretrained
    else:
        raise ValueError(f'unsupported model type: {config.model_type}')

    param = {'config': config, "local_files_only": local_files_only, "use_auth_token": use_auth_token,
             "low_cpu_mem_usage": low_cpu_mem_usage, "cache_dir": cache_dir}
    if torch_dtype is not None:
        param['torch_dtype'] = torch_dtype
    if device_map is not None:
        param['device_map'] = device_map
    model = model_class(model_name, **param)
    # add new special tokens to the tokenizer and the model if they are not already present
    tokenizer.add_special_tokens({'additional_special_tokens': list(ADDITIONAL_SP_TOKENS.values())})
    model.resize_token_embeddings(len(tokenizer))
    return tokenizer, model, config


def label_smoothed_loss(logits, labels, epsilon):
    """ https://github.com/huggingface/transformers/blob/55bb4c06f7be141c6d895dbe1f11018dc8580b2d/src/transformers/trainer_pt_utils.py#L430 """
    log_probs = - functional.log_softmax(logits, dim=-1)
    if labels.dim() == log_probs.dim() - 1:
        labels = labels.unsqueeze(-1)

    padding_mask = labels.eq(CE_IGNORE_INDEX)
    # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
    # will ignore them in any case.
    labels.clamp_min_(0)

    nll_loss = log_probs.gather(dim=-1, index=labels)
    nll_loss.masked_fill_(padding_mask, 0.0)

    # works for fp16 input tensor too, by internally upcasting it to fp32
    smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
    smoothed_loss.masked_fill_(padding_mask, 0.0)

    # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
    num_active_elements = padding_mask.numel() - padding_mask.long().sum()
    nll_loss = nll_loss.sum() / num_active_elements
    smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
    return (1 - epsilon) * nll_loss + epsilon * smoothed_loss


class Dataset(torch.utils.data.Dataset):
    """ torch.utils.data.Dataset wrapper converting into tensor """
    float_tensors = ['attention_mask']

    def __init__(self, data: List):
        self.data = data

    def __len__(self):
        return len(self.data)

    def to_tensor(self, name, data):
        if name in self.float_tensors:
            return torch.tensor(data, dtype=torch.float32)
        return torch.tensor(data, dtype=torch.long)

    def __getitem__(self, idx):
        return {k: self.to_tensor(k, v) for k, v in self.data[idx].items()}


class EncodePlus:
    """ Wrapper of encode_plus for multiprocessing. """

    def __init__(self,
                 tokenizer,
                 max_length: int = 512,
                 max_length_output: int = 34,
                 drop_overflow_error_text: bool = False,
                 skip_overflow_error: bool = False,
                 drop_highlight_error_text: bool = False,
                 prefix_type: str = None,
                 padding: bool = True):
        """ Wrapper of encode_plus for multiprocessing.

        @param tokenizer: transformers tokenizer.
        @param max_length: Max text length of input.
        @param max_length_output: Max text length of output.
        @param drop_overflow_error_text: If true, return None when the input exceeds the max length.
        @param skip_overflow_error: If true, skip the length check and keep inputs that exceed the max length.
        @param drop_highlight_error_text: If true, return None when a highlight span is not found in the input
            (otherwise an error is raised).
        @param prefix_type: Task prefix to prepend to the text; one of `ae`, `qg`, `qag`, or `qa`.
        @param padding: Pad the sequence to the max length.
        """
        self.prefix = TASK_PREFIX[prefix_type] if prefix_type is not None else None
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_length_output = max_length_output
        # NOTE: for model training, we should drop the exceeded input but not for the evaluator
        self.drop_overflow_error_text = drop_overflow_error_text
        self.skip_overflow_error = skip_overflow_error
        self.drop_highlight_error_text = drop_highlight_error_text
        # truncation should be true for the batch process, but not necessary to process single input
        self.param_in = {'truncation': True, 'max_length': self.max_length}
        self.param_out = {'truncation': True, 'max_length': self.max_length_output}
        if padding:
            self.param_in['padding'] = 'max_length'
            self.param_out['padding'] = 'max_length'

    def __call__(self, inputs):
        return self.encode_plus(*inputs)

    def encode_plus(self, input_sequence: str, output_sequence: str = None, input_highlight: str = None):
        """ encode_plus

        @param input_sequence: Input sequence.
        @param output_sequence: Output sequence.
        @param input_highlight: Sub-sequence of `input_sequence` to be surrounded by <hl>.
        @return: The output of `encode_plus`.
        """
        # add highlight to the input
        if input_highlight is not None:
            position = input_sequence.find(input_highlight)
            if position == -1:
                if self.drop_highlight_error_text:
                    return None
                raise HighlightNotFoundError(input_highlight, input_sequence)
            input_sequence = '{0}{1} {2} {1}{3}'.format(
                input_sequence[:position], ADDITIONAL_SP_TOKENS['hl'], input_highlight,
                input_sequence[position + len(input_highlight):])
        if self.prefix is not None:
            input_sequence = f'{self.prefix}: {input_sequence}'

        # handling overflow text
        # drop_overflow_error_text ==> remove the overflow sentence from input
        # skip_overflow_error ==> keep the overflow sentence
        # none of them ==> raise error
        if self.drop_overflow_error_text or not self.skip_overflow_error:
            if len(self.tokenizer.encode(input_sequence)) > self.max_length:
                if not self.drop_overflow_error_text:  # raise error for overflow text
                    raise ExceedMaxLengthError(self.max_length)
                return None  # remove overflow text
            if output_sequence is not None:
                if len(self.tokenizer.encode(output_sequence)) > self.max_length_output:
                    if not self.drop_overflow_error_text:  # raise error for overflow text
                        raise ExceedMaxLengthError(self.max_length)
                    return None  # remove overflow text
        if type(self.tokenizer) is transformers.models.mbart.tokenization_mbart_fast.MBartTokenizerFast:
            encode = self.tokenizer(input_sequence, **self.param_in)
        else:
            encode = self.tokenizer(text_target=input_sequence, **self.param_in)
        if output_sequence is not None:
            encode['labels'] = self.tokenizer.encode(output_sequence, **self.param_out)
        return encode


class TransformersQG:
    """ Transformers Language Model for Question Generation. """

    def __init__(self,
                 model: str = None,
                 max_length: int = 512,
                 max_length_output: int = 256,
                 model_ae: str = None,
                 max_length_ae: int = 512,
                 max_length_output_ae: int = 64,
                 cache_dir: str = None,
                 add_prefix: bool = None,
                 language: str = 'vi',
                 label_smoothing: float = None,
                 skip_overflow_error: bool = False,
                 drop_overflow_error_text: bool = False,
                 drop_highlight_error_text: bool = False,
                 drop_answer_error_text: bool = False,
                 use_auth_token: bool = False,
                 torch_dtype=None,
                 device_map: str = None,
                 low_cpu_mem_usage: bool = False,
                 is_qg: bool = None,
                 is_qag: bool = None,
                 is_qa: bool = None,
                 is_ae: bool = None):
        """ Transformers Language Model for Question Generation.

        @param model: Model alias or path to local model file.
        @param max_length: Max text length of input.
        @param max_length_output: Max text length of output.
        @param cache_dir: Directory to cache transformers model files.
        @param add_prefix: Whether the model uses a task-specific prefix (e.g. True for T5 but False for BART models).
        @param language: Language alias for SpaCy language-specific pipelines (sentencizer/keyword extraction).
        @param label_smoothing: [Fine-tuning parameter] Label smoothing.
        @param drop_overflow_error_text: If true, return None when the input exceeds the max length.
        @param skip_overflow_error: If true, skip the length check and keep inputs that exceed the max length.
        @param drop_highlight_error_text: If true, return None when a highlight span is not found in the paragraph
            (otherwise an error is raised).
        @param use_auth_token: [optional] Huggingface transformers argument of `use_auth_token`
        """

        # take default model given the language
        if model is None:
            assert language in DEFAULT_MODELS.keys(), \
                f"Model with language '{language}' is not available. Please choose language from " \
                f"'{DEFAULT_MODELS.keys()}' or specify 'model'."
            model = DEFAULT_MODELS[language]

        # classify model type
        self.is_qg = 'qg' in model.split('-') if is_qg is None else is_qg
        self.is_ae = 'ae' in model.split('-') if is_ae is None else is_ae
        self.is_qa = 'qa' in model.split('-') if is_qa is None else is_qa
        self.is_qag = 'qag' in model.split('-') if is_qag is None else is_qag
        # configs
        self.model_name = model
        self.max_length = max_length
        self.max_length_output = max_length_output
        self.label_smoothing = label_smoothing
        self.drop_overflow_error_text = drop_overflow_error_text
        self.skip_overflow_error = skip_overflow_error
        self.drop_highlight_error_text = drop_highlight_error_text
        self.drop_answer_error_text = drop_answer_error_text
        self.model_name_ae = model_ae
        self.max_length_ae = max_length_ae
        self.max_length_output_ae = max_length_output_ae
        # load model
        self.tokenizer, self.model, config = load_language_model(
            self.model_name, cache_dir=cache_dir, use_auth_token=use_auth_token, device_map=device_map,
            torch_dtype=torch_dtype, low_cpu_mem_usage=low_cpu_mem_usage)
        if 'add_prefix' not in config.to_dict().keys():
            # this means the model is not fine-tuned
            # assert add_prefix, '`add_prefix` is required for non-fine-tuned models'
            self.add_prefix = add_prefix
        else:
            self.add_prefix = config.add_prefix

        # set default behaviour for answer extraction
        if self.model_name_ae is None:
            self.model_name_ae = self.model_name if self.is_ae else "positionrank"
        # load answer extraction model
        self.answer_model_type = None
        if self.model_name_ae in VALID_METHODS:
            logging.info(f'use spaCy answer extraction model: {self.model_name_ae}')
            self.tokenizer_ae = self.model_ae = self.add_prefix_ae = None
            self.spacy_module = SpacyPipeline(language, self.model_name_ae)
            self.answer_model_type = 'spacy'
        else:
            logging.info(f'use LMQG fine-tuned answer extraction model: {self.model_name_ae}')
            if self.model_name == self.model_name_ae:
                logging.info("the same model as QG is used as AE")
                assert self.is_ae, f"the model ({self.model_name_ae}) is not fine-tuned for AE"
                self.tokenizer_ae = self.model_ae = self.add_prefix_ae = None
                self.answer_model_type = 'multitask'
            else:
                logging.info(f"loading 2nd model for AE: {self.model_name_ae}")
                self.tokenizer_ae, self.model_ae, config_ae = load_language_model(
                    model_ae, cache_dir=cache_dir, use_auth_token=use_auth_token)
                self.add_prefix_ae = config_ae.add_prefix
                self.answer_model_type = 'pipeline'
            self.spacy_module = SpacyPipeline(language)

        # GPU setup
        self.device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'
        self.parallel = False
        if torch.cuda.device_count() > 1:
            self.parallel = True
            self.model = torch.nn.DataParallel(self.model)
            if self.model_ae is not None:
                self.model_ae = torch.nn.DataParallel(self.model_ae)
        self.model.to(self.device)
        if self.model_ae is not None:
            self.model_ae.to(self.device)
        logging.info(f'Model `{self.model_name}`')
        logging.info(f'\t * Num of GPU in use: {torch.cuda.device_count()}')
        logging.info(f'\t * Prefix: {self.add_prefix}')
        logging.info(f'\t * Language: {language} (ignore at the training phase)')

    def generate_qa_end2end(self,
                            list_context: str or List,
                            batch_size: int = None,
                            num_beams: int = 4,
                            cache_path: str = None,
                            splitting_symbol: str = ' [SEP] ',
                            question_prefix: str = "question: ",
                            answer_prefix: str = ", answer: "):
        """ Generate question-answer pairs end-to-end from the input paragraph(s).

        @param list_context: Input text or list of input texts.
        @param batch_size: Batch size.
        @param num_beams: Number of beams for model generation.
        @param cache_path: Path to cache pre-computed features.
        @return: List of (question, answer) tuples.
        """
        logging.info(f'running model for `question_answer_pair_generation`')
        assert self.is_qag, "`generate_qa_end2end` is only available for an end-to-end QAG model"
        prefix_type = 'qag' if self.add_prefix else None
        single_input = type(list_context) is str
        list_context = [list_context] if single_input else list_context
        output = self.generate_prediction(
            list_context, prefix_type=prefix_type, cache_path=cache_path, num_beams=num_beams, batch_size=batch_size
        )

        def format_qa(list_raw_string):
            tmp = []
            for raw_string in list_raw_string:
                if len(raw_string.split(answer_prefix)) != 2 or question_prefix not in raw_string:
                    logging.info(f"invalid prediction: {raw_string}")
                else:
                    q, a = raw_string.split(answer_prefix)
                    a = re.sub(r'\A\s+', '', a)
                    a = re.sub(r'\s+\Z', '', a)
                    q = q.replace(question_prefix, "")
                    q = re.sub(r'\A\s+', '', q)
                    q = re.sub(r'\s+\Z', '', q)
                    tmp.append((q, a))
            return tmp

        output = [format_qa(o.split(splitting_symbol)) for o in output]
        return output[0] if single_input else output

    def generate_qa(self,
                    list_context: str or List,
                    batch_size: int = None,
                    num_beams: int = 4,
                    cache_path: str = None,
                    num_questions: int = None,
                    sentence_level: bool = False):
        """ Generate question-answer pairs given the context.

        @param list_context: Input text.
        @param batch_size: Batch size.
        @param num_beams: Number of beams for model generation.
        @param cache_path: Path to cache pre-computed features.
        @param num_questions: Max number of questions.
        @param sentence_level: Run prediction on each sentence of the context independently to reduce complexity.
        @return: List of (question, answer) tuples.
        """
        if self.is_qag:
            return self.generate_qa_end2end(list_context, batch_size, num_beams, cache_path)
        single_input = type(list_context) is str
        list_context = [list_context] if single_input else list_context
        original_input_length = len(list_context)

        logging.info('running model for `ae`')
        list_answer = self.generate_a(
            list_context,
            batch_size=batch_size,
            num_beams=num_beams,
            cache_path=cache_path,
            sentence_level=sentence_level,
            num_questions=num_questions
        )
        valid_context_id = [n for n, a in enumerate(list_answer) if a is not None]
        list_context = [list_context[n] for n in valid_context_id]
        list_answer = [list_answer[n] for n in valid_context_id]
        qg_input, qg_hl, list_length = [], [], [0]
        for c, a in zip(list_context, list_answer):
            qg_hl += a
            qg_input += [c] * len(a)
            list_length.append(list_length[-1] + len(a))
        logging.info('running model for `qg`')
        list_question = self.generate_q(
            qg_input,
            list_answer=qg_hl,
            batch_size=batch_size,
            cache_path=cache_path,
            num_beams=num_beams,
            sentence_level=sentence_level
        )

        assert len(qg_hl) == len(list_question), f"{len(qg_hl)} != {len(list_question)}"

        # return to nested list
        list_question = [list_question[list_length[n - 1]:list_length[n]] for n in range(1, len(list_length))]
        list_answer = [qg_hl[list_length[n - 1]:list_length[n]] for n in range(1, len(list_length))]
        output_list = [None] * original_input_length

        for n, _id in enumerate(valid_context_id):
            output_list[_id] = [(q, a) for q, a in zip(list_question[n], list_answer[n])]
        return output_list[0] if single_input else output_list

    def generate_a(self,
                   context: str or List,
                   batch_size: int = None,
                   num_beams: int = 4,
                   cache_path: str = None,
                   sentence_level: bool = False,
                   num_questions: int = None):
        """ Generate answers from each sentence.

        @param context: Input text.
        @param batch_size: Batch size.
        @param num_beams: Number of beams for model generation.
        @param cache_path: Path to cache pre-computed features.
        @param sentence_level: Run prediction on each sentence of the context independently to reduce complexity.
        @param num_questions: Max number of questions.
        @return: List of generated answers.
        """
        logging.info(f'running model for `answer_extraction`')
        if self.answer_model_type == 'spacy':
            num_questions = 10 if num_questions is None else num_questions
            if type(context) is str:
                return self.spacy_module.keyword(context, num_questions)
            else:
                return [self.spacy_module.keyword(c, num_questions) for c in context]
        single_input = type(context) is str
        context = [context] if single_input else context
        list_sentences = [self.spacy_module.sentence(c) for c in context]  # split into sentence
        list_inputs = [[c] * len(s) for c, s in zip(context, list_sentences)]
        list_length = [0] + np.cumsum([len(s) for s in list_sentences]).tolist()
        if sentence_level:
            list_inputs = list_sentences
        # flatten inputs
        flat_sentences = list(chain(*list_sentences))
        flat_inputs = list(chain(*list_inputs))
        if self.answer_model_type == 'multitask':
            answer = self.generate_prediction(
                flat_inputs,  # list_input,
                highlights=flat_sentences,  # highlights=list_sentence,
                prefix_type='ae' if self.add_prefix else None,
                cache_path=cache_path,
                num_beams=num_beams,
                batch_size=batch_size
            )
        elif self.answer_model_type == 'pipeline':
            answer = self.generate_prediction(
                flat_inputs,  # list_input,
                highlights=flat_sentences,  # highlights=list_sentence,
                prefix_type='ae' if self.add_prefix_ae else None,
                cache_path=cache_path,
                num_beams=num_beams,
                batch_size=batch_size,
                switch_to_model_ae=True
            )
        else:
            raise ValueError(f"unknown answer model type: {self.answer_model_type}")
        # return to nested list
        answer = [clean(a) for a in answer]
        list_answer = [answer[list_length[n - 1]:list_length[n]] for n in range(1, len(list_length))]
        list_answer = [[a for a, c in zip(a_sent, c_sent) if a is not None and a in c]
                       for a_sent, c_sent in zip(list_answer, list_inputs)]
        list_answer = [None if len(a) == 0 else a for a in list_answer]
        if not self.drop_answer_error_text:
            if any(a is None for a in list_answer):
                raise AnswerNotFoundError([context[n] for n, a in enumerate(list_answer) if a is None][0])
        return list_answer[0] if single_input else list_answer

    def generate_q(self,
                   list_context: str or List,
                   list_answer: List = None,
                   batch_size: int = None,
                   num_beams: int = 4,
                   cache_path: str = None,
                   sentence_level: bool = False):
        """ Generate questions from paragraphs and answers. Note that `list_answer` is needed unless the answers are
        already highlighted in the `list_context`, e.g. "I live in <hl> Tokyo <hl>."

        @param list_context: List of input texts.
        @param list_answer: List of answers in the `list_context` that are highlighted by <hl>.
        @param batch_size: Batch size.
        @param num_beams: Number of beams for model generation.
        @param cache_path: Path to cache pre-computed features.
        @param sentence_level: Run prediction on each sentence of the context independently to reduce complexity.
        @return: List of generated sentences.
        """
        assert self.is_qg, "model is not fine-tuned for QG"
        if list_answer is not None:
            assert type(list_context) is type(list_answer), f"{type(list_context)} != {type(list_answer)}"
        single_input = False
        if type(list_context) is str:
            list_context = [list_context]
            list_answer = [list_answer] if list_answer is not None else None
            single_input = True
        output = self.generate_prediction(
            list_context,
            highlights=list_answer,
            prefix_type='qg' if self.add_prefix else None,
            cache_path=cache_path,
            num_beams=num_beams,
            batch_size=batch_size,
            sentence_level=sentence_level
        )
        if single_input:
            return output[0]
        return output

    def generate_prediction(self,
                            inputs: List,
                            highlights: List or None = None,
                            prefix_type: str = None,
                            num_beams: int = 4,
                            batch_size: int = None,
                            cache_path: str = None,
                            sentence_level: bool = False,
                            switch_to_model_ae: bool = False):
        """ General method to generate model predictions.

        @param inputs: List of input sequences.
        @param highlights: List of sub-sequences from list_context to be highlighted by <hl>.
        @param batch_size: Batch size.
        @param num_beams: Number of beams for model generation.
        @param cache_path: Path to cache pre-computed features.
        @param prefix_type: Task prefix to prepend to the text; one of `ae`, `qg`, `qag`, or `qa`.
        @return: List of generated sequences.
        """
        self.eval()
        if switch_to_model_ae:
            assert self.model_ae is not None and self.tokenizer_ae is not None
            model = self.model_ae
            tokenizer = self.tokenizer_ae
            max_length_output = self.max_length_output_ae
        else:
            model = self.model
            tokenizer = self.tokenizer
            max_length_output = self.max_length_output

        if sentence_level:
            assert highlights is not None, '`sentence_level` needs `highlights`.'
            assert len(highlights) == len(inputs), str([len(highlights), len(inputs)])
            list_sentence = []
            for context, answer in zip(inputs, highlights):
                s = [sentence for sentence in self.spacy_module.sentence(context) if answer in sentence]
                list_sentence.append(s[0] if len(s) != 0 else context)
            inputs = list_sentence

        assert type(inputs) is list, inputs
        encode_list = self.text_to_encode(
            inputs,
            highlights=highlights,
            prefix_type=prefix_type,
            cache_path=cache_path,
            switch_to_model_ae=switch_to_model_ae
        )
        loader = self.get_data_loader(encode_list, batch_size=batch_size)
        outputs = []
        for encode in loader:
            with torch.no_grad():
                if 'labels' in encode:
                    encode.pop('labels')
                encode = {k: v.to(self.device) for k, v in encode.items()}
                encode['max_length'] = max_length_output
                encode['num_beams'] = num_beams
                tensor = model.module.generate(**encode) if self.parallel else model.generate(**encode)
            outputs += tokenizer.batch_decode(tensor, skip_special_tokens=True)
        return outputs
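
Note that `generate_prediction` above calls three helpers (`self.eval`, `self.text_to_encode`, `self.get_data_loader`) that are not included in this upload. A minimal sketch of what they might look like, inferred from the `Dataset`, `EncodePlus`, `pickle_save`/`pickle_load` and `NUM_WORKERS` definitions in this file; the class name, signatures, and caching behaviour below are assumptions, not the author's code:

# Sketch only: possible implementations of the missing helpers, written as a standalone class
# so the snippet is self-contained; in practice they would be methods of TransformersQG.
class _TransformersQGHelpersSketch:

    def eval(self):
        # put the (possibly DataParallel-wrapped) models into evaluation mode
        (self.model.module if self.parallel else self.model).eval()
        if self.model_ae is not None:
            (self.model_ae.module if self.parallel else self.model_ae).eval()

    def text_to_encode(self, inputs, outputs=None, highlights=None, prefix_type=None,
                       cache_path=None, switch_to_model_ae=False):
        # encode raw text with EncodePlus, optionally caching the features with pickle_save/pickle_load
        if cache_path is not None and os.path.exists(cache_path):
            return pickle_load(cache_path)
        outputs = [None] * len(inputs) if outputs is None else outputs
        highlights = [None] * len(inputs) if highlights is None else highlights
        tokenizer = self.tokenizer_ae if switch_to_model_ae else self.tokenizer
        max_length = self.max_length_ae if switch_to_model_ae else self.max_length
        max_length_output = self.max_length_output_ae if switch_to_model_ae else self.max_length_output
        encoder = EncodePlus(
            tokenizer, max_length=max_length, max_length_output=max_length_output,
            drop_overflow_error_text=self.drop_overflow_error_text,
            skip_overflow_error=self.skip_overflow_error,
            drop_highlight_error_text=self.drop_highlight_error_text,
            prefix_type=prefix_type)
        encode_list = [encoder(i) for i in zip(inputs, outputs, highlights)]
        encode_list = [e for e in encode_list if e is not None]
        if cache_path is not None:
            pickle_save(encode_list, cache_path)
        return encode_list

    def get_data_loader(self, encode_list, batch_size: int = None):
        # wrap the encoded features in the Dataset class defined above and return a torch DataLoader
        batch_size = len(encode_list) if batch_size is None else batch_size  # assumed default
        return torch.utils.data.DataLoader(
            Dataset(encode_list), batch_size=batch_size, num_workers=NUM_WORKERS, shuffle=False, drop_last=False)

With those helpers in place, end-to-end usage would look roughly like this (the checkpoint name is a placeholder for a QG fine-tuned seq2seq model, not an actual repository):

qg = TransformersQG(language='vi', model='your-org/vit5-base-qg')  # placeholder checkpoint name
qa_pairs = qg.generate_qa("Hà Nội là thủ đô của Việt Nam.")  # list of (question, answer) tuples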
plms/spacy_module.py ADDED
@@ -0,0 +1,72 @@
import spacy

__all__ = ('SpacyPipeline', 'VALID_METHODS')

MODELS = {
    "en": "en_core_web_sm",
    "ja": "ja_core_news_sm",
    "zh": "zh_core_web_sm",
    "de": "de_core_news_sm",
    "es": "es_core_news_sm",
    "it": "it_core_news_sm",
    "ko": "ko_core_news_sm",
    "ru": "ru_core_news_sm",
    "fr": "fr_core_news_sm",
    "vi": "vi_core_news_lg"
}
VALID_METHODS = ['positionrank', 'textrank', 'biasedtextrank', 'yake', 'ner']


class SpacyPipeline:

    def __init__(self, language, algorithm: str = None):
        model = "vi_core_news_lg" if language not in MODELS else MODELS[language]

        self.nlp = spacy.load(model)
        self.nlp.add_pipe("sentencizer")
        self.algorithm = algorithm
        self.library = None
        if self.algorithm is not None and self.algorithm != 'ner':
            assert algorithm in VALID_METHODS, f'invalid algorithm {algorithm}\n- valid list: {VALID_METHODS}'
            if self.algorithm == 'yake':
                import spacy_ke  # need to load yake
                self.nlp.add_pipe("yake")
                self.library = 'spacy_ke'
            elif self.algorithm in ['textrank', 'biasedtextrank', 'positionrank']:
                import pytextrank
                self.nlp.add_pipe(algorithm)
                self.library = 'pytextrank'
            else:
                raise ValueError(f'unknown algorithm: {self.algorithm}')

    def _get_keyword(self, output, original_document=None, n=None):
        if self.algorithm == 'ner':
            return [str(i) for i in output.ents]
        assert original_document is not None
        assert n is not None
        if self.library == 'spacy_ke':
            return [str(term) for term, score in output._.extract_keywords(n) if str(term) in original_document]
        return [str(i.text) for i in output._.phrases[:n] if str(i.text) in original_document]

    def sentence_keyword(self, string: str, n: int = 10):
        out = self.nlp(string)
        sentence = [str(i) for i in out.sents if len(i) > 0]
        keyword = self._get_keyword(out, string, n)
        return sentence, keyword

    def sentence(self, string: str):
        return [str(i) for i in self.nlp(string).sents if len(i) > 0]

    def token(self, string: str):
        return [str(i) for i in self.nlp.tokenizer(string)]

    def keyword(self, string: str, n: int = 10):
        return self._get_keyword(self.nlp(string), string, n)

    def ner(self, string: str, n: int = None):
        keywords = self.nlp(string).ents
        return keywords[:min(len(keywords), n)] if n is not None else keywords

    @property
    def language(self):
        return self.nlp.lang
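
For orientation, a short usage sketch of the pipeline above; keyword extraction is how TransformersQG picks answer candidates when `model_ae` is one of `VALID_METHODS`. The example sentence is illustrative, and the `vi_core_news_lg` spaCy model plus `pytextrank` need to be installed:

from plms.spacy_module import SpacyPipeline

nlp = SpacyPipeline(language='vi', algorithm='positionrank')  # pytextrank-based keyphrase extraction
text = "Hà Nội là thủ đô của Việt Nam. Thành phố nằm bên bờ sông Hồng."
print(nlp.sentence(text))      # sentence split, as used by TransformersQG.generate_a
print(nlp.keyword(text, n=5))  # top-5 keyphrases, used as answer candidates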