Shaltiel commited on
Commit
8f0599d
1 Parent(s): 46741a2

Upload 4 files

Browse files
BertForJointParsing.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import re
3
+ from operator import itemgetter
4
+ import torch
5
+ from torch import nn
6
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
7
+ from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
8
+ from transformers.models.bert.modeling_bert import BertOnlyMLMHead
9
+ from transformers.utils import ModelOutput
10
+ from .BertForSyntaxParsing import BertSyntaxParsingHead, SyntaxLabels, SyntaxLogitsOutput, parse_logits as syntax_parse_logits
11
+ from .BertForPrefixMarking import BertPrefixMarkingHead, parse_logits as prefix_parse_logits, encode_sentences_for_bert_for_prefix_marking, get_prefixes_from_str
12
+ from .BertForMorphTagging import BertMorphTaggingHead, MorphLogitsOutput, MorphLabels, parse_logits as morph_parse_logits
13
+
14
+ import warnings
15
+
16
+ @dataclass
17
+ class JointParsingOutput(ModelOutput):
18
+ loss: Optional[torch.FloatTensor] = None
19
+ # logits will contain the optional predictions for the given labels
20
+ logits: Optional[Union[SyntaxLogitsOutput, None]] = None
21
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
22
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
23
+ # if no labels are given, we will always include the syntax logits separately
24
+ syntax_logits: Optional[SyntaxLogitsOutput] = None
25
+ ner_logits: Optional[torch.FloatTensor] = None
26
+ prefix_logits: Optional[torch.FloatTensor] = None
27
+ lex_logits: Optional[torch.FloatTensor] = None
28
+ morph_logits: Optional[MorphLogitsOutput] = None
29
+
30
+ # wrapper class to wrap a torch.nn.Module so that you can store a module in multiple linked
31
+ # properties without registering the parameter multiple times
32
+ class ModuleRef:
33
+ def __init__(self, module: torch.nn.Module):
34
+ self.module = module
35
+
36
+ def forward(self, *args, **kwargs):
37
+ return self.module.forward(*args, **kwargs)
38
+
39
+ def __call__(self, *args, **kwargs):
40
+ return self.module(*args, **kwargs)
41
+
42
+ class BertForJointParsing(BertPreTrainedModel):
43
+ _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
44
+
45
+ def __init__(self, config, do_syntax=None, do_ner=None, do_prefix=None, do_lex=None, do_morph=None, syntax_head_size=64):
46
+ super().__init__(config)
47
+
48
+ self.bert = BertModel(config, add_pooling_layer=False)
49
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
50
+ # create all the heads as None, and then populate them as defined
51
+ self.syntax, self.ner, self.prefix, self.lex, self.morph = (None,)*5
52
+
53
+ if do_syntax is not None:
54
+ config.do_syntax = do_syntax
55
+ config.syntax_head_size = syntax_head_size
56
+ if do_ner is not None: config.do_ner = do_ner
57
+ if do_prefix is not None: config.do_prefix = do_prefix
58
+ if do_lex is not None: config.do_lex = do_lex
59
+ if do_morph is not None: config.do_morph = do_morph
60
+
61
+ # add all the individual heads
62
+ if config.do_syntax:
63
+ self.syntax = BertSyntaxParsingHead(config)
64
+ if config.do_ner:
65
+ self.num_labels = config.num_labels
66
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels) # name it same as in BertForTokenClassification
67
+ self.ner = ModuleRef(self.classifier)
68
+ if config.do_prefix:
69
+ self.prefix = BertPrefixMarkingHead(config)
70
+ if config.do_lex:
71
+ self.cls = BertOnlyMLMHead(config) # name it the same as in BertForMaskedLM
72
+ self.lex = ModuleRef(self.cls)
73
+ if config.do_morph:
74
+ self.morph = BertMorphTaggingHead(config)
75
+
76
+ # Initialize weights and apply final processing
77
+ self.post_init()
78
+
79
+ def get_output_embeddings(self):
80
+ return self.cls.predictions.decoder if self.lex is not None else None
81
+
82
+ def set_output_embeddings(self, new_embeddings):
83
+ if self.lex is not None:
84
+
85
+ self.cls.predictions.decoder = new_embeddings
86
+
87
+ def forward(
88
+ self,
89
+ input_ids: Optional[torch.Tensor] = None,
90
+ attention_mask: Optional[torch.Tensor] = None,
91
+ token_type_ids: Optional[torch.Tensor] = None,
92
+ position_ids: Optional[torch.Tensor] = None,
93
+ prefix_class_id_options: Optional[torch.Tensor] = None,
94
+ labels: Optional[Union[SyntaxLabels, MorphLabels, torch.Tensor]] = None,
95
+ labels_type: Optional[Literal['syntax', 'ner', 'prefix', 'lex', 'morph']] = None,
96
+ head_mask: Optional[torch.Tensor] = None,
97
+ inputs_embeds: Optional[torch.Tensor] = None,
98
+ output_attentions: Optional[bool] = None,
99
+ output_hidden_states: Optional[bool] = None,
100
+ return_dict: Optional[bool] = None,
101
+ compute_syntax_mst: Optional[bool] = None
102
+ ):
103
+ if return_dict is False:
104
+ warnings.warn("Specified `return_dict=False` but the flag is ignored and treated as always True in this model.")
105
+
106
+ if labels is not None and labels_type is None:
107
+ raise ValueError("Cannot specify labels without labels_type")
108
+
109
+ if labels_type == 'seg' and prefix_class_id_options is None:
110
+ raise ValueError('Cannot calculate prefix logits without prefix_class_id_options')
111
+
112
+ if compute_syntax_mst is not None and self.syntax is None:
113
+ raise ValueError("Cannot compute syntax MST when the syntax head isn't loaded")
114
+
115
+
116
+ bert_outputs = self.bert(
117
+ input_ids,
118
+ attention_mask=attention_mask,
119
+ token_type_ids=token_type_ids,
120
+ position_ids=position_ids,
121
+ head_mask=head_mask,
122
+ inputs_embeds=inputs_embeds,
123
+ output_attentions=output_attentions,
124
+ output_hidden_states=output_hidden_states,
125
+ return_dict=True,
126
+ )
127
+
128
+ # calculate the extended attention mask for any child that might need it
129
+ extended_attention_mask = None
130
+ if attention_mask is not None:
131
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.size())
132
+
133
+ # extract the hidden states, and apply the dropout
134
+ hidden_states = self.dropout(bert_outputs[0])
135
+
136
+ logits = None
137
+ syntax_logits = None
138
+ ner_logits = None
139
+ prefix_logits = None
140
+ lex_logits = None
141
+ morph_logits = None
142
+
143
+ # Calculate the syntax
144
+ if self.syntax is not None and (labels is None or labels_type == 'syntax'):
145
+ # apply the syntax head
146
+ loss, syntax_logits = self.syntax(hidden_states, extended_attention_mask, labels, compute_syntax_mst)
147
+ logits = syntax_logits
148
+
149
+ # Calculate the NER
150
+ if self.ner is not None and (labels is None or labels_type == 'ner'):
151
+ ner_logits = self.ner(hidden_states)
152
+ logits = ner_logits
153
+ if labels is not None:
154
+ loss_fct = nn.CrossEntropyLoss()
155
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
156
+
157
+ # Calculate the segmentation
158
+ if self.prefix is not None and (labels is None or labels_type == 'prefix'):
159
+ loss, prefix_logits = self.prefix(hidden_states, prefix_class_id_options, labels)
160
+ logits = prefix_logits
161
+
162
+ # Calculate the lexeme
163
+ if self.lex is not None and (labels is None or labels_type == 'lex'):
164
+ lex_logits = self.lex(hidden_states)
165
+ logits = lex_logits
166
+ if labels is not None:
167
+ loss_fct = nn.CrossEntropyLoss() # -100 index = padding token
168
+ loss = loss_fct(lex_logits.view(-1, self.config.vocab_size), labels.view(-1))
169
+
170
+ if self.morph is not None and (labels is None or labels_type == 'morph'):
171
+ loss, morph_logits = self.morph(hidden_states, labels)
172
+ logits = morph_logits
173
+
174
+ # no labels => logits = None
175
+ if labels is None: logits = None
176
+
177
+ return JointParsingOutput(
178
+ loss,
179
+ logits,
180
+ hidden_states=bert_outputs.hidden_states,
181
+ attentions=bert_outputs.attentions,
182
+ # all the predicted logits section
183
+ syntax_logits=syntax_logits,
184
+ ner_logits=ner_logits,
185
+ prefix_logits=prefix_logits,
186
+ lex_logits=lex_logits,
187
+ morph_logits=morph_logits
188
+ )
189
+
190
+ def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, output_style: Literal['json', 'ud', 'iahlt_ud'] = 'json'):
191
+ is_single_sentence = isinstance(sentences, str)
192
+ if is_single_sentence:
193
+ sentences = [sentences]
194
+
195
+ if output_style not in ['json', 'ud', 'iahlt_ud']:
196
+ raise ValueError('output_style must be in json/ud/iahlt_ud')
197
+ if output_style in ['ud', 'iahlt_ud'] and (self.prefix is None or self.morph is None or self.syntax is None or self.lex is None):
198
+ raise ValueError("Cannot output UD format when any of the prefix,morph,syntax, and lex heads aren't loaded.")
199
+
200
+ # predict the logits for the sentence
201
+ if self.prefix is not None:
202
+ inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
203
+ else:
204
+ inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_offsets_mapping=True, return_tensors='pt')
205
+
206
+ offset_mapping = inputs.pop('offset_mapping')
207
+ # Copy the tensors to the right device, and parse!
208
+ inputs = {k:v.to(self.device) for k,v in inputs.items()}
209
+ output = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_syntax_mst)
210
+
211
+ input_ids = inputs['input_ids'].tolist() # convert once
212
+ final_output = [dict(text=sentence, tokens=combine_token_wordpieces(ids, offsets, tokenizer)) for sentence, ids, offsets in zip(sentences, input_ids, offset_mapping)]
213
+ # Syntax logits: each sentence gets a dict(tree: List[dict(word,dep_head,dep_head_idx,dep_func)], root_idx: int)
214
+ if output.syntax_logits is not None:
215
+ for sent_idx,parsed in enumerate(syntax_parse_logits(input_ids, sentences, tokenizer, output.syntax_logits)):
216
+ merge_token_list(final_output[sent_idx]['tokens'], parsed['tree'], 'syntax')
217
+ final_output[sent_idx]['root_idx'] = parsed['root_idx']
218
+
219
+ # Prefix logits: each sentence gets a list([prefix_segment, word_without_prefix]) - **WITH CLS & SEP**
220
+ if output.prefix_logits is not None:
221
+ for sent_idx,parsed in enumerate(prefix_parse_logits(input_ids, sentences, tokenizer, output.prefix_logits)):
222
+ merge_token_list(final_output[sent_idx]['tokens'], map(tuple, parsed[1:-1]), 'seg')
223
+
224
+ # Lex logits each sentence gets a list(tuple(word, lexeme))
225
+ if output.lex_logits is not None:
226
+ for sent_idx, parsed in enumerate(lex_parse_logits(input_ids, sentences, tokenizer, output.lex_logits)):
227
+ merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'lex')
228
+
229
+ # morph logits each sentences get a dict(text=str, tokens=list(dict(token, pos, feats, prefixes, suffix, suffix_feats?)))
230
+ if output.morph_logits is not None:
231
+ for sent_idx,parsed in enumerate(morph_parse_logits(input_ids, sentences, tokenizer, output.morph_logits)):
232
+ merge_token_list(final_output[sent_idx]['tokens'], parsed['tokens'], 'morph')
233
+
234
+ # NER logits each sentence gets a list(tuple(word, ner))
235
+ if output.ner_logits is not None:
236
+ for sent_idx,parsed in enumerate(ner_parse_logits(input_ids, sentences, tokenizer, output.ner_logits, self.config.id2label)):
237
+ if per_token_ner:
238
+ merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
239
+ final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(final_output[sent_idx], parsed)
240
+
241
+ if output_style in ['ud', 'iahlt_ud']:
242
+ final_output = convert_output_to_ud(final_output, style='htb' if output_style == 'ud' else 'iahlt')
243
+
244
+ if is_single_sentence:
245
+ final_output = final_output[0]
246
+ return final_output
247
+
248
+
249
+
250
+ def aggregate_ner_tokens(final_output, parsed):
251
+ entities = []
252
+ prev = None
253
+ for token_idx, (d, (word, pred)) in enumerate(zip(final_output['tokens'], parsed)):
254
+ # O does nothing
255
+ if pred == 'O': prev = None
256
+ # B- || I-entity != prev (different entity or none)
257
+ elif pred.startswith('B-') or pred[2:] != prev:
258
+ prev = pred[2:]
259
+ entities.append([[word], dict(label=prev, start=d['offsets']['start'], end=d['offsets']['end'], token_start=token_idx, token_end=token_idx)])
260
+ else:
261
+ entities[-1][0].append(word)
262
+ entities[-1][1]['end'] = d['offsets']['end']
263
+ entities[-1][1]['token_end'] = token_idx
264
+
265
+ return [dict(phrase=' '.join(words), **d) for words, d in entities]
266
+
267
+ def merge_token_list(src, update, key):
268
+ for token_src, token_update in zip(src, update):
269
+ token_src[key] = token_update
270
+
271
+ def combine_token_wordpieces(input_ids: List[int], offset_mapping: torch.Tensor, tokenizer: BertTokenizerFast):
272
+ offset_mapping = offset_mapping.tolist()
273
+ ret = []
274
+ special_toks = tokenizer.all_special_tokens
275
+ for token, offsets in zip(tokenizer.convert_ids_to_tokens(input_ids), offset_mapping):
276
+ if token in special_toks: continue
277
+ if token.startswith('##'):
278
+ ret[-1]['token'] += token[2:]
279
+ ret[-1]['offsets']['end'] = offsets[1]
280
+ else: ret.append(dict(token=token, offsets=dict(start=offsets[0], end=offsets[1])))
281
+ return ret
282
+
283
+ def ner_parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str]):
284
+ predictions = torch.argmax(logits, dim=-1).tolist()
285
+ batch_ret = []
286
+
287
+ special_toks = tokenizer.all_special_tokens
288
+ for batch_idx in range(len(sentences)):
289
+
290
+ ret = []
291
+ batch_ret.append(ret)
292
+
293
+ tokens = tokenizer.convert_ids_to_tokens(input_ids[batch_idx])
294
+ for tok_idx in range(len(tokens)):
295
+ token = tokens[tok_idx]
296
+ if token in special_toks: continue
297
+
298
+ # wordpieces should just be appended to the previous word
299
+ # we modify the last token in ret
300
+ # by discarding the original end position and replacing it with the new token's end position
301
+ if token.startswith('##'):
302
+ continue
303
+ # for each token, we append a tuple containing: token, label, start position, end position
304
+ ret.append((token, id2label[predictions[batch_idx][tok_idx]]))
305
+
306
+ return batch_ret
307
+
308
+ def lex_parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
309
+
310
+ predictions = torch.argsort(logits, dim=-1, descending=True)[..., :3].tolist()
311
+ batch_ret = []
312
+
313
+ special_toks = tokenizer.all_special_tokens
314
+ for batch_idx in range(len(sentences)):
315
+ intermediate_ret = []
316
+ tokens = tokenizer.convert_ids_to_tokens(input_ids[batch_idx])
317
+ for tok_idx in range(len(tokens)):
318
+ token = tokens[tok_idx]
319
+ if token in special_toks: continue
320
+
321
+ # wordpieces should just be appended to the previous word
322
+ if token.startswith('##'):
323
+ intermediate_ret[-1] = (intermediate_ret[-1][0] + token[2:], intermediate_ret[-1][1])
324
+ continue
325
+ intermediate_ret.append((token, tokenizer.convert_ids_to_tokens(predictions[batch_idx][tok_idx])))
326
+
327
+ # build the final output taking into account valid letters
328
+ ret = []
329
+ batch_ret.append(ret)
330
+ for (token, lexemes) in intermediate_ret:
331
+ # must overlap on at least 2 non אהוי letters
332
+ possible_lets = set(c for c in token if c not in 'אהוי')
333
+ final_lex = '[BLANK]'
334
+ for lex in lexemes:
335
+ if sum(c in possible_lets for c in lex) >= min([2, len(possible_lets), len([c for c in lex if c not in 'אהוי'])]):
336
+ final_lex = lex
337
+ break
338
+ ret.append((token, final_lex))
339
+
340
+ return batch_ret
341
+
342
+ ud_prefixes_to_pos = {
343
+ 'ש': ['SCONJ'],
344
+ 'מש': ['SCONJ'],
345
+ 'כש': ['SCONJ'],
346
+ 'לכש': ['SCONJ'],
347
+ 'בש': ['SCONJ'],
348
+ 'לש': ['SCONJ'],
349
+ 'ו': ['CCONJ'],
350
+ 'ל': ['ADP'],
351
+ 'ה': ['DET', 'SCONJ'],
352
+ 'מ': ['ADP', 'SCONJ'],
353
+ 'ב': ['ADP'],
354
+ 'כ': ['ADP', 'ADV'],
355
+ }
356
+ ud_suffix_to_htb_str = {
357
+ 'Gender=Masc|Number=Sing|Person=3': '_הוא',
358
+ 'Gender=Masc|Number=Plur|Person=3': '_הם',
359
+ 'Gender=Fem|Number=Sing|Person=3': '_היא',
360
+ 'Gender=Fem|Number=Plur|Person=3': '_הן',
361
+ 'Gender=Fem,Masc|Number=Plur|Person=1': '_אנחנו',
362
+ 'Gender=Fem,Masc|Number=Sing|Person=1': '_אני',
363
+ 'Gender=Masc|Number=Plur|Person=2': '_אתם',
364
+ 'Gender=Masc|Number=Sing|Person=3': '_הוא',
365
+ 'Gender=Masc|Number=Sing|Person=2': '_אתה',
366
+ 'Gender=Fem|Number=Sing|Person=2': '_את',
367
+ 'Gender=Masc|Number=Plur|Person=3': '_הם'
368
+ }
369
+ def convert_output_to_ud(output_sentences, style: Literal['htb', 'iahlt']):
370
+ if style not in ['htb', 'iahlt']:
371
+ raise ValueError('style must be htb/iahlt')
372
+
373
+ final_output = []
374
+ for sent_idx, sentence in enumerate(output_sentences):
375
+ # next, go through each word and insert it in the UD format. Store in a temp format for the post process
376
+ intermediate_output = []
377
+ ranges = []
378
+ # store a mapping between each word index and the actual line it appears in
379
+ idx_to_key = {-1: 0}
380
+ for word_idx,word in enumerate(sentence['tokens']):
381
+ try:
382
+ # handle blank lexemes
383
+ if word['lex'] == '[BLANK]':
384
+ word['lex'] = word['seg'][-1]
385
+ except KeyError:
386
+ import json
387
+ print(json.dumps(sentence, ensure_ascii=False, indent=2))
388
+ exit(0)
389
+
390
+ start = len(intermediate_output)
391
+ # Add in all the prefixes
392
+ if len(word['seg']) > 1:
393
+ for pre in get_prefixes_from_str(word['seg'][0], greedy=True):
394
+ # pos - just take the first valid pos that appears in the predicted prefixes list.
395
+ pos = next((pos for pos in ud_prefixes_to_pos[pre] if pos in word['morph']['prefixes']), ud_prefixes_to_pos[pre][0])
396
+ dep, func = ud_get_prefix_dep(pre, word, word_idx)
397
+ intermediate_output.append(dict(word=pre, lex=pre, pos=pos, dep=dep, func=func, feats='_'))
398
+
399
+ # if there was an implicit heh, add it in dependent on the method
400
+ if not 'ה' in pre and intermediate_output[-1]['pos'] == 'ADP' and 'DET' in word['morph']['prefixes']:
401
+ if style == 'htb':
402
+ intermediate_output.append(dict(word='ה_', lex='ה', pos='DET', dep=word_idx, func='det', feats='_'))
403
+ elif style == 'iahlt':
404
+ intermediate_output[-1]['feats'] = 'Definite=Def|PronType=Art'
405
+
406
+
407
+ idx_to_key[word_idx] = len(intermediate_output) + 1
408
+ # add the main word in!
409
+ intermediate_output.append(dict(
410
+ word=word['seg'][-1], lex=word['lex'], pos=word['morph']['pos'],
411
+ dep=word['syntax']['dep_head_idx'], func=word['syntax']['dep_func'],
412
+ feats='|'.join(f'{k}={v}' for k,v in word['morph']['feats'].items())))
413
+
414
+ # if we have suffixes, this changes things
415
+ if word['morph']['suffix']:
416
+ # first determine the dependency info:
417
+ # For adp, num, det - they main word points to here, and the suffix points to the dependency
418
+ entry_to_assign_suf_dep = None
419
+ if word['morph']['pos'] in ['ADP', 'NUM', 'DET']:
420
+ entry_to_assign_suf_dep = intermediate_output[-1]
421
+ intermediate_output[-1]['func'] = 'case'
422
+ dep = word['syntax']['dep_head_idx']
423
+ func = word['syntax']['dep_func']
424
+ else:
425
+ # if pos is verb -> obj, num -> dep, default to -> nmod:poss
426
+ dep = word_idx
427
+ func = {'VERB': 'obj', 'NUM': 'dep'}.get(word['morph']['pos'], 'nmod:poss')
428
+
429
+ s_word, s_lex = word['seg'][-1], word['lex']
430
+ # update the word of the string and extract the string of the suffix!
431
+ # for IAHLT:
432
+ if style == 'iahlt':
433
+ # we need to shorten the main word and extract the suffix
434
+ # if it is longer than the lexeme - just take off the lexeme.
435
+ if len(s_word) > len(s_lex):
436
+ idx = len(s_lex)
437
+ # Otherwise, try to find the last letter of the lexeme, and fail that just take the last letter
438
+ else:
439
+ # take either len-1, or the last occurence (which can be -1 === len-1)
440
+ idx = min([len(s_word) - 1, s_word.rfind(s_lex[-1])])
441
+ # extract the suffix and update the main word
442
+ suf = s_word[idx:]
443
+ intermediate_output[-1]['word'] = s_word[:idx]
444
+ # for htb:
445
+ elif style == 'htb':
446
+ # main word becomes the lexeme, the suffix is based on the features
447
+ intermediate_output[-1]['word'] = (s_lex if s_lex != s_word else s_word[:-1]) + '_'
448
+ suf_feats = word['morph']['suffix_feats']
449
+ suf = ud_suffix_to_htb_str.get(f"Gender={suf_feats.get('Gender', 'Fem,Masc')}|Number={suf_feats.get('Number', 'Sing')}|Person={suf_feats.get('Person', '3')}", "_הוא")
450
+ # for HTB, if the function is poss, then add a shel pointing to the next word
451
+ if func == 'nmod:poss' and s_lex != 'של':
452
+ intermediate_output.append(dict(word='_של_', lex='של', pos='ADP', dep=len(intermediate_output) + 2, func='case', feats='_', absolute_dep=True))
453
+ # add the main suffix in
454
+ intermediate_output.append(dict(word=suf, lex='הוא', pos='PRON', dep=dep, func=func, feats='|'.join(f'{k}={v}' for k,v in word['morph']['suffix_feats'].items())))
455
+ if entry_to_assign_suf_dep:
456
+ entry_to_assign_suf_dep['dep'] = len(intermediate_output)
457
+ entry_to_assign_suf_dep['absolute_dep'] = True
458
+
459
+ end = len(intermediate_output)
460
+ ranges.append((start, end, word['token']))
461
+
462
+ # now that we have the intermediate output, combine it to the final output
463
+ cur_output = []
464
+ final_output.append(cur_output)
465
+ # first, add the headers
466
+ cur_output.append(f'# sent_id = {sent_idx + 1}')
467
+ cur_output.append(f'# text = {sentence["text"]}')
468
+
469
+ # add in all the actual entries
470
+ for start,end,token in ranges:
471
+ if end - start > 1:
472
+ cur_output.append(f'{start + 1}-{end}\t{token}\t_\t_\t_\t_\t_\t_\t_\t_')
473
+ for idx,output in enumerate(intermediate_output[start:end], start + 1):
474
+ # compute the actual dependency location
475
+ dep = output['dep'] if output.get('absolute_dep', False) else idx_to_key[output['dep']]
476
+ func = normalize_dep_rel(output['func'], style)
477
+ # and add the full ud string in
478
+ cur_output.append('\t'.join([
479
+ str(idx),
480
+ output['word'],
481
+ output['lex'],
482
+ output['pos'],
483
+ output['pos'],
484
+ output['feats'],
485
+ str(dep),
486
+ func,
487
+ '_', '_'
488
+ ]))
489
+ return final_output
490
+
491
+ def normalize_dep_rel(dep, style: Literal['htb', 'iahlt']):
492
+ if style == 'iahlt':
493
+ if dep == 'compound:smixut': return 'compound'
494
+ if dep == 'nsubj:cop': return 'nsubj'
495
+ if dep == 'mark:q': return 'mark'
496
+ if dep == 'case:gen' or dep == 'case:acc': return 'case'
497
+ return dep
498
+
499
+
500
+ def ud_get_prefix_dep(pre, word, word_idx):
501
+ does_follow_main = False
502
+
503
+ # shin goes to the main word for verbs, otherwise follows the word
504
+ if pre.endswith('ש'):
505
+ does_follow_main = word['morph']['pos'] != 'VERB'
506
+ func = 'mark'
507
+ # vuv goes to the main word if the function is in the list, otherwise follows
508
+ elif pre == 'ו':
509
+ does_follow_main = word['syntax']['dep_func'] not in ["conj", "acl:recl", "parataxis", "root", "acl", "amod", "list", "appos", "dep", "flatccomp"]
510
+ func = 'cc'
511
+ else:
512
+ # for adj, noun, propn, pron, verb - prefixes go to the main word
513
+ if word['morph']['pos'] in ["ADJ", "NOUN", "PROPN", "PRON", "VERB"]:
514
+ does_follow_main = False
515
+ # otherwise - prefix follows the word if the function is in the list
516
+ else: does_follow_main = word['syntax']['dep_func'] in ["compound:affix", "det", "aux", "nummod", "advmod", "dep", "cop", "mark", "fixed"]
517
+
518
+ func = 'case'
519
+ if pre == 'ה':
520
+ func = 'det' if 'DET' in word['morph']['prefixes'] else 'mark'
521
+
522
+ return (word['syntax']['dep_head_idx'] if does_follow_main else word_idx), func
523
+
BertForMorphTagging.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from operator import itemgetter
3
+ from transformers.utils import ModelOutput
4
+ import torch
5
+ from torch import nn
6
+ from typing import Dict, List, Tuple, Optional
7
+ from dataclasses import dataclass
8
+ from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
9
+
10
+ ALL_POS = ['DET', 'NOUN', 'VERB', 'CCONJ', 'ADP', 'PRON', 'PUNCT', 'ADJ', 'ADV', 'SCONJ', 'NUM', 'PROPN', 'AUX', 'X', 'INTJ', 'SYM']
11
+ ALL_PREFIX_POS = ['SCONJ', 'DET', 'ADV', 'CCONJ', 'ADP', 'NUM']
12
+ ALL_SUFFIX_POS = ['none', 'ADP_PRON', 'PRON']
13
+ ALL_FEATURES = [
14
+ ('Gender', ['none', 'Masc', 'Fem', 'Fem,Masc']),
15
+ ('Number', ['none', 'Sing', 'Plur', 'Plur,Sing', 'Dual', 'Dual,Plur']),
16
+ ('Person', ['none', '1', '2', '3', '1,2,3']),
17
+ ('Tense', ['none', 'Past', 'Fut', 'Pres', 'Imp'])
18
+ ]
19
+
20
+ @dataclass
21
+ class MorphLogitsOutput(ModelOutput):
22
+ prefix_logits: torch.FloatTensor = None
23
+ pos_logits: torch.FloatTensor = None
24
+ features_logits: List[torch.FloatTensor] = None
25
+ suffix_logits: torch.FloatTensor = None
26
+ suffix_features_logits: List[torch.FloatTensor] = None
27
+
28
+ def detach(self):
29
+ return MorphLogitsOutput(self.prefix_logits.detach(), self.pos_logits.detach(), [logits.deatch() for logits in self.features_logits], self.suffix_logits.detach(), [logits.deatch() for logits in self.suffix_features_logits])
30
+
31
+
32
+ @dataclass
33
+ class MorphTaggingOutput(ModelOutput):
34
+ loss: Optional[torch.FloatTensor] = None
35
+ logits: Optional[MorphLogitsOutput] = None
36
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
37
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
38
+
39
+ @dataclass
40
+ class MorphLabels(ModelOutput):
41
+ prefix_labels: Optional[torch.FloatTensor] = None
42
+ pos_labels: Optional[torch.FloatTensor] = None
43
+ features_labels: Optional[List[torch.FloatTensor]] = None
44
+ suffix_labels: Optional[torch.FloatTensor] = None
45
+ suffix_features_labels: Optional[List[torch.FloatTensor]] = None
46
+
47
+ def detach(self):
48
+ return MorphLabels(self.prefix_labels.detach(), self.pos_labels.detach(), [labels.detach() for labels in self.features_labels], self.suffix_labels.detach(), [labels.detach() for labels in self.suffix_features_labels])
49
+
50
+ def to(self, device):
51
+ return MorphLabels(self.prefix_labels.to(device), self.pos_labels.to(device), [feat.to(device) for feat in self.features_labels], self.suffix_labels.to(device), [feat.to(device) for feat in self.suffix_features_labels])
52
+
53
+ class BertMorphTaggingHead(nn.Module):
54
+ def __init__(self, config):
55
+ super().__init__()
56
+ self.config = config
57
+
58
+ self.num_prefix_classes = len(ALL_PREFIX_POS)
59
+ self.num_pos_classes = len(ALL_POS)
60
+ self.num_suffix_classes = len(ALL_SUFFIX_POS)
61
+ self.num_features_classes = list(map(len, map(itemgetter(1), ALL_FEATURES)))
62
+ # we need a classifier for prefix cls and POS cls
63
+ # the prefix will use BCEWithLogits for multiple labels cls
64
+ self.prefix_cls = nn.Linear(config.hidden_size, self.num_prefix_classes)
65
+ # and pos + feats will use good old cross entropy for single label
66
+ self.pos_cls = nn.Linear(config.hidden_size, self.num_pos_classes)
67
+ self.features_cls = nn.ModuleList([nn.Linear(config.hidden_size, len(features)) for _, features in ALL_FEATURES])
68
+ # and suffix + feats will also be cross entropy
69
+ self.suffix_cls = nn.Linear(config.hidden_size, self.num_suffix_classes)
70
+ self.suffix_features_cls = nn.ModuleList([nn.Linear(config.hidden_size, len(features)) for _, features in ALL_FEATURES])
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: torch.Tensor,
75
+ labels: Optional[MorphLabels] = None):
76
+ # run each of the classifiers on the transformed output
77
+ prefix_logits = self.prefix_cls(hidden_states)
78
+ pos_logits = self.pos_cls(hidden_states)
79
+ suffix_logits = self.suffix_cls(hidden_states)
80
+ features_logits = [cls(hidden_states) for cls in self.features_cls]
81
+ suffix_features_logits = [cls(hidden_states) for cls in self.suffix_features_cls]
82
+
83
+ loss = None
84
+ if labels is not None:
85
+ # step 1: prefix labels loss
86
+ loss_fct = nn.BCEWithLogitsLoss(weight=(labels.prefix_labels != -100).float())
87
+ loss = loss_fct(prefix_logits, labels.prefix_labels)
88
+ # step 2: pos labels loss
89
+ loss_fct = nn.CrossEntropyLoss()
90
+ loss += loss_fct(pos_logits.view(-1, self.num_pos_classes), labels.pos_labels.view(-1))
91
+ # step 2b: features
92
+ for feat_logits,feat_labels,num_features in zip(features_logits, labels.features_labels, self.num_features_classes):
93
+ loss += loss_fct(feat_logits.view(-1, num_features), feat_labels.view(-1))
94
+ # step 3: suffix logits loss
95
+ loss += loss_fct(suffix_logits.view(-1, self.num_suffix_classes), labels.suffix_labels.view(-1))
96
+ # step 3b: suffix features
97
+ for feat_logits,feat_labels,num_features in zip(suffix_features_logits, labels.suffix_features_labels, self.num_features_classes):
98
+ loss += loss_fct(feat_logits.view(-1, num_features), feat_labels.view(-1))
99
+
100
+ return loss, MorphLogitsOutput(prefix_logits, pos_logits, features_logits, suffix_logits, suffix_features_logits)
101
+
102
+ class BertForMorphTagging(BertPreTrainedModel):
103
+
104
+ def __init__(self, config):
105
+ super().__init__(config)
106
+
107
+ self.bert = BertModel(config, add_pooling_layer=False)
108
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
109
+ self.morph = BertMorphTaggingHead(config)
110
+
111
+ # Initialize weights and apply final processing
112
+ self.post_init()
113
+
114
+ def forward(
115
+ self,
116
+ input_ids: Optional[torch.Tensor] = None,
117
+ attention_mask: Optional[torch.Tensor] = None,
118
+ token_type_ids: Optional[torch.Tensor] = None,
119
+ position_ids: Optional[torch.Tensor] = None,
120
+ labels: Optional[MorphLabels] = None,
121
+ head_mask: Optional[torch.Tensor] = None,
122
+ inputs_embeds: Optional[torch.Tensor] = None,
123
+ output_attentions: Optional[bool] = None,
124
+ output_hidden_states: Optional[bool] = None,
125
+ return_dict: Optional[bool] = None,
126
+ ):
127
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
128
+
129
+ bert_outputs = self.bert(
130
+ input_ids,
131
+ attention_mask=attention_mask,
132
+ token_type_ids=token_type_ids,
133
+ position_ids=position_ids,
134
+ head_mask=head_mask,
135
+ inputs_embeds=inputs_embeds,
136
+ output_attentions=output_attentions,
137
+ output_hidden_states=output_hidden_states,
138
+ return_dict=return_dict,
139
+ )
140
+
141
+ hidden_states = bert_outputs[0]
142
+ hidden_states = self.dropout(hidden_states)
143
+
144
+ loss, logits = self.morph(hidden_states, labels)
145
+
146
+ if not return_dict:
147
+ return (loss,logits) + bert_outputs[2:]
148
+
149
+ return MorphTaggingOutput(
150
+ loss=loss,
151
+ logits=logits,
152
+ hidden_states=bert_outputs.hidden_states,
153
+ attentions=bert_outputs.attentions,
154
+ )
155
+
156
+ def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, padding='longest'):
157
+ # tokenize the inputs and convert them to relevant device
158
+ inputs = tokenizer(sentences, padding=padding, truncation=True, return_tensors='pt')
159
+ inputs = {k:v.to(self.device) for k,v in inputs.items()}
160
+ # calculate the logits
161
+ logits = self.forward(**inputs, return_dict=True).logits
162
+ return parse_logits(inputs['input_ids'].tolist(), sentences, tokenizer, logits)
163
+
164
+ def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: MorphLogitsOutput):
165
+ prefix_logits, pos_logits, feats_logits, suffix_logits, suffix_feats_logits = \
166
+ logits.prefix_logits, logits.pos_logits, logits.features_logits, logits.suffix_logits, logits.suffix_features_logits
167
+
168
+ prefix_predictions = (prefix_logits > 0.5).int().tolist() # Threshold at 0.5 for multi-label classification
169
+ pos_predictions = pos_logits.argmax(axis=-1).tolist()
170
+ suffix_predictions = suffix_logits.argmax(axis=-1).tolist()
171
+ feats_predictions = [logits.argmax(axis=-1).tolist() for logits in feats_logits]
172
+ suffix_feats_predictions = [logits.argmax(axis=-1).tolist() for logits in suffix_feats_logits]
173
+
174
+ # create the return dictionary
175
+ # for each sentence, return a dict object with the following files { text, tokens }
176
+ # Where tokens is a list of dicts, where each dict is:
177
+ # { pos: str, feats: dict, prefixes: List[str], suffix: str | bool, suffix_feats: dict | None}
178
+ special_toks = tokenizer.all_special_tokens
179
+ ret = []
180
+ for sent_idx,sentence in enumerate(sentences):
181
+ input_id_strs = tokenizer.convert_ids_to_tokens(input_ids[sent_idx])
182
+ # iterate through each token in the sentence, ignoring special tokens
183
+ tokens = []
184
+ for token_idx,token_str in enumerate(input_id_strs):
185
+ if token_str in special_toks: continue
186
+ if token_str.startswith('##'):
187
+ tokens[-1]['token'] += token_str[2:]
188
+ continue
189
+ tokens.append(dict(
190
+ token=token_str,
191
+ pos=ALL_POS[pos_predictions[sent_idx][token_idx]],
192
+ feats=get_features_dict_from_predictions(feats_predictions, (sent_idx, token_idx)),
193
+ prefixes=[ALL_PREFIX_POS[idx] for idx,i in enumerate(prefix_predictions[sent_idx][token_idx]) if i > 0],
194
+ suffix=get_suffix_or_false(ALL_SUFFIX_POS[suffix_predictions[sent_idx][token_idx]]),
195
+ ))
196
+ if tokens[-1]['suffix']:
197
+ tokens[-1]['suffix_feats'] = get_features_dict_from_predictions(suffix_feats_predictions, (sent_idx, token_idx))
198
+ ret.append(dict(text=sentence, tokens=tokens))
199
+ return ret
200
+
201
+ def get_suffix_or_false(suffix):
202
+ return False if suffix == 'none' else suffix
203
+
204
+ def get_features_dict_from_predictions(predictions, idx):
205
+ ret = {}
206
+ for (feat_idx, (feat_name, feat_values)) in enumerate(ALL_FEATURES):
207
+ val = feat_values[predictions[feat_idx][idx[0]][idx[1]]]
208
+ if val != 'none':
209
+ ret[feat_name] = val
210
+ return ret
211
+
212
+
BertForPrefixMarking.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.utils import ModelOutput
2
+ import torch
3
+ from torch import nn
4
+ from typing import Dict, List, Tuple, Optional
5
+ from dataclasses import dataclass
6
+ from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
7
+
8
+ # define the classes, and the possible prefixes for each class
9
+ POSSIBLE_PREFIX_CLASSES = [ ['לכש', 'כש', 'מש', 'בש', 'לש'], ['מ'], ['ש'], ['ה'], ['ו'], ['כ'], ['ל'], ['ב'] ]
10
+ # map each individual prefix to it's class number
11
+ PREFIXES_TO_CLASS = {w:i for i,l in enumerate(POSSIBLE_PREFIX_CLASSES) for w in l}
12
+ # keep a list of all the prefixes, sorted by length, so that we can decompose
13
+ # a given prefixes and figure out the classes
14
+ ALL_PREFIX_ITEMS = list(sorted(PREFIXES_TO_CLASS.keys(), key=len, reverse=True))
15
+ TOTAL_POSSIBLE_PREFIX_CLASSES = len(POSSIBLE_PREFIX_CLASSES)
16
+
17
+ def get_prefixes_from_str(s, greedy=False):
18
+ # keep trimming prefixes from the string
19
+ while len(s) > 0 and s[0] in PREFIXES_TO_CLASS:
20
+ # find the longest string to trim
21
+ next_pre = next((pre for pre in ALL_PREFIX_ITEMS if s.startswith(pre)), None)
22
+ if next_pre is None:
23
+ return
24
+ yield next_pre
25
+ # if the chosen prefix is more than one letter, there is always an option that the
26
+ # prefix is actually just the first letter of the prefix - so offer that up as a valid prefix
27
+ # as well. We will still jump to the length of the longer one, since if the next two/three
28
+ # letters are a prefix, they have to be the longest one
29
+ if not greedy and len(next_pre) > 1:
30
+ yield next_pre[0]
31
+ s = s[len(next_pre):]
32
+
33
+ def get_prefix_classes_from_str(s, greedy=False):
34
+ for pre in get_prefixes_from_str(s, greedy):
35
+ yield PREFIXES_TO_CLASS[pre]
36
+
37
+ @dataclass
38
+ class PrefixesClassifiersOutput(ModelOutput):
39
+ loss: Optional[torch.FloatTensor] = None
40
+ logits: Optional[torch.FloatTensor] = None
41
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
42
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
43
+
44
+ class BertPrefixMarkingHead(nn.Module):
45
+ def __init__(self, config) -> None:
46
+ super().__init__()
47
+ self.config = config
48
+
49
+ # an embedding table containing an embedding for each prefix class + 1 for NONE
50
+ # we will concatenate either the embedding/NONE for each class - and we want the concatenate
51
+ # size to be the hidden_size
52
+ prefix_class_embed = config.hidden_size // TOTAL_POSSIBLE_PREFIX_CLASSES
53
+ self.prefix_class_embeddings = nn.Embedding(TOTAL_POSSIBLE_PREFIX_CLASSES + 1, prefix_class_embed)
54
+
55
+ # one layer for transformation, apply an activation, then another N classifiers for each prefix class
56
+ self.transform = nn.Linear(config.hidden_size + prefix_class_embed * TOTAL_POSSIBLE_PREFIX_CLASSES, config.hidden_size)
57
+ self.activation = nn.Tanh()
58
+ self.classifiers = nn.ModuleList([nn.Linear(config.hidden_size, 2) for _ in range(TOTAL_POSSIBLE_PREFIX_CLASSES)])
59
+
60
+ def forward(
61
+ self,
62
+ hidden_states: torch.Tensor,
63
+ prefix_class_id_options: torch.Tensor,
64
+ labels: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
65
+
66
+ # encode the prefix_class_id_options
67
+ # If input_ids is batch x seq_len
68
+ # Then sequence_output is batch x seq_len x hidden_dim
69
+ # So prefix_class_id_options is batch x seq_len x TOTAL_POSSIBLE_PREFIX_CLASSES
70
+ # Looking up the embeddings should give us batch x seq_len x TOTAL_POSSIBLE_PREFIX_CLASSES x hidden_dim / N
71
+ possible_class_embed = self.prefix_class_embeddings(prefix_class_id_options)
72
+ # then flatten the final dimension - now we have batch x seq_len x hidden_dim_2
73
+ possible_class_embed = possible_class_embed.reshape(possible_class_embed.shape[:-2] + (-1,))
74
+
75
+ # concatenate the new class embed into the sequence output before the transform
76
+ pre_transform_output = torch.cat((hidden_states, possible_class_embed), dim=-1) # batch x seq_len x (hidden_dim + hidden_dim_2)
77
+ pre_logits_output = self.activation(self.transform(pre_transform_output))# batch x seq_len x hidden_dim
78
+
79
+ # run each of the classifiers on the transformed output
80
+ logits = torch.cat([cls(pre_logits_output).unsqueeze(-2) for cls in self.classifiers], dim=-2)
81
+
82
+ loss = None
83
+ if labels is not None:
84
+ loss_fct = nn.CrossEntropyLoss()
85
+ loss = loss_fct(logits.view(-1, 2), labels.view(-1))
86
+
87
+ return (loss, logits)
88
+
89
+
90
+
91
+ class BertForPrefixMarking(BertPreTrainedModel):
92
+
93
+ def __init__(self, config):
94
+ super().__init__(config)
95
+
96
+ self.bert = BertModel(config, add_pooling_layer=False)
97
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
98
+ self.prefix = BertPrefixMarkingHead(config)
99
+
100
+ # Initialize weights and apply final processing
101
+ self.post_init()
102
+
103
+ def forward(
104
+ self,
105
+ input_ids: Optional[torch.Tensor] = None,
106
+ attention_mask: Optional[torch.Tensor] = None,
107
+ token_type_ids: Optional[torch.Tensor] = None,
108
+ prefix_class_id_options: Optional[torch.Tensor] = None,
109
+ position_ids: Optional[torch.Tensor] = None,
110
+ labels: Optional[torch.Tensor] = None,
111
+ head_mask: Optional[torch.Tensor] = None,
112
+ inputs_embeds: Optional[torch.Tensor] = None,
113
+ output_attentions: Optional[bool] = None,
114
+ output_hidden_states: Optional[bool] = None,
115
+ return_dict: Optional[bool] = None,
116
+ ):
117
+ r"""
118
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
119
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
120
+ """
121
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
122
+
123
+ bert_outputs = self.bert(
124
+ input_ids,
125
+ attention_mask=attention_mask,
126
+ token_type_ids=token_type_ids,
127
+ position_ids=position_ids,
128
+ head_mask=head_mask,
129
+ inputs_embeds=inputs_embeds,
130
+ output_attentions=output_attentions,
131
+ output_hidden_states=output_hidden_states,
132
+ return_dict=return_dict,
133
+ )
134
+
135
+ hidden_states = bert_outputs[0]
136
+ hidden_states = self.dropout(hidden_states)
137
+
138
+ loss, logits = self.prefix.forward(hidden_states, prefix_class_id_options, labels)
139
+ if not return_dict:
140
+ return (loss,logits,) + bert_outputs[2:]
141
+
142
+ return PrefixesClassifiersOutput(
143
+ loss=loss,
144
+ logits=logits,
145
+ hidden_states=bert_outputs.hidden_states,
146
+ attentions=bert_outputs.attentions,
147
+ )
148
+
149
+ def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, padding='longest'):
150
+ # step 1: encode the sentences through using the tokenizer, and get the input tensors + prefix id tensors
151
+ inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
152
+ inputs.pop('offset_mapping')
153
+ inputs = {k:v.to(self.device) for k,v in inputs.items()}
154
+
155
+ # run through bert
156
+ logits = self.forward(**inputs, return_dict=True).logits
157
+ return parse_logits(inputs['input_ids'].tolist(), sentences, tokenizer, logits)
158
+
159
+ def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.FloatTensor):
160
+ # extract the predictions by argmaxing the final dimension (batch x sequence x prefixes x prediction)
161
+ logit_preds = torch.argmax(logits, axis=3).tolist()
162
+
163
+ ret = []
164
+
165
+ for sent_idx,sent_ids in enumerate(input_ids):
166
+ tokens = tokenizer.convert_ids_to_tokens(sent_ids)
167
+ ret.append([])
168
+ for tok_idx,token in enumerate(tokens):
169
+ # If we've reached the pad token, then we are at the end
170
+ if token == tokenizer.pad_token: continue
171
+ if token.startswith('##'): continue
172
+
173
+ # combine the next tokens in? only if it's a breakup
174
+ next_tok_idx = tok_idx + 1
175
+ while next_tok_idx < len(tokens) and tokens[next_tok_idx].startswith('##'):
176
+ token += tokens[next_tok_idx][2:]
177
+ next_tok_idx += 1
178
+
179
+ prefix_len = get_predicted_prefix_len_from_logits(token, logit_preds[sent_idx][tok_idx])
180
+
181
+ if not prefix_len:
182
+ ret[-1].append([token])
183
+ else:
184
+ ret[-1].append([token[:prefix_len], token[prefix_len:]])
185
+ return ret
186
+
187
+ def encode_sentences_for_bert_for_prefix_marking(tokenizer: BertTokenizerFast, sentences: List[str], padding='longest', truncation=True):
188
+ inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_offsets_mapping=True, return_tensors='pt')
189
+ # create our prefix_id_options array which will be like the input ids shape but with an addtional
190
+ # dimension containing for each prefix whether it can be for that word
191
+ prefix_id_options = torch.full(inputs['input_ids'].shape + (TOTAL_POSSIBLE_PREFIX_CLASSES,), TOTAL_POSSIBLE_PREFIX_CLASSES, dtype=torch.long)
192
+
193
+ # go through each token, and fill in the vector accordingly
194
+ for sent_idx, sent_ids in enumerate(inputs['input_ids']):
195
+ tokens = tokenizer.convert_ids_to_tokens(sent_ids)
196
+ for tok_idx, token in enumerate(tokens):
197
+ # if the first letter isn't a valid prefix letter, nothing to talk about
198
+ if len(token) < 2 or not token[0] in PREFIXES_TO_CLASS: continue
199
+
200
+ # combine the next tokens in? only if it's a breakup
201
+ next_tok_idx = tok_idx + 1
202
+ while next_tok_idx < len(tokens) and tokens[next_tok_idx].startswith('##'):
203
+ token += tokens[next_tok_idx][2:]
204
+ next_tok_idx += 1
205
+
206
+ # find all the possible prefixes - and mark them as 0 (and in the possible mark it as it's value for embed lookup)
207
+ for pre_class in get_prefix_classes_from_str(token):
208
+ prefix_id_options[sent_idx, tok_idx, pre_class] = pre_class
209
+
210
+ inputs['prefix_class_id_options'] = prefix_id_options
211
+ return inputs
212
+
213
+ def get_predicted_prefix_len_from_logits(token, token_logits):
214
+ # Go through each possible prefix, and check if the prefix is yes - and if
215
+ # so increase the counter of the matched length, otherwise break out. That will solve cases
216
+ # of predicting prefix combinations that don't exist on the word.
217
+ # For example, if we have the word ושכשהלכתי and the model predict ו & כש, then we will only
218
+ # take the vuv because in order to get the כש we need the ש as well.
219
+ # Two extra items:
220
+ # 1] Don't allow the same prefix multiple times
221
+ # 2] Always check that the word starts with that prefix - otherwise it's bad
222
+ # (except for the case of multi-letter prefix, where we force the next to be last)
223
+ cur_len, skip_next, last_check, seen_prefixes = 0, False, False, set()
224
+ for prefix in get_prefixes_from_str(token):
225
+ # Are we skipping this prefix? This will be the case where we matched כש, don't allow ש
226
+ if skip_next:
227
+ skip_next = False
228
+ continue
229
+ # check for duplicate prefixes, we don't allow two of the same prefix
230
+ # if it predicted two of the same, then we will break out
231
+ if prefix in seen_prefixes: break
232
+ seen_prefixes.add(prefix)
233
+
234
+ # check if we predicted this prefix
235
+ if token_logits[PREFIXES_TO_CLASS[prefix]]:
236
+ cur_len += len(prefix)
237
+ if last_check: break
238
+ skip_next = len(prefix) > 1
239
+ # Otherwise, we predicted no. If we didn't, then this is the end of the prefix
240
+ # and time to break out. *Except* if it's a multi letter prefix, then we allow
241
+ # just the next letter - e.g., if כש doesn't match, then we allow כ, but then we know
242
+ # the word continues with a ש, and if it's not כש, then it's not כ-ש- (invalid)
243
+ elif len(prefix) > 1:
244
+ last_check = True
245
+ else:
246
+ break
247
+
248
+ return cur_len
BertForSyntaxParsing.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from transformers.utils import ModelOutput
3
+ import torch
4
+ from torch import nn
5
+ from typing import Dict, List, Tuple, Optional, Union
6
+ from dataclasses import dataclass
7
+ from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
8
+
9
+ ALL_FUNCTION_LABELS = ["nsubj", "nsubj:cop", "punct", "mark", "mark:q", "case", "case:gen", "case:acc", "fixed", "obl", "det", "amod", "acl:relcl", "nmod", "cc", "conj", "root", "compound:smixut", "cop", "compound:affix", "advmod", "nummod", "appos", "nsubj:pass", "nmod:poss", "xcomp", "obj", "aux", "parataxis", "advcl", "ccomp", "csubj", "acl", "obl:tmod", "csubj:pass", "dep", "dislocated", "nmod:tmod", "nmod:npmod", "flat", "obl:npmod", "goeswith", "reparandum", "orphan", "list", "discourse", "iobj", "vocative", "expl", "flat:name"]
10
+
11
+ @dataclass
12
+ class SyntaxLogitsOutput(ModelOutput):
13
+ dependency_logits: torch.FloatTensor = None
14
+ function_logits: torch.FloatTensor = None
15
+ dependency_head_indices: torch.LongTensor = None
16
+
17
+ def detach(self):
18
+ return SyntaxTaggingOutput(self.dependency_logits.detach(), self.function_logits.detach(), self.dependency_head_indices.detach())
19
+
20
+ @dataclass
21
+ class SyntaxTaggingOutput(ModelOutput):
22
+ loss: Optional[torch.FloatTensor] = None
23
+ logits: Optional[SyntaxLogitsOutput] = None
24
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
25
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
26
+
27
+ @dataclass
28
+ class SyntaxLabels(ModelOutput):
29
+ dependency_labels: Optional[torch.LongTensor] = None
30
+ function_labels: Optional[torch.LongTensor] = None
31
+
32
+ def detach(self):
33
+ return SyntaxLabels(self.dependency_labels.detach(), self.function_labels.detach())
34
+
35
+ def to(self, device):
36
+ return SyntaxLabels(self.dependency_labels.to(device), self.function_labels.to(device))
37
+
38
+ class BertSyntaxParsingHead(nn.Module):
39
+ def __init__(self, config):
40
+ super().__init__()
41
+ self.config = config
42
+
43
+ # the attention query & key values
44
+ self.head_size = config.syntax_head_size# int(config.hidden_size / config.num_attention_heads * 2)
45
+ self.query = nn.Linear(config.hidden_size, self.head_size)
46
+ self.key = nn.Linear(config.hidden_size, self.head_size)
47
+ # the function classifier gets two encoding values and predicts the labels
48
+ self.num_function_classes = len(ALL_FUNCTION_LABELS)
49
+ self.cls = nn.Linear(config.hidden_size * 2, self.num_function_classes)
50
+
51
+ def forward(
52
+ self,
53
+ hidden_states: torch.Tensor,
54
+ extended_attention_mask: Optional[torch.Tensor],
55
+ labels: Optional[SyntaxLabels] = None,
56
+ compute_mst: bool = False) -> Tuple[torch.Tensor, SyntaxLogitsOutput]:
57
+
58
+ # Take the dot product between "query" and "key" to get the raw attention scores.
59
+ query_layer = self.query(hidden_states)
60
+ key_layer = self.key(hidden_states)
61
+ attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(self.head_size)
62
+
63
+ # add in the attention mask
64
+ if extended_attention_mask is not None:
65
+ if extended_attention_mask.ndim == 4:
66
+ extended_attention_mask = extended_attention_mask.squeeze(1)
67
+ attention_scores += extended_attention_mask# batch x seq x seq
68
+
69
+ # At this point take the hidden_state of the word and of the dependency word, and predict the function
70
+ # If labels are provided, use the labels.
71
+ if self.training and labels is not None:
72
+ # Note that the labels can have -100, so just set those to zero with a max
73
+ dep_indices = labels.dependency_labels.clamp_min(0)
74
+ # Otherwise - check if he wants the MST or just the argmax
75
+ elif compute_mst:
76
+ dep_indices = compute_mst_tree(attention_scores, extended_attention_mask)
77
+ else:
78
+ dep_indices = torch.argmax(attention_scores, dim=-1)
79
+
80
+ # After we retrieved the dependency indicies, create a tensor of teh batch indices, and and retrieve the vectors of the heads to calculate the function
81
+ batch_indices = torch.arange(dep_indices.size(0)).view(-1, 1).expand(-1, dep_indices.size(1)).to(dep_indices.device)
82
+ dep_vectors = hidden_states[batch_indices, dep_indices, :] # batch x seq x dim
83
+
84
+ # concatenate that with the last hidden states, and send to the classifier output
85
+ cls_inputs = torch.cat((hidden_states, dep_vectors), dim=-1)
86
+ function_logits = self.cls(cls_inputs)
87
+
88
+ loss = None
89
+ if labels is not None:
90
+ loss_fct = nn.CrossEntropyLoss()
91
+ # step 1: dependency scores loss - this is applied to the attention scores
92
+ loss = loss_fct(attention_scores.view(-1, hidden_states.size(-2)), labels.dependency_labels.view(-1))
93
+ # step 2: function loss
94
+ loss += loss_fct(function_logits.view(-1, self.num_function_classes), labels.function_labels.view(-1))
95
+
96
+ return (loss, SyntaxLogitsOutput(attention_scores, function_logits, dep_indices))
97
+
98
+
99
+ class BertForSyntaxParsing(BertPreTrainedModel):
100
+
101
+ def __init__(self, config):
102
+ super().__init__(config)
103
+
104
+ self.bert = BertModel(config, add_pooling_layer=False)
105
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
106
+ self.syntax = BertSyntaxParsingHead(config)
107
+
108
+ # Initialize weights and apply final processing
109
+ self.post_init()
110
+
111
+ def forward(
112
+ self,
113
+ input_ids: Optional[torch.Tensor] = None,
114
+ attention_mask: Optional[torch.Tensor] = None,
115
+ token_type_ids: Optional[torch.Tensor] = None,
116
+ position_ids: Optional[torch.Tensor] = None,
117
+ labels: Optional[SyntaxLabels] = None,
118
+ head_mask: Optional[torch.Tensor] = None,
119
+ inputs_embeds: Optional[torch.Tensor] = None,
120
+ output_attentions: Optional[bool] = None,
121
+ output_hidden_states: Optional[bool] = None,
122
+ return_dict: Optional[bool] = None,
123
+ compute_syntax_mst: Optional[bool] = None,
124
+ ):
125
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
126
+
127
+ bert_outputs = self.bert(
128
+ input_ids,
129
+ attention_mask=attention_mask,
130
+ token_type_ids=token_type_ids,
131
+ position_ids=position_ids,
132
+ head_mask=head_mask,
133
+ inputs_embeds=inputs_embeds,
134
+ output_attentions=output_attentions,
135
+ output_hidden_states=output_hidden_states,
136
+ return_dict=return_dict,
137
+ )
138
+
139
+ extended_attention_mask = None
140
+ if attention_mask is not None:
141
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.size())
142
+ # apply the syntax head
143
+ loss, logits = self.syntax(self.dropout(bert_outputs[0]), extended_attention_mask, labels, compute_syntax_mst)
144
+
145
+ if not return_dict:
146
+ return (loss,(logits.dependency_logits, logits.function_logits)) + bert_outputs[2:]
147
+
148
+ return SyntaxTaggingOutput(
149
+ loss=loss,
150
+ logits=logits,
151
+ hidden_states=bert_outputs.hidden_states,
152
+ attentions=bert_outputs.attentions,
153
+ )
154
+
155
+ def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, compute_mst=True):
156
+ if isinstance(sentences, str):
157
+ sentences = [sentences]
158
+
159
+ # predict the logits for the sentence
160
+ inputs = tokenizer(sentences, padding='longest', truncation=True, return_tensors='pt')
161
+ inputs = {k:v.to(self.device) for k,v in inputs.items()}
162
+ logits = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_mst).logits
163
+ return parse_logits(inputs['input_ids'].tolist(), sentences, tokenizer, logits)
164
+
165
+ def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: SyntaxLogitsOutput):
166
+ outputs = []
167
+
168
+ special_toks = tokenizer.all_special_tokens
169
+ for i in range(len(sentences)):
170
+ deps = logits.dependency_head_indices[i].tolist()
171
+ funcs = logits.function_logits.argmax(-1)[i].tolist()
172
+ toks = [tok for tok in tokenizer.convert_ids_to_tokens(input_ids[i]) if tok not in special_toks]
173
+
174
+ # first, go through the tokens and create a mapping between each dependency index and the index without wordpieces
175
+ # wordpieces. At the same time, append the wordpieces in
176
+ idx_mapping = {-1:-1} # default root
177
+ real_idx = -1
178
+ for i in range(len(toks)):
179
+ if not toks[i].startswith('##'):
180
+ real_idx += 1
181
+ idx_mapping[i] = real_idx
182
+
183
+ # build our tree, keeping tracking of the root idx
184
+ tree = []
185
+ root_idx = 0
186
+ for i in range(len(toks)):
187
+ if toks[i].startswith('##'):
188
+ tree[-1]['word'] += toks[i][2:]
189
+ continue
190
+
191
+ dep_idx = deps[i + 1] - 1 # increase 1 for cls, decrease 1 for cls
192
+ if dep_idx == len(toks): dep_idx = i - 1 # if he predicts sep, then just point to the previous word
193
+
194
+ dep_head = 'root' if dep_idx == -1 else toks[dep_idx]
195
+ dep_func = ALL_FUNCTION_LABELS[funcs[i + 1]]
196
+
197
+ if dep_head == 'root': root_idx = len(tree)
198
+ tree.append(dict(word=toks[i], dep_head_idx=idx_mapping[dep_idx], dep_func=dep_func))
199
+ # append the head word
200
+ for d in tree:
201
+ d['dep_head'] = tree[d['dep_head_idx']]['word']
202
+
203
+ outputs.append(dict(tree=tree, root_idx=root_idx))
204
+ return outputs
205
+
206
+
207
+ def compute_mst_tree(attention_scores: torch.Tensor, extended_attention_mask: torch.LongTensor):
208
+ # attention scores should be 3 dimensions - batch x seq x seq (if it is 2 - just unsqueeze)
209
+ if attention_scores.ndim == 2: attention_scores = attention_scores.unsqueeze(0)
210
+ if attention_scores.ndim != 3 or attention_scores.shape[1] != attention_scores.shape[2]:
211
+ raise ValueError(f'Expected attention scores to be of shape batch x seq x seq, instead got {attention_scores.shape}')
212
+
213
+ batch_size, seq_len, _ = attention_scores.shape
214
+ # start by softmaxing so the scores are comparable
215
+ attention_scores = attention_scores.softmax(dim=-1)
216
+
217
+ batch_indices = torch.arange(batch_size, device=attention_scores.device)
218
+ seq_indices = torch.arange(seq_len, device=attention_scores.device)
219
+
220
+ seq_lens = torch.full((batch_size,), seq_len)
221
+
222
+ if extended_attention_mask is not None:
223
+ seq_lens = torch.argmax((extended_attention_mask != 0).int(), dim=2).squeeze(1)
224
+ # zero out any padding
225
+ attention_scores[extended_attention_mask.squeeze(1) != 0] = 0
226
+
227
+ # set the values for the CLS and sep to all by very low, so they never get chosen as a replacement arc
228
+ attention_scores[:, 0, :] = 0
229
+ attention_scores[batch_indices, seq_lens - 1, :] = 0
230
+ attention_scores[batch_indices, :, seq_lens - 1] = 0 # can never predict sep
231
+ # set the values for each token pointing to itself be 0
232
+ attention_scores[:, seq_indices, seq_indices] = 0
233
+
234
+ # find the root, and make him super high so we never have a conflict
235
+ root_cands = torch.argsort(attention_scores[:, :, 0], dim=-1)
236
+ attention_scores[batch_indices.unsqueeze(1), root_cands, 0] = 0
237
+ attention_scores[batch_indices, root_cands[:, -1], 0] = 1.0
238
+
239
+ # we start by getting the argmax for each score, and then computing the cycles and contracting them
240
+ sorted_indices = torch.argsort(attention_scores, dim=-1, descending=True)
241
+ indices = sorted_indices[:, :, 0].clone() # take the argmax
242
+
243
+ attention_scores = attention_scores.tolist()
244
+ seq_lens = seq_lens.tolist()
245
+ sorted_indices = [[sub_l[:slen] for sub_l in l[:slen]] for l,slen in zip(sorted_indices.tolist(), seq_lens)]
246
+
247
+
248
+ # go through each batch item and make sure our tree works
249
+ for batch_idx in range(batch_size):
250
+ # We have one root - detect the cycles and contract them. A cycle can never contain the root so really
251
+ # for every cycle, we look at all the nodes, and find the highest arc out of the cycle for any values. Replace that and tada
252
+ has_cycle, cycle_nodes = detect_cycle(indices[batch_idx], seq_lens[batch_idx])
253
+ contracted_arcs = set()
254
+ while has_cycle:
255
+ base_idx, head_idx = choose_contracting_arc(indices[batch_idx], sorted_indices[batch_idx], cycle_nodes, contracted_arcs, seq_lens[batch_idx], attention_scores[batch_idx])
256
+ indices[batch_idx, base_idx] = head_idx
257
+ contracted_arcs.add(base_idx)
258
+ # find the next cycle
259
+ has_cycle, cycle_nodes = detect_cycle(indices[batch_idx], seq_lens[batch_idx])
260
+
261
+ return indices
262
+
263
+ def detect_cycle(indices: torch.LongTensor, seq_len: int):
264
+ # Simple cycle detection algorithm
265
+ # Returns a boolean indicating if a cycle is detected and the nodes involved in the cycle
266
+ visited = set()
267
+ for node in range(1, seq_len - 1): # ignore the CLS/SEP tokens
268
+ if node in visited:
269
+ continue
270
+ current_path = set()
271
+ while node not in visited:
272
+ visited.add(node)
273
+ current_path.add(node)
274
+ node = indices[node].item()
275
+ if node == 0: break # roots never point to anything
276
+ if node in current_path:
277
+ return True, current_path # Cycle detected
278
+ return False, None
279
+
280
+ def choose_contracting_arc(indices: torch.LongTensor, sorted_indices: List[List[int]], cycle_nodes: set, contracted_arcs: set, seq_len: int, scores: List[List[float]]):
281
+ # Chooses the highest-scoring, non-cycling arc from a graph. Iterates through 'cycle_nodes' to find
282
+ # the best arc based on 'scores', avoiding cycles and zero node connections.
283
+ # For each node, we only look at the next highest scoring non-cycling arc
284
+ best_base_idx, best_head_idx = -1, -1
285
+ score = 0
286
+
287
+ # convert the indices to a list once, to avoid multiple conversions (saves a few seconds)
288
+ currents = indices.tolist()
289
+ for base_node in cycle_nodes:
290
+ if base_node in contracted_arcs: continue
291
+ # we don't want to take anything that has a higher score than the current value - we can end up in an endless loop
292
+ # Since the indices are sorted, as soon as we find our current item, we can move on to the next.
293
+ current = currents[base_node]
294
+ found_current = False
295
+
296
+ for head_node in sorted_indices[base_node]:
297
+ if head_node == current:
298
+ found_current = True
299
+ continue
300
+ if head_node in contracted_arcs: continue
301
+ if not found_current or head_node in cycle_nodes or head_node == 0:
302
+ continue
303
+
304
+ current_score = scores[base_node][head_node]
305
+ if current_score > score:
306
+ best_base_idx, best_head_idx, score = base_node, head_node, current_score
307
+ break
308
+
309
+ if best_base_idx == -1:
310
+ raise ValueError('Stuck in endless loop trying to compute syntax mst. Please try again setting compute_syntax_mst=False')
311
+
312
+ return best_base_idx, best_head_idx