import json
import re
import time
import uuid
from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from utils_demo import *

from concrete.ml.deployment import FHEModelClient, FHEModelServer

TOLERANCE_PROBA = 0.77

CURRENT_DIR = Path(__file__).parent

DEPLOYMENT_DIR = CURRENT_DIR / "deployment"
KEYS_DIR = DEPLOYMENT_DIR / ".fhe_keys"


class FHEAnonymizer:
    def __init__(self):
        # Load the tokenizer and the embedding model used to represent each token
        self.tokenizer = AutoTokenizer.from_pretrained("obi/deid_roberta_i2b2")
        self.embeddings_model = AutoModel.from_pretrained("obi/deid_roberta_i2b2")

        self.punctuation_list = PUNCTUATION_LIST
        self.uuid_map = read_json(MAPPING_UUID_PATH)

        self.client = FHEModelClient(DEPLOYMENT_DIR, key_dir=KEYS_DIR)
        self.server = FHEModelServer(DEPLOYMENT_DIR)

    def generate_key(self):
        clean_directory()

        # Generate the private and evaluation keys on the client side
        self.client.generate_private_and_evaluation_keys()

        # Get the serialized evaluation keys
        self.evaluation_key = self.client.get_serialized_evaluation_keys()
        assert isinstance(self.evaluation_key, bytes)

        evaluation_key_path = KEYS_DIR / "evaluation_key"

        with evaluation_key_path.open("wb") as f:
            f.write(self.evaluation_key)

    def encrypt_query(self, text: str):
        # Pattern to identify words and non-words (including punctuation, spaces, etc.)
        tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", text)

        encrypted_tokens = []

        for token in tokens:
            # Skip whitespace-only tokens: they carry nothing to classify
            if re.match(r"^\s+$", token):
                continue

            # Embed the token, then quantize, encrypt and serialize it for the server
            emb_x = get_batch_text_representation([token], self.embeddings_model, self.tokenizer)
            encrypted_x = self.client.quantize_encrypt_serialize(emb_x)
            assert isinstance(encrypted_x, bytes)
            encrypted_tokens.append(encrypted_x)

        write_pickle(KEYS_DIR / "encrypted_quantized_query", encrypted_tokens)

    def run_server(self):
        encrypted_tokens = read_pickle(KEYS_DIR / "encrypted_quantized_query")

        encrypted_output, timing = [], []
        for enc_x in encrypted_tokens:
            start_time = time.time()
            enc_y = self.server.run(enc_x, self.evaluation_key)
            timing.append((time.time() - start_time) / 60.0)  # Elapsed time, in minutes
            encrypted_output.append(enc_y)

        write_pickle(KEYS_DIR / "encrypted_output", encrypted_output)
        write_pickle(KEYS_DIR / "encrypted_timing", timing)

        return encrypted_output, timing

    def decrypt_output(self, text):
        encrypted_output = read_pickle(KEYS_DIR / "encrypted_output")

        # Re-tokenize the query with the same pattern as `encrypt_query`, so that
        # non-whitespace tokens line up with the encrypted predictions
        tokens = re.findall(r"(\b[\w\.\/\-@]+\b|[\s,.!?;:'\"-]+)", text)

        decrypted_output, identified_words_with_prob = [], []

        i = 0
        for token in tokens:
            # Skip whitespace-only tokens, as they were never encrypted
            if re.match(r"^\s+$", token):
                continue

            # Decrypt the prediction and read the probability of the "PII" class
            encrypted_token = encrypted_output[i]
            prediction_proba = self.client.deserialize_decrypt_dequantize(encrypted_token)
            probability = prediction_proba[0][1]
            i += 1

            if probability >= TOLERANCE_PROBA:
                identified_words_with_prob.append((token, probability))

                # Use the existing UUID if available, otherwise generate a new one
                tmp_uuid = self.uuid_map.get(token, str(uuid.uuid4())[:8])
                decrypted_output.append(tmp_uuid)
                self.uuid_map[token] = tmp_uuid
            else:
                decrypted_output.append(token)

        # Update the UUID map with the tokens identified in this query
        with open(MAPPING_UUID_PATH, "w") as file:
            json.dump(self.uuid_map, file)

        write_pickle(KEYS_DIR / "reconstructed_sentence", " ".join(decrypted_output))
        write_pickle(KEYS_DIR / "identified_words_with_prob", identified_words_with_prob)

    def run_server_and_decrypt_output(self, text):
        self.run_server()
        self.decrypt_output(text)
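

# A minimal usage sketch, not part of the original module. It assumes the
# deployment artifacts exist under DEPLOYMENT_DIR and that utils_demo provides
# the helpers imported above; the query string is purely illustrative.
if __name__ == "__main__":
    anonymizer = FHEAnonymizer()

    # One-time client-side setup: generate the private and evaluation keys
    anonymizer.generate_key()

    # The client encrypts the query token by token, the server runs the FHE
    # model on each ciphertext, and the client decrypts the predictions,
    # replacing likely-PII tokens with short UUIDs
    query = "John Doe visited the clinic on 2021-05-12."
    anonymizer.encrypt_query(query)
    anonymizer.run_server_and_decrypt_output(query)

    print(read_pickle(KEYS_DIR / "reconstructed_sentence"))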