File size: 5,433 Bytes
127130c
 
fbd9a75
127130c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4040d43
127130c
 
 
 
 
 
fbd9a75
127130c
 
 
4040d43
127130c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"Client-server interface custom implementation for seizure detection models."

from pathlib import Path

from concrete import fhe

from common import SEIZURE_DETECTION_MODEL_PATH
from seizure_detection import SeizureDetector


class FHEServer:
    """Server interface to run a FHE circuit for seizure detection.

    It loads the server-side artifact (``server.zip``) from ``model_path`` and evaluates it
    over serialized encrypted inputs, using serialized evaluation keys passed to ``run``.
    """

    def __init__(self, model_path):
        """Initialize the FHE interface.

        Args:
            model_path (str or Path): The path to the directory where the circuit is saved.
                The directory must contain the ``server.zip`` artifact.
        """
        # Accept plain strings as well as Path objects: the `/` operator used below is only
        # defined for Path.
        self.model_path = Path(model_path)

        # Load the FHE circuit
        self.server = fhe.Server.load(self.model_path / "server.zip")

    def run(self, serialized_encrypted_image, serialized_evaluation_keys):
        """Run seizure detection on the server over an encrypted image.

        Args:
            serialized_encrypted_image (bytes): The encrypted and serialized image.
            serialized_evaluation_keys (bytes): The serialized evaluation keys.

        Returns:
            bytes: The serialized encrypted output of the circuit (decrypted and
            post-processed on the client side).
        """
        # Deserialize the encrypted input image and the evaluation keys
        encrypted_image = fhe.Value.deserialize(serialized_encrypted_image)
        evaluation_keys = fhe.EvaluationKeys.deserialize(serialized_evaluation_keys)

        # Execute the seizure detection in FHE
        encrypted_output = self.server.run(encrypted_image, evaluation_keys=evaluation_keys)

        # Serialize the encrypted output so it can be sent back to the client
        return encrypted_output.serialize()


class FHEDev:
    """Development interface to export the compiled seizure detection model.

    It writes the server and client artifacts needed by ``FHEServer`` and ``FHEClient``.
    """

    def __init__(self, seizure_detector, model_path):
        """Initialize the FHE interface.

        The target directory (including any missing parents) is created immediately.

        Args:
            seizure_detector (SeizureDetector): The seizure detection model to use in the FHE interface.
                It must expose an ``fhe_circuit`` attribute; ``None`` means it is not compiled yet.
            model_path (str or Path): The path to the directory where the circuit is saved.
        """
        self.seizure_detector = seizure_detector

        # Accept both str and Path: `.mkdir` below and the `/` operator in `save` only
        # exist on Path, so a plain string previously crashed here.
        self.model_path = Path(model_path)

        self.model_path.mkdir(parents=True, exist_ok=True)

    def save(self):
        """Export all needed artifacts for the client and server interfaces.

        Writes ``server.zip`` and ``client.zip`` into ``model_path``.

        Raises:
            ValueError: If the seizure detector has not been compiled yet.
        """
        circuit = self.seizure_detector.fhe_circuit

        # Explicit check rather than `assert`, which is stripped when Python runs with -O
        if circuit is None:
            raise ValueError("The model must be compiled before saving it.")

        # Save the circuit for the server, using via_mlir in order to handle cross-platform
        # execution
        path_circuit_server = self.model_path / "server.zip"
        circuit.server.save(path_circuit_server, via_mlir=True)

        # Save the circuit for the client
        path_circuit_client = self.model_path / "client.zip"
        circuit.client.save(path_circuit_client)


class FHEClient:
    """Client interface to encrypt and decrypt FHE data associated to a SeizureDetector."""

    def __init__(self, key_dir=None, model_path=None):
        """Initialize the FHE interface.

        Args:
            key_dir (Path): The path to the directory where the keys are stored. Default to None.
            model_path (str or Path): The path to the directory where the circuit is saved. The
                directory must contain the ``client.zip`` artifact. Defaults to
                ``SEIZURE_DETECTION_MODEL_PATH``.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        # `model_path` comes after `key_dir` so existing positional calls keep working
        self.model_path = Path(SEIZURE_DETECTION_MODEL_PATH if model_path is None else model_path)
        self.key_dir = key_dir

        # Explicit check rather than `assert`, which is stripped when Python runs with -O
        if not self.model_path.exists():
            raise FileNotFoundError(
                f"{self.model_path} does not exist. Please specify a valid path."
            )

        # Load the client
        self.client = fhe.Client.load(self.model_path / "client.zip", self.key_dir)

        # Instantiate the seizure detector (used only for post-processing decrypted outputs)
        self.seizure_detector = SeizureDetector()

    def generate_private_and_evaluation_keys(self, force=False):
        """Generate the private and evaluation keys.

        Args:
            force (bool): If True, regenerate the keys even if they already exist.
        """
        self.client.keygen(force=force)

    def get_serialized_evaluation_keys(self):
        """Get the serialized evaluation keys.

        Returns:
            bytes: The evaluation keys.
        """
        return self.client.evaluation_keys.serialize()

    def encrypt_serialize(self, input_image):
        """Encrypt and serialize an input image given in the clear.

        Args:
            input_image (numpy.ndarray): The image to encrypt and serialize.

        Returns:
            bytes: The encrypted and serialized image.
        """
        # Encrypt the image
        encrypted_image = self.client.encrypt(input_image)

        # Serialize the encrypted image to be sent to the server
        return encrypted_image.serialize()

    def deserialize_decrypt_post_process(self, serialized_encrypted_output):
        """Deserialize, decrypt and post-process the output in the clear.

        Args:
            serialized_encrypted_output (bytes): The serialized and encrypted output.

        Returns:
            bool: The decrypted and deserialized boolean indicating seizure detection
            (as returned by ``SeizureDetector.post_processing``).
        """
        # Deserialize the encrypted output
        encrypted_output = fhe.Value.deserialize(serialized_encrypted_output)

        # Decrypt the output
        output = self.client.decrypt(encrypted_output)

        # Post-process the output (if needed)
        return self.seizure_detector.post_processing(output)