Shaltiel committed
Commit 707c144
1 Parent(s): d23c5a4

Updated with fixed MST and batching from DictaBERT-Joint

Files changed (1)
  1. BertForSyntaxParsing.py +50 -22
BertForSyntaxParsing.py CHANGED
@@ -73,7 +73,7 @@ class BertSyntaxParsingHead(nn.Module):
             dep_indices = labels.dependency_labels.clamp_min(0)
         # Otherwise - check if he wants the MST or just the argmax
         elif compute_mst:
-            dep_indices = compute_mst_tree(attention_scores)
+            dep_indices = compute_mst_tree(attention_scores, extended_attention_mask)
         else:
             dep_indices = torch.argmax(attention_scores, dim=-1)
 
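The only functional change in this hunk is threading `extended_attention_mask` through to `compute_mst_tree`, so MST decoding can distinguish real tokens from padding when sentences are batched. A minimal sketch of the mask convention this implies, assuming the standard Hugging Face additive mask (0 at real tokens, a large negative value at padding); the diff itself only shows that non-zero entries are treated as padding:

import torch

# Assumed construction: additive mask of shape (batch, 1, seq), matching how
# compute_mst_tree below squeezes dim 1 and argmaxes over dim 2.
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])            # 1 = real token, 0 = padding
extended = (1.0 - attention_mask[:, None, :]) * -10000.0    # 0 at tokens, -10000 at padding

# The first non-zero position doubles as the sequence length:
seq_lens = torch.argmax((extended != 0).int(), dim=2).squeeze(1)
print(seq_lens)   # tensor([3])

For a row with no padding at all the argmax falls back to 0, and the later `seq_lens - 1` indexing then appears to rely on negative-index wraparound to land on the true final position.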
@@ -160,14 +160,17 @@ class BertForSyntaxParsing(BertPreTrainedModel):
         inputs = tokenizer(sentences, padding='longest', truncation=True, return_tensors='pt')
         inputs = {k:v.to(self.device) for k,v in inputs.items()}
         logits = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_mst).logits
-        return parse_logits(inputs, sentences, tokenizer, logits)
+        return parse_logits(inputs['input_ids'].tolist(), sentences, tokenizer, logits)
 
-def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: SyntaxLogitsOutput):
+def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: SyntaxLogitsOutput):
     outputs = []
+
+    special_toks = tokenizer.all_special_tokens
+    special_toks.remove(tokenizer.unk_token)
     for i in range(len(sentences)):
         deps = logits.dependency_head_indices[i].tolist()
         funcs = logits.function_logits.argmax(-1)[i].tolist()
-        toks = tokenizer.convert_ids_to_tokens(inputs['input_ids'][i])[1:-1] # ignore cls and sep
+        toks = [tok for tok in tokenizer.convert_ids_to_tokens(input_ids[i]) if tok not in special_toks]
 
         # first, go through the tokens and create a mapping between each dependency index and the index without
         # wordpieces. At the same time, append the wordpieces in
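Dropping the `[1:-1]` slice matters under batching: with `padding='longest'`, a shorter sentence in the batch keeps its `[SEP]` and trailing `[PAD]` tokens inside the slice, which misaligns every later index. The new code filters all special tokens but keeps `[UNK]`, since it stands in for a real input word. A minimal illustration (the wordpiece strings are invented for the example):

# Invented token strings; only the filtering logic is the point.
toks = ['[CLS]', 'the', 'dog', '[UNK]', '[SEP]', '[PAD]', '[PAD]']
special_toks = {'[CLS]', '[SEP]', '[PAD]', '[MASK]'}    # [UNK] deliberately not listed

print(toks[1:-1])                                    # ['the', 'dog', '[UNK]', '[SEP]', '[PAD]'] - old slice, wrong under padding
print([t for t in toks if t not in special_toks])    # ['the', 'dog', '[UNK]'] - what parse_logits now builds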
@@ -187,6 +190,8 @@ def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenize
                 continue
 
             dep_idx = deps[i + 1] - 1 # increase 1 for cls, decrease 1 for cls
+            if dep_idx == len(toks): dep_idx = i - 1 # if he predicts sep, then just point to the previous word
+
             dep_head = 'root' if dep_idx == -1 else toks[dep_idx]
             dep_func = ALL_FUNCTION_LABELS[funcs[i + 1]]
 
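For context on the off-by-one bookkeeping: head indices in `deps` are positions in the full sequence, where `[CLS]` occupies position 0, so the head of word `i` is `deps[i + 1]`, and subtracting 1 maps it back into word space, with `-1` meaning the head is `[CLS]`, i.e. the root. The added guard remaps a head prediction that lands on `[SEP]` to the previous word instead of indexing past the end of `toks`. A toy trace with made-up values:

toks = ['he', 'saw', 'her']    # word-level tokens, special tokens already stripped
deps = [0, 2, 0, 2, 0]         # predicted head per sequence position, incl. [CLS]@0 and [SEP]@4

for i in range(len(toks)):
    dep_idx = deps[i + 1] - 1              # shift past [CLS], then back into word space
    if dep_idx == len(toks):               # head predicted as [SEP]
        dep_idx = i - 1                    # fall back to the previous word
    print(toks[i], '->', 'root' if dep_idx == -1 else toks[dep_idx])
# he -> saw, saw -> root, her -> saw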
@@ -200,7 +205,7 @@ def parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenize
     return outputs
 
 
-def compute_mst_tree(attention_scores: torch.Tensor):
+def compute_mst_tree(attention_scores: torch.Tensor, extended_attention_mask: torch.LongTensor):
     # attention scores should be 3 dimensions - batch x seq x seq (if it is 2 - just unsqueeze)
     if attention_scores.ndim == 2: attention_scores = attention_scores.unsqueeze(0)
     if attention_scores.ndim != 3 or attention_scores.shape[1] != attention_scores.shape[2]:
@@ -209,40 +214,58 @@ def compute_mst_tree(attention_scores: torch.Tensor):
     batch_size, seq_len, _ = attention_scores.shape
     # start by softmaxing so the scores are comparable
     attention_scores = attention_scores.softmax(dim=-1)
+
+    batch_indices = torch.arange(batch_size, device=attention_scores.device)
+    seq_indices = torch.arange(seq_len, device=attention_scores.device)
+
+    seq_lens = torch.full((batch_size,), seq_len)
+
+    if extended_attention_mask is not None:
+        seq_lens = torch.argmax((extended_attention_mask != 0).int(), dim=2).squeeze(1)
+        # zero out any padding
+        attention_scores[extended_attention_mask.squeeze(1) != 0] = 0
 
     # set the values for the CLS and sep to all by very low, so they never get chosen as a replacement arc
-    attention_scores[:, 0, :] = -10000
-    attention_scores[:, -1, :] = -10000
-    attention_scores[:, :, -1] = -10000 # can never predict sep
+    attention_scores[:, 0, :] = 0
+    attention_scores[batch_indices, seq_lens - 1, :] = 0
+    attention_scores[batch_indices, :, seq_lens - 1] = 0 # can never predict sep
+    # set the values for each token pointing to itself be 0
+    attention_scores[:, seq_indices, seq_indices] = 0
 
     # find the root, and make him super high so we never have a conflict
     root_cands = torch.argsort(attention_scores[:, :, 0], dim=-1)
-    batch_indices = torch.arange(batch_size, device=root_cands.device)
-    attention_scores[batch_indices.unsqueeze(1), root_cands, 0] = -10000
-    attention_scores[batch_indices, root_cands[:, -1], 0] = 10000
-
+    attention_scores[batch_indices.unsqueeze(1), root_cands, 0] = 0
+    attention_scores[batch_indices, root_cands[:, -1], 0] = 1.0
+
     # we start by getting the argmax for each score, and then computing the cycles and contracting them
     sorted_indices = torch.argsort(attention_scores, dim=-1, descending=True)
     indices = sorted_indices[:, :, 0].clone() # take the argmax
 
+    attention_scores = attention_scores.tolist()
+    seq_lens = seq_lens.tolist()
+    sorted_indices = [[sub_l[:slen] for sub_l in l[:slen]] for l,slen in zip(sorted_indices.tolist(), seq_lens)]
+
+
     # go through each batch item and make sure our tree works
     for batch_idx in range(batch_size):
         # We have one root - detect the cycles and contract them. A cycle can never contain the root so really
         # for every cycle, we look at all the nodes, and find the highest arc out of the cycle for any values. Replace that and tada
-        has_cycle, cycle_nodes = detect_cycle(indices[batch_idx])
+        has_cycle, cycle_nodes = detect_cycle(indices[batch_idx], seq_lens[batch_idx])
+        contracted_arcs = set()
         while has_cycle:
-            base_idx, head_idx = choose_contracting_arc(indices[batch_idx], sorted_indices[batch_idx], cycle_nodes, attention_scores[batch_idx])
+            base_idx, head_idx = choose_contracting_arc(indices[batch_idx], sorted_indices[batch_idx], cycle_nodes, contracted_arcs, seq_lens[batch_idx], attention_scores[batch_idx])
             indices[batch_idx, base_idx] = head_idx
+            contracted_arcs.add(base_idx)
             # find the next cycle
-            has_cycle, cycle_nodes = detect_cycle(indices[batch_idx])
-
+            has_cycle, cycle_nodes = detect_cycle(indices[batch_idx], seq_lens[batch_idx])
+
     return indices
 
-def detect_cycle(indices: torch.LongTensor):
+def detect_cycle(indices: torch.LongTensor, seq_len: int):
     # Simple cycle detection algorithm
     # Returns a boolean indicating if a cycle is detected and the nodes involved in the cycle
     visited = set()
-    for node in range(1, len(indices) - 1): # ignore the CLS/SEP tokens
+    for node in range(1, seq_len - 1): # ignore the CLS/SEP tokens
         if node in visited:
             continue
         current_path = set()
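`detect_cycle` relies on a standard observation: when every node has exactly one outgoing head arc, following head pointers from any node must either reach the root or re-enter the current path, and re-entering means a cycle. The function body past the hunk boundary is not shown, so the following is a plausible reconstruction of the pattern rather than the file's exact code:

# Hedged sketch of the pointer-chasing cycle check.
def detect_cycle_sketch(heads, seq_len):
    visited = set()
    for node in range(1, seq_len - 1):                 # skip [CLS] and [SEP]
        if node in visited:
            continue
        current_path = set()
        while node not in visited and node != 0:       # 0 ([CLS]) acts as the root sink
            visited.add(node)
            current_path.add(node)
            node = heads[node]                         # follow the head pointer
            if node in current_path:
                return True, current_path              # walked back into this path: cycle
    return False, None

print(detect_cycle_sketch([0, 2, 1, 0, 0], 5))    # (True, {1, 2}): 1 -> 2 -> 1
print(detect_cycle_sketch([0, 2, 3, 0, 0], 5))    # (False, None): every chain reaches the root

Passing `seq_len` explicitly instead of using `len(indices)` is the batching fix: with padding, `len(indices)` is the padded length, so the walk would stray into arcs that were never meaningfully predicted.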
@@ -255,31 +278,36 @@ def detect_cycle(indices: torch.LongTensor):
                 return True, current_path # Cycle detected
     return False, None
 
-def choose_contracting_arc(indices: torch.LongTensor, sorted_indices: torch.LongTensor, cycle_nodes: set, scores: torch.FloatTensor):
+def choose_contracting_arc(indices: torch.LongTensor, sorted_indices: List[List[int]], cycle_nodes: set, contracted_arcs: set, seq_len: int, scores: List[List[float]]):
     # Chooses the highest-scoring, non-cycling arc from a graph. Iterates through 'cycle_nodes' to find
     # the best arc based on 'scores', avoiding cycles and zero node connections.
     # For each node, we only look at the next highest scoring non-cycling arc
     best_base_idx, best_head_idx = -1, -1
-    score = float('-inf')
+    score = 0
 
     # convert the indices to a list once, to avoid multiple conversions (saves a few seconds)
     currents = indices.tolist()
     for base_node in cycle_nodes:
+        if base_node in contracted_arcs: continue
         # we don't want to take anything that has a higher score than the current value - we can end up in an endless loop
         # Since the indices are sorted, as soon as we find our current item, we can move on to the next.
         current = currents[base_node]
        found_current = False
 
-        for head_node in sorted_indices[base_node].tolist():
+        for head_node in sorted_indices[base_node]:
            if head_node == current:
                found_current = True
                continue
+            if head_node in contracted_arcs: continue
            if not found_current or head_node in cycle_nodes or head_node == 0:
                continue
 
-            current_score = scores[base_node, head_node].item()
+            current_score = scores[base_node][head_node]
            if current_score > score:
                best_base_idx, best_head_idx, score = base_node, head_node, current_score
                break
 
+    if best_base_idx == -1:
+        raise ValueError('Stuck in endless loop trying to compute syntax mst. Please try again setting compute_syntax_mst=False')
+
     return best_base_idx, best_head_idx
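Taken together this is a greedy cycle-breaking decoder in the spirit of Chu-Liu/Edmonds: force exactly one root, take the argmax head for every token, then repeatedly find a cycle and redirect one node in it to its next-best head outside the cycle. The new `contracted_arcs` set pins arcs that were already redirected so the loop cannot oscillate between two cycles forever, and the `ValueError` is the escape hatch if it still gets stuck. Swapping the ±10000 sentinels for 0 and 1.0 works because the scores are post-softmax probabilities in [0, 1]. A toy end-to-end trace of the idea, in plain Python with illustrative numbers:

# has_cycle() is a compact stand-in for detect_cycle.
def has_cycle(heads, start):
    seen, node = set(), start
    while node not in seen and node != 0:
        seen.add(node)
        node = heads[node]
    return node != 0                       # stopped anywhere but the root: cycle

scores = [
    [0.0, 0.0, 0.0, 0.0, 0.0],             # [CLS] row, zeroed out
    [0.0, 0.0, 0.9, 0.1, 0.0],             # token 1 prefers head 2
    [0.0, 0.8, 0.0, 0.2, 0.0],             # token 2 prefers head 1 -> 1 <-> 2 cycle
    [1.0, 0.0, 0.0, 0.0, 0.0],             # token 3 is the forced root (head = [CLS])
    [0.0, 0.0, 0.0, 0.0, 0.0],             # [SEP] row, zeroed out
]
heads = [max(range(5), key=lambda h: scores[i][h]) for i in range(5)]
print(heads, has_cycle(heads, 1))          # [0, 2, 1, 0, 0] True
# contract: inside the cycle {1, 2}, the best arc leaving it is 2 -> 3 (0.2)
heads[2] = 3
print(has_cycle(heads, 1), has_cycle(heads, 2))    # False False - now a tree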
 