# Scrape artifact (git-blame view header), commented out so the file parses:
# File size: 1,353 Bytes
# 14023f9 731d735 14023f9 731d735 14023f9 731d735 14023f9 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
from transformers import PIPELINE_REGISTRY, TokenClassificationPipeline
import torch
class MultiLabelTokenClassificationPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that permits multiple labels per token.

    Unlike the parent ``TokenClassificationPipeline`` (softmax + argmax, one
    label per token), this applies an independent sigmoid to each label logit
    and keeps every label whose probability exceeds 0.5, falling back to
    ``'O'`` when none does.
    """

    def __init__(self, model, tokenizer, **kwargs):
        """Initialize the pipeline and install the fixed BIO label map.

        Args:
            model: The token-classification model (forwarded to the parent).
            tokenizer: Its tokenizer (forwarded to the parent).
            **kwargs: Extra pipeline options forwarded to the parent.
        """
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        # Hard-coded BIO label map for stereotype / generalization /
        # unfairness spans; deliberately overrides whatever id2label the
        # model config declares.
        self.id2label = {
            0: 'O',
            1: 'B-STEREO',
            2: 'I-STEREO',
            3: 'B-GEN',
            4: 'I-GEN',
            5: 'B-UNFAIR',
            6: 'I-UNFAIR',
        }

    def postprocess(self, model_outputs, **kwargs):
        """Convert raw model outputs into per-token multi-label results.

        Args:
            model_outputs: Indexable pair where element 0 is a batch of
                per-token logit tensors and element 1 the matching batches
                of token strings.
                NOTE(review): this (logits, tokens) layout is assumed from
                the indexing below — confirm against the pipeline's
                ``_forward``/``preprocess``.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            list[list[dict]]: One list per sequence of
            ``{"token": str, "labels": list[str]}`` entries, with special
            tokens skipped.
        """
        results = []
        for logits, tokens in zip(model_outputs[0], model_outputs[1]):
            # Independent per-label probabilities; a token may clear the
            # 0.5 threshold for several labels at once.
            probabilities = torch.sigmoid(logits)
            predicted_labels = (probabilities > 0.5).int()
            token_results = []
            for i, token in enumerate(tokens):
                # Guard clause: drop special tokens ([CLS], [SEP], padding, ...).
                if token in self.tokenizer.all_special_tokens:
                    continue
                # Indices of all labels that fired for token i; squeeze(-1)
                # turns the (k, 1) nonzero result into a flat (k,) vector.
                label_indices = (predicted_labels[i] == 1).nonzero(as_tuple=False).squeeze(-1)
                labels = (
                    [self.id2label[idx.item()] for idx in label_indices]
                    if label_indices.numel() > 0
                    else ['O']  # nothing above threshold -> outside tag
                )
                token_results.append({"token": token, "labels": labels})
            results.append(token_results)
        return results


# BUG FIX: ``PIPELINE_REGISTRY.register_pipeline`` is a plain registration
# call that returns None — it is not a decorator factory. Used with ``@`` it
# bound the class name to None (and registered ``pipeline_class=None``).
# Register after the class definition instead, passing the real class.
PIPELINE_REGISTRY.register_pipeline(
    "multi_label_token_classification",
    pipeline_class=MultiLabelTokenClassificationPipeline,
)
# | (trailing scrape artifact, commented out)