File size: 2,654 Bytes
3bb49fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from imports import *
from all_datasets import *

class PhoBertLstmCrf(RobertaForTokenClassification):
    """PhoBERT (RoBERTa) encoder + BiLSTM + CRF for token classification (NER).

    ``forward`` runs the RoBERTa encoder, optionally re-indexes its hidden states
    with ``valid_ids`` (one vector per word), runs a bidirectional LSTM over the
    packed word sequences, applies the inherited ``dropout`` and ``classifier``,
    and decodes tags with a CRF layer.
    """

    def __init__(self, config):
        super(PhoBertLstmCrf, self).__init__(config=config)
        self.num_labels = config.num_labels
        # Each direction gets hidden_size // 2 units so the concatenated BiLSTM
        # output has hidden_size features (exactly, when hidden_size is even);
        # presumably this matches the input size of the inherited classifier — confirm.
        self.lstm = nn.LSTM(input_size=config.hidden_size,
                            hidden_size=config.hidden_size // 2,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True)
        self.crf = CRF(config.num_labels, batch_first=True)

    @staticmethod
    def sort_batch(src_tensor, lengths):
        """Sort a minibatch by sequence length, longest first.

        Returns the sorted batch, the sorted lengths, and the index that
        restores the original order (``sorted_batch[reversed_idx]`` is the
        input order). The sorted output can be fed to
        ``pack_padded_sequence(...)``.
        """
        seq_lengths, perm_idx = lengths.sort(0, descending=True)
        seq_tensor = src_tensor[perm_idx]
        # Sorting the permutation yields its inverse permutation.
        _, reversed_idx = perm_idx.sort(0, descending=False)
        return seq_tensor, seq_lengths, reversed_idx

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, valid_ids=None,
                label_masks=None):
        """Encode a batch and return CRF-decoded tags, plus the CRF loss when ``labels`` is given.

        Args:
            input_ids: Token ids for the RoBERTa encoder.
            token_type_ids: Optional segment ids forwarded to the encoder.
            attention_mask: Optional encoder attention mask. It is also used as the
                word mask when both ``label_masks`` and ``valid_ids`` are None.
            labels: Optional gold tag ids aligned with the (gathered) positions.
                Values at positions masked out by ``label_masks`` are ignored.
            valid_ids: Optional per-row indices into the encoder output. Row ``b`` of
                the word sequence is ``hidden[b, valid_ids[b]]``. When None, encoder
                positions are used as-is.
            label_masks: Mask over the (gathered) positions; non-zero marks a real word.
                Real words are assumed to form a prefix of each row (required by
                ``pack_padded_sequence``) — TODO confirm against the data pipeline.

        Returns:
            ``NerOutput(tags=...)``. With ``labels``, it also sets ``loss`` (negated CRF
            log-likelihood) and ``cls_metrics`` (the same decoded tags).

        Raises:
            ValueError: if ``valid_ids`` is given without ``label_masks``.
        """
        seq_outputs = self.roberta(input_ids=input_ids,
                                   token_type_ids=token_type_ids,
                                   attention_mask=attention_mask,
                                   head_mask=None)[0]

        batch_size = seq_outputs.size(0)
        if valid_ids is not None:
            # Advanced indexing: row b keeps the encoder vectors at positions valid_ids[b].
            range_vector = torch.arange(0, batch_size, dtype=torch.long, device=seq_outputs.device).unsqueeze(1)
            seq_outputs = seq_outputs[range_vector, valid_ids]
        # Length of the (possibly gathered) sequence axis; the LSTM output is padded
        # back to this so it lines up with label_masks / labels.
        max_len = seq_outputs.size(1)

        if label_masks is None:
            if valid_ids is not None:
                raise ValueError("label_masks must be provided together with valid_ids")
            # Without re-indexing, word positions coincide with token positions.
            if attention_mask is not None:
                label_masks = attention_mask
            else:
                label_masks = torch.ones(seq_outputs.shape[:2], dtype=torch.long, device=seq_outputs.device)

        # A single bool mask drives the packing lengths, CRF decoding and the CRF loss,
        # so all three agree on which positions are real.
        crf_mask = label_masks != 0
        seq_lens = crf_mask.sum(dim=-1)

        sorted_seq_outputs, sorted_seq_lens, reversed_idx = self.sort_batch(src_tensor=seq_outputs,
                                                                            lengths=seq_lens)
        # pack_padded_sequence expects lengths on the CPU.
        packed_words = pack_padded_sequence(sorted_seq_outputs, sorted_seq_lens.cpu(), batch_first=True)
        lstm_outs, _ = self.lstm(packed_words)
        lstm_outs, _ = pad_packed_sequence(lstm_outs, batch_first=True, total_length=max_len)
        seq_outputs = lstm_outs[reversed_idx]  # restore the caller's batch order

        seq_outputs = self.dropout(seq_outputs)
        logits = self.classifier(seq_outputs)
        seq_tags = self.crf.decode(logits, mask=crf_mask)

        if labels is None:
            return NerOutput(tags=seq_tags)

        # NOTE(review): assuming CRF is pytorch-crf's torchcrf.CRF, its log-likelihood
        # indexes emission/transition tables with the label at every position (masked
        # ones are then multiplied by 0). Padding ids such as -100 would index out of
        # range, so masked positions get a harmless valid id instead.
        safe_labels = labels.masked_fill(~crf_mask, 0)
        log_likelihood = self.crf(logits, safe_labels, mask=crf_mask)
        return NerOutput(loss=-1.0 * log_likelihood, tags=seq_tags, cls_metrics=seq_tags)