Upload BertForSyntaxParsing.py
BertForSyntaxParsing.py
CHANGED (+38 -36)
@@ -2,11 +2,11 @@ import math
 from transformers.utils import ModelOutput
 import torch
 from torch import nn
-from typing import List, Tuple, Optional, Union
+from typing import Dict, List, Tuple, Optional, Union
 from dataclasses import dataclass
 from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
 
-ALL_FUNCTION_LABELS = ["nsubj", "punct", "mark", "case", "fixed", "obl", "det", "amod", "acl:relcl", "nmod", "cc", "conj", "root", "compound", "cop", "compound:affix", "advmod", "nummod", "appos", "nsubj:pass", "nmod:poss", "xcomp", "obj", "aux", "parataxis", "advcl", "ccomp", "csubj", "acl", "obl:tmod", "csubj:pass", "dep", "dislocated", "nmod:tmod", "nmod:npmod", "flat", "obl:npmod", "goeswith", "reparandum", "orphan", "list", "discourse", "iobj", "vocative", "expl", "flat:name"]
+ALL_FUNCTION_LABELS = ["nsubj", "nsubj:cop", "punct", "mark", "mark:q", "case", "case:gen", "case:acc", "fixed", "obl", "det", "amod", "acl:relcl", "nmod", "cc", "conj", "root", "compound:smixut", "cop", "compound:affix", "advmod", "nummod", "appos", "nsubj:pass", "nmod:poss", "xcomp", "obj", "aux", "parataxis", "advcl", "ccomp", "csubj", "acl", "obl:tmod", "csubj:pass", "dep", "dislocated", "nmod:tmod", "nmod:npmod", "flat", "obl:npmod", "goeswith", "reparandum", "orphan", "list", "discourse", "iobj", "vocative", "expl", "flat:name"]
 
 @dataclass
 class SyntaxLogitsOutput(ModelOutput):
@@ -160,44 +160,46 @@ class BertForSyntaxParsing(BertPreTrainedModel):
         inputs = tokenizer(sentences, padding='longest', truncation=True, return_tensors='pt')
         inputs = {k:v.to(self.device) for k,v in inputs.items()}
         logits = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_mst).logits
+        return parse_logits(inputs, sentences, tokenizer, logits)
 
-        outputs = []
-        for i in range(len(sentences)):
-            deps = logits.dependency_head_indices[i].tolist()
-            funcs = logits.function_logits.argmax(-1)[i].tolist()
-            toks = tokenizer.convert_ids_to_tokens(inputs['input_ids'][i])[1:-1] # ignore cls and sep
-
-            # first, go through the tokens and create a mapping between each dependency index and the index
-            # without wordpieces; wordpieces get merged back into the preceding token when building the tree
-            idx_mapping = {-1:-1} # default root
-            real_idx = -1
-            for i in range(len(toks)):
-                if not toks[i].startswith('##'):
-                    real_idx += 1
-                idx_mapping[i] = real_idx
-
-            # build our tree, keeping track of the root idx
-            tree = []
-            root_idx = 0
-            for i in range(len(toks)):
-                if toks[i].startswith('##'):
-                    tree[-1]['word'] += toks[i][2:]
-                    continue
-
-                dep_idx = deps[i + 1] - 1 # increase 1 for cls, decrease 1 for cls
-                dep_head = 'root' if dep_idx == -1 else toks[dep_idx]
-                dep_func = ALL_FUNCTION_LABELS[funcs[i + 1]]
-
-                if dep_head == 'root': root_idx = len(tree)
-                tree.append(dict(word=toks[i], dep_head_idx=idx_mapping[dep_idx], dep_func=dep_func))
-            # append the head word
-            for d in tree:
-                d['dep_head'] = tree[d['dep_head_idx']]['word']
-
-            outputs.append(dict(tree=tree, root_idx=root_idx))
-        return outputs
+
+def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: SyntaxLogitsOutput):
+    outputs = []
+    for i in range(len(sentences)):
+        deps = logits.dependency_head_indices[i].tolist()
+        funcs = logits.function_logits.argmax(-1)[i].tolist()
+        toks = tokenizer.convert_ids_to_tokens(inputs['input_ids'][i])[1:-1] # ignore cls and sep
+
+        # first, go through the tokens and create a mapping between each dependency index and the index
+        # without wordpieces; wordpieces get merged back into the preceding token when building the tree
+        idx_mapping = {-1:-1} # default root
+        real_idx = -1
+        for i in range(len(toks)):
+            if not toks[i].startswith('##'):
+                real_idx += 1
+            idx_mapping[i] = real_idx
+
+        # build our tree, keeping track of the root idx
+        tree = []
+        root_idx = 0
+        for i in range(len(toks)):
+            if toks[i].startswith('##'):
+                tree[-1]['word'] += toks[i][2:]
+                continue
+
+            dep_idx = deps[i + 1] - 1 # increase 1 for cls, decrease 1 for cls
+            dep_head = 'root' if dep_idx == -1 else toks[dep_idx]
+            dep_func = ALL_FUNCTION_LABELS[funcs[i + 1]]
+
+            if dep_head == 'root': root_idx = len(tree)
+            tree.append(dict(word=toks[i], dep_head_idx=idx_mapping[dep_idx], dep_func=dep_func))
+        # append the head word
+        for d in tree:
+            d['dep_head'] = tree[d['dep_head_idx']]['word']
+
+        outputs.append(dict(tree=tree, root_idx=root_idx))
+    return outputs
 
 def compute_mst_tree(attention_scores: torch.Tensor):
     # attention scores should be 3 dimensions - batch x seq x seq (if it is 2 - just unsqueeze)
     if attention_scores.ndim == 2: attention_scores = attention_scores.unsqueeze(0)
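For context, the parse_logits helper introduced in this commit can be called directly, without going through the model's own prediction path. The sketch below only reuses calls that appear in the diff above (the tokenizer call, the forward pass with compute_syntax_mst, and parse_logits); the local import path and the '<model-repo>' checkpoint name are placeholders for wherever this module and its weights live, not something defined by this commit.

import torch
from transformers import BertTokenizerFast
# assumes BertForSyntaxParsing.py is importable from the working directory; adjust to your setup
from BertForSyntaxParsing import BertForSyntaxParsing, parse_logits

# '<model-repo>' is a placeholder checkpoint name, not part of this commit
tokenizer = BertTokenizerFast.from_pretrained('<model-repo>')
model = BertForSyntaxParsing.from_pretrained('<model-repo>')
model.eval()

sentences = ['<your sentence here>']

# the same steps the model runs internally before delegating to parse_logits in the diff above
inputs = tokenizer(sentences, padding='longest', truncation=True, return_tensors='pt')
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
    logits = model(**inputs, return_dict=True, compute_syntax_mst=True).logits

parsed = parse_logits(inputs, sentences, tokenizer, logits)
# each entry is dict(tree=[...], root_idx=int); every node carries word, dep_head_idx, dep_head and dep_func
for node in parsed[0]['tree']:
    print(node['word'], '->', node['dep_head'], '(' + node['dep_func'] + ')')

Factoring the tree-building code out of the model class is what makes this direct call possible: converting logits into dependency trees no longer requires access to the model object itself.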