import math
from transformers.utils import ModelOutput
import torch
from torch import nn
from typing import List, Tuple, Optional, Union
from dataclasses import dataclass
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast

ALL_FUNCTION_LABELS = ["nsubj", "punct", "mark", "case", "fixed", "obl", "det", "amod", "acl:relcl", "nmod", "cc", "conj", "root", "compound", "cop", "compound:affix", "advmod", "nummod", "appos", "nsubj:pass", "nmod:poss", "xcomp", "obj", "aux", "parataxis", "advcl", "ccomp", "csubj", "acl", "obl:tmod", "csubj:pass", "dep", "dislocated", "nmod:tmod", "nmod:npmod", "flat", "obl:npmod", "goeswith", "reparandum", "orphan", "list", "discourse", "iobj", "vocative", "expl", "flat:name"]

@dataclass
class SyntaxLogitsOutput(ModelOutput):
    dependency_logits: torch.FloatTensor = None
    function_logits: torch.FloatTensor = None
    dependency_head_indices: torch.LongTensor = None

    def detach(self):
        return SyntaxLogitsOutput(self.dependency_logits.detach(), self.function_logits.detach(), self.dependency_head_indices.detach())

@dataclass
class SyntaxTaggingOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[SyntaxLogitsOutput] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

@dataclass
class SyntaxLabels(ModelOutput):
    dependency_labels: Optional[torch.LongTensor] = None
    function_labels: Optional[torch.LongTensor] = None

    def detach(self):
        return SyntaxLabels(self.dependency_labels.detach(), self.function_labels.detach())
    
    def to(self, device):
        return SyntaxLabels(self.dependency_labels.to(device), self.function_labels.to(device))

class BertSyntaxParsingHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # the attention query & key values
        self.head_size = config.syntax_head_size  # int(config.hidden_size / config.num_attention_heads * 2)
        self.query = nn.Linear(config.hidden_size, self.head_size)
        self.key = nn.Linear(config.hidden_size, self.head_size)
        # the function classifier gets two encoding values and predicts the labels
        self.num_function_classes = len(ALL_FUNCTION_LABELS)
        self.cls = nn.Linear(config.hidden_size * 2, self.num_function_classes)

    def forward(
            self, 
            hidden_states: torch.Tensor, 
            extended_attention_mask: Optional[torch.Tensor],
            labels: Optional[SyntaxLabels] = None,
            compute_mst: bool = False) -> Tuple[torch.Tensor, SyntaxLogitsOutput]:
        
        # Take the dot product between "query" and "key" to get the raw attention scores.
        query_layer = self.query(hidden_states)
        key_layer = self.key(hidden_states)
        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(self.head_size)

        # add in the attention mask
        if extended_attention_mask is not None:
            if extended_attention_mask.ndim == 4:
                extended_attention_mask = extended_attention_mask.squeeze(1)
            attention_scores += extended_attention_mask  # batch x seq x seq

        # At this point take the hidden_state of the word and of the dependency word, and predict the function
        # If labels are provided, use the labels.
        if self.training and labels is not None:
            # Note that the labels can have -100, so just set those to zero with a max
            dep_indices = labels.dependency_labels.clamp_min(0)
        # Otherwise - check whether the caller wants the MST or just the argmax
        elif compute_mst:
            dep_indices = compute_mst_tree(attention_scores)
        else: 
            dep_indices = torch.argmax(attention_scores, dim=-1)
        
        # After retrieving the dependency indices, create a tensor of the batch indices and retrieve the head vectors to compute the function labels
        batch_indices = torch.arange(dep_indices.size(0)).view(-1, 1).expand(-1, dep_indices.size(1)).to(dep_indices.device)
        dep_vectors = hidden_states[batch_indices, dep_indices, :] # batch x seq x dim

        # concatenate that with the last hidden states, and send to the classifier output
        cls_inputs = torch.cat((hidden_states, dep_vectors), dim=-1)
        function_logits = self.cls(cls_inputs)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # step 1: dependency scores loss - this is applied to the attention scores
            loss = loss_fct(attention_scores.view(-1, hidden_states.size(-2)), labels.dependency_labels.view(-1))
            # step 2: function loss
            loss += loss_fct(function_logits.view(-1, self.num_function_classes), labels.function_labels.view(-1))
        
        return (loss, SyntaxLogitsOutput(attention_scores, function_logits, dep_indices))


class BertForSyntaxParsing(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.syntax = BertSyntaxParsingHead(config)
        
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[SyntaxLabels] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        compute_syntax_mst: Optional[bool] = None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        extended_attention_mask = None
        if attention_mask is not None:
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.size())
        # apply the syntax head
        loss, logits = self.syntax(self.dropout(bert_outputs[0]), extended_attention_mask, labels, compute_syntax_mst)
        
        if not return_dict:
            return (loss, (logits.dependency_logits, logits.function_logits)) + bert_outputs[2:]
        
        return SyntaxTaggingOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_outputs.hidden_states,
            attentions=bert_outputs.attentions,
        )

    def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, compute_mst=True):
        if isinstance(sentences, str):
            sentences = [sentences]
            
        # predict the logits for the sentence
        inputs = tokenizer(sentences, padding='longest', truncation=True, return_tensors='pt')
        inputs = {k:v.to(self.device) for k,v in inputs.items()}
        logits = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_mst).logits

        outputs = []        
        for i in range(len(sentences)):
            deps = logits.dependency_head_indices[i].tolist()
            funcs = logits.function_logits.argmax(-1)[i].tolist()
            toks = tokenizer.convert_ids_to_tokens(inputs['input_ids'][i])[1:-1] # ignore cls and sep
            
            # first, go through the tokens and create a mapping between each token index and its index
            # once the wordpieces are merged back into whole words
            idx_mapping = {-1: -1} # default root
            real_idx = -1
            for j in range(len(toks)):
                if not toks[j].startswith('##'):
                    real_idx += 1
                idx_mapping[j] = real_idx
                
            # build our tree, keeping track of the root idx
            tree = []
            root_idx = 0
            for j in range(len(toks)):
                if toks[j].startswith('##'):
                    tree[-1]['word'] += toks[j][2:]
                    continue

                dep_idx = deps[j + 1] - 1 # add 1 to skip [CLS] in the model output, subtract 1 to map the head index back to toks (without [CLS])
                dep_head = 'root' if dep_idx == -1 else toks[dep_idx]
                dep_func = ALL_FUNCTION_LABELS[funcs[j + 1]]

                if dep_head == 'root': root_idx = len(tree)
                tree.append(dict(word=toks[j], dep_head_idx=idx_mapping[dep_idx], dep_func=dep_func))
            # append the head word
            for d in tree:
                d['dep_head'] = tree[d['dep_head_idx']]['word']
            
            outputs.append(dict(tree=tree, root_idx=root_idx))
        return outputs

    
def compute_mst_tree(attention_scores: torch.Tensor):
    # attention scores should be 3 dimensions - batch x seq x seq (if it is 2 - just unsqueeze)
    if attention_scores.ndim == 2: attention_scores = attention_scores.unsqueeze(0)
    if attention_scores.ndim != 3 or attention_scores.shape[1] != attention_scores.shape[2]:
        raise ValueError(f'Expected attention scores to be of shape batch x seq x seq, instead got {attention_scores.shape}')
    
    batch_size, seq_len, _ = attention_scores.shape
    # start by softmaxing so the scores are comparable
    attention_scores = attention_scores.softmax(dim=-1)

    # set the scores for the [CLS] and [SEP] positions to a very low value, so they never get chosen as a replacement arc
    attention_scores[:, 0, :] = -10000
    attention_scores[:, -1, :] = -10000
    attention_scores[:, :, -1] = -10000 # can never predict sep
    
    # find the best root candidate and give it a very high score, so there is never a conflict
    root_cands = torch.argsort(attention_scores[:, :, 0], dim=-1)
    batch_indices = torch.arange(batch_size, device=root_cands.device)
    attention_scores[batch_indices.unsqueeze(1), root_cands, 0] = -10000
    attention_scores[batch_indices, root_cands[:, -1], 0] = 10000
    
    # we start by getting the argmax for each score, and then computing the cycles and contracting them
    sorted_indices = torch.argsort(attention_scores, dim=-1, descending=True)
    indices = sorted_indices[:, :, 0].clone() # take the argmax
    
    # go through each batch item and make sure our tree works
    for batch_idx in range(batch_size):
        # We have exactly one root - detect the cycles and contract them. A cycle can never contain the root, so
        # for every cycle we look at all of its nodes, find the highest-scoring arc that leaves the cycle, and use it to break the cycle.
        has_cycle, cycle_nodes = detect_cycle(indices[batch_idx])
        while has_cycle:
            base_idx, head_idx = choose_contracting_arc(indices[batch_idx], sorted_indices[batch_idx], cycle_nodes, attention_scores[batch_idx])
            indices[batch_idx, base_idx] = head_idx
            # find the next cycle
            has_cycle, cycle_nodes = detect_cycle(indices[batch_idx])
            
    return indices
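
# Note (added for clarity; not part of the original file): compute_mst_tree expects the raw head
# scores produced by BertSyntaxParsingHead, shaped batch x seq x seq where seq includes the [CLS]
# and [SEP] positions, and returns a LongTensor of shape batch x seq whose entry [b, i] is the
# predicted head position of token i (position 0, i.e. [CLS], marks the root). A hypothetical call:
#     scores = torch.randn(2, 7, 7)     # 2 sentences, 5 real tokens plus [CLS]/[SEP]
#     heads = compute_mst_tree(scores)  # LongTensor of shape (2, 7)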

def detect_cycle(indices: torch.LongTensor):
    # Simple cycle detection algorithm
    # Returns a boolean indicating if a cycle is detected and the nodes involved in the cycle
    visited = set()
    for node in range(1, len(indices) - 1): # ignore the CLS/SEP tokens
        if node in visited:
            continue
        current_path = set()
        while node not in visited:
            visited.add(node)
            current_path.add(node)
            node = indices[node].item()
            if node == 0: break # roots never point to anything
            if node in current_path:
                return True, current_path  # Cycle detected
    return False, None
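
# Worked example (added for clarity; not part of the original file): for head indices
# [0, 2, 3, 1, 0] (positions 0 and 4 standing in for [CLS]/[SEP]), tokens 1 -> 2 -> 3 -> 1
# form a cycle, so detect_cycle returns (True, {1, 2, 3}). For a well-formed tree such as
# [0, 0, 1, 1, 0] it returns (False, None).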

def choose_contracting_arc(indices: torch.LongTensor, sorted_indices: torch.LongTensor, cycle_nodes: set, scores: torch.FloatTensor):
    # Chooses the highest-scoring, non-cycling arc from a graph. Iterates through 'cycle_nodes' to find
    # the best arc based on 'scores', avoiding cycles and zero node connections.
    # For each node, we only look at the next highest scoring non-cycling arc 
    best_base_idx, best_head_idx = -1, -1
    score = float('-inf')
    
    # convert the indices to a list once, to avoid multiple conversions (saves a few seconds)
    currents = indices.tolist()
    for base_node in cycle_nodes:
        # we never pick an arc scoring higher than the current one - otherwise we could end up in an endless loop.
        # Since the indices are sorted, as soon as we find the current head we can take the next valid one.
        current = currents[base_node]
        found_current = False
        
        for head_node in sorted_indices[base_node].tolist():
            if head_node == current:
                found_current = True
                continue
            if not found_current or head_node in cycle_nodes or head_node == 0: 
                continue
            
            current_score = scores[base_node, head_node].item()
            if current_score > score:
                best_base_idx, best_head_idx, score = base_node, head_node, current_score
            break
    
    return best_base_idx, best_head_idx
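

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original module).
# It assumes a checkpoint trained with this head exists and that its config
# defines `syntax_head_size`; MODEL_PATH below is a placeholder, not a real hub ID.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    MODEL_PATH = "path-or-hub-id-of-a-compatible-checkpoint"  # placeholder

    tokenizer = BertTokenizerFast.from_pretrained(MODEL_PATH)
    model = BertForSyntaxParsing.from_pretrained(MODEL_PATH)
    model.eval()

    with torch.no_grad():
        parsed = model.predict("The quick brown fox jumps over the lazy dog", tokenizer, compute_mst=True)

    # each entry holds the dependency tree of one sentence plus the index of its root word
    for entry in parsed:
        for token in entry['tree']:
            print(token['word'], '->', token['dep_head'], f"({token['dep_func']})")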