File size: 2,464 Bytes
07ae5ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
from typing import List, Dict, Any

class MultiModelAnonymizer:
    """Anonymize PII in text using several HF token-classification pipelines.

    Each configured model runs as a raw ``"ner"`` pipeline; their entity
    predictions are pooled, filtered by per-type confidence thresholds,
    merged into non-overlapping spans, and the surviving spans are replaced
    in the text with an anonymization label.
    """

    def __init__(self, model_paths: List[Dict[str, str]], thresholds: Dict[str, float]):
        """
        Args:
            model_paths: one dict per recognizer with keys ``'model'`` and
                ``'tokenizer'`` (names/paths passed to ``from_pretrained``).
            thresholds: minimum score per entity type; types not listed fall
                back to 0.7 in :meth:`merge_overlapping_entities`.
        """
        self.recognizers = []
        for path in model_paths:
            model = AutoModelForTokenClassification.from_pretrained(path['model'])
            tokenizer = AutoTokenizer.from_pretrained(path['tokenizer'])
            self.recognizers.append(pipeline("ner", model=model, tokenizer=tokenizer))
        self.thresholds = thresholds

    def merge_overlapping_entities(self, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Filter and merge raw entity dicts into non-overlapping spans.

        Args:
            entities: dicts with ``start``, ``end``, ``entity``, ``word``,
                ``score`` keys, as produced by the ``"ner"`` pipelines.

        Returns:
            Non-overlapping entity dicts, sorted by ``start``, each having
            passed its type's confidence threshold.
        """
        # Sort by start ascending; for equal starts the longer span comes
        # first, so the widest prediction anchors any merge.
        sorted_entities = sorted(entities, key=lambda x: (x['start'], -x['end']))
        merged: List[Dict[str, Any]] = []

        for entity in sorted_entities:
            entity_type = entity['entity']
            # NOTE(review): raw "ner" pipeline labels usually carry BIO
            # prefixes (e.g. "B-PER", "I-PER"); thresholds keyed by bare type
            # names would never match and silently fall back to 0.7 — confirm
            # the threshold keys against the models' label sets.
            threshold = self.thresholds.get(entity_type, 0.7)

            if not merged or entity['start'] >= merged[-1]['end']:
                # No overlap with the last accepted span: keep iff confident.
                if entity['score'] >= threshold:
                    merged.append(entity)
            else:
                prev = merged[-1]
                if entity['entity'] == prev['entity']:
                    # Same type: widen the previous span. Every span already
                    # in `merged` passed its type's threshold, so this max()
                    # check can only succeed; kept for clarity/safety.
                    if max(entity['score'], prev['score']) >= threshold:
                        merged[-1] = {
                            'start': min(prev['start'], entity['start']),
                            'end': max(prev['end'], entity['end']),
                            'entity': prev['entity'],
                            # Strip WordPiece continuation markers when gluing.
                            'word': prev['word'] + entity['word'].replace('##', ''),
                            'score': max(prev['score'], entity['score']),
                        }
                elif entity['score'] > prev['score'] and entity['score'] >= threshold:
                    # Conflicting types: keep the higher-confidence span.
                    merged[-1] = entity

        return merged

    def anonymize(self, text: str, anon_label: str = "[X]") -> str:
        """Return ``text`` with every confident entity span replaced.

        Args:
            text: input text; the pipelines' character offsets index into it.
            anon_label: replacement string for each span. Defaults to
                ``"[X]"`` (backward compatible with the previous hard-coded
                label).
        """
        all_entities: List[Dict[str, Any]] = []
        for recognizer in self.recognizers:
            all_entities.extend(recognizer(text))

        merged_entities = self.merge_overlapping_entities(all_entities)
        # Replace right-to-left so earlier character offsets remain valid
        # after each splice.
        merged_entities.sort(key=lambda x: x['start'], reverse=True)

        anonymized_text = text
        for entity in merged_entities:
            anonymized_text = (
                anonymized_text[:entity['start']]
                + anon_label
                + anonymized_text[entity['end']:]
            )

        return anonymized_text