Zamanonymize3 / fhe_anonymizer.py
jfrery-zama's picture
update anonymize file in clear with roberta +update uuid map with query id
d0b1031
raw
history blame
2.86 kB
import gensim
import re
from concrete.ml.deployment import FHEModelClient, FHEModelServer
from pathlib import Path
from concrete.ml.common.serialization.loaders import load
import uuid
import json
from transformers import AutoTokenizer, AutoModel
from utils_demo import get_batch_text_representation
# Directory that contains this module; the UUID-map JSON file and the FHE
# "deployment" folder used by FHEAnonymizer are located relative to it.
base_dir = Path(__file__).parents[0]
class FHEAnonymizer:
    """Anonymize free text by replacing words flagged by an FHE classifier.

    Each word-like token of the input is embedded with the
    ``obi/deid_roberta_i2b2`` model, classified through the Concrete ML FHE
    deployment stored in ``<base_dir>/deployment``, and, when the predicted
    probability is at least 0.5, replaced by a short UUID. Token -> UUID
    assignments are kept in ``self.uuid_map`` and written back to
    ``<base_dir>/original_document_uuid_mapping.json`` so that a given word
    keeps the same UUID across queries.
    """

    # File (relative to base_dir) that persists the token -> UUID mapping.
    _UUID_MAP_FILENAME = "original_document_uuid_mapping.json"

    # Tokenizer with three alternatives:
    #   1. word-like runs (word characters plus ``./-@``, so strings such as
    #      "j.doe@mail.com" or "555-1234" stay in one token);
    #   2. runs of whitespace and common punctuation;
    #   3. any other single character (fallback).
    # The fallback makes the matches cover the entire input, so
    # "".join(tokens) == text. Without it, re.findall skipped characters such
    # as "(", ")", "/", "#" or "%", and they vanished from the anonymized output.
    _TOKEN_RE = re.compile(r"\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+|.", re.DOTALL)

    # A token is sent to the classifier only if it starts with a word character.
    _WORD_START_RE = re.compile(r"\w")

    def __init__(self, punctuation_list=".,!?:;"):
        """Load the embedding model, the UUID map and the FHE client/server.

        Args:
            punctuation_list: Stored as ``self.punctuation_list``; no method of
                this class reads it.

        Raises:
            OSError: If the UUID-map JSON file cannot be opened.
        """
        # Tokenizer and model that produce the per-word embeddings fed to the
        # FHE classifier.
        self.tokenizer = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
        self.embeddings_model = AutoModel.from_pretrained("obi/deid_roberta_i2b2")
        self.punctuation_list = punctuation_list

        # Previously assigned token -> UUID pairs, reused so that a word keeps
        # the same pseudonym across queries.
        with open(base_dir / self._UUID_MAP_FILENAME, "r", encoding="utf-8") as file:
            self.uuid_map = json.load(file)

        path_to_model = (base_dir / "deployment").resolve()
        self.client = FHEModelClient(path_to_model)
        self.server = FHEModelServer(path_to_model)
        # Client and server run in the same process here. The client holds
        # the private key; the server is only given the serialized evaluation key.
        self.client.generate_private_and_evaluation_keys()
        self.evaluation_key = self.client.get_serialized_evaluation_keys()

    def fhe_inference(self, x):
        """Run the deployed classifier on ``x`` under FHE and return its output.

        The client quantizes, encrypts and serializes ``x``. The server
        evaluates it homomorphically with the serialized evaluation key. The
        client then deserializes, decrypts and de-quantizes the result.
        """
        enc_x = self.client.quantize_encrypt_serialize(x)
        enc_y = self.server.run(enc_x, self.evaluation_key)
        return self.client.deserialize_decrypt_dequantize(enc_y)

    def __call__(self, text: str):
        """Anonymize ``text``.

        Args:
            text: Input text.

        Returns:
            A ``(anonymized_text, identified_words_with_prob)`` tuple.
            ``anonymized_text`` is ``text`` with every flagged word replaced by
            its UUID; all other characters are kept verbatim.
            ``identified_words_with_prob`` lists the ``(word, probability)``
            pairs whose probability was >= 0.5, in input order.

        Side effects:
            Newly flagged words get a fresh 8-character UUID stored in
            ``self.uuid_map``. The whole map is then written to
            ``<base_dir>/original_document_uuid_mapping.json`` on every call.
        """
        identified_words_with_prob = []
        processed_tokens = []

        for token in self._TOKEN_RE.findall(text):
            # Whitespace, punctuation and any other token that does not start
            # with a word character are copied through unchanged.
            if not self._WORD_START_RE.match(token):
                processed_tokens.append(token)
                continue

            # One FHE prediction per word.
            x = get_batch_text_representation([token], self.embeddings_model, self.tokenizer)
            prediction_proba = self.fhe_inference(x)
            # NOTE(review): presumably column 1 is the "sensitive word" class;
            # confirm against the training code.
            probability = prediction_proba[0][1]

            if probability >= 0.5:
                identified_words_with_prob.append((token, probability))
                # Reuse the UUID already assigned to this word. A new one is
                # drawn only for words not yet in the map.
                tmp_uuid = self.uuid_map.get(token)
                if tmp_uuid is None:
                    tmp_uuid = str(uuid.uuid4())[:8]
                    self.uuid_map[token] = tmp_uuid
                processed_tokens.append(tmp_uuid)
            else:
                processed_tokens.append(token)

        # Persist the map, including any entries created by this query.
        with open(base_dir / self._UUID_MAP_FILENAME, "w", encoding="utf-8") as file:
            json.dump(self.uuid_map, file)

        # Tokens tile the whole input, so joining them reconstructs the text.
        return "".join(processed_tokens), identified_words_with_prob