|
from collections.abc import Mapping

import torch
from transformers import (
    PIPELINE_REGISTRY,
    AutoModelForTokenClassification,
    TokenClassificationPipeline,
)
|
|
|
class MultiLabelTokenClassificationPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that can assign several labels per token.

    Each label's logit is passed through an independent sigmoid; every label
    whose probability is strictly above ``threshold`` is attached to the
    token. Tokens for which no label fires are reported as ``["O"]``.
    Special tokens of the tokenizer are omitted from the output.
    """

    # Label map used when the caller does not supply one. It is copied per
    # instance in ``__init__`` so instances never share one mutable dict.
    DEFAULT_ID2LABEL = {
        0: "O",
        1: "B-STEREO",
        2: "I-STEREO",
        3: "B-GEN",
        4: "I-GEN",
        5: "B-UNFAIR",
        6: "I-UNFAIR",
    }

    def __init__(self, model, tokenizer, *, id2label=None, threshold=0.5, **kwargs):
        """Create the pipeline.

        Args:
            model: Forwarded to ``TokenClassificationPipeline``.
            tokenizer: Forwarded to ``TokenClassificationPipeline``; also used
                in ``postprocess`` to map ids to tokens and to recognise
                special tokens.
            id2label: Optional mapping from label index to label name.
                Defaults to ``DEFAULT_ID2LABEL``. Every index the model can
                predict above ``threshold`` must be a key, otherwise
                ``postprocess`` raises ``KeyError``.
            threshold: A label is assigned when its sigmoid probability is
                strictly greater than this value. Defaults to 0.5.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        self.id2label = dict(self.DEFAULT_ID2LABEL if id2label is None else id2label)
        self.threshold = threshold

    def postprocess(self, model_outputs, **kwargs):
        """Turn raw model outputs into per-token multi-label predictions.

        Args:
            model_outputs: One of
                * a dict, or a list of dicts (one per chunk), each holding
                  batched ``"logits"`` and ``"input_ids"`` of which the first
                  row is decoded -- the form the parent pipeline's forward
                  step produces; or
                * a legacy pair ``(logits_batch, tokens_batch)`` where
                  ``logits_batch[k]`` holds per-token label logits and
                  ``tokens_batch[k]`` the matching token strings.
            **kwargs: Post-processing options from the parent pipeline;
                accepted for compatibility and ignored.

        Returns:
            A list with one entry per sequence (legacy pair) or per chunk
            (dict input). Each entry is a list of
            ``{"token": str, "labels": list[str]}`` dicts for the non-special
            tokens, in input order. Empty input yields ``[]``.
        """
        # Build the membership set once per call rather than re-reading the
        # tokenizer property and scanning a list for every token.
        special_tokens = set(self.tokenizer.all_special_tokens)

        if isinstance(model_outputs, Mapping):
            model_outputs = [model_outputs]

        if isinstance(model_outputs, (list, tuple)):
            if not model_outputs:
                return []
            if all(isinstance(chunk, Mapping) for chunk in model_outputs):
                return [self._decode_chunk(chunk, special_tokens) for chunk in model_outputs]

        # Legacy form: parallel batches of logits and token strings.
        logits_batch, tokens_batch = model_outputs[0], model_outputs[1]
        return [
            self._decode_sequence(logits, tokens, special_tokens)
            for logits, tokens in zip(logits_batch, tokens_batch)
        ]

    def _decode_chunk(self, chunk, special_tokens):
        """Decode one forward-output dict into token/label records."""
        # NOTE(review): assumes each chunk keeps a leading batch dimension of
        # size 1 (row 0 is decoded), matching how the parent pipeline reads
        # these dicts -- confirm for the installed transformers version.
        logits = chunk["logits"][0]
        token_ids = chunk["input_ids"][0]
        token_ids = token_ids.tolist() if hasattr(token_ids, "tolist") else list(token_ids)
        tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
        return self._decode_sequence(logits, tokens, special_tokens)

    def _decode_sequence(self, logits, tokens, special_tokens):
        """Label every non-special token of one sequence (``["O"]`` if none fire)."""
        # Independent sigmoid per label (not a softmax), so several labels
        # can be active for the same token.
        active = torch.sigmoid(torch.as_tensor(logits)) > self.threshold
        decoded = []
        for position, token in enumerate(tokens):
            if token in special_tokens:
                continue
            label_ids = active[position].nonzero(as_tuple=False).flatten().tolist()
            labels = [self.id2label[label_id] for label_id in label_ids] or ["O"]
            decoded.append({"token": token, "labels": labels})
        return decoded


# Register the task once the class exists. ``register_pipeline`` is a
# registration call (not a class decorator) and needs the real pipeline class.
PIPELINE_REGISTRY.register_pipeline(
    task="multi_label_token_classification",
    pipeline_class=MultiLabelTokenClassificationPipeline,
    pt_model=AutoModelForTokenClassification,
)
|
|