# Zamanonymize3 / fhe_anonymizer.py
import json
import re
import uuid
from pathlib import Path

from concrete.ml.common.serialization.loaders import load
from concrete.ml.deployment import FHEModelClient, FHEModelServer
from transformers import AutoTokenizer, AutoModel

from utils_demo import get_batch_text_representation

base_dir = Path(__file__).parent

class FHEAnonymizer:
    def __init__(self, punctuation_list=".,!?:;"):
        # Load the RoBERTa tokenizer and embedding model used to represent each token
        self.tokenizer = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
        self.embeddings_model = AutoModel.from_pretrained("obi/deid_roberta_i2b2")

        self.punctuation_list = punctuation_list

        # Load the serialized XGBoost NER model (used for clear-text inference
        # as an alternative to the FHE path, see __call__)
        with open(base_dir / "models/without_pronoun_cml_xgboost.model", "r") as model_file:
            self.fhe_ner_detection = load(file=model_file)

        # Mapping from already-anonymized words to their UUID replacements
        with open(base_dir / "original_document_uuid_mapping.json", "r") as file:
            self.uuid_map = json.load(file)

        # Set up the FHE client/server pair and generate the encryption keys
        path_to_model = (base_dir / "deployment").resolve()
        self.client = FHEModelClient(path_to_model)
        self.server = FHEModelServer(path_to_model)
        self.client.generate_private_and_evaluation_keys()
        self.evaluation_key = self.client.get_serialized_evaluation_keys()
    def fhe_inference(self, x):
        # Encrypt the input, run the model server-side on encrypted data,
        # then decrypt and dequantize the result client-side
        enc_x = self.client.quantize_encrypt_serialize(x)
        enc_y = self.server.run(enc_x, self.evaluation_key)
        y = self.client.deserialize_decrypt_dequantize(enc_y)
        return y
    def __call__(self, text: str):
        # Pattern that matches words (including emails, paths, and hyphenated
        # terms) as well as the separators between them (whitespace, punctuation)
        token_pattern = r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)"
        tokens = re.findall(token_pattern, text)
        identified_words_with_prob = []
        processed_tokens = []

        for token in tokens:
            # Pass whitespace and punctuation-only tokens through unchanged
            if not token.strip() or not re.match(r"\w+", token):
                processed_tokens.append(token)
                continue

            # Embed the word, then predict over the encrypted representation
            x = get_batch_text_representation([token], self.embeddings_model, self.tokenizer)
            # Clear-text alternative: self.fhe_ner_detection.predict_proba(x)
            prediction_proba = self.fhe_inference(x)
            probability = prediction_proba[0][1]

            if probability >= 0.5:
                identified_words_with_prob.append((token, probability))

                # Reuse the existing UUID if available, otherwise generate a new one
                tmp_uuid = self.uuid_map.get(token, str(uuid.uuid4())[:8])
                processed_tokens.append(tmp_uuid)
                self.uuid_map[token] = tmp_uuid
            else:
                processed_tokens.append(token)

        # Reconstruct the sentence from the processed tokens and separators
        reconstructed_sentence = "".join(processed_tokens)
        return reconstructed_sentence, identified_words_with_prob