Zamanonymize3 / fhe_anonymizer.py
jfrery-zama's picture
add probability along with detected words
2b591f4
raw
history blame
No virus
2.9 kB
import gensim
import re
from concrete.ml.deployment import FHEModelClient, FHEModelServer
from pathlib import Path
from concrete.ml.common.serialization.loaders import load
# Directory containing this module; the model artifacts (embedding model,
# Concrete ML model, deployment directory) are loaded relative to it.
# Resolved once at import time so it is absolute: ``__file__`` may be a
# relative path, which would break if the working directory changed later.
base_dir = Path(__file__).resolve().parent
class FHEAnonymizer:
    """Mask words that a Concrete ML classifier flags as sensitive.

    Each whitespace-separated word of the preprocessed input is embedded with
    a gensim FastText model and scored by the Concrete ML model loaded from
    ``cml_xgboost.model``. Words whose score reaches the threshold are
    replaced with ``"<REMOVED>"``.

    An FHE client/server pair built from the ``deployment`` directory is also
    set up and exposed through :meth:`fhe_inference`. Note that
    :meth:`__call__` scores words with ``self.fhe_ner_detection.predict_proba``
    and does not go through :meth:`fhe_inference`.
    """

    def __init__(self, punctuation_list=".,!?:;"):
        """Load the models and generate the FHE keys.

        Args:
            punctuation_list: characters treated as punctuation by
                :meth:`preprocess_sentences`.
        """
        self.embeddings_model = gensim.models.FastText.load(
            str(base_dir / "embedded_model.model")
        )
        self.punctuation_list = punctuation_list

        # The serialized model is read in text mode; decode it explicitly
        # instead of relying on the platform's default encoding.
        with open(
            base_dir / "cml_xgboost.model", "r", encoding="utf-8"
        ) as model_file:
            self.fhe_ner_detection = load(file=model_file)

        path_to_model = (base_dir / "deployment").resolve()
        self.client = FHEModelClient(path_to_model)
        self.server = FHEModelServer(path_to_model)
        # Keys are generated once per instance; the serialized evaluation keys
        # are reused by every fhe_inference call.
        self.client.generate_private_and_evaluation_keys()
        self.evaluation_key = self.client.get_serialized_evaluation_keys()

    def fhe_inference(self, x):
        """Run the deployed model on ``x`` through the FHE client/server pair.

        The client quantizes, encrypts and serializes ``x``. The server
        evaluates it with the stored evaluation keys. The client then
        deserializes, decrypts and dequantizes the result.

        Returns:
            The decrypted, dequantized model output.
        """
        encrypted_input = self.client.quantize_encrypt_serialize(x)
        encrypted_output = self.server.run(encrypted_input, self.evaluation_key)
        return self.client.deserialize_decrypt_dequantize(encrypted_output)

    def __call__(self, text: str, *, threshold: float = 0.5):
        """Replace the words of ``text`` classified as sensitive.

        Args:
            text: raw input text. It is normalized with
                :meth:`preprocess_sentences` and then split on whitespace.
            threshold: minimum score (column 1 of ``predict_proba``) for a
                word to be masked. Defaults to 0.5.

        Returns:
            A ``(modified_text, identified_words_with_prob)`` tuple.
            ``modified_text`` is the preprocessed words joined by single
            spaces, with each detected word replaced by ``"<REMOVED>"``.
            ``identified_words_with_prob`` lists ``(word, probability)``
            pairs for the detected words, in input order, with the
            probability as a plain ``float``.
        """
        text = self.preprocess_sentences(text)
        identified_words_with_prob = []  # (word, probability) tuples
        new_text = []

        for word in text.split():
            # Look up the word's embedding; ``[None]`` adds a leading axis so
            # the classifier receives a single-row batch (assumes the
            # embedding is a 1-D array).
            x = self.embeddings_model.wv[word][None]
            prediction_proba = self.fhe_ner_detection.predict_proba(x)
            # Column 1 is treated as the probability of the "sensitive" class.
            # Convert to a plain float so the result is JSON-serializable.
            probability = float(prediction_proba[0][1])

            if probability >= threshold:
                identified_words_with_prob.append((word, probability))
                new_text.append("<REMOVED>")
            else:
                new_text.append(word)

        return " ".join(new_text), identified_words_with_prob

    def preprocess_sentences(self, sentence, verbose=False):
        """Normalize whitespace and punctuation in ``sentence``.

        Steps, applied in order:

        1. Replace runs of newlines with a single space.
        2. Collapse runs of spaces into one.
        3. Rewrite a trailing ``'s`` as `` s`` (``"John's"`` -> ``"John s"``).
        4. Drop the whitespace character directly before punctuation.
        5. Remove punctuation not enclosed by word characters on both sides
           (so the dot in ``"3.14"`` survives).
        6. Drop whitespace directly before punctuation once more.

        "Punctuation" is the set of characters in ``self.punctuation_list``.
        When that set is empty, steps 4-6 are skipped.

        Args:
            sentence: text to normalize.
            verbose: if true, print the intermediate text after every step.

        Returns:
            The normalized text.
        """
        steps = [
            (r"\n+", " "),
            (" +", " "),
            (r"'s\b", " s"),
        ]
        # An empty list would produce the invalid/incorrect class "[]", so
        # punctuation handling only runs when there is something to match.
        if self.punctuation_list:
            punct = "[{}]".format(re.escape(self.punctuation_list))
            before_punct = r"\s({})".format(punct)
            steps += [
                (before_punct, r"\1"),
                (r"(?<!\w){0}|{0}(?!\w)".format(punct), ""),
                (before_punct, r"\1"),
            ]

        for pattern, replacement in steps:
            sentence = re.sub(pattern, replacement, sentence)
            if verbose:
                print(sentence)
        return sentence