Roman committed on
Commit
21c7197
1 Parent(s): c27c2c1

Add app files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +7 -0
  2. README.md +69 -5
  3. app.py +463 -0
  4. common.py +54 -0
  5. compile.py +47 -0
  6. custom_client_server.py +204 -0
  7. filters.py +359 -0
  8. filters/black and white/deployment/client.zip +3 -0
  9. filters/black and white/deployment/serialized_processing.json +1 -0
  10. filters/black and white/deployment/server.zip +3 -0
  11. filters/black and white/server.onnx +3 -0
  12. filters/black_and_white/deployment/client.zip +3 -0
  13. filters/black_and_white/deployment/serialized_processing.json +1 -0
  14. filters/black_and_white/deployment/server.zip +3 -0
  15. filters/black_and_white/server.onnx +3 -0
  16. filters/blur/deployment/client.zip +3 -0
  17. filters/blur/deployment/serialized_processing.json +1 -0
  18. filters/blur/deployment/server.zip +3 -0
  19. filters/blur/server.onnx +3 -0
  20. filters/identity/deployment/client.zip +3 -0
  21. filters/identity/deployment/serialized_processing.json +1 -0
  22. filters/identity/deployment/server.zip +3 -0
  23. filters/identity/server.onnx +3 -0
  24. filters/inverted/deployment/client.zip +3 -0
  25. filters/inverted/deployment/serialized_processing.json +1 -0
  26. filters/inverted/deployment/server.zip +3 -0
  27. filters/inverted/server.onnx +3 -0
  28. filters/ridge detection/deployment/client.zip +3 -0
  29. filters/ridge detection/deployment/serialized_processing.json +1 -0
  30. filters/ridge detection/deployment/server.zip +3 -0
  31. filters/ridge detection/server.onnx +3 -0
  32. filters/ridge_detection/deployment/client.zip +3 -0
  33. filters/ridge_detection/deployment/serialized_processing.json +1 -0
  34. filters/ridge_detection/deployment/server.zip +3 -0
  35. filters/ridge_detection/server.onnx +3 -0
  36. filters/rotate/deployment/client.zip +3 -0
  37. filters/rotate/deployment/serialized_processing.json +1 -0
  38. filters/rotate/deployment/server.zip +3 -0
  39. filters/rotate/server.onnx +3 -0
  40. filters/sharpen/deployment/client.zip +3 -0
  41. filters/sharpen/deployment/serialized_processing.json +1 -0
  42. filters/sharpen/deployment/server.zip +3 -0
  43. filters/sharpen/server.onnx +3 -0
  44. generate_dev_files.py +40 -0
  45. input_examples/arc.jpg +0 -0
  46. input_examples/book.jpg +0 -0
  47. input_examples/computer.jpg +0 -0
  48. input_examples/tree.jpg +0 -0
  49. input_examples/zama_math.jpg +0 -0
  50. input_examples/zebra.jpg +0 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
+ .venv
+ .playground/
+ .artifacts
+ .fhe_keys
+ server_tmp/
+ client_tmp/
+ .artifacts
README.md CHANGED
@@ -1,12 +1,76 @@
  ---
- title: Encrypted Image Filtering
- emoji: 🐨
  colorFrom: yellow
- colorTo: pink
  sdk: gradio
- sdk_version: 3.16.1
  app_file: app.py
- pinned: false
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: Image Filtering on Encrypted Images using FHE
+ emoji: 🥷💬
  colorFrom: yellow
+ colorTo: yellow
  sdk: gradio
+ sdk_version: 3.2
  app_file: app.py
+ pinned: true
+ tags: [FHE, PPML, privacy, privacy preserving machine learning, homomorphic encryption,
+ security]
+ python_version: 3.8.15
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ # Image filtering using FHE
+
+ ## Running the application on your machine
+
+ In this directory, i.e. `image_filtering`, follow the steps below.
+
+ ### Do once
+
+ First, create a virtual environment and activate it:
+
+ <!--pytest-codeblocks:skip-->
+
+ ```bash
+ python3 -m venv .venv
+ source .venv/bin/activate
+ ```
+
+ Then, install the required packages:
+
+ <!--pytest-codeblocks:skip-->
+
+ ```bash
+ pip3 install -U pip wheel setuptools --ignore-installed
+ pip3 install -r requirements.txt --ignore-installed
+ ```
+
+ If you are not on Linux, or if you want to compile the FHE filters yourself, run:
+
+ <!--pytest-codeblocks:skip-->
+
+ ```bash
+ python3 compile.py
+ ```
+
+ Check that it finishes well (it should print "Done!").
+
+ It is also possible to manually add new filters in `filters.py`. However, in order to use them
+ interactively in the app, you first need to update the `AVAILABLE_FILTERS` list found in
+ `common.py` and then compile them (a minimal sketch of a new filter is given after the command
+ below) by running:
+
+ <!--pytest-codeblocks:skip-->
+
+ ```bash
+ python3 generate_dev_files.py
+ ```
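
For illustration, here is a minimal sketch of what such a new filter could look like, following the conventions used in `filters.py` in this commit (a small `nn.Module` plus a new branch in the `Filter` constructor). The "vertical flip" filter shown here is hypothetical and not part of this commit:

```python
# Hypothetical "vertical flip" filter, mirroring the _TorchRotate pattern
# from filters.py (illustrative sketch, not part of this commit)
from torch import nn


class _TorchVerticalFlip(nn.Module):
    """Torch model flipping an image upside down."""

    def forward(self, x):
        # Inputs follow the (batch, channels, height, width) layout used in
        # filters.py, so flipping axis 2 flips the image vertically
        return x.flip([2])


# In Filter.__init__, a matching branch would then be added:
#     elif image_filter == "vertical flip":
#         self.torch_model = _TorchVerticalFlip()
# and "vertical flip" appended to AVAILABLE_FILTERS in common.py.
```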
+
+ ## Run the following steps each time you relaunch the application
+
+ In a terminal, run:
+
+ <!--pytest-codeblocks:skip-->
+
+ ```bash
+ source .venv/bin/activate
+ python3 app.py
+ ```
+
+ ## Interacting with the application
+
+ Open the given URL link (search for a line like `Running on local URL: http://127.0.0.1:8888/`).
app.py ADDED
@@ -0,0 +1,463 @@
+ """A local gradio app that filters images using FHE."""
+
+ import os
+ import shutil
+ import subprocess
+ import time
+
+ import gradio as gr
+ import numpy
+ import requests
+ from common import (
+     AVAILABLE_FILTERS,
+     CLIENT_TMP_PATH,
+     EXAMPLES,
+     FILTERS_PATH,
+     INPUT_SHAPE,
+     KEYS_PATH,
+     REPO_DIR,
+     SERVER_URL,
+ )
+ from custom_client_server import CustomFHEClient
+
+ # Launch the server in a subprocess so that both the server and the client run in the
+ # same terminal
+ subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
+ time.sleep(3)
+
+
+ def shorten_bytes_object(bytes_object, limit=500):
+     """Shorten the input bytes object to a given length.
+
+     Encrypted data is too large for displaying it in the browser using Gradio. This function
+     provides a shortened representation of it.
+
+     Args:
+         bytes_object (bytes): The input to shorten.
+         limit (int): The length to consider. Defaults to 500.
+
+     Returns:
+         str: A hexadecimal representation of the shortened bytes object.
+
+     """
+     # Define a shift for better display
+     shift = 100
+     return bytes_object[shift : limit + shift].hex()
+
+
+ def get_client(user_id, image_filter):
+     """Get the client API.
+
+     Args:
+         user_id (int): The current user's ID.
+         image_filter (str): The filter chosen by the user.
+
+     Returns:
+         CustomFHEClient: The client API.
+     """
+     return CustomFHEClient(
+         FILTERS_PATH / f"{image_filter}/deployment", KEYS_PATH / f"{image_filter}_{user_id}"
+     )
+
+
+ def get_client_file_path(name, user_id, image_filter):
+     """Get the correct temporary file path for the client.
+
+     Args:
+         name (str): The desired file name.
+         user_id (int): The current user's ID.
+         image_filter (str): The filter chosen by the user.
+
+     Returns:
+         pathlib.Path: The file path.
+     """
+     return CLIENT_TMP_PATH / f"{name}_{image_filter}_{user_id}"
+
+
+ def clean_temporary_files(n_keys=20):
+     """Clean keys and encrypted images.
+
+     A maximum of n_keys keys are allowed to be stored. Once this limit is reached, the oldest
+     are deleted.
+
+     Args:
+         n_keys (int): The maximum number of keys to be stored. Defaults to 20.
+
+     """
+     # Get the oldest files in the key directory
+     list_files = sorted(KEYS_PATH.iterdir(), key=os.path.getmtime)
+
+     # If more than n_keys keys are found, remove the oldest
+     user_ids = []
+     if len(list_files) > n_keys:
+         n_files_to_delete = len(list_files) - n_keys
+         for p in list_files[:n_files_to_delete]:
+             user_ids.append(p.name)
+             shutil.rmtree(p)
+
+     # Get all the encrypted objects in the temporary folder
+     list_files_tmp = CLIENT_TMP_PATH.iterdir()
+
+     # Delete all the temporary files related to the removed users
+     for file in list_files_tmp:
+         for user_id in user_ids:
+             if file.name.endswith(f"{user_id}.npy"):
+                 file.unlink()
+
+
+ def keygen(image_filter):
+     """Generate the private key associated to a filter.
+
+     Args:
+         image_filter (str): The current filter to consider.
+
+     Returns:
+         (user_id, True) (Tuple[int, bool]): The current user's ID and a boolean used for
+             visual display.
+
+     """
+     # Clean temporary files
+     clean_temporary_files()
+
+     # Create an ID for the current user
+     user_id = numpy.random.randint(0, 2**32)
+
+     # Retrieve the client API
+     # Currently, the key generation needs to be done after choosing a filter
+     # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2258
+     client = get_client(user_id, image_filter)
+
+     # Generate a private key
+     client.generate_private_and_evaluation_keys(force=True)
+
+     # Retrieve the serialized evaluation key. In this case, as circuits are fully leveled, this
+     # evaluation key is empty. However, for software reasons, it is still needed for proper FHE
+     # execution
+     evaluation_key = client.get_serialized_evaluation_keys()
+
+     # Save evaluation_key as bytes in a file as it is too large to pass through regular Gradio
+     # buttons (see https://github.com/gradio-app/gradio/issues/1877)
+     evaluation_key_path = get_client_file_path("evaluation_key", user_id, image_filter)
+
+     with evaluation_key_path.open("wb") as evaluation_key_file:
+         evaluation_key_file.write(evaluation_key)
+
+     return (user_id, True)
+
+
+ def encrypt(user_id, input_image, image_filter):
+     """Encrypt the given image for a specific user and filter.
+
+     Args:
+         user_id (int): The current user's ID.
+         input_image (numpy.ndarray): The image to encrypt.
+         image_filter (str): The current filter to consider.
+
+     Returns:
+         (input_image, encrypted_image_short) (Tuple[numpy.ndarray, str]): The input image and
+             a shortened representation of the encrypted image.
+
+     """
+     if user_id == "":
+         raise gr.Error("Please generate the private key first.")
+
+     # Retrieve the client API
+     client = get_client(user_id, image_filter)
+
+     # Pre-process, encrypt and serialize the image
+     encrypted_image = client.pre_process_encrypt_serialize(input_image)
+
+     # Save encrypted_image as bytes in a file as it is too large to pass through regular Gradio
+     # buttons (see https://github.com/gradio-app/gradio/issues/1877)
+     encrypted_image_path = get_client_file_path("encrypted_image", user_id, image_filter)
+
+     with encrypted_image_path.open("wb") as encrypted_image_file:
+         encrypted_image_file.write(encrypted_image)
+
+     # Create a truncated version of the encrypted image for display
+     encrypted_image_short = shorten_bytes_object(encrypted_image)
+
+     return (input_image, encrypted_image_short)
+
+
+ def send_input(user_id, image_filter):
+     """Send the encrypted input image as well as the evaluation key to the server.
+
+     Args:
+         user_id (int): The current user's ID.
+         image_filter (str): The current filter to consider.
+     """
+     # Get the evaluation key path
+     evaluation_key_path = get_client_file_path("evaluation_key", user_id, image_filter)
+
+     if user_id == "" or not evaluation_key_path.is_file():
+         raise gr.Error("Please generate the private key first.")
+
+     encrypted_input_path = get_client_file_path("encrypted_image", user_id, image_filter)
+
+     if not encrypted_input_path.is_file():
+         raise gr.Error("Please generate the private key and then encrypt an image first.")
+
+     # Define the data and files to post
+     data = {
+         "user_id": user_id,
+         "filter": image_filter,
+     }
+
+     files = [
+         ("files", open(encrypted_input_path, "rb")),
+         ("files", open(evaluation_key_path, "rb")),
+     ]
+
+     # Send the encrypted input image and evaluation key to the server
+     url = SERVER_URL + "send_input"
+     with requests.post(
+         url=url,
+         data=data,
+         files=files,
+     ) as response:
+         return response.ok
+
+
+ def run_fhe(user_id, image_filter):
+     """Apply the filter on the encrypted image previously sent using FHE.
+
+     Args:
+         user_id (int): The current user's ID.
+         image_filter (str): The current filter to consider.
+     """
+     data = {
+         "user_id": user_id,
+         "filter": image_filter,
+     }
+
+     # Trigger the FHE execution on the encrypted image previously sent
+     url = SERVER_URL + "run_fhe"
+     with requests.post(
+         url=url,
+         data=data,
+     ) as response:
+         if response.ok:
+             return response.json()
+         else:
+             raise gr.Error("Please wait for the input image to be sent to the server.")
+
+
+ def get_output(user_id, image_filter):
+     """Retrieve the encrypted output image.
+
+     Args:
+         user_id (int): The current user's ID.
+         image_filter (str): The current filter to consider.
+
+     Returns:
+         encrypted_output_image_short (str): A representation of the encrypted result.
+
+     """
+     data = {
+         "user_id": user_id,
+         "filter": image_filter,
+     }
+
+     # Retrieve the encrypted output image
+     url = SERVER_URL + "get_output"
+     with requests.post(
+         url=url,
+         data=data,
+     ) as response:
+         if response.ok:
+             # Save the encrypted output as bytes in a file as it is too large to pass through
+             # regular Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877)
+             encrypted_output_path = get_client_file_path("encrypted_output", user_id, image_filter)
+
+             with encrypted_output_path.open("wb") as encrypted_output_file:
+                 encrypted_output_file.write(response.content)
+
+             # Create a truncated version of the encrypted output for display
+             encrypted_output_image_short = shorten_bytes_object(response.content)
+
+             return encrypted_output_image_short
+         else:
+             raise gr.Error("Please wait for the FHE execution to be completed.")
+
+
+ def decrypt_output(user_id, image_filter):
+     """Decrypt the result.
+
+     Args:
+         user_id (int): The current user's ID.
+         image_filter (str): The current filter to consider.
+
+     Returns:
+         (output_image, False, False) (Tuple[numpy.ndarray, bool, bool]): The decrypted output,
+             as well as two booleans used for resetting Gradio checkboxes.
+
+     """
+     if user_id == "":
+         raise gr.Error("Please generate the private key first.")
+
+     # Get the encrypted output path
+     encrypted_output_path = get_client_file_path("encrypted_output", user_id, image_filter)
+
+     if not encrypted_output_path.is_file():
+         raise gr.Error("Please run the FHE execution first.")
+
+     # Load the encrypted output as bytes
+     with encrypted_output_path.open("rb") as encrypted_output_file:
+         encrypted_output_image = encrypted_output_file.read()
+
+     # Retrieve the client API
+     client = get_client(user_id, image_filter)
+
+     # Deserialize, decrypt and post-process the encrypted output
+     output_image = client.deserialize_decrypt_post_process(encrypted_output_image)
+
+     return output_image, False, False
+
+
+ demo = gr.Blocks()
+
+
+ print("Starting the demo...")
+ with demo:
+     gr.Markdown(
+         """
+         <p align="center">
+         </p>
+         <p align="center">
+         </p>
+         """
+     )
+
+     gr.Markdown("## Client side")
+     gr.Markdown(
+         f"Step 1. Upload an image. It will automatically be resized to shape "
+         f"({INPUT_SHAPE[0]}x{INPUT_SHAPE[1]}). "
+         "The image is however displayed using its original resolution."
+     )
+     with gr.Row():
+         input_image = gr.Image(
+             label="Upload an image here.", shape=INPUT_SHAPE, source="upload", interactive=True
+         )
+
+         examples = gr.Examples(
+             examples=EXAMPLES, inputs=[input_image], examples_per_page=5, label="Examples to use."
+         )
+
+     gr.Markdown("Step 2. Choose your filter")
+     image_filter = gr.Dropdown(
+         choices=AVAILABLE_FILTERS, value="inverted", label="Choose your filter", interactive=True
+     )
+
+     gr.Markdown("### Notes")
+     gr.Markdown(
+         """
+         - The private key is used to encrypt and decrypt the data and shall never be shared.
+         - No public keys are required for these filter operators.
+         """
+     )
+
+     with gr.Row():
+         keygen_button = gr.Button("Step 3. Generate the private key.")
+
+         keygen_checkbox = gr.Checkbox(label="Private key generated:", interactive=False)
+
+     with gr.Row():
+         encrypt_button = gr.Button("Step 4. Encrypt the image using FHE.")
+
+         user_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)
+
+         # Display an image representation
+         # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2265
+         encrypted_image = gr.Textbox(
+             label="Encrypted image representation:", max_lines=2, interactive=False
+         )
+
+     gr.Markdown("## Server side")
+     gr.Markdown(
+         "The encrypted value is received by the server. The server can then compute the filter "
+         "directly over encrypted values. Once the computation is finished, the server returns "
+         "the encrypted results to the client."
+     )
+
+     with gr.Row():
+         send_input_button = gr.Button("Step 5. Send the encrypted image to the server.")
+
+         send_input_checkbox = gr.Checkbox(label="Encrypted image sent.", interactive=False)
+
+     with gr.Row():
+         execute_fhe_button = gr.Button("Step 6. Run FHE execution")
+
+         fhe_execution_time = gr.Textbox(
+             label="Total FHE execution time (in seconds).", max_lines=1, interactive=False
+         )
+
+     with gr.Row():
+         get_output_button = gr.Button("Step 7. Receive the encrypted output image from the server.")
+
+         # Display an image representation
+         # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2265
+         encrypted_output_image = gr.Textbox(
+             label="Encrypted output image representation:", max_lines=2, interactive=False
+         )
+
+     gr.Markdown("## Client side")
+     gr.Markdown(
+         "The encrypted output is sent back to the client, who can finally decrypt it with the "
+         "private key. Only the client is aware of the original image and its transformed version."
+     )
+
+     decrypt_button = gr.Button("Step 8. Decrypt the output")
+
+     # Final input vs output display
+     with gr.Row():
+         original_image = gr.Image(
+             input_image.value,
+             label=f"Input image ({INPUT_SHAPE[0]}x{INPUT_SHAPE[1]}):",
+             interactive=False,
+         )
+         original_image.style(height=256, width=256)
+
+         output_image = gr.Image(
+             label=f"Output image ({INPUT_SHAPE[0]}x{INPUT_SHAPE[1]}):", interactive=False
+         )
+         output_image.style(height=256, width=256)
+
+     # Button to generate the private key
+     keygen_button.click(
+         keygen,
+         inputs=[image_filter],
+         outputs=[user_id, keygen_checkbox],
+     )
+
+     # Button to encrypt inputs on the client side
+     encrypt_button.click(
+         encrypt,
+         inputs=[user_id, input_image, image_filter],
+         outputs=[original_image, encrypted_image],
+     )
+
+     # Button to send the encrypted inputs to the server using a POST method
+     send_input_button.click(
+         send_input, inputs=[user_id, image_filter], outputs=[send_input_checkbox]
+     )
+
+     # Button to trigger the FHE execution on the server
+     execute_fhe_button.click(run_fhe, inputs=[user_id, image_filter], outputs=[fhe_execution_time])
+
+     # Button to retrieve the encrypted output from the server
+     get_output_button.click(
+         get_output, inputs=[user_id, image_filter], outputs=[encrypted_output_image]
+     )
+
+     # Button to decrypt the output on the client side
+     decrypt_button.click(
+         decrypt_output,
+         inputs=[user_id, image_filter],
+         outputs=[output_image, keygen_checkbox, send_input_checkbox],
+     )
+
+     gr.Markdown(
+         "The app was built with [Concrete-ML](https://github.com/zama-ai/concrete-ml), an "
+         "open-source set of tools for Privacy-Preserving Machine Learning (PPML) by "
+         "[Zama](https://zama.ai/). Try it yourself and don't forget to star it on "
+         "GitHub &#11088;."
+     )
+
+ demo.launch(share=False)
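
For reference, the three endpoints used by the buttons above (`send_input`, `run_fhe` and `get_output`) can also be called directly. Here is a minimal sketch with `requests`, assuming the default `SERVER_URL` from `common.py`; the user ID and file names are made-up examples of the paths normally produced by `get_client_file_path`:

```python
# Sketch of the HTTP exchange performed by app.py (send_input, run_fhe,
# get_output). The user ID and file names below are made-up examples.
import requests

SERVER_URL = "http://localhost:8000/"
data = {"user_id": 1234567890, "filter": "inverted"}

# 1. Send the encrypted image and the evaluation key to the server
with open("client_tmp/encrypted_image_inverted_1234567890", "rb") as image_file, open(
    "client_tmp/evaluation_key_inverted_1234567890", "rb"
) as key_file:
    files = [("files", image_file), ("files", key_file)]
    requests.post(url=SERVER_URL + "send_input", data=data, files=files)

# 2. Trigger the FHE execution and retrieve the execution time
response = requests.post(url=SERVER_URL + "run_fhe", data=data)
print("FHE execution time (s):", response.json())

# 3. Retrieve the encrypted output, ready for client-side decryption
response = requests.post(url=SERVER_URL + "get_output", data=data)
encrypted_output = response.content
```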
common.py ADDED
@@ -0,0 +1,54 @@
+ "All the constants used in this repo."
+
+ from pathlib import Path
+
+ import numpy as np
+ from PIL import Image
+
+ # The repository's directory
+ REPO_DIR = Path(__file__).parent
+
+ # The repository's main directories
+ FILTERS_PATH = REPO_DIR / "filters"
+ KEYS_PATH = REPO_DIR / ".fhe_keys"
+ CLIENT_TMP_PATH = REPO_DIR / "client_tmp"
+ SERVER_TMP_PATH = REPO_DIR / "server_tmp"
+
+ # Create the directories if they do not exist yet
+ KEYS_PATH.mkdir(exist_ok=True)
+ CLIENT_TMP_PATH.mkdir(exist_ok=True)
+ SERVER_TMP_PATH.mkdir(exist_ok=True)
+
+ # All the filters currently available in the app
+ AVAILABLE_FILTERS = [
+     "identity",
+     "inverted",
+     "rotate",
+     "black and white",
+     "blur",
+     "sharpen",
+     "ridge detection",
+ ]
+
+ # The input image's shape. Images with larger input shapes will be cropped and/or resized to this
+ INPUT_SHAPE = (100, 100)
+
+ # Generate random images as an inputset for compilation
+ np.random.seed(42)
+ INPUTSET = tuple(
+     np.random.randint(0, 255, size=(INPUT_SHAPE + (3,)), dtype=np.int64) for _ in range(10)
+ )
+
+
+ def load_image(image_path):
+     """Load an image, convert it to RGB, resize it to INPUT_SHAPE and cast it to int64."""
+     image = Image.open(image_path).convert("RGB").resize(INPUT_SHAPE)
+     image = np.asarray(image, dtype="int64")
+     return image
+
+
+ _INPUTSET_DIR = REPO_DIR / "input_examples"
+
+ # List of all image examples suggested in the app
+ EXAMPLES = [str(image) for image in _INPUTSET_DIR.glob("**/*")]
+
+ SERVER_URL = "http://localhost:8000/"
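
As a quick usage sketch of the helper above, loading one of the example images committed under `input_examples/` yields an integer array with the `INPUT_SHAPE + (3,)` layout expected by the filters (assuming the script runs from the repository root):

```python
# Sketch: load a committed example image with common.load_image
from common import EXAMPLES, INPUT_SHAPE, load_image

image = load_image(EXAMPLES[0])

# Images are resized to INPUT_SHAPE and cast to int64 RGB arrays,
# i.e. (100, 100, 3) with the default INPUT_SHAPE
assert image.shape == INPUT_SHAPE + (3,)
print(image.shape, image.dtype)
```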
compile.py ADDED
@@ -0,0 +1,47 @@
+ "A script to manually compile all filters."
+
+ import json
+ import shutil
+
+ import numpy as np
+ import onnx
+ from common import AVAILABLE_FILTERS, FILTERS_PATH, INPUT_SHAPE, INPUTSET, KEYS_PATH
+ from custom_client_server import CustomFHEClient, CustomFHEDev
+
+ print("Starting to compile the filters.")
+
+ for image_filter in AVAILABLE_FILTERS:
+     print("\nCompiling filter:", image_filter)
+
+     # Load the onnx model
+     onnx_model = onnx.load(FILTERS_PATH / f"{image_filter}/server.onnx")
+
+     deployment_path = FILTERS_PATH / f"{image_filter}/deployment"
+
+     # Retrieve the client API related to the current filter
+     model = CustomFHEClient(deployment_path, KEYS_PATH).model
+
+     image_shape = INPUT_SHAPE + (3,)
+
+     # Compile the model using the loaded onnx model
+     model.compile(INPUTSET, onnx_model=onnx_model)
+
+     processing_json_path = deployment_path / "serialized_processing.json"
+
+     # Load the serialized_processing.json file
+     with open(processing_json_path, "r") as f:
+         serialized_processing = json.load(f)
+
+     # Delete the deployment folder and its content if it exists
+     if deployment_path.is_dir():
+         shutil.rmtree(deployment_path)
+
+     # Save the files needed for deployment
+     fhe_api = CustomFHEDev(model=model, path_dir=deployment_path)
+     fhe_api.save()
+
+     # Write the serialized_processing.json file to the deployment folder
+     with open(processing_json_path, "w") as f:
+         json.dump(serialized_processing, f)
+
+ print("Done!")
custom_client_server.py ADDED
@@ -0,0 +1,204 @@
+ "Client-server interface implementation for custom models."
+
+ from pathlib import Path
+ from typing import Any
+
+ import concrete.numpy as cnp
+ import numpy as np
+ from filters import Filter
+
+ from concrete.ml.common.debugging.custom_assert import assert_true
+
+
+ class CustomFHEDev:
+     """Dev API to save the custom model and then load and run the FHE circuit."""
+
+     model: Any = None
+
+     def __init__(self, path_dir: str, model: Any = None):
+         """Initialize the FHE API.
+
+         Args:
+             path_dir (str): The path to the directory where the circuit is saved.
+             model (Any): The model to use for the FHE API.
+         """
+
+         self.path_dir = Path(path_dir)
+         self.model = model
+
+         # Create the directory path if it does not exist yet
+         Path(self.path_dir).mkdir(parents=True, exist_ok=True)
+
+     def save(self):
+         """Export all needed artifacts for the client and server.
+
+         Raises:
+             Exception: If path_dir is not empty.
+         """
+         # Check if the path_dir is empty with pathlib
+         listdir = list(Path(self.path_dir).glob("**/*"))
+         if len(listdir) > 0:
+             raise Exception(
+                 f"path_dir: {self.path_dir} is not empty. "
+                 "Please delete it before saving a new model."
+             )
+
+         assert_true(
+             hasattr(self.model, "fhe_circuit"),
+             "The model must be compiled and have a fhe_circuit object",
+         )
+
+         # The model must be compiled with jit=False: in a jit model, everything is in memory,
+         # so it is not serializable.
+         assert_true(
+             not self.model.fhe_circuit.configuration.jit,
+             "The model must be compiled with the configuration option jit=False.",
+         )
+
+         # Export the parameters
+         self.model.to_json(path_dir=self.path_dir, file_name="serialized_processing")
+
+         # Save the circuit for the server
+         path_circuit_server = self.path_dir / "server.zip"
+         self.model.fhe_circuit.server.save(path_circuit_server)
+
+         # Save the circuit for the client
+         path_circuit_client = self.path_dir / "client.zip"
+         self.model.fhe_circuit.client.save(path_circuit_client)
+
+
+ class CustomFHEClient:
+     """Client API to encrypt and decrypt FHE data."""
+
+     client: cnp.Client
+
+     def __init__(self, path_dir: str, key_dir: str = None):
+         """Initialize the FHE API.
+
+         Args:
+             path_dir (str): The path to the directory where the circuit is saved.
+             key_dir (str): The path to the directory where the keys are stored.
+         """
+         self.path_dir = Path(path_dir)
+         self.key_dir = Path(key_dir)
+
+         # If path_dir does not exist, raise an error
+         assert_true(
+             Path(path_dir).exists(), f"{path_dir} does not exist. Please specify a valid path."
+         )
+
+         # Load
+         self.load()
+
+     def load(self):  # pylint: disable=no-value-for-parameter
+         """Load the parameters along with the FHE specs."""
+
+         # Load the client
+         self.client = cnp.Client.load(self.path_dir / "client.zip", self.key_dir)
+
+         # Load the model
+         self.model = Filter.from_json(self.path_dir / "serialized_processing.json")
+
+     def generate_private_and_evaluation_keys(self, force=False):
+         """Generate the private and evaluation keys.
+
+         Args:
+             force (bool): If True, regenerate the keys even if they already exist.
+         """
+         self.client.keygen(force)
+
+     def get_serialized_evaluation_keys(self) -> cnp.EvaluationKeys:
+         """Get the serialized evaluation keys.
+
+         Returns:
+             cnp.EvaluationKeys: The evaluation keys.
+         """
+         return self.client.evaluation_keys.serialize()
+
+     def pre_process_encrypt_serialize(self, x: np.ndarray) -> cnp.PublicArguments:
+         """Pre-process, encrypt and serialize the values.
+
+         Args:
+             x (numpy.ndarray): The values to encrypt and serialize.
+
+         Returns:
+             cnp.PublicArguments: The encrypted and serialized values.
+         """
+         # Pre-process the values
+         x = self.model.pre_processing(x)
+
+         # Encrypt the values
+         enc_x = self.client.encrypt(x)
+
+         # Serialize the encrypted values to be sent to the server
+         serialized_enc_x = self.client.specs.serialize_public_args(enc_x)
+         return serialized_enc_x
+
+     def deserialize_decrypt_post_process(
+         self, serialized_encrypted_output: cnp.PublicArguments
+     ) -> np.ndarray:
+         """Deserialize, decrypt and post-process the values.
+
+         Args:
+             serialized_encrypted_output (cnp.PublicArguments): The serialized and encrypted
+                 output.
+
+         Returns:
+             numpy.ndarray: The decrypted values.
+         """
+         # Deserialize the encrypted values
+         deserialized_encrypted_output = self.client.specs.unserialize_public_result(
+             serialized_encrypted_output
+         )
+
+         # Decrypt the values
+         deserialized_decrypted_output = self.client.decrypt(deserialized_encrypted_output)
+
+         # Apply the model post-processing
+         deserialized_decrypted_output = self.model.post_processing(deserialized_decrypted_output)
+         return deserialized_decrypted_output
+
+
+ class CustomFHEServer:
+     """Server API to load and run the FHE circuit."""
+
+     server: cnp.Server
+
+     def __init__(self, path_dir: str):
+         """Initialize the FHE API.
+
+         Args:
+             path_dir (str): The path to the directory where the circuit is saved.
+         """
+
+         self.path_dir = Path(path_dir)
+
+         # Load the FHE circuit
+         self.load()
+
+     def load(self):
+         """Load the circuit."""
+         self.server = cnp.Server.load(self.path_dir / "server.zip")
+
+     def run(
+         self,
+         serialized_encrypted_data: cnp.PublicArguments,
+         serialized_evaluation_keys: cnp.EvaluationKeys,
+     ) -> cnp.PublicResult:
+         """Run the model on the server over encrypted data.
+
+         Args:
+             serialized_encrypted_data (cnp.PublicArguments): The encrypted and serialized data.
+             serialized_evaluation_keys (cnp.EvaluationKeys): The serialized evaluation keys.
+
+         Returns:
+             cnp.PublicResult: The result of the model.
+         """
+         assert_true(self.server is not None, "Model has not been loaded.")
+
+         deserialized_encrypted_data = self.server.client_specs.unserialize_public_args(
+             serialized_encrypted_data
+         )
+         deserialized_evaluation_keys = cnp.EvaluationKeys.unserialize(serialized_evaluation_keys)
+         result = self.server.run(deserialized_encrypted_data, deserialized_evaluation_keys)
+         serialized_result = self.server.client_specs.serialize_public_result(result)
+         return serialized_result
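
Taken together, these three classes implement the whole exchange. A minimal end-to-end sketch, assuming the `inverted` filter's deployment artifacts committed below and an arbitrary key directory name:

```python
# Sketch of a full encrypt -> run -> decrypt roundtrip with the classes above.
# The key directory name is an arbitrary example.
import numpy as np

from custom_client_server import CustomFHEClient, CustomFHEServer

client = CustomFHEClient("filters/inverted/deployment", ".fhe_keys/demo")
server = CustomFHEServer("filters/inverted/deployment")

client.generate_private_and_evaluation_keys(force=True)
evaluation_keys = client.get_serialized_evaluation_keys()

# Encrypt a random 100x100 RGB image, run the filter over the encrypted
# data, then decrypt and post-process the result on the client side
image = np.random.randint(0, 255, size=(100, 100, 3), dtype=np.int64)
encrypted_image = client.pre_process_encrypt_serialize(image)
encrypted_result = server.run(encrypted_image, evaluation_keys)
result = client.deserialize_decrypt_post_process(encrypted_result)

print(result.shape)  # (100, 100, 3): the color-inverted image
```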
filters.py ADDED
@@ -0,0 +1,359 @@
+ "Filter definitions, with pre-processing, post-processing and compilation methods."
+
+ import json
+
+ import numpy as np
+ import torch
+ from common import AVAILABLE_FILTERS
+ from concrete.numpy.compilation.compiler import Compiler
+ from torch import nn
+
+ from concrete.ml.common.debugging.custom_assert import assert_true
+ from concrete.ml.common.utils import generate_proxy_function
+ from concrete.ml.onnx.convert import get_equivalent_numpy_forward
+ from concrete.ml.torch.numpy_module import NumpyModule
+ from concrete.ml.version import __version__ as CML_VERSION
+
+ # Add a "black and white" filter
+ # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2277
+
+
+ class _TorchIdentity(nn.Module):
+     """Torch identity model."""
+
+     def forward(self, x):
+         """Identity forward pass.
+
+         Args:
+             x (torch.Tensor): The input image.
+
+         Returns:
+             x (torch.Tensor): The input image.
+         """
+         return x
+
+
+ class _TorchInverted(nn.Module):
+     """Torch inverted model."""
+
+     def forward(self, x):
+         """Forward pass for inverting an image's colors.
+
+         Args:
+             x (torch.Tensor): The input image.
+
+         Returns:
+             torch.Tensor: The (color) inverted image.
+         """
+         return 255 - x
+
+
+ class _TorchRotate(nn.Module):
+     """Torch rotation model."""
+
+     def forward(self, x):
+         """Forward pass for rotating an image.
+
+         Args:
+             x (torch.Tensor): The input image.
+
+         Returns:
+             torch.Tensor: The rotated image.
+         """
+         return x.transpose(2, 3)
+
+
+ class _TorchConv2D(nn.Module):
+     """Torch model for applying a single 2D convolution operator on images."""
+
+     def __init__(self, kernel, n_in_channels=3, n_out_channels=3, groups=1):
+         """Initialize the filter.
+
+         Args:
+             kernel (np.ndarray): The convolution kernel to consider.
+             n_in_channels (int): The number of input channels. Defaults to 3.
+             n_out_channels (int): The number of output channels. Defaults to 3.
+             groups (int): The number of groups used in the convolution. Defaults to 1.
+         """
+         super().__init__()
+         self.kernel = kernel
+         self.n_out_channels = n_out_channels
+         self.n_in_channels = n_in_channels
+         self.groups = groups
+
+     def forward(self, x):
+         """Forward pass for filtering the image using a 2D kernel.
+
+         Args:
+             x (torch.Tensor): The input image.
+
+         Returns:
+             torch.Tensor: The filtered image.
+
+         """
+         # Define the convolution parameters
+         stride = 1
+         kernel_shape = self.kernel.shape
+
+         # Ensure the kernel has a proper shape
+         # If the kernel has a 1D shape, a (1, 1) kernel is used for each in_channels
+         if len(kernel_shape) == 1:
+             kernel = self.kernel.reshape(
+                 self.n_out_channels,
+                 self.n_in_channels // self.groups,
+                 1,
+                 1,
+             )
+
+         # Else, if the kernel has a 2D shape, a single (Kw, Kh) kernel is used on all in_channels
+         elif len(kernel_shape) == 2:
+             kernel = self.kernel.expand(
+                 self.n_out_channels,
+                 self.n_in_channels // self.groups,
+                 kernel_shape[0],
+                 kernel_shape[1],
+             )
+         else:
+             raise ValueError(
+                 "Wrong kernel shape, only 1D or 2D kernels are accepted. Got kernel of shape "
+                 f"{kernel_shape}"
+             )
+
+         return nn.functional.conv2d(x, kernel, stride=stride, groups=self.groups)
+
+
+ class Filter:
+     """Filter class used in the app."""
+
+     def __init__(self, image_filter="inverted"):
+         """Initialize the filter class using a given filter.
+
+         Most filters can be found at https://en.wikipedia.org/wiki/Kernel_(image_processing).
+
+         Args:
+             image_filter (str): The filter to consider. Defaults to "inverted".
+         """
+
+         assert_true(
+             image_filter in AVAILABLE_FILTERS,
+             f"Unsupported image filter or transformation. Expected one of {*AVAILABLE_FILTERS,}, "
+             f"but got {image_filter}",
+         )
+
+         self.filter = image_filter
+         self.divide = None
+         self.repeat_out_channels = False
+
+         if image_filter == "identity":
+             self.torch_model = _TorchIdentity()
+
+         elif image_filter == "inverted":
+             self.torch_model = _TorchInverted()
+
+         elif image_filter == "rotate":
+             self.torch_model = _TorchRotate()
+
+         elif image_filter == "black and white":
+             # Define the grayscale weights (RGB order). These weights were used in PAL and NTSC
+             # video systems and can be found at https://en.wikipedia.org/wiki/Grayscale
+             # They are initially float weights (0.299, 0.587, 0.114), with
+             # 0.299 + 0.587 + 0.114 = 1
+             # However, since FHE computations require weights to be integers, we first multiply
+             # them by a factor of 1000. The output image's values are then divided by 1000 in
+             # post-processing in order to retrieve the correct result
+             kernel = torch.tensor([299, 587, 114])
+
+             self.torch_model = _TorchConv2D(kernel, n_out_channels=1, groups=1)
+
+             # Division value for post-processing
+             self.divide = 1000
+
+             # Grayscaled image needs to be put in RGB format for Gradio display
+             self.repeat_out_channels = True
+
+         elif image_filter == "blur":
+             kernel = torch.ones((3, 3), dtype=torch.int64)
+
+             self.torch_model = _TorchConv2D(kernel, n_out_channels=3, groups=3)
+
+             # Division value for post-processing
+             self.divide = 9
+
+         elif image_filter == "sharpen":
+             kernel = torch.tensor(
+                 [
+                     [0, -1, 0],
+                     [-1, 5, -1],
+                     [0, -1, 0],
+                 ]
+             )
+
+             self.torch_model = _TorchConv2D(kernel, n_out_channels=3, groups=3)
+
+         elif image_filter == "ridge detection":
+             # Make the filter properly grayscaled, as it is commonly used
+             # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2265
+
+             kernel = torch.tensor(
+                 [
+                     [-1, -1, -1],
+                     [-1, 9, -1],
+                     [-1, -1, -1],
+                 ]
+             )
+
+             self.torch_model = _TorchConv2D(kernel, n_out_channels=1, groups=1)
+
+             # Ridge detection is usually displayed as a grayscaled image, which needs to be put
+             # in RGB format for Gradio display
+             self.repeat_out_channels = True
+
+         self.onnx_model = None
+         self.fhe_circuit = None
+
+     def compile(self, inputset, onnx_model=None):
+         """Compile the model using an inputset.
+
+         Args:
+             inputset (List[np.ndarray]): The set of images to use for compilation.
+             onnx_model (onnx.ModelProto): The loaded onnx model to consider. If None, it will be
+                 generated automatically using a NumpyModule. Defaults to None.
+
+         Returns:
+             concrete.numpy.Circuit: The compiled FHE circuit.
+         """
+         # Reshape the inputs found in inputset. This is done because Torch and Numpy don't follow
+         # the same shape conventions.
+         inputset = tuple(
+             np.expand_dims(input.transpose(2, 0, 1), axis=0).astype(np.int64) for input in inputset
+         )
+
+         # If no onnx model was given, generate a new one.
+         if onnx_model is None:
+             numpy_module = NumpyModule(
+                 self.torch_model,
+                 dummy_input=torch.from_numpy(inputset[0]),
+             )
+
+             onnx_model = numpy_module.onnx_model
+
+         # Get the proxy function and parameter mappings for initializing the compiler
+         self.onnx_model = onnx_model
+         numpy_filter = get_equivalent_numpy_forward(onnx_model)
+
+         numpy_filter_proxy, parameters_mapping = generate_proxy_function(numpy_filter, ["inputs"])
+
+         compiler = Compiler(
+             numpy_filter_proxy,
+             {parameters_mapping["inputs"]: "encrypted"},
+         )
+
+         # Compile the filter
+         self.fhe_circuit = compiler.compile(inputset)
+
+         return self.fhe_circuit
+
+     def pre_processing(self, input_image):
+         """Apply the processing needed before encryption.
+
+         Args:
+             input_image (np.ndarray): The image to pre-process.
+
+         Returns:
+             input_image (np.ndarray): The pre-processed image.
+         """
+         # Reshape the input image. This is done because Torch and Numpy don't follow the same
+         # shape conventions.
+         input_image = np.expand_dims(input_image.transpose(2, 0, 1), axis=0).astype(np.int64)
+
+         return input_image
+
+     def post_processing(self, output_image):
+         """Apply the processing needed after decryption.
+
+         Args:
+             output_image (np.ndarray): The decrypted image to post-process.
+
+         Returns:
+             output_image (np.ndarray): The post-processed image.
+         """
+         # Apply a division if needed
+         if self.divide is not None:
+             output_image //= self.divide
+
+         # Clip the image's values to proper RGB standards as filters don't handle such constraints
+         output_image = output_image.clip(0, 255)
+
+         # Reshape the output image. This is done because Torch and Numpy don't follow the same
+         # shape conventions.
+         output_image = output_image.transpose(0, 2, 3, 1).squeeze(0)
+
+         # Grayscaled image needs to be put in RGB format for Gradio display
+         if self.repeat_out_channels:
+             output_image = output_image.repeat(3, axis=2)
+
+         return output_image
+
+     @classmethod
+     def from_json(cls, json_path):
+         """Instantiate a filter using a json file.
+
+         Args:
+             json_path (Union[str, pathlib.Path]): Path to the json file.
+
+         Returns:
+             model (Filter): The instantiated filter class.
+         """
+         # Load the parameters from the json file
+         with open(json_path, "r", encoding="utf-8") as f:
+             serialized_processing = json.load(f)
+
+         # Make sure the version in serialized_model is the same as CML_VERSION
+         assert_true(
+             serialized_processing["cml_version"] == CML_VERSION,
+             f"The version of Concrete ML library ({CML_VERSION}) is different "
+             f"from the one used to save the model ({serialized_processing['cml_version']}). "
+             "Please update to the proper Concrete ML version.",
+         )
+
+         # Initialize the model
+         model = cls(image_filter=serialized_processing["model_filter"])
+
+         return model
+
+     def to_json(self, path_dir, file_name="serialized_processing"):
+         """Export the parameters to a json file.
+
+         Args:
+             path_dir (Union[str, pathlib.Path]): The path to consider when saving the file.
+             file_name (str): The file name. Defaults to "serialized_processing".
+         """
+         # Serialize the parameters
+         serialized_processing = {
+             "model_filter": self.filter,
+         }
+         serialized_processing = self._clean_dict_types_for_json(serialized_processing)
+
+         # Add the version of the current CML library
+         serialized_processing["cml_version"] = CML_VERSION
+
+         # Save the json file
+         with open(path_dir / f"{file_name}.json", "w", encoding="utf-8") as f:
+             json.dump(serialized_processing, f)
+
+     def _clean_dict_types_for_json(self, d: dict) -> dict:
+         """Clean all values in the dict to be json serializable.
+
+         Args:
+             d (Dict): The dict to clean.
+
+         Returns:
+             Dict: The cleaned dict.
+         """
+         key_to_delete = []
+         for key, value in d.items():
+             if isinstance(value, list) and len(value) > 0 and isinstance(value[0], dict):
+                 d[key] = [self._clean_dict_types_for_json(v) for v in value]
+             elif isinstance(value, dict):
+                 d[key] = self._clean_dict_types_for_json(value)
+             elif isinstance(value, (np.generic, np.ndarray)):
+                 d[key] = d[key].tolist()
+
+         for key in key_to_delete:
+             d.pop(key)
+         return d
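
To make the integer-weight trick used by the "black and white" filter concrete, here is a small worked example (plain NumPy, outside of FHE) of the scaling by 1000 and the division applied back in `post_processing`:

```python
# Worked example of the grayscale quantization in the "black and white"
# filter: the float weights (0.299, 0.587, 0.114) are scaled to integers
# so the circuit only computes integer operations; post-processing divides
import numpy as np

pixel = np.array([200, 100, 50], dtype=np.int64)  # an example RGB pixel

# Integer computation, as done over encrypted values
weights = np.array([299, 587, 114], dtype=np.int64)
gray_scaled = int((weights * pixel).sum())  # 299*200 + 587*100 + 114*50 = 124200

# Division applied after decryption (divide=1000 in post_processing)
gray = gray_scaled // 1000  # 124

# Close to the float reference: 0.299*200 + 0.587*100 + 0.114*50 = 124.2
print(gray)
```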
filters/black and white/deployment/client.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b20b09e3118c2dc24004c4f1c4bc1465cf4b0ed0e1c907fffb7695b3db6bbace
+ size 388
filters/black and white/deployment/serialized_processing.json ADDED
@@ -0,0 +1 @@
+ {"model_filter": "black and white", "cml_version": "0.6.0-rc0"}
filters/black and white/deployment/server.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1fb1f2ff4aa7a1a56cf5d0f8d63d34ee912c06b347fe5e97088c79ad0ba6e902
+ size 4870
filters/black and white/server.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3f4774c394a6fec8cc43dae14ce627837aa998fcc78ba4ab67ad1c5bf92dd3ee
+ size 336
filters/black_and_white/deployment/client.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:000285a62f642b20eda541c6697e33de3d725c254ff5c2098e3157fc73cd017b
+ size 388
filters/black_and_white/deployment/serialized_processing.json ADDED
@@ -0,0 +1 @@
+ {"model_filter": "black_and_white", "cml_version": "0.6.0-rc0"}
filters/black_and_white/deployment/server.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9867657ff1e7b2c8eb3c72f28be8b8e8ee0b355762b99f34a25a2c9de0cb104c
+ size 4762
filters/black_and_white/server.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3f4774c394a6fec8cc43dae14ce627837aa998fcc78ba4ab67ad1c5bf92dd3ee
+ size 336
filters/blur/deployment/client.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8846612ef61a81a0f18b96a4bcca90bffde5400f9e689ac208f40673e3581aca
+ size 391
filters/blur/deployment/serialized_processing.json ADDED
@@ -0,0 +1 @@
+ {"model_filter": "blur", "cml_version": "0.6.0-rc0"}
filters/blur/deployment/server.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2652c963dbd9b82788671b9f133e70131f9da5ecf0a3123c9aa323ff69ee77a3
+ size 8651
filters/blur/server.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b8fd3d313ec3a9d565a0621921768317f66e53596ad950ca2be6b1efbcf984bd
+ size 532
filters/identity/deployment/client.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c7320407c56796bf0fe4d719f5e5826650f83c8424cb15779ac8c5b5ef0722fd
+ size 378
filters/identity/deployment/serialized_processing.json ADDED
@@ -0,0 +1 @@
+ {"model_filter": "identity", "cml_version": "0.6.0-rc0"}
filters/identity/deployment/server.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:246a6063277532e2ef246df2d0ce7f0c5d38fbfaa85c8a0d649cada63e7b0bb9
+ size 2637
filters/identity/server.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:71a8a398ea2edac9b0dfd41232c74549d1b8c159d391a4b3d42e2b4b731da02b
+ size 155
filters/inverted/deployment/client.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c7320407c56796bf0fe4d719f5e5826650f83c8424cb15779ac8c5b5ef0722fd
+ size 378
filters/inverted/deployment/serialized_processing.json ADDED
@@ -0,0 +1 @@
+ {"model_filter": "inverted", "cml_version": "0.6.0-rc0"}
filters/inverted/deployment/server.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c7f54eb012a12e29a927fbfeb2d7c811533ebaa1d50527e14019c940a7c86f52
+ size 5136
filters/inverted/server.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b2fead709dbc8f0ab1f19ff1878e0ac9ce110b2b3ced261d7a87d32e0fc58b61
+ size 211
filters/ridge detection/deployment/client.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:88f760b83837021929bccf86aaffefed2f9e5e97c3638346d32238e1027cb7a2
+ size 397
filters/ridge detection/deployment/serialized_processing.json ADDED
@@ -0,0 +1 @@
+ {"model_filter": "ridge detection", "cml_version": "0.6.0-rc0"}
filters/ridge detection/deployment/server.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bef0e5a94d7d50c8ac658b9e1a411c81051dba914a1b19f0e2badc53a2f36fdc
+ size 5020
filters/ridge detection/server.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:48821745ed7a9b25b5ba8ae0dc3da35739985bf5dd1dac5b3a9c207adbbf1c45
+ size 532
filters/ridge_detection/deployment/client.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d241694b8c01dce2ad8a5ce2dbe12190e40d6912e88d086dbc0e047aba4dfafb
+ size 397
filters/ridge_detection/deployment/serialized_processing.json ADDED
@@ -0,0 +1 @@
+ {"model_filter": "ridge_detection", "cml_version": "0.6.0-rc0"}
filters/ridge_detection/deployment/server.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3605e14d8533e3c57edf30a7da32d4441fcb68228a8ebd028015338b8b5d5f70
+ size 4884
filters/ridge_detection/server.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1e05d56c4988abd621aee6dea4efe2dfdaf1d09dfb78bb7bf7b6bb3a00d3e80b
+ size 532
filters/rotate/deployment/client.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4f43305830edade90df38b59070c255948810dc9a8c58eda16157e0424b9bffe
+ size 383
filters/rotate/deployment/serialized_processing.json ADDED
@@ -0,0 +1 @@
+ {"model_filter": "rotate", "cml_version": "0.6.0-rc0"}
filters/rotate/deployment/server.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b382f12dafa436e4e5e4adc0346fa81b52d5fad709a19f1b2cad52001a97c984
+ size 5366
filters/rotate/server.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aa03ea9a684b65c29c2cc0e6ab20f6b6349f35c4bd70921d264e74298a758de1
+ size 178
filters/sharpen/deployment/client.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d666044a75f5e7d4642145181ea239de6076f8ae424d971c7139e3467a758793
+ size 396
filters/sharpen/deployment/serialized_processing.json ADDED
@@ -0,0 +1 @@
+ {"model_filter": "sharpen", "cml_version": "0.6.0-rc0"}
filters/sharpen/deployment/server.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6922f5fa0d0a0584636ce755dbd921bc36fca082f2a0facb74669f4b24b67368
+ size 8720
filters/sharpen/server.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c7958a3c9be1b578486ec1708701340263ce3ad70b7cd3ff281230797f67de0d
+ size 532
generate_dev_files.py ADDED
@@ -0,0 +1,40 @@
+ "A script to generate all development files necessary for the image filtering demo."
+
+ import shutil
+ from pathlib import Path
+
+ import numpy as np
+ import onnx
+ from common import AVAILABLE_FILTERS, FILTERS_PATH, INPUT_SHAPE, INPUTSET
+ from custom_client_server import CustomFHEDev
+ from filters import Filter
+
+ print("Generating deployment files for all available filters")
+
+ for image_filter in AVAILABLE_FILTERS:
+     print("Filter:", image_filter, "\n")
+
+     # Create the filter instance
+     filter = Filter(image_filter)
+
+     image_shape = INPUT_SHAPE + (3,)
+
+     # Compile the filter on the inputset
+     filter.compile(INPUTSET)
+
+     filter_path = FILTERS_PATH / image_filter
+
+     deployment_path = filter_path / "deployment"
+
+     # Delete the deployment folder and its content if it exists
+     if deployment_path.is_dir():
+         shutil.rmtree(deployment_path)
+
+     # Save the files needed for deployment
+     fhe_dev_filter = CustomFHEDev(deployment_path, filter)
+     fhe_dev_filter.save()
+
+     # Save the ONNX model
+     onnx.save(filter.onnx_model, filter_path / "server.onnx")
+
+ print("Done!")
input_examples/arc.jpg ADDED
input_examples/book.jpg ADDED
input_examples/computer.jpg ADDED
input_examples/tree.jpg ADDED
input_examples/zama_math.jpg ADDED
input_examples/zebra.jpg ADDED