File size: 2,861 Bytes
646bd9e
 
 
 
 
df6182e
 
1dfccc3
 
646bd9e
 
 
2b591f4
646bd9e
 
 
d0b1031
1dfccc3
 
 
646bd9e
 
df6182e
 
 
646bd9e
 
 
 
 
 
 
 
 
 
 
 
 
df6182e
 
 
 
 
 
 
 
 
 
 
646bd9e
 
1dfccc3
 
628fe8f
2b591f4
646bd9e
df6182e
 
b160148
df6182e
 
 
 
 
 
646bd9e
d0b1031
 
 
 
df6182e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import gensim
import re
from concrete.ml.deployment import FHEModelClient, FHEModelServer
from pathlib import Path
from concrete.ml.common.serialization.loaders import load
import uuid
import json
from transformers import AutoTokenizer, AutoModel
from utils_demo import get_batch_text_representation

# Directory containing this script; used to resolve the FHE deployment
# artifacts and the persisted UUID-mapping JSON relative to this file.
base_dir = Path(__file__).parent


class FHEAnonymizer:
    """Anonymize sensitive words in free text with an FHE classifier.

    Each word token is embedded with a RoBERTa de-identification model
    ("obi/deid_roberta_i2b2") and classified under Fully Homomorphic
    Encryption. Words predicted sensitive (probability >= 0.5) are
    replaced by short UUIDs; the word -> UUID mapping is persisted to
    JSON so the same word keeps the same replacement across calls.
    """

    # Splits text into word tokens and separator runs (whitespace /
    # punctuation). Compiled once instead of per __call__.
    _TOKEN_RE = re.compile(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)")
    # Matches tokens that contain at least one leading word character.
    _WORD_RE = re.compile(r"\w+")

    def __init__(self, punctuation_list=".,!?:;"):
        """Load the embedding model, persisted UUID map, and FHE artifacts.

        Args:
            punctuation_list: punctuation characters exposed on the
                instance (kept for backward compatibility; tokenization
                itself uses the class-level regexes).
        """
        # Tokenizer and embeddings model used to featurize each word.
        self.tokenizer = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
        self.embeddings_model = AutoModel.from_pretrained("obi/deid_roberta_i2b2")

        self.punctuation_list = punctuation_list

        # word -> UUID mapping persisted from previous queries.
        # Explicit encoding so the JSON round-trip is locale-independent.
        with open(base_dir / "original_document_uuid_mapping.json", "r", encoding="utf-8") as file:
            self.uuid_map = json.load(file)

        # Client and server share the same compiled deployment artifacts;
        # keys are generated once and reused for every inference.
        path_to_model = (base_dir / "deployment").resolve()
        self.client = FHEModelClient(path_to_model)
        self.server = FHEModelServer(path_to_model)
        self.client.generate_private_and_evaluation_keys()
        self.evaluation_key = self.client.get_serialized_evaluation_keys()

    def fhe_inference(self, x):
        """Run one encrypted inference round-trip.

        Args:
            x: embedding features for a single token batch.

        Returns:
            The dequantized prediction probabilities from the server.
        """
        enc_x = self.client.quantize_encrypt_serialize(x)
        enc_y = self.server.run(enc_x, self.evaluation_key)
        y = self.client.deserialize_decrypt_dequantize(enc_y)
        return y

    def __call__(self, text: str):
        """Anonymize ``text``, replacing sensitive words with short UUIDs.

        Args:
            text: the input document or query string.

        Returns:
            Tuple ``(anonymized_text, identified_words_with_prob)`` where
            the second element lists ``(word, probability)`` pairs that
            were flagged as sensitive.
        """
        tokens = self._TOKEN_RE.findall(text)
        identified_words_with_prob = []
        processed_tokens = []

        for token in tokens:
            # Pass whitespace / punctuation runs through untouched.
            if not token.strip() or not self._WORD_RE.match(token):
                processed_tokens.append(token)
                continue

            # Embed the word and classify it under FHE.
            x = get_batch_text_representation([token], self.embeddings_model, self.tokenizer)
            prediction_proba = self.fhe_inference(x)
            probability = prediction_proba[0][1]

            if probability >= 0.5:
                identified_words_with_prob.append((token, probability))

                # Reuse the existing UUID when present; only mint a new one
                # otherwise. (The previous dict.get(..., uuid4()) generated
                # and discarded a UUID on every hit, since dict.get
                # evaluates its default eagerly.)
                if token in self.uuid_map:
                    tmp_uuid = self.uuid_map[token]
                else:
                    # 8 hex chars keeps replacements short; collision risk
                    # is negligible at typical document scales.
                    tmp_uuid = str(uuid.uuid4())[:8]
                    self.uuid_map[token] = tmp_uuid
                processed_tokens.append(tmp_uuid)
            else:
                processed_tokens.append(token)

        # Persist the (possibly extended) UUID map after every query.
        with open(base_dir / "original_document_uuid_mapping.json", "w", encoding="utf-8") as file:
            json.dump(self.uuid_map, file)

        # Reconstruct the sentence: separators were kept verbatim, so a
        # plain join restores the original spacing and punctuation.
        reconstructed_sentence = ''.join(processed_tokens)
        return reconstructed_sentence, identified_words_with_prob