Shaltiel committed on
Commit
c397e3d
•
1 Parent(s): 406f5ef

Upload 2 files

Browse files
Files changed (2) hide show
  1. BertForJointParsing.py +26 -12
  2. BertForSyntaxParsing.py +1 -1
BertForJointParsing.py CHANGED
@@ -186,13 +186,15 @@ class BertForJointParsing(BertPreTrainedModel):
186
  morph_logits=morph_logits
187
  )
188
 
189
- def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, as_iahlt_ud=False, as_htb_ud=False):
190
  is_single_sentence = isinstance(sentences, str)
191
  if is_single_sentence:
192
  sentences = [sentences]
193
 
194
- if (as_htb_ud or as_iahlt_ud) and (self.prefix is None or self.morph is None or self.syntax is None or self.lex is None):
195
- raise ValueError("Cannot output UD format when any of the prefix,morph,syntax,lex heads aren't loaded.")
 
 
196
 
197
  # predict the logits for the sentence
198
  if self.prefix is not None:
@@ -233,8 +235,8 @@ class BertForJointParsing(BertPreTrainedModel):
233
  merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
234
  final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
235
 
236
- if as_iahlt_ud or as_htb_ud:
237
- final_output = convert_output_to_ud(final_output, htb_extras=as_htb_ud)
238
 
239
  if is_single_sentence:
240
  final_output = final_output[0]
@@ -339,7 +341,10 @@ ud_suffix_to_htb_str = {
339
  'Gender=Fem|Number=Sing|Person=2': '_ืืช',
340
  'Gender=Masc|Number=Plur|Person=3': '_ื”ื'
341
  }
342
- def convert_output_to_ud(output_sentences, htb_extras=False):
 
 
 
343
  final_output = []
344
  for sent_idx, sentence in enumerate(output_sentences):
345
  # next, go through each word and insert it in the UD format. Store in a temp format for the post process
@@ -363,9 +368,9 @@ def convert_output_to_ud(output_sentences, htb_extras=False):
363
 
364
  # if there was an implicit heh, add it in dependent on the method
365
  if not 'ื”' in pre and intermediate_output[-1]['pos'] == 'ADP' and 'DET' in word['morph']['prefixes']:
366
- if htb_extras:
367
  intermediate_output.append(dict(word='ื”_', lex='ื”', pos='DET', dep=word_idx, func='det', feats='_'))
368
- else:
369
  intermediate_output[-1]['feats'] = 'Definite=Def|PronType=Art'
370
 
371
 
@@ -394,7 +399,7 @@ def convert_output_to_ud(output_sentences, htb_extras=False):
394
  s_word, s_lex = word['seg'][-1], word['lex']
395
  # update the word of the string and extract the string of the suffix!
396
  # for IAHLT:
397
- if not htb_extras:
398
  # we need to shorten the main word and extract the suffix
399
  # if it is longer than the lexeme - just take off the lexeme.
400
  if len(s_word) > len(s_lex):
@@ -407,7 +412,7 @@ def convert_output_to_ud(output_sentences, htb_extras=False):
407
  suf = s_word[idx:]
408
  intermediate_output[-1]['word'] = s_word[:idx]
409
  # for htb:
410
- else:
411
  # main word becomes the lexeme, the suffix is based on the features
412
  intermediate_output[-1]['word'] = (s_lex if s_lex != s_word else s_word[:-1]) + '_'
413
  suf_feats = word['morph']['suffix_feats']
@@ -438,6 +443,7 @@ def convert_output_to_ud(output_sentences, htb_extras=False):
438
  for idx,output in enumerate(intermediate_output[start:end], start + 1):
439
  # compute the actual dependency location
440
  dep = output['dep'] if output.get('absolute_dep', False) else idx_to_key[output['dep']]
 
441
  # and add the full ud string in
442
  cur_output.append('\t'.join([
443
  str(idx),
@@ -447,12 +453,20 @@ def convert_output_to_ud(output_sentences, htb_extras=False):
447
  output['pos'],
448
  output['feats'],
449
  str(dep),
450
- output['func'],
451
  '_', '_'
452
  ]))
453
  return final_output
454
 
455
-
 
 
 
 
 
 
 
 
456
  def ud_get_prefix_dep(pre, word, word_idx):
457
  does_follow_main = False
458
 
 
186
  morph_logits=morph_logits
187
  )
188
 
189
+ def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, output_style: Literal['json', 'ud', 'iahlt_ud'] = 'json'):
190
  is_single_sentence = isinstance(sentences, str)
191
  if is_single_sentence:
192
  sentences = [sentences]
193
 
194
+ if output_style not in ['json', 'ud', 'iahlt_ud']:
195
+ raise ValueError('output_style must be in json/ud/iahlt_ud')
196
+ if output_style in ['ud', 'iahlt_ud'] and (self.prefix is None or self.morph is None or self.syntax is None or self.lex is None):
197
+ raise ValueError("Cannot output UD format when any of the prefix,morph,syntax, and lex heads aren't loaded.")
198
 
199
  # predict the logits for the sentence
200
  if self.prefix is not None:
 
235
  merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
236
  final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
237
 
238
+ if output_style in ['ud', 'iahlt_ud']:
239
+ final_output = convert_output_to_ud(final_output, style='htb' if output_style == 'ud' else 'iahlt')
240
 
241
  if is_single_sentence:
242
  final_output = final_output[0]
 
341
  'Gender=Fem|Number=Sing|Person=2': '_ืืช',
342
  'Gender=Masc|Number=Plur|Person=3': '_ื”ื'
343
  }
344
+ def convert_output_to_ud(output_sentences, style: Literal['htb', 'iahlt']):
345
+ if style not in ['htb', 'iahlt']:
346
+ raise ValueError('style must be htb/iahlt')
347
+
348
  final_output = []
349
  for sent_idx, sentence in enumerate(output_sentences):
350
  # next, go through each word and insert it in the UD format. Store in a temp format for the post process
 
368
 
369
  # if there was an implicit heh, add it in dependent on the method
370
  if not 'ื”' in pre and intermediate_output[-1]['pos'] == 'ADP' and 'DET' in word['morph']['prefixes']:
371
+ if style == 'htb':
372
  intermediate_output.append(dict(word='ื”_', lex='ื”', pos='DET', dep=word_idx, func='det', feats='_'))
373
+ elif style == 'iahlt':
374
  intermediate_output[-1]['feats'] = 'Definite=Def|PronType=Art'
375
 
376
 
 
399
  s_word, s_lex = word['seg'][-1], word['lex']
400
  # update the word of the string and extract the string of the suffix!
401
  # for IAHLT:
402
+ if style == 'iahlt':
403
  # we need to shorten the main word and extract the suffix
404
  # if it is longer than the lexeme - just take off the lexeme.
405
  if len(s_word) > len(s_lex):
 
412
  suf = s_word[idx:]
413
  intermediate_output[-1]['word'] = s_word[:idx]
414
  # for htb:
415
+ elif style == 'htb':
416
  # main word becomes the lexeme, the suffix is based on the features
417
  intermediate_output[-1]['word'] = (s_lex if s_lex != s_word else s_word[:-1]) + '_'
418
  suf_feats = word['morph']['suffix_feats']
 
443
  for idx,output in enumerate(intermediate_output[start:end], start + 1):
444
  # compute the actual dependency location
445
  dep = output['dep'] if output.get('absolute_dep', False) else idx_to_key[output['dep']]
446
+ func = normalize_dep_rel(output['func'], style)
447
  # and add the full ud string in
448
  cur_output.append('\t'.join([
449
  str(idx),
 
453
  output['pos'],
454
  output['feats'],
455
  str(dep),
456
+ func,
457
  '_', '_'
458
  ]))
459
  return final_output
460
 
461
def normalize_dep_rel(dep, style: Literal['htb', 'iahlt']):
    """Normalize a dependency-relation label according to the output style.

    For the 'iahlt' style, HTB-specific subtyped relations are collapsed to
    their plain UD base labels; any other label passes through unchanged.
    For the 'htb' style the label is always returned as-is.
    """
    if style != 'iahlt':
        return dep
    # IAHLT does not use these HTB subtypes — map them to their base relation.
    collapsed = {
        'compound:smixut': 'compound',
        'nsubj:cop': 'nsubj',
        'mark:q': 'mark',
        'case:gen': 'case',
        'case:acc': 'case',
    }
    return collapsed.get(dep, dep)
468
+
469
+
470
  def ud_get_prefix_dep(pre, word, word_idx):
471
  does_follow_main = False
472
 
BertForSyntaxParsing.py CHANGED
@@ -6,7 +6,7 @@ from typing import Dict, List, Tuple, Optional, Union
6
  from dataclasses import dataclass
7
  from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
8
 
9
- ALL_FUNCTION_LABELS = ["nsubj", "punct", "mark", "case", "fixed", "obl", "det", "amod", "acl:relcl", "nmod", "cc", "conj", "root", "compound", "cop", "compound:affix", "advmod", "nummod", "appos", "nsubj:pass", "nmod:poss", "xcomp", "obj", "aux", "parataxis", "advcl", "ccomp", "csubj", "acl", "obl:tmod", "csubj:pass", "dep", "dislocated", "nmod:tmod", "nmod:npmod", "flat", "obl:npmod", "goeswith", "reparandum", "orphan", "list", "discourse", "iobj", "vocative", "expl", "flat:name"]
10
 
11
  @dataclass
12
  class SyntaxLogitsOutput(ModelOutput):
 
6
  from dataclasses import dataclass
7
  from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
8
 
9
# Inventory of dependency-relation labels emitted by the syntax head.
# NOTE(review): order presumably maps each label to a model class index —
# do not reorder; confirm against the head's config before changing.
ALL_FUNCTION_LABELS = [
    "nsubj", "nsubj:cop", "punct", "mark", "mark:q", "case", "case:gen",
    "case:acc", "fixed", "obl", "det", "amod", "acl:relcl", "nmod", "cc",
    "conj", "root", "compound:smixut", "cop", "compound:affix", "advmod",
    "nummod", "appos", "nsubj:pass", "nmod:poss", "xcomp", "obj", "aux",
    "parataxis", "advcl", "ccomp", "csubj", "acl", "obl:tmod", "csubj:pass",
    "dep", "dislocated", "nmod:tmod", "nmod:npmod", "flat", "obl:npmod",
    "goeswith", "reparandum", "orphan", "list", "discourse", "iobj",
    "vocative", "expl", "flat:name",
]
10
 
11
  @dataclass
12
  class SyntaxLogitsOutput(ModelOutput):