JonathanEGP
committed on
Create multi_model_anonymizer.py
multi_model_anonymizer.py +57 -0
multi_model_anonymizer.py
ADDED
@@ -0,0 +1,57 @@
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
from typing import List, Dict, Any

class MultiModelAnonymizer:
    def __init__(self, model_paths: List[Dict[str, str]], thresholds: Dict[str, float]):
        self.recognizers = []
        for path in model_paths:
            model = AutoModelForTokenClassification.from_pretrained(path['model'])
            tokenizer = AutoTokenizer.from_pretrained(path['tokenizer'])
            self.recognizers.append(pipeline("ner", model=model, tokenizer=tokenizer))
        self.thresholds = thresholds

    def merge_overlapping_entities(self, entities):
        sorted_entities = sorted(entities, key=lambda x: (x['start'], -x['end']))
        merged = []

        for entity in sorted_entities:
            entity_type = entity['entity']
            threshold = self.thresholds.get(entity_type, 0.7)

            if not merged or entity['start'] >= merged[-1]['end']:
                if entity['score'] >= threshold:
                    merged.append(entity)
            else:
                prev = merged[-1]
                if entity['entity'] == prev['entity']:
                    if max(entity['score'], prev['score']) >= threshold:
                        merged[-1] = {
                            'start': min(prev['start'], entity['start']),
                            'end': max(prev['end'], entity['end']),
                            'entity': prev['entity'],
                            'word': prev['word'] + entity['word'].replace('##', ''),
                            'score': max(prev['score'], entity['score'])
                        }
                elif entity['score'] > prev['score'] and entity['score'] >= threshold:
                    merged[-1] = entity

        return merged

    def anonymize(self, text: str) -> str:
        all_entities = []

        for recognizer in self.recognizers:
            entities = recognizer(text)
            all_entities.extend(entities)

        merged_entities = self.merge_overlapping_entities(all_entities)
        merged_entities.sort(key=lambda x: -x['start'])

        anonymized_text = text
        for entity in merged_entities:
            start = entity['start']
            end = entity['end']
            anon_label = "[X]"
            anonymized_text = anonymized_text[:start] + anon_label + anonymized_text[end:]

        return anonymized_text
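
For reference, a minimal usage sketch follows. The checkpoint names, threshold values, and sample sentence are illustrative assumptions, not part of the commit. Because the script builds plain "ner" pipelines (no aggregation_strategy), the per-type thresholds are keyed by the raw IOB tags (e.g. "B-PER", "I-PER") emitted by whichever models are configured.

# Illustrative configuration: any token-classification checkpoints on the Hub
# (e.g. dslim/bert-base-NER) could be plugged in here.
model_paths = [
    {"model": "dslim/bert-base-NER", "tokenizer": "dslim/bert-base-NER"},
    {"model": "Davlan/bert-base-multilingual-cased-ner-hrl",
     "tokenizer": "Davlan/bert-base-multilingual-cased-ner-hrl"},
]

# Per-tag confidence thresholds; any tag not listed falls back to the 0.7 default.
thresholds = {"B-PER": 0.85, "I-PER": 0.85, "B-LOC": 0.75, "I-LOC": 0.75}

anonymizer = MultiModelAnonymizer(model_paths, thresholds)
print(anonymizer.anonymize("John Smith flew from Berlin to Madrid."))
# Detected spans are replaced with "[X]", giving something like
# "[X] [X] flew from [X] to [X]."

Two design points worth noting: merge_overlapping_entities collapses overlapping same-type spans into one span scored with the higher of the two confidences, so agreement between models never lowers a score; and anonymize sorts the merged spans by descending start offset before substitution, so replacing text from the end of the string keeps the earlier character offsets valid.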