import gensim
import re
from concrete.ml.deployment import FHEModelClient, FHEModelServer
from pathlib import Path
from concrete.ml.common.serialization.loaders import load
import uuid
import json
from transformers import AutoTokenizer, AutoModel
from utils_demo import get_batch_text_representation

base_dir = Path(__file__).parent


class FHEAnonymizer:
    """Replace words flagged by an FHE-evaluated classifier with short IDs.

    Each word token is embedded with the ``obi/deid_roberta_i2b2`` model, the
    embedding is classified through the FHE client/server pair, and tokens
    whose second output column is >= 0.5 are replaced by an 8-character ID.
    IDs are looked up in / added to ``self.uuid_map`` so the same word maps to
    the same ID across calls on one instance.
    """

    # Token alternatives, tried in order at each position:
    #   1. a "word": word characters plus ``. / - @`` bounded by \b on both ends;
    #   2. a run of whitespace and common punctuation;
    #   3. fallback: any other single non-space character.
    # Without (3), re.findall silently skips characters such as ``( ) # $`` or
    # a free-standing ``/`` and the reconstructed sentence loses text. (3) only
    # matches where (1) and (2) both fail, so it never changes their tokens.
    TOKEN_PATTERN = re.compile(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+|\S)")

    def __init__(self, punctuation_list=".,!?:;"):
        # Tokenizer + encoder used to build the classifier's input features.
        self.tokenizer = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
        self.embeddings_model = AutoModel.from_pretrained("obi/deid_roberta_i2b2")

        # Stored for API compatibility; not read anywhere in this class.
        self.punctuation_list = punctuation_list

        # Plaintext (non-FHE) Concrete ML model. __call__ does not use it --
        # inference goes through the FHE client/server below -- but it is kept
        # as a public attribute.
        with open(base_dir / "models/without_pronoun_cml_xgboost.model", "r") as model_file:
            self.fhe_ner_detection = load(file=model_file)

        # word -> anonymized ID mapping loaded from disk. __call__ extends it
        # in memory only; nothing here writes it back to the file.
        with open(base_dir / "original_document_uuid_mapping.json", "r", encoding="utf-8") as file:
            self.uuid_map = json.load(file)

        path_to_model = (base_dir / "deployment").resolve()
        self.client = FHEModelClient(path_to_model)
        self.server = FHEModelServer(path_to_model)

        # Keys are generated once per instance; the server only receives the
        # serialized evaluation key.
        self.client.generate_private_and_evaluation_keys()
        self.evaluation_key = self.client.get_serialized_evaluation_keys()

    def fhe_inference(self, x):
        """Evaluate the deployed model on ``x`` under FHE and return the decrypted output.

        The client quantizes, encrypts and serializes ``x``; the server runs the
        circuit with the evaluation key; the client then deserializes, decrypts
        and dequantizes the result. ``__call__`` reads the result as
        ``[row][class]`` probabilities.
        """
        enc_x = self.client.quantize_encrypt_serialize(x)
        enc_y = self.server.run(enc_x, self.evaluation_key)
        return self.client.deserialize_decrypt_dequantize(enc_y)

    def __call__(self, text: str):
        """Anonymize ``text``.

        Args:
            text: the input text.

        Returns:
            ``(anonymized_text, identified_words_with_prob)``. The first element is
            ``text`` with every flagged word replaced by its ID; every other
            character is preserved. The second lists ``(word, probability)`` for
            each flagged occurrence, in order.
        """
        # NOTE: the raw tokens are deliberately NOT printed/logged -- they are
        # the sensitive, pre-anonymization text.
        tokens = self.TOKEN_PATTERN.findall(text)

        identified_words_with_prob = []
        processed_tokens = []

        for token in tokens:
            # Whitespace, punctuation and other non-word tokens pass through verbatim.
            if not token.strip() or not re.match(r"\w+", token):
                processed_tokens.append(token)
                continue

            # One encrypted inference per word token.
            x = get_batch_text_representation([token], self.embeddings_model, self.tokenizer)
            prediction_proba = self.fhe_inference(x)
            probability = prediction_proba[0][1]

            if probability < 0.5:
                processed_tokens.append(token)
                continue

            identified_words_with_prob.append((token, probability))
            # Reuse this word's existing ID; only mint a new one on a miss.
            # NOTE(review): 8 hex chars = 32 bits, so distinct words can
            # collide on the same ID for large mappings.
            tmp_uuid = self.uuid_map.get(token)
            if tmp_uuid is None:
                tmp_uuid = str(uuid.uuid4())[:8]
                self.uuid_map[token] = tmp_uuid
            processed_tokens.append(tmp_uuid)

        reconstructed_sentence = "".join(processed_tokens)
        return reconstructed_sentence, identified_words_with_prob