File size: 1,353 Bytes
14023f9
731d735
 
14023f9
 
 
 
731d735
 
 
 
 
 
 
 
 
 
14023f9
 
 
731d735
 
14023f9
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from transformers import (
    PIPELINE_REGISTRY,
    AutoModelForTokenClassification,
    TokenClassificationPipeline,
)

class MultiLabelTokenClassificationPipeline(TokenClassificationPipeline):
    """Token classification where each token may carry several labels at once.

    Every label logit is scored with an independent sigmoid, and every label
    whose probability is strictly greater than ``threshold`` is assigned to the
    token. Tokens with no label above the threshold are reported as ``["O"]``.
    """

    def __init__(self, model, tokenizer, *, threshold=0.5, **kwargs):
        """Create the pipeline.

        Args:
            model: Token-classification model, forwarded to the base class.
            tokenizer: Tokenizer, forwarded to the base class and used in
                ``postprocess`` to map input ids back to token strings.
            threshold: Sigmoid probability a label must strictly exceed to be
                assigned. Defaults to 0.5 (the previously hard-coded value).
            **kwargs: Forwarded unchanged to ``TokenClassificationPipeline``.
        """
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        self.threshold = threshold
        # Fixed label set for this model. The keys are logit column indices
        # (postprocess maps active columns through this dict).
        # NOTE(review): confirm it matches model.config.id2label.
        self.id2label = {
            0: 'O',
            1: 'B-STEREO',
            2: 'I-STEREO',
            3: 'B-GEN',
            4: 'I-GEN',
            5: 'B-UNFAIR',
            6: 'I-UNFAIR'
        }

    def postprocess(self, model_outputs, **kwargs):
        """Turn raw model outputs into per-token multi-label predictions.

        Args:
            model_outputs: Output of ``_forward``. Depending on the
                transformers version this is a single output dict, or a list
                of per-chunk output dicts. Each dict must contain ``"logits"``
                and ``"input_ids"`` tensors with a leading batch dim of 1.
            **kwargs: Base-class postprocess options (e.g.
                ``aggregation_strategy``); accepted and ignored.

        Returns:
            One list per output dict (chunk). Each list holds
            ``{"token": str, "labels": list[str]}`` entries for every
            non-special token. ``labels`` is ``["O"]`` when no label
            probability exceeds ``self.threshold``.
        """
        # Older (non-chunked) pipelines hand over one dict; normalize to a list.
        if isinstance(model_outputs, dict):
            model_outputs = [model_outputs]

        # Hoisted out of the token loop; set membership instead of a list scan.
        special_tokens = set(self.tokenizer.all_special_tokens)

        results = []
        for output in model_outputs:
            # Drop the leading batch dim of 1 -> (seq_len, num_labels) / (seq_len,).
            logits = output["logits"][0]
            input_ids = output["input_ids"][0]
            tokens = self.tokenizer.convert_ids_to_tokens(input_ids.tolist())

            # Independent per-label decision (multi-label), not argmax.
            active = torch.sigmoid(logits.float()) > self.threshold

            token_results = []
            for token, row in zip(tokens, active):
                if token in special_tokens:
                    continue
                label_ids = row.nonzero(as_tuple=True)[0].tolist()
                labels = [self.id2label[idx] for idx in label_ids] or ['O']
                token_results.append({"token": token, "labels": labels})
            results.append(token_results)
        return results


# register_pipeline() is a plain registration method that returns None, so it
# cannot be used as a class decorator. Register after the class exists.
PIPELINE_REGISTRY.register_pipeline(
    "multi_label_token_classification",
    pipeline_class=MultiLabelTokenClassificationPipeline,
    pt_model=AutoModelForTokenClassification,
)