Upload 2 files
Browse files- BertForJointParsing.py +26 -12
- BertForSyntaxParsing.py +1 -1
BertForJointParsing.py
CHANGED
@@ -186,13 +186,15 @@ class BertForJointParsing(BertPreTrainedModel):
|
|
186 |
morph_logits=morph_logits
|
187 |
)
|
188 |
|
189 |
-
def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False,
|
190 |
is_single_sentence = isinstance(sentences, str)
|
191 |
if is_single_sentence:
|
192 |
sentences = [sentences]
|
193 |
|
194 |
-
if
|
195 |
-
raise ValueError(
|
|
|
|
|
196 |
|
197 |
# predict the logits for the sentence
|
198 |
if self.prefix is not None:
|
@@ -233,8 +235,8 @@ class BertForJointParsing(BertPreTrainedModel):
|
|
233 |
merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
|
234 |
final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
|
235 |
|
236 |
-
if
|
237 |
-
final_output = convert_output_to_ud(final_output,
|
238 |
|
239 |
if is_single_sentence:
|
240 |
final_output = final_output[0]
|
@@ -339,7 +341,10 @@ ud_suffix_to_htb_str = {
|
|
339 |
'Gender=Fem|Number=Sing|Person=2': '_את',
|
340 |
'Gender=Masc|Number=Plur|Person=3': '_הם'
|
341 |
}
|
342 |
-
def convert_output_to_ud(output_sentences,
|
|
|
|
|
|
|
343 |
final_output = []
|
344 |
for sent_idx, sentence in enumerate(output_sentences):
|
345 |
# next, go through each word and insert it in the UD format. Store in a temp format for the post process
|
@@ -363,9 +368,9 @@ def convert_output_to_ud(output_sentences, htb_extras=False):
|
|
363 |
|
364 |
# if there was an implicit heh, add it in dependent on the method
|
365 |
if not 'ה' in pre and intermediate_output[-1]['pos'] == 'ADP' and 'DET' in word['morph']['prefixes']:
|
366 |
-
if
|
367 |
intermediate_output.append(dict(word='ה_', lex='ה', pos='DET', dep=word_idx, func='det', feats='_'))
|
368 |
-
|
369 |
intermediate_output[-1]['feats'] = 'Definite=Def|PronType=Art'
|
370 |
|
371 |
|
@@ -394,7 +399,7 @@ def convert_output_to_ud(output_sentences, htb_extras=False):
|
|
394 |
s_word, s_lex = word['seg'][-1], word['lex']
|
395 |
# update the word of the string and extract the string of the suffix!
|
396 |
# for IAHLT:
|
397 |
-
if
|
398 |
# we need to shorten the main word and extract the suffix
|
399 |
# if it is longer than the lexeme - just take off the lexeme.
|
400 |
if len(s_word) > len(s_lex):
|
@@ -407,7 +412,7 @@ def convert_output_to_ud(output_sentences, htb_extras=False):
|
|
407 |
suf = s_word[idx:]
|
408 |
intermediate_output[-1]['word'] = s_word[:idx]
|
409 |
# for htb:
|
410 |
-
|
411 |
# main word becomes the lexeme, the suffix is based on the features
|
412 |
intermediate_output[-1]['word'] = (s_lex if s_lex != s_word else s_word[:-1]) + '_'
|
413 |
suf_feats = word['morph']['suffix_feats']
|
@@ -438,6 +443,7 @@ def convert_output_to_ud(output_sentences, htb_extras=False):
|
|
438 |
for idx,output in enumerate(intermediate_output[start:end], start + 1):
|
439 |
# compute the actual dependency location
|
440 |
dep = output['dep'] if output.get('absolute_dep', False) else idx_to_key[output['dep']]
|
|
|
441 |
# and add the full ud string in
|
442 |
cur_output.append('\t'.join([
|
443 |
str(idx),
|
@@ -447,12 +453,20 @@ def convert_output_to_ud(output_sentences, htb_extras=False):
|
|
447 |
output['pos'],
|
448 |
output['feats'],
|
449 |
str(dep),
|
450 |
-
|
451 |
'_', '_'
|
452 |
]))
|
453 |
return final_output
|
454 |
|
455 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
456 |
def ud_get_prefix_dep(pre, word, word_idx):
|
457 |
does_follow_main = False
|
458 |
|
|
|
186 |
morph_logits=morph_logits
|
187 |
)
|
188 |
|
189 |
+
def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, output_style: Literal['json', 'ud', 'iahlt_ud'] = 'json'):
|
190 |
is_single_sentence = isinstance(sentences, str)
|
191 |
if is_single_sentence:
|
192 |
sentences = [sentences]
|
193 |
|
194 |
+
if output_style not in ['json', 'ud', 'iahlt_ud']:
|
195 |
+
raise ValueError('output_style must be in json/ud/iahlt_ud')
|
196 |
+
if output_style in ['ud', 'iahlt_ud'] and (self.prefix is None or self.morph is None or self.syntax is None or self.lex is None):
|
197 |
+
raise ValueError("Cannot output UD format when any of the prefix,morph,syntax, and lex heads aren't loaded.")
|
198 |
|
199 |
# predict the logits for the sentence
|
200 |
if self.prefix is not None:
|
|
|
235 |
merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
|
236 |
final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
|
237 |
|
238 |
+
if output_style in ['ud', 'iahlt_ud']:
|
239 |
+
final_output = convert_output_to_ud(final_output, style='htb' if output_style == 'ud' else 'iahlt')
|
240 |
|
241 |
if is_single_sentence:
|
242 |
final_output = final_output[0]
|
|
|
341 |
'Gender=Fem|Number=Sing|Person=2': '_את',
|
342 |
'Gender=Masc|Number=Plur|Person=3': '_הם'
|
343 |
}
|
344 |
+
def convert_output_to_ud(output_sentences, style: Literal['htb', 'iahlt']):
|
345 |
+
if style not in ['htb', 'iahlt']:
|
346 |
+
raise ValueError('style must be htb/iahlt')
|
347 |
+
|
348 |
final_output = []
|
349 |
for sent_idx, sentence in enumerate(output_sentences):
|
350 |
# next, go through each word and insert it in the UD format. Store in a temp format for the post process
|
|
|
368 |
|
369 |
# if there was an implicit heh, add it in dependent on the method
|
370 |
if not 'ה' in pre and intermediate_output[-1]['pos'] == 'ADP' and 'DET' in word['morph']['prefixes']:
|
371 |
+
if style == 'htb':
|
372 |
intermediate_output.append(dict(word='ה_', lex='ה', pos='DET', dep=word_idx, func='det', feats='_'))
|
373 |
+
elif style == 'iahlt':
|
374 |
intermediate_output[-1]['feats'] = 'Definite=Def|PronType=Art'
|
375 |
|
376 |
|
|
|
399 |
s_word, s_lex = word['seg'][-1], word['lex']
|
400 |
# update the word of the string and extract the string of the suffix!
|
401 |
# for IAHLT:
|
402 |
+
if style == 'iahlt':
|
403 |
# we need to shorten the main word and extract the suffix
|
404 |
# if it is longer than the lexeme - just take off the lexeme.
|
405 |
if len(s_word) > len(s_lex):
|
|
|
412 |
suf = s_word[idx:]
|
413 |
intermediate_output[-1]['word'] = s_word[:idx]
|
414 |
# for htb:
|
415 |
+
elif style == 'htb':
|
416 |
# main word becomes the lexeme, the suffix is based on the features
|
417 |
intermediate_output[-1]['word'] = (s_lex if s_lex != s_word else s_word[:-1]) + '_'
|
418 |
suf_feats = word['morph']['suffix_feats']
|
|
|
443 |
for idx,output in enumerate(intermediate_output[start:end], start + 1):
|
444 |
# compute the actual dependency location
|
445 |
dep = output['dep'] if output.get('absolute_dep', False) else idx_to_key[output['dep']]
|
446 |
+
func = normalize_dep_rel(output['func'], style)
|
447 |
# and add the full ud string in
|
448 |
cur_output.append('\t'.join([
|
449 |
str(idx),
|
|
|
453 |
output['pos'],
|
454 |
output['feats'],
|
455 |
str(dep),
|
456 |
+
func,
|
457 |
'_', '_'
|
458 |
]))
|
459 |
return final_output
|
460 |
|
461 |
+
def normalize_dep_rel(dep, style: Literal['htb', 'iahlt']):
    """Normalize a dependency relation label for the requested UD style.

    HTB uses several Hebrew-specific relation subtypes (e.g.
    'compound:smixut'); the IAHLT convention collapses these back to their
    universal base labels. For 'htb' the label passes through unchanged.
    """
    if style != 'iahlt':
        return dep
    # IAHLT: map HTB-specific subtypes to their universal base relation.
    collapsed = {
        'compound:smixut': 'compound',
        'nsubj:cop': 'nsubj',
        'mark:q': 'mark',
        'case:gen': 'case',
        'case:acc': 'case',
    }
    return collapsed.get(dep, dep)
|
468 |
+
|
469 |
+
|
470 |
def ud_get_prefix_dep(pre, word, word_idx):
|
471 |
does_follow_main = False
|
472 |
|
BertForSyntaxParsing.py
CHANGED
@@ -6,7 +6,7 @@ from typing import Dict, List, Tuple, Optional, Union
|
|
6 |
from dataclasses import dataclass
|
7 |
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
|
8 |
|
9 |
-
ALL_FUNCTION_LABELS = ["nsubj", "punct", "mark", "case", "fixed", "obl", "det", "amod", "acl:relcl", "nmod", "cc", "conj", "root", "compound", "cop", "compound:affix", "advmod", "nummod", "appos", "nsubj:pass", "nmod:poss", "xcomp", "obj", "aux", "parataxis", "advcl", "ccomp", "csubj", "acl", "obl:tmod", "csubj:pass", "dep", "dislocated", "nmod:tmod", "nmod:npmod", "flat", "obl:npmod", "goeswith", "reparandum", "orphan", "list", "discourse", "iobj", "vocative", "expl", "flat:name"]
|
10 |
|
11 |
@dataclass
|
12 |
class SyntaxLogitsOutput(ModelOutput):
|
|
|
6 |
from dataclasses import dataclass
|
7 |
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
|
8 |
|
9 |
+
# Dependency-relation (deprel) label set predicted by the syntax head,
# including the Hebrew-specific UD subtypes (nsubj:cop, mark:q, case:gen,
# case:acc, compound:smixut).
# NOTE(review): list order presumably defines the model's label indices —
# do not reorder; confirm against the trained checkpoint.
ALL_FUNCTION_LABELS = [
    "nsubj", "nsubj:cop", "punct", "mark", "mark:q", "case", "case:gen",
    "case:acc", "fixed", "obl", "det", "amod", "acl:relcl", "nmod", "cc",
    "conj", "root", "compound:smixut", "cop", "compound:affix", "advmod",
    "nummod", "appos", "nsubj:pass", "nmod:poss", "xcomp", "obj", "aux",
    "parataxis", "advcl", "ccomp", "csubj", "acl", "obl:tmod", "csubj:pass",
    "dep", "dislocated", "nmod:tmod", "nmod:npmod", "flat", "obl:npmod",
    "goeswith", "reparandum", "orphan", "list", "discourse", "iobj",
    "vocative", "expl", "flat:name",
]
|
10 |
|
11 |
@dataclass
|
12 |
class SyntaxLogitsOutput(ModelOutput):
|