Thomas Chardonnens commited on
Commit
127130c
1 Parent(s): cbb90a5

baseline, wip

Browse files
app.py CHANGED
@@ -1,7 +1,437 @@
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A local gradio app that detects seizures with EEG using FHE."""
2
+ from PIL import Image
3
+ import os
4
+ import shutil
5
+ import subprocess
6
+ import time
7
  import gradio as gr
8
+ import numpy
9
+ import requests
10
+ from itertools import chain
11
 
12
+ from common import (
13
+ CLIENT_TMP_PATH,
14
+ SERVER_TMP_PATH,
15
+ EXAMPLES,
16
+ INPUT_SHAPE,
17
+ KEYS_PATH,
18
+ REPO_DIR,
19
+ SERVER_URL,
20
+ )
21
+ from client_server_interface import FHEClient
22
 
23
+ # Uncomment here to have both the server and client in the same terminal
24
+ subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
25
+ time.sleep(3)
26
+
27
+ def shorten_bytes_object(bytes_object, limit=500):
28
+ """Shorten the input bytes object to a given length.
29
+
30
+ Encrypted data is too large for displaying it in the browser using Gradio. This function
31
+ provides a shorten representation of it.
32
+
33
+ Args:
34
+ bytes_object (bytes): The input to shorten
35
+ limit (int): The length to consider. Default to 500.
36
+
37
+ Returns:
38
+ str: Hexadecimal string shorten representation of the input byte object.
39
+
40
+ """
41
+ # Define a shift for better display
42
+ shift = 100
43
+ return bytes_object[shift : limit + shift].hex()
44
+
45
+ def get_client(user_id):
46
+ """Get the client API.
47
+
48
+ Args:
49
+ user_id (int): The current user's ID.
50
+
51
+ Returns:
52
+ FHEClient: The client API.
53
+ """
54
+ return FHEClient(
55
+ key_dir=KEYS_PATH / f"seizure_detection_{user_id}",
56
+ )
57
+
58
+ def get_client_file_path(name, user_id):
59
+ """Get the correct temporary file path for the client.
60
+
61
+ Args:
62
+ name (str): The desired file name.
63
+ user_id (int): The current user's ID.
64
+
65
+ Returns:
66
+ pathlib.Path: The file path.
67
+ """
68
+ return CLIENT_TMP_PATH / f"{name}_seizure_detection_{user_id}"
69
+
70
+ def clean_temporary_files(n_keys=20):
71
+ """Clean keys and encrypted images.
72
+
73
+ A maximum of n_keys keys and associated temporary files are allowed to be stored. Once this
74
+ limit is reached, the oldest files are deleted.
75
+
76
+ Args:
77
+ n_keys (int): The maximum number of keys and associated files to be stored. Default to 20.
78
+
79
+ """
80
+ # Get the oldest key files in the key directory
81
+ key_dirs = sorted(KEYS_PATH.iterdir(), key=os.path.getmtime)
82
+
83
+ # If more than n_keys keys are found, remove the oldest
84
+ user_ids = []
85
+ if len(key_dirs) > n_keys:
86
+ n_keys_to_delete = len(key_dirs) - n_keys
87
+ for key_dir in key_dirs[:n_keys_to_delete]:
88
+ user_ids.append(key_dir.name)
89
+ shutil.rmtree(key_dir)
90
+
91
+ # Get all the encrypted objects in the temporary folder
92
+ client_files = CLIENT_TMP_PATH.iterdir()
93
+ server_files = SERVER_TMP_PATH.iterdir()
94
+
95
+ # Delete all files related to the ids whose keys were deleted
96
+ for file in chain(client_files, server_files):
97
+ for user_id in user_ids:
98
+ if user_id in file.name:
99
+ file.unlink()
100
+
101
+ def keygen():
102
+ """Generate the private key for seizure detection.
103
+
104
+ Returns:
105
+ (user_id, True) (Tuple[int, bool]): The current user's ID and a boolean used for visual display.
106
+
107
+ """
108
+ # Clean temporary files
109
+ clean_temporary_files()
110
+
111
+ # Create an ID for the current user
112
+ user_id = numpy.random.randint(0, 2**32)
113
+
114
+ # Retrieve the client API
115
+ client = get_client(user_id)
116
+
117
+ # Generate a private key
118
+ client.generate_private_and_evaluation_keys(force=True)
119
+
120
+ # Retrieve the serialized evaluation key
121
+ evaluation_key = client.get_serialized_evaluation_keys()
122
+
123
+ # Save evaluation_key as bytes in a file as it is too large to pass through regular Gradio
124
+ # buttons (see https://github.com/gradio-app/gradio/issues/1877)
125
+ evaluation_key_path = get_client_file_path("evaluation_key", user_id)
126
+
127
+ with evaluation_key_path.open("wb") as evaluation_key_file:
128
+ evaluation_key_file.write(evaluation_key)
129
+
130
+ return (user_id, True)
131
+
132
+ def encrypt(user_id, input_image):
133
+ """Encrypt the given image for seizure detection.
134
+
135
+ Args:
136
+ user_id (int): The current user's ID.
137
+ input_image (numpy.ndarray): The image to encrypt.
138
+
139
+ Returns:
140
+ (input_image, encrypted_image_short) (Tuple[bytes]): The encrypted image and one of its
141
+ representation.
142
+
143
+ """
144
+ if user_id == "":
145
+ raise gr.Error("Please generate the private key first.")
146
+
147
+ if input_image is None:
148
+ raise gr.Error("Please choose an image first.")
149
+
150
+ if input_image.shape[-1] != 3:
151
+ raise ValueError(f"Input image must have 3 channels (RGB). Current shape: {input_image.shape}")
152
+
153
+ # Resize the image if it hasn't the shape (224, 224, 3)
154
+ if input_image.shape != (224, 224, 3):
155
+ input_image_pil = Image.fromarray(input_image)
156
+ input_image_pil = input_image_pil.resize((224, 224))
157
+ input_image = numpy.array(input_image_pil)
158
+
159
+ # Retrieve the client API
160
+ client = get_client(user_id)
161
+
162
+ # Pre-process, encrypt and serialize the image
163
+ encrypted_image = client.encrypt_serialize(input_image)
164
+
165
+ # Save encrypted_image to bytes in a file, since too large to pass through regular Gradio
166
+ # buttons, https://github.com/gradio-app/gradio/issues/1877
167
+ encrypted_image_path = get_client_file_path("encrypted_image", user_id)
168
+
169
+ with encrypted_image_path.open("wb") as encrypted_image_file:
170
+ encrypted_image_file.write(encrypted_image)
171
+
172
+ # Create a truncated version of the encrypted image for display
173
+ encrypted_image_short = shorten_bytes_object(encrypted_image)
174
+
175
+ return (resize_img(input_image), encrypted_image_short)
176
+
177
+ def send_input(user_id):
178
+ """Send the encrypted input image as well as the evaluation key to the server.
179
+
180
+ Args:
181
+ user_id (int): The current user's ID.
182
+ """
183
+ # Get the evaluation key path
184
+ evaluation_key_path = get_client_file_path("evaluation_key", user_id)
185
+
186
+ if user_id == "" or not evaluation_key_path.is_file():
187
+ raise gr.Error("Please generate the private key first.")
188
+
189
+ encrypted_input_path = get_client_file_path("encrypted_image", user_id)
190
+
191
+ if not encrypted_input_path.is_file():
192
+ raise gr.Error("Please generate the private key and then encrypt an image first.")
193
+
194
+ # Define the data and files to post
195
+ data = {
196
+ "user_id": user_id,
197
+ }
198
+
199
+ files = [
200
+ ("files", open(encrypted_input_path, "rb")),
201
+ ("files", open(evaluation_key_path, "rb")),
202
+ ]
203
+
204
+ # Send the encrypted input image and evaluation key to the server
205
+ url = SERVER_URL + "send_input"
206
+ with requests.post(
207
+ url=url,
208
+ data=data,
209
+ files=files,
210
+ ) as response:
211
+ return response.ok
212
+
213
+ def run_fhe(user_id):
214
+ """Apply the seizure detection model on the encrypted image previously sent using FHE.
215
+
216
+ Args:
217
+ user_id (int): The current user's ID.
218
+ """
219
+ data = {
220
+ "user_id": user_id,
221
+ }
222
+
223
+ # Trigger the FHE execution on the encrypted image previously sent
224
+ url = SERVER_URL + "run_fhe"
225
+ with requests.post(
226
+ url=url,
227
+ data=data,
228
+ ) as response:
229
+ if response.ok:
230
+ return response.json()
231
+ else:
232
+ raise gr.Error("Please wait for the input image to be sent to the server.")
233
+
234
+ def get_output(user_id):
235
+ """Retrieve the encrypted output (boolean).
236
+
237
+ Args:
238
+ user_id (int): The current user's ID.
239
+
240
+ Returns:
241
+ encrypted_output_short (bytes): A representation of the encrypted result.
242
+
243
+ """
244
+ data = {
245
+ "user_id": user_id,
246
+ }
247
+
248
+ # Retrieve the encrypted output
249
+ url = SERVER_URL + "get_output"
250
+ with requests.post(
251
+ url=url,
252
+ data=data,
253
+ ) as response:
254
+ if response.ok:
255
+ encrypted_output = response.content
256
+
257
+ # Save the encrypted output to bytes in a file as it is too large to pass through regular
258
+ # Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877)
259
+ encrypted_output_path = get_client_file_path("encrypted_output", user_id)
260
+
261
+ with encrypted_output_path.open("wb") as encrypted_output_file:
262
+ encrypted_output_file.write(encrypted_output)
263
+
264
+ # Create a truncated version of the encrypted output for display
265
+ encrypted_output_short = shorten_bytes_object(encrypted_output)
266
+
267
+ return encrypted_output_short
268
+ else:
269
+ raise gr.Error("Please wait for the FHE execution to be completed.")
270
+
271
+ def decrypt_output(user_id):
272
+ """Decrypt the result.
273
+
274
+ Args:
275
+ user_id (int): The current user's ID.
276
+
277
+ Returns:
278
+ bool: The decrypted output (True if seizure detected, False otherwise)
279
+
280
+ """
281
+ if user_id == "":
282
+ raise gr.Error("Please generate the private key first.")
283
+
284
+ # Get the encrypted output path
285
+ encrypted_output_path = get_client_file_path("encrypted_output", user_id)
286
+
287
+ if not encrypted_output_path.is_file():
288
+ raise gr.Error("Please run the FHE execution first.")
289
+
290
+ # Load the encrypted output as bytes
291
+ with encrypted_output_path.open("rb") as encrypted_output_file:
292
+ encrypted_output = encrypted_output_file.read()
293
+
294
+ # Retrieve the client API
295
+ client = get_client(user_id)
296
+
297
+ # Deserialize, decrypt and post-process the encrypted output
298
+ decrypted_output = client.deserialize_decrypt_post_process(encrypted_output)
299
+
300
+ return "Seizure detected" if decrypted_output else "No seizure detected"
301
+
302
+ def resize_img(img, width=256, height=256):
303
+ """Resize the image."""
304
+ if img.dtype != numpy.uint8:
305
+ img = img.astype(numpy.uint8)
306
+ img_pil = Image.fromarray(img)
307
+ # Resize the image
308
+ resized_img_pil = img_pil.resize((width, height))
309
+ # Convert back to a NumPy array
310
+ return numpy.array(resized_img_pil)
311
+
312
+ demo = gr.Blocks()
313
+
314
+ print("Starting the demo...")
315
+ with demo:
316
+ gr.Markdown(
317
+ """
318
+ <h1 align="center">Seizure Detection on Encrypted EEG Data Using Fully Homomorphic Encryption</h1>
319
+ """
320
+ )
321
+
322
+ gr.Markdown("## Client side")
323
+ gr.Markdown("### Step 1: Upload an EEG image. ")
324
+ gr.Markdown(
325
+ f"The image will automatically be resized to shape (224x224). "
326
+ "The image here, however, is displayed in its original resolution."
327
+ )
328
+ with gr.Row():
329
+ input_image = gr.Image(
330
+ value=None, label="Upload an EEG image here.", height=256,
331
+ width=256, sources="upload", interactive=True,
332
+ )
333
+
334
+ examples = gr.Examples(
335
+ examples=EXAMPLES, inputs=[input_image], examples_per_page=5, label="Examples to use."
336
+ )
337
+
338
+ gr.Markdown("### Step 2: Generate the private key.")
339
+ keygen_button = gr.Button("Generate the private key.")
340
+
341
+ with gr.Row():
342
+ keygen_checkbox = gr.Checkbox(label="Private key generated:", interactive=False)
343
+
344
+ user_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
345
+
346
+ gr.Markdown("### Step 3: Encrypt the image using FHE.")
347
+ encrypt_button = gr.Button("Encrypt the image using FHE.")
348
+
349
+ with gr.Row():
350
+ encrypted_input = gr.Textbox(
351
+ label="Encrypted input representation:", max_lines=2, interactive=False
352
+ )
353
+
354
+ gr.Markdown("## Server side")
355
+ gr.Markdown(
356
+ "The encrypted value is received by the server. The server can then compute the seizure "
357
+ "detection directly over encrypted values. Once the computation is finished, the server returns "
358
+ "the encrypted results to the client."
359
+ )
360
+
361
+ gr.Markdown("### Step 4: Send the encrypted image to the server.")
362
+ send_input_button = gr.Button("Send the encrypted image to the server.")
363
+ send_input_checkbox = gr.Checkbox(label="Encrypted image sent.", interactive=False)
364
+
365
+ gr.Markdown("### Step 5: Run FHE execution.")
366
+ execute_fhe_button = gr.Button("Run FHE execution.")
367
+ fhe_execution_time = gr.Textbox(
368
+ label="Total FHE execution time (in seconds):", max_lines=1, interactive=False
369
+ )
370
+
371
+ gr.Markdown("### Step 6: Receive the encrypted output from the server.")
372
+ get_output_button = gr.Button("Receive the encrypted output from the server.")
373
+
374
+ with gr.Row():
375
+ encrypted_output = gr.Textbox(
376
+ label="Encrypted output representation:",
377
+ max_lines=2,
378
+ interactive=False
379
+ )
380
+
381
+ gr.Markdown("## Client side")
382
+ gr.Markdown(
383
+ "The encrypted output is sent back to the client, who can finally decrypt it with the "
384
+ "private key. Only the client is aware of the original image and the detection result."
385
+ )
386
+
387
+ gr.Markdown("### Step 7: Decrypt the output.")
388
+ decrypt_button = gr.Button("Decrypt the output")
389
+
390
+ with gr.Row():
391
+ decrypted_output = gr.Textbox(
392
+ label="Seizure detection result:",
393
+ interactive=False
394
+ )
395
+
396
+ # Button to generate the private key
397
+ keygen_button.click(
398
+ keygen,
399
+ outputs=[user_id, keygen_checkbox],
400
+ )
401
+
402
+ # Button to encrypt inputs on the client side
403
+ encrypt_button.click(
404
+ encrypt,
405
+ inputs=[user_id, input_image],
406
+ outputs=[input_image, encrypted_input],
407
+ )
408
+
409
+ # Button to send the encodings to the server using post method
410
+ send_input_button.click(
411
+ send_input, inputs=[user_id], outputs=[send_input_checkbox]
412
+ )
413
+
414
+ # Button to send the encodings to the server using post method
415
+ execute_fhe_button.click(run_fhe, inputs=[user_id], outputs=[fhe_execution_time])
416
+
417
+ # Button to send the encodings to the server using post method
418
+ get_output_button.click(
419
+ get_output,
420
+ inputs=[user_id],
421
+ outputs=[encrypted_output]
422
+ )
423
+
424
+ # Button to decrypt the output on the client side
425
+ decrypt_button.click(
426
+ decrypt_output,
427
+ inputs=[user_id],
428
+ outputs=[decrypted_output],
429
+ )
430
+
431
+ gr.Markdown(
432
+ "The app was built with [Concrete-ML](https://github.com/zama-ai/concrete-ml), a "
433
+ "Privacy-Preserving Machine Learning (PPML) open-source set of tools by [Zama](https://zama.ai/). "
434
+ "Try it yourself and don't forget to star on Github &#11088;."
435
+ )
436
+
437
+ demo.launch(share=False)
client_server_interface.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "Client-server interface custom implementation for seizure detection models."
2
+
3
+ from concrete import fhe
4
+
5
+ from seizure_detection import SeizureDetector
6
+
7
+
8
+ class FHEServer:
9
+ """Server interface to run a FHE circuit for seizure detection."""
10
+
11
+ def __init__(self, model_path):
12
+ """Initialize the FHE interface.
13
+
14
+ Args:
15
+ model_path (Path): The path to the directory where the circuit is saved.
16
+ """
17
+ self.model_path = model_path
18
+
19
+ # Load the FHE circuit
20
+ self.server = fhe.Server.load(self.model_path / "server.zip")
21
+
22
+ def run(self, serialized_encrypted_image, serialized_evaluation_keys):
23
+ """Run seizure detection on the server over an encrypted image.
24
+
25
+ Args:
26
+ serialized_encrypted_image (bytes): The encrypted and serialized image.
27
+ serialized_evaluation_keys (bytes): The serialized evaluation keys.
28
+
29
+ Returns:
30
+ bytes: The encrypted boolean output indicating seizure detection.
31
+ """
32
+ # Deserialize the encrypted input image and the evaluation keys
33
+ encrypted_image = fhe.Value.deserialize(serialized_encrypted_image)
34
+ evaluation_keys = fhe.EvaluationKeys.deserialize(serialized_evaluation_keys)
35
+
36
+ # Execute the seizure detection in FHE
37
+ encrypted_output = self.server.run(encrypted_image, evaluation_keys=evaluation_keys)
38
+
39
+ # Serialize the encrypted output
40
+ serialized_encrypted_output = encrypted_output.serialize()
41
+
42
+ return serialized_encrypted_output
43
+
44
+
45
+ class FHEDev:
46
+ """Development interface to save and load the seizure detection model."""
47
+
48
+ def __init__(self, seizure_detector, model_path):
49
+ """Initialize the FHE interface.
50
+
51
+ Args:
52
+ seizure_detector (SeizureDetector): The seizure detection model to use in the FHE interface.
53
+ model_path (str): The path to the directory where the circuit is saved.
54
+ """
55
+
56
+ self.seizure_detector = seizure_detector
57
+ self.model_path = model_path
58
+
59
+ self.model_path.mkdir(parents=True, exist_ok=True)
60
+
61
+ def save(self):
62
+ """Export all needed artifacts for the client and server interfaces."""
63
+
64
+ assert self.seizure_detector.fhe_circuit is not None, (
65
+ "The model must be compiled before saving it."
66
+ )
67
+
68
+ # Save the circuit for the server, using the via_mlir in order to handle cross-platform
69
+ # execution
70
+ path_circuit_server = self.model_path / "server.zip"
71
+ self.seizure_detector.fhe_circuit.server.save(path_circuit_server, via_mlir=True)
72
+
73
+ # Save the circuit for the client
74
+ path_circuit_client = self.model_path / "client.zip"
75
+ self.seizure_detector.fhe_circuit.client.save(path_circuit_client)
76
+
77
+
78
+ class FHEClient:
79
+ """Client interface to encrypt and decrypt FHE data associated to a SeizureDetector."""
80
+
81
+ def __init__(self, model_path, key_dir=None):
82
+ """Initialize the FHE interface.
83
+
84
+ Args:
85
+ model_path (Path): The path to the directory where the circuit is saved.
86
+ key_dir (Path): The path to the directory where the keys are stored. Default to None.
87
+ """
88
+ self.model_path = model_path
89
+ self.key_dir = key_dir
90
+
91
+ # If model_path does not exist raise
92
+ assert model_path.exists(), f"{model_path} does not exist. Please specify a valid path."
93
+
94
+ # Load the client
95
+ self.client = fhe.Client.load(self.model_path / "client.zip", self.key_dir)
96
+
97
+ # Instantiate the seizure detector
98
+ self.seizure_detector = SeizureDetector()
99
+
100
+ def generate_private_and_evaluation_keys(self, force=False):
101
+ """Generate the private and evaluation keys.
102
+
103
+ Args:
104
+ force (bool): If True, regenerate the keys even if they already exist.
105
+ """
106
+ self.client.keygen(force)
107
+
108
+ def get_serialized_evaluation_keys(self):
109
+ """Get the serialized evaluation keys.
110
+
111
+ Returns:
112
+ bytes: The evaluation keys.
113
+ """
114
+ return self.client.evaluation_keys.serialize()
115
+
116
+ def encrypt_serialize(self, input_image):
117
+ """Encrypt and serialize the input image in the clear.
118
+
119
+ Args:
120
+ input_image (numpy.ndarray): The image to encrypt and serialize.
121
+
122
+ Returns:
123
+ bytes: The pre-processed, encrypted and serialized image.
124
+ """
125
+ # Encrypt the image
126
+ encrypted_image = self.client.encrypt(input_image)
127
+
128
+ # Serialize the encrypted image to be sent to the server
129
+ serialized_encrypted_image = encrypted_image.serialize()
130
+ return serialized_encrypted_image
131
+
132
+ def deserialize_decrypt_post_process(self, serialized_encrypted_output):
133
+ """Deserialize, decrypt and post-process the output in the clear.
134
+
135
+ Args:
136
+ serialized_encrypted_output (bytes): The serialized and encrypted output.
137
+
138
+ Returns:
139
+ bool: The decrypted and deserialized boolean indicating seizure detection.
140
+ """
141
+ # Deserialize the encrypted output
142
+ encrypted_output = fhe.Value.deserialize(serialized_encrypted_output)
143
+
144
+ # Decrypt the output
145
+ output = self.client.decrypt(encrypted_output)
146
+
147
+ # Post-process the output (if needed)
148
+ seizure_detected = self.seizure_detector.post_processing(output)
149
+
150
+ return seizure_detected
common.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "All the constants used in this repo."
2
+
3
+ from pathlib import Path
4
+
5
+ # This repository's directory
6
+ REPO_DIR = Path(__file__).parent
7
+
8
+ # This repository's main necessary folders
9
+ FILTERS_PATH = REPO_DIR / "filters"
10
+ KEYS_PATH = REPO_DIR / ".fhe_keys"
11
+ CLIENT_TMP_PATH = REPO_DIR / "client_tmp"
12
+ SERVER_TMP_PATH = REPO_DIR / "server_tmp"
13
+
14
+ # Create the necessary folders
15
+ KEYS_PATH.mkdir(exist_ok=True)
16
+ CLIENT_TMP_PATH.mkdir(exist_ok=True)
17
+ SERVER_TMP_PATH.mkdir(exist_ok=True)
18
+
19
+ # The input images' shape. Images with different input shapes will be cropped and resized by Gradio
20
+ INPUT_SHAPE = (224, 224)
21
+
22
+ # Retrieve the input examples directory
23
+ INPUT_EXAMPLES_DIR = REPO_DIR / "input_examples"
24
+
25
+ # List of all image examples suggested in the demo
26
+ EXAMPLES = [str(image) for image in INPUT_EXAMPLES_DIR.glob("**/*")]
27
+
28
+ # Store the server's URL
29
+ SERVER_URL = "http://localhost:8000/"
input_examples/eeg-1.png ADDED
input_examples/eeg-2.png ADDED
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ concrete-ml==1.1.0
2
+ gradio
seizure_detection.py ADDED
File without changes
server.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Server that will listen for GET and POST requests from the client."""
2
+
3
+ import time
4
+ from typing import List
5
+ from fastapi import FastAPI, File, Form, UploadFile
6
+ from fastapi.responses import JSONResponse, Response
7
+
8
+ from common import SERVER_TMP_PATH
9
+ from client_server_interface import FHEServer
10
+
11
+ # Load the server object for seizure detection
12
+ FHE_SERVER = FHEServer(model_path="path/to/seizure_detection_model")
13
+
14
+ def get_server_file_path(name, user_id):
15
+ """Get the correct temporary file path for the server.
16
+
17
+ Args:
18
+ name (str): The desired file name.
19
+ user_id (int): The current user's ID.
20
+
21
+ Returns:
22
+ pathlib.Path: The file path.
23
+ """
24
+ return SERVER_TMP_PATH / f"{name}_seizure_detection_{user_id}"
25
+
26
+
27
+ # Initialize an instance of FastAPI
28
+ app = FastAPI()
29
+
30
+ # Define the default route
31
+ @app.get("/")
32
+ def root():
33
+ return {"message": "Welcome to Your Seizure Detection FHE Server!"}
34
+
35
+
36
+ @app.post("/send_input")
37
+ def send_input(
38
+ user_id: str = Form(),
39
+ files: List[UploadFile] = File(),
40
+ ):
41
+ """Send the inputs to the server."""
42
+ # Retrieve the encrypted input image and the evaluation key paths
43
+ encrypted_image_path = get_server_file_path("encrypted_image", user_id)
44
+ evaluation_key_path = get_server_file_path("evaluation_key", user_id)
45
+
46
+ # Write the files using the above paths
47
+ with encrypted_image_path.open("wb") as encrypted_image, evaluation_key_path.open(
48
+ "wb"
49
+ ) as evaluation_key:
50
+ encrypted_image.write(files[0].file.read())
51
+ evaluation_key.write(files[1].file.read())
52
+
53
+
54
+ @app.post("/run_fhe")
55
+ def run_fhe(
56
+ user_id: str = Form(),
57
+ ):
58
+ """Execute seizure detection on the encrypted input image using FHE."""
59
+ # Retrieve the encrypted input image and the evaluation key paths
60
+ encrypted_image_path = get_server_file_path("encrypted_image", user_id)
61
+ evaluation_key_path = get_server_file_path("evaluation_key", user_id)
62
+
63
+ # Read the files using the above paths
64
+ with encrypted_image_path.open("rb") as encrypted_image_file, evaluation_key_path.open(
65
+ "rb"
66
+ ) as evaluation_key_file:
67
+ encrypted_image = encrypted_image_file.read()
68
+ evaluation_key = evaluation_key_file.read()
69
+
70
+ # Run the FHE execution
71
+ start = time.time()
72
+ encrypted_output = FHE_SERVER.run(encrypted_image, evaluation_key)
73
+ fhe_execution_time = round(time.time() - start, 2)
74
+
75
+ # Retrieve the encrypted output path
76
+ encrypted_output_path = get_server_file_path("encrypted_output", user_id)
77
+
78
+ # Write the file using the above path
79
+ with encrypted_output_path.open("wb") as encrypted_output_file:
80
+ encrypted_output_file.write(encrypted_output)
81
+
82
+ return JSONResponse(content=fhe_execution_time)
83
+
84
+
85
+ @app.post("/get_output")
86
+ def get_output(
87
+ user_id: str = Form(),
88
+ ):
89
+ """Retrieve the encrypted output."""
90
+ # Retrieve the encrypted output path
91
+ encrypted_output_path = get_server_file_path("encrypted_output", user_id)
92
+
93
+ # Read the file using the above path
94
+ with encrypted_output_path.open("rb") as encrypted_output_file:
95
+ encrypted_output = encrypted_output_file.read()
96
+
97
+ return Response(encrypted_output)