JonathanEGP committed on
Commit
07ae5ea
·
verified ·
1 Parent(s): 5a4dccf

Create multi_model_anonymizer.py

Browse files
Files changed (1) hide show
  1. multi_model_anonymizer.py +57 -0
multi_model_anonymizer.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
2
+ from typing import List, Dict, Any
3
+
4
class MultiModelAnonymizer:
    """Anonymize text using the combined output of several NER pipelines.

    Each recognizer is called on the input text and is expected to return a
    list of dicts carrying the keys ``'start'``, ``'end'``, ``'entity'``,
    ``'word'`` and ``'score'``.  Detections from all recognizers are pooled,
    filtered by a per-entity-type score threshold, de-overlapped, and every
    surviving span is replaced by a placeholder label.
    """

    # Score threshold applied to entity types that have no entry in
    # ``thresholds``.  Kept as a class attribute so that instances created
    # without running ``__init__`` still have a usable fallback.
    default_threshold: float = 0.7

    def __init__(self, model_paths: List[Dict[str, str]], thresholds: Dict[str, float],
                 default_threshold: float = 0.7):
        """Load one token-classification pipeline per entry of *model_paths*.

        Args:
            model_paths: One dict per model; ``path['model']`` is passed to
                ``AutoModelForTokenClassification.from_pretrained`` and
                ``path['tokenizer']`` to ``AutoTokenizer.from_pretrained``.
            thresholds: Minimum score per entity type (keyed by the value of
                an entity's ``'entity'`` field).
            default_threshold: Minimum score for entity types missing from
                *thresholds*.
        """
        self.recognizers = []
        for path in model_paths:
            model = AutoModelForTokenClassification.from_pretrained(path['model'])
            tokenizer = AutoTokenizer.from_pretrained(path['tokenizer'])
            self.recognizers.append(pipeline("ner", model=model, tokenizer=tokenizer))
        self.thresholds = thresholds
        self.default_threshold = default_threshold

    def merge_overlapping_entities(self, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return a thresholded list of entities with no overlapping spans.

        Entities are processed in order of ``start`` (longer span first on
        ties).  An entity that does not overlap the last kept one is kept if
        its score reaches the threshold for its type.  An overlapping entity
        is handled as follows:

        * same type as the last kept entity: the two are fused into one span
          covering both, carrying the higher score;
        * different type: it replaces the last kept entity only when it has a
          strictly higher score and reaches its own threshold; otherwise it
          is discarded, and any part of its span lying outside the kept
          entity is not covered by the result.

        Spans that merely touch (``start == previous end``) are not treated
        as overlapping.  Input dicts are not modified.

        Args:
            entities: Dicts with ``'start'``, ``'end'``, ``'entity'``,
                ``'word'`` and ``'score'`` keys.

        Returns:
            The kept entities, ordered by ``start``.
        """
        ordered = sorted(entities, key=lambda e: (e['start'], -e['end']))
        merged: List[Dict[str, Any]] = []

        for entity in ordered:
            threshold = self.thresholds.get(entity['entity'], self.default_threshold)

            if not merged or entity['start'] >= merged[-1]['end']:
                if entity['score'] >= threshold:
                    merged.append(entity)
                continue

            prev = merged[-1]
            if entity['entity'] == prev['entity']:
                # ``prev`` already cleared the threshold for this type, so the
                # fused entity (which keeps the higher score) does as well.
                # Only the text that extends past ``prev`` is appended, so a
                # span reported by two recognizers is not duplicated.  This is
                # best-effort: it assumes the word's trailing characters map
                # one-to-one onto the trailing characters of its span.
                extra = entity['end'] - prev['end']
                tail = entity['word'].replace('##', '')[-extra:] if extra > 0 else ''
                merged[-1] = {
                    'start': min(prev['start'], entity['start']),
                    'end': max(prev['end'], entity['end']),
                    'entity': prev['entity'],
                    'word': prev['word'] + tail,
                    'score': max(prev['score'], entity['score']),
                }
            elif entity['score'] > prev['score'] and entity['score'] >= threshold:
                merged[-1] = entity

        return merged

    def anonymize(self, text: str, anon_label: str = "[X]") -> str:
        """Replace every detected entity span in *text* with *anon_label*.

        Args:
            text: The text to anonymize.
            anon_label: Placeholder substituted for each kept entity span.

        Returns:
            The text with each kept span replaced; unchanged if nothing is
            detected.
        """
        all_entities: List[Dict[str, Any]] = []
        for recognizer in self.recognizers:
            all_entities.extend(recognizer(text))

        merged_entities = self.merge_overlapping_entities(all_entities)

        # Kept spans do not overlap, so the result can be assembled in one
        # left-to-right pass instead of re-slicing the whole string per entity.
        pieces: List[str] = []
        cursor = 0
        for entity in sorted(merged_entities, key=lambda e: e['start']):
            pieces.append(text[cursor:entity['start']])
            pieces.append(anon_label)
            cursor = entity['end']
        pieces.append(text[cursor:])

        return "".join(pieces)