maximuspowers
committed on
Commit
•
14023f9
1
Parent(s):
f0d4714
pipeline registry?
Browse files- pipeline.py +16 -28
pipeline.py
CHANGED
@@ -1,15 +1,10 @@
|
|
1 |
-
from
|
2 |
-
import json
|
3 |
import torch
|
4 |
-
from transformers import BertTokenizerFast, BertForTokenClassification
|
5 |
-
|
6 |
-
class BiasNERPipeline:
|
7 |
-
def __init__(self, model_path: str = 'maximuspowers/bias-detection-ner'):
|
8 |
-
self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
9 |
-
self.model = BertForTokenClassification.from_pretrained(model_path)
|
10 |
-
self.model.eval()
|
11 |
-
self.model.to('cuda' if torch.cuda.is_available() else 'cpu')
|
12 |
|
|
|
|
|
|
|
|
|
13 |
self.id2label = {
|
14 |
0: 'O',
|
15 |
1: 'B-STEREO',
|
@@ -20,23 +15,16 @@ class BiasNERPipeline:
|
|
20 |
6: 'I-UNFAIR'
|
21 |
}
|
22 |
|
23 |
-
def
|
24 |
-
|
25 |
-
|
26 |
-
attention_mask = tokenized_inputs['attention_mask'].to(self.model.device)
|
27 |
-
|
28 |
-
with torch.no_grad():
|
29 |
-
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
30 |
-
logits = outputs.logits
|
31 |
probabilities = torch.sigmoid(logits)
|
32 |
predicted_labels = (probabilities > 0.5).int()
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
return json.dumps(result, indent=4)
|
|
|
1 |
+
from transformers import PIPELINE_REGISTRY, TokenClassificationPipeline
|
|
|
2 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
+
@PIPELINE_REGISTRY.register_pipeline(task="multi_label_token_classification", pipeline_class=None)
|
5 |
+
class MultiLabelTokenClassificationPipeline(TokenClassificationPipeline):
|
6 |
+
def __init__(self, model, tokenizer, **kwargs):
|
7 |
+
super().__init__(model=model, tokenizer=tokenizer, **kwargs)
|
8 |
self.id2label = {
|
9 |
0: 'O',
|
10 |
1: 'B-STEREO',
|
|
|
15 |
6: 'I-UNFAIR'
|
16 |
}
|
17 |
|
18 |
+
def postprocess(self, model_outputs, **kwargs):
    """Turn raw logits into per-token multi-label predictions.

    Each logit is passed through a sigmoid and thresholded at 0.5, so a
    token may carry several labels at once. Tokens listed in the
    tokenizer's special tokens are dropped from the output, and a token
    with no label above the threshold is reported as ``['O']``.

    Args:
        model_outputs: Indexable pair; ``model_outputs[0]`` is iterated
            for per-sequence logits and ``model_outputs[1]`` for the
            matching token sequences.
            NOTE(review): assumes each logits item is 2-D
            (tokens x labels) and tokens are strings -- confirm against
            whatever produces ``model_outputs``.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        A list with one entry per sequence; each entry is a list of
        ``{"token": ..., "labels": [...]}`` dicts in token order.
    """
    # Read the special tokens once: the attribute lookup and the linear
    # membership scan are loop-invariant, so a set avoids repeating them
    # for every token.
    special_tokens = frozenset(self.tokenizer.all_special_tokens)

    results = []
    for logits, tokens in zip(model_outputs[0], model_outputs[1]):
        probabilities = torch.sigmoid(logits)
        predicted_labels = (probabilities > 0.5).int()

        token_results = []
        for i, token in enumerate(tokens):
            if token in special_tokens:
                continue
            # Indices of the labels whose probability cleared the threshold.
            label_indices = (
                (predicted_labels[i] == 1)
                .nonzero(as_tuple=False)
                .squeeze(-1)
                .tolist()
            )
            # Fall back to the outside tag when nothing fired.
            labels = [self.id2label[idx] for idx in label_indices] or ['O']
            token_results.append({"token": token, "labels": labels})
        results.append(token_results)
    return results
|
|
|
|