# Zamanonymize3 / fhe_anonymizer.py
import json
import re
import uuid
from pathlib import Path

import gensim
from concrete.ml.common.serialization.loaders import load
from concrete.ml.deployment import FHEModelClient, FHEModelServer

# Directory containing this file; model and deployment artifacts are resolved relative to it.
base_dir = Path(__file__).parent

class FHEAnonymizer:
    def __init__(self, punctuation_list=".,!?:;"):
        # FastText embeddings used to vectorize each token before classification.
        self.embeddings_model = gensim.models.FastText.load(
            str(base_dir / "models/without_pronoun_embedded_model.model")
        )
        self.punctuation_list = punctuation_list

        # XGBoost NER classifier serialized with Concrete ML (clear-text model).
        with open(base_dir / "models/without_pronoun_cml_xgboost.model", "r") as model_file:
            self.fhe_ner_detection = load(file=model_file)

        # Mapping from previously anonymized words to their UUID replacements.
        with open(base_dir / "original_document_uuid_mapping.json", "r") as file:
            self.uuid_map = json.load(file)

        # FHE client/server pair built from the compiled deployment artifacts,
        # plus the keys needed to encrypt inputs and run encrypted inference.
        path_to_model = (base_dir / "deployment").resolve()
        self.client = FHEModelClient(path_to_model)
        self.server = FHEModelServer(path_to_model)
        self.client.generate_private_and_evaluation_keys()
        self.evaluation_key = self.client.get_serialized_evaluation_keys()

    def fhe_inference(self, x):
        # Quantize and encrypt the input on the client side, run the encrypted
        # inference on the server, then decrypt and dequantize the result.
        enc_x = self.client.quantize_encrypt_serialize(x)
        enc_y = self.server.run(enc_x, self.evaluation_key)
        y = self.client.deserialize_decrypt_dequantize(enc_y)
        return y

    def __call__(self, text: str):
        # Pattern to identify words and non-words (including punctuation, spaces, etc.)
        token_pattern = r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)"
        tokens = re.findall(token_pattern, text)
        identified_words_with_prob = []
        processed_tokens = []
        print(tokens)  # Debug output: the tokenized input

        for token in tokens:
            # Directly append non-word tokens or whitespace to processed_tokens
            if not token.strip() or not re.match(r"\w+", token):
                processed_tokens.append(token)
                continue

            # Embed the word and run the encrypted NER prediction on it
            x = self.embeddings_model.wv[token][None]
            # Clear-text alternative: prediction_proba = self.fhe_ner_detection.predict_proba(x)
            prediction_proba = self.fhe_inference(x)
            probability = prediction_proba[0][1]

            if probability >= 0.5:
                identified_words_with_prob.append((token, probability))
                # Use the existing UUID if available, otherwise generate a new one
                tmp_uuid = self.uuid_map.get(token, str(uuid.uuid4())[:8])
                processed_tokens.append(tmp_uuid)
                self.uuid_map[token] = tmp_uuid
            else:
                processed_tokens.append(token)

        # Reconstruct the sentence from the (possibly replaced) tokens
        reconstructed_sentence = "".join(processed_tokens)
        return reconstructed_sentence, identified_words_with_prob
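

# A minimal usage sketch (illustrative only, not part of the original module):
# it assumes the model files and deployment artifacts referenced in __init__
# exist next to this file, and the sample sentence below is made up.
if __name__ == "__main__":
    anonymizer = FHEAnonymizer()
    sample_text = "John Doe lives in Paris and works at Acme Corp."
    anonymized_text, flagged_words = anonymizer(sample_text)
    print(anonymized_text)  # Text with detected entities replaced by short UUIDs
    print(flagged_words)    # (word, probability) pairs the model flagged as sensitive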