"""Multi-label token-classification pipeline for Hugging Face ``transformers``.

Unlike ``TokenClassificationPipeline`` (one argmax label per token), each token
here may receive zero or more labels: logits go through an element-wise sigmoid
and every label whose probability exceeds a threshold is kept.
"""

from collections.abc import Mapping

import torch
from transformers import (
    PIPELINE_REGISTRY,
    AutoModelForTokenClassification,
    TokenClassificationPipeline,
)


class MultiLabelTokenClassificationPipeline(TokenClassificationPipeline):
    """Token classification where each token can carry several labels.

    A call returns a list with one entry per sequence row found in the model
    outputs (one per chunk when the input is split with ``stride``). Each entry
    is a list of ``{"token": str, "labels": list[str]}`` dicts covering the
    non-special tokens. A token with no label above the threshold gets ``["O"]``.

    Postprocessing uses ``torch``, so this pipeline is PyTorch-only.
    """

    # Index -> label mapping used when the caller does not supply ``id2label``.
    DEFAULT_ID2LABEL = {
        0: "O",
        1: "B-STEREO",
        2: "I-STEREO",
        3: "B-GEN",
        4: "I-GEN",
        5: "B-UNFAIR",
        6: "I-UNFAIR",
    }

    def __init__(self, model, tokenizer=None, id2label=None, **kwargs):
        """Build the pipeline.

        Args:
            model: The token-classification model.
            tokenizer: Tokenizer matching ``model``.
            id2label: Optional mapping from output index to label name. It
                defaults to ``DEFAULT_ID2LABEL``. Keys are coerced to ``int``
                so a JSON-loaded mapping with string keys also works.
            **kwargs: Forwarded to ``TokenClassificationPipeline``, e.g.
                ``threshold`` (see ``postprocess``).
        """
        # id2label is consumed here: the base class would reject it as an
        # unknown pipeline parameter.
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        mapping = self.DEFAULT_ID2LABEL if id2label is None else id2label
        self.id2label = {int(idx): label for idx, label in mapping.items()}

    def _sanitize_parameters(self, threshold=None, **kwargs):
        """Route the extra ``threshold`` argument to ``postprocess``.

        All other arguments are handled by the base class unchanged.
        """
        preprocess_params, forward_params, postprocess_params = super()._sanitize_parameters(**kwargs)
        if threshold is not None:
            postprocess_params["threshold"] = threshold
        return preprocess_params, forward_params, postprocess_params

    def _label_tokens(self, tokens, active, is_special):
        """Return ``{"token", "labels"}`` dicts for the non-special tokens of one sequence."""
        token_results = []
        for token, flags, special in zip(tokens, active, is_special):
            if special:
                continue
            indices = flags.nonzero(as_tuple=True)[0].tolist()
            # Unknown indices fall back to the generic HF-style name instead of raising.
            labels = [self.id2label.get(idx, f"LABEL_{idx}") for idx in indices] or ["O"]
            token_results.append({"token": token, "labels": labels})
        return token_results

    def postprocess(self, model_outputs, threshold=0.5, **kwargs):
        """Convert raw logits into per-token label lists.

        Args:
            model_outputs: The forward outputs. This is either one dict with
                ``"logits"``, ``"input_ids"`` and, optionally,
                ``"special_tokens_mask"``, or a list of such dicts (one per chunk).
            threshold: Sigmoid probability a label must exceed to be assigned.
            **kwargs: Base-class postprocess options (e.g.
                ``aggregation_strategy``). They are accepted but ignored.

        Returns:
            A list with one list of token dicts per sequence row.
        """
        # Recent transformers run TokenClassificationPipeline as a ChunkPipeline
        # and hand postprocess a list of per-chunk dicts; older releases pass a
        # single dict. Normalize to a list.
        chunks = [model_outputs] if isinstance(model_outputs, Mapping) else list(model_outputs)
        special_tokens = set(self.tokenizer.all_special_tokens)

        results = []
        for chunk in chunks:
            special_masks = chunk.get("special_tokens_mask")
            for row, (logits, input_ids) in enumerate(zip(chunk["logits"], chunk["input_ids"])):
                tokens = self.tokenizer.convert_ids_to_tokens(input_ids.tolist())
                # float32 so half-precision logits also work with sigmoid on CPU.
                probabilities = torch.sigmoid(torch.as_tensor(logits).float())
                active = probabilities > threshold
                if special_masks is not None:
                    # The mask also marks padding added by batched collation.
                    is_special = [bool(flag) for flag in special_masks[row].tolist()]
                else:
                    is_special = [token in special_tokens for token in tokens]
                results.append(self._label_tokens(tokens, active, is_special))
        return results


# register_pipeline is a plain method (it returns None), not a decorator, so the
# class is registered after it has been defined.
PIPELINE_REGISTRY.register_pipeline(
    "multi_label_token_classification",
    pipeline_class=MultiLabelTokenClassificationPipeline,
    pt_model=AutoModelForTokenClassification,
)