Shaltiel committed
Commit 7897807
1 Parent(s): f5f1fb9

Upload BertForSyntaxParsing.py

Files changed (1):
  1. BertForSyntaxParsing.py +38 -36
BertForSyntaxParsing.py CHANGED
@@ -2,11 +2,11 @@ import math
 from transformers.utils import ModelOutput
 import torch
 from torch import nn
-from typing import List, Tuple, Optional, Union
+from typing import Dict, List, Tuple, Optional, Union
 from dataclasses import dataclass
 from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
 
-ALL_FUNCTION_LABELS = ["nsubj", "punct", "mark", "case", "fixed", "obl", "det", "amod", "acl:relcl", "nmod", "cc", "conj", "root", "compound", "cop", "compound:affix", "advmod", "nummod", "appos", "nsubj:pass", "nmod:poss", "xcomp", "obj", "aux", "parataxis", "advcl", "ccomp", "csubj", "acl", "obl:tmod", "csubj:pass", "dep", "dislocated", "nmod:tmod", "nmod:npmod", "flat", "obl:npmod", "goeswith", "reparandum", "orphan", "list", "discourse", "iobj", "vocative", "expl", "flat:name"]
+ALL_FUNCTION_LABELS = ["nsubj", "nsubj:cop", "punct", "mark", "mark:q", "case", "case:gen", "case:acc", "fixed", "obl", "det", "amod", "acl:relcl", "nmod", "cc", "conj", "root", "compound:smixut", "cop", "compound:affix", "advmod", "nummod", "appos", "nsubj:pass", "nmod:poss", "xcomp", "obj", "aux", "parataxis", "advcl", "ccomp", "csubj", "acl", "obl:tmod", "csubj:pass", "dep", "dislocated", "nmod:tmod", "nmod:npmod", "flat", "obl:npmod", "goeswith", "reparandum", "orphan", "list", "discourse", "iobj", "vocative", "expl", "flat:name"]
 
 @dataclass
 class SyntaxLogitsOutput(ModelOutput):
@@ -160,44 +160,46 @@ class BertForSyntaxParsing(BertPreTrainedModel):
         inputs = tokenizer(sentences, padding='longest', truncation=True, return_tensors='pt')
         inputs = {k:v.to(self.device) for k,v in inputs.items()}
         logits = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_mst).logits
-
-        outputs = []
-        for i in range(len(sentences)):
-            deps = logits.dependency_head_indices[i].tolist()
-            funcs = logits.function_logits.argmax(-1)[i].tolist()
-            toks = tokenizer.convert_ids_to_tokens(inputs['input_ids'][i])[1:-1] # ignore cls and sep
-
-            # first, create a mapping between each dependency index and the corresponding word index without wordpieces
-            idx_mapping = {-1:-1} # default root
-            real_idx = -1
-            for i in range(len(toks)):
-                if not toks[i].startswith('##'):
-                    real_idx += 1
-                idx_mapping[i] = real_idx
-
-            # build our tree, keeping track of the root idx; wordpieces are merged into the preceding token
-            tree = []
-            root_idx = 0
-            for i in range(len(toks)):
-                if toks[i].startswith('##'):
-                    tree[-1]['word'] += toks[i][2:]
-                    continue
-
-                dep_idx = deps[i + 1] - 1 # add 1 to skip [CLS] when indexing deps, subtract 1 to drop [CLS] from the head index
-                dep_head = 'root' if dep_idx == -1 else toks[dep_idx]
-                dep_func = ALL_FUNCTION_LABELS[funcs[i + 1]]
-
-                if dep_head == 'root': root_idx = len(tree)
-                tree.append(dict(word=toks[i], dep_head_idx=idx_mapping[dep_idx], dep_func=dep_func))
-            # fill in the head word for each entry
-            for d in tree:
-                d['dep_head'] = tree[d['dep_head_idx']]['word']
-
-            outputs.append(dict(tree=tree, root_idx=root_idx))
-        return outputs
+        return parse_logits(inputs, sentences, tokenizer, logits)
+
+def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: SyntaxLogitsOutput):
+    outputs = []
+    for i in range(len(sentences)):
+        deps = logits.dependency_head_indices[i].tolist()
+        funcs = logits.function_logits.argmax(-1)[i].tolist()
+        toks = tokenizer.convert_ids_to_tokens(inputs['input_ids'][i])[1:-1] # ignore cls and sep
+
+        # first, create a mapping between each dependency index and the corresponding word index without wordpieces
+        idx_mapping = {-1:-1} # default root
+        real_idx = -1
+        for i in range(len(toks)):
+            if not toks[i].startswith('##'):
+                real_idx += 1
+            idx_mapping[i] = real_idx
+
+        # build our tree, keeping track of the root idx; wordpieces are merged into the preceding token
+        tree = []
+        root_idx = 0
+        for i in range(len(toks)):
+            if toks[i].startswith('##'):
+                tree[-1]['word'] += toks[i][2:]
+                continue
+
+            dep_idx = deps[i + 1] - 1 # add 1 to skip [CLS] when indexing deps, subtract 1 to drop [CLS] from the head index
+            dep_head = 'root' if dep_idx == -1 else toks[dep_idx]
+            dep_func = ALL_FUNCTION_LABELS[funcs[i + 1]]
+
+            if dep_head == 'root': root_idx = len(tree)
+            tree.append(dict(word=toks[i], dep_head_idx=idx_mapping[dep_idx], dep_func=dep_func))
+        # fill in the head word for each entry
+        for d in tree:
+            d['dep_head'] = tree[d['dep_head_idx']]['word']
+
+        outputs.append(dict(tree=tree, root_idx=root_idx))
+    return outputs
 
 def compute_mst_tree(attention_scores: torch.Tensor):
     # attention scores should be 3 dimensions - batch x seq x seq (if it is 2 - just unsqueeze)
     if attention_scores.ndim == 2: attention_scores = attention_scores.unsqueeze(0)
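The main change in this commit is that the tree-building logic moves out of predict into a module-level parse_logits function, so it can be reused on logits computed elsewhere. A minimal usage sketch, assuming a checkpoint that ships this module via trust_remote_code; the checkpoint name and sentence are placeholders, and predict's signature is inferred from its body in the diff:

from transformers import AutoModel, AutoTokenizer

# placeholder repo name -- substitute a checkpoint that bundles BertForSyntaxParsing.py
NAME = 'org/syntax-parsing-checkpoint'
tokenizer = AutoTokenizer.from_pretrained(NAME)
model = AutoModel.from_pretrained(NAME, trust_remote_code=True)

# predict() tokenizes, runs forward(), and (after this commit) delegates to parse_logits()
outputs = model.predict(['This is a sample sentence.'], tokenizer, compute_mst=True)

# each tree entry holds word, dep_head_idx, dep_func, plus the resolved dep_head word
for entry in outputs[0]['tree']:
    print(entry['word'], '->', entry['dep_head'], f"({entry['dep_func']})")

Since parse_logits is now a free function, the same post-processing can also be applied to a cached forward pass via parse_logits(inputs, sentences, tokenizer, logits).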
 
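The subtlest part of parse_logits is the wordpiece bookkeeping: heads are predicted per wordpiece, with [CLS] at index 0, while the output tree is per word. Below is a standalone sketch of the idx_mapping loop from the diff, run on a made-up token list (no model needed):

toks = ['the', 'quick', 'fo', '##x', 'jumps']  # hypothetical wordpieces, [CLS]/[SEP] already stripped

idx_mapping = {-1: -1}  # -1 is the artificial root, which maps to itself
real_idx = -1
for i in range(len(toks)):
    if not toks[i].startswith('##'):
        real_idx += 1  # a new word starts at this wordpiece
    idx_mapping[i] = real_idx

print(idx_mapping)  # {-1: -1, 0: 0, 1: 1, 2: 2, 3: 2, 4: 3}
# '##x' (wordpiece 3) maps to word 2, so a head prediction pointing at either
# piece of 'fox' resolves to the same dep_head_idx in the merged word list.

The second loop in parse_logits then glues the '##' pieces onto the previous word and uses idx_mapping to translate each predicted head into a dep_head_idx over the merged words.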