Roman
commited on
Commit
•
7b32412
1
Parent(s):
b1501ef
chore: remove CML client/server API by only using CN
Browse files- app.py +10 -13
- client_server_interface.py +175 -0
- compile.py +8 -8
- custom_client_server.py +0 -35
- filters.py +23 -45
- filters/black and white/deployment/client.zip +2 -2
- filters/black and white/deployment/serialized_processing.json +1 -1
- filters/black and white/deployment/server.zip +1 -1
- filters/blur/deployment/client.zip +2 -2
- filters/blur/deployment/serialized_processing.json +1 -1
- filters/blur/deployment/server.zip +1 -1
- filters/identity/deployment/client.zip +2 -2
- filters/identity/deployment/serialized_processing.json +1 -1
- filters/identity/deployment/server.zip +1 -1
- filters/inverted/deployment/client.zip +2 -2
- filters/inverted/deployment/serialized_processing.json +1 -1
- filters/inverted/deployment/server.zip +1 -1
- filters/ridge detection/deployment/client.zip +2 -2
- filters/ridge detection/deployment/serialized_processing.json +1 -1
- filters/ridge detection/deployment/server.zip +1 -1
- filters/rotate/deployment/client.zip +2 -2
- filters/rotate/deployment/serialized_processing.json +1 -1
- filters/rotate/deployment/server.zip +1 -1
- filters/sharpen/deployment/client.zip +2 -2
- filters/sharpen/deployment/serialized_processing.json +1 -1
- filters/sharpen/deployment/server.zip +1 -1
- generate_dev_files.py +3 -3
- server.py +4 -9
app.py
CHANGED
@@ -19,7 +19,7 @@ from common import (
|
|
19 |
REPO_DIR,
|
20 |
SERVER_URL,
|
21 |
)
|
22 |
-
from
|
23 |
|
24 |
# Uncomment here to have both the server and client in the same terminal
|
25 |
subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
|
@@ -33,11 +33,11 @@ def decrypt_output_with_wrong_key(encrypted_image, filter_name):
|
|
33 |
filter_path = FILTERS_PATH / f"{filter_name}/deployment"
|
34 |
|
35 |
# Instantiate the client interface and generate a new private key
|
36 |
-
wrong_client =
|
37 |
wrong_client.generate_private_and_evaluation_keys(force=True)
|
38 |
|
39 |
-
# Deserialize, decrypt and post-
|
40 |
-
output_image = wrong_client.
|
41 |
|
42 |
return output_image
|
43 |
|
@@ -53,7 +53,7 @@ def shorten_bytes_object(bytes_object, limit=500):
|
|
53 |
limit (int): The length to consider. Default to 500.
|
54 |
|
55 |
Returns:
|
56 |
-
|
57 |
|
58 |
"""
|
59 |
# Define a shift for better display
|
@@ -69,9 +69,9 @@ def get_client(user_id, filter_name):
|
|
69 |
filter_name (str): The filter chosen by the user
|
70 |
|
71 |
Returns:
|
72 |
-
|
73 |
"""
|
74 |
-
return
|
75 |
FILTERS_PATH / f"{filter_name}/deployment", KEYS_PATH / f"{filter_name}_{user_id}"
|
76 |
)
|
77 |
|
@@ -184,11 +184,8 @@ def encrypt(user_id, input_image, filter_name):
|
|
184 |
# Retrieve the client API
|
185 |
client = get_client(user_id, filter_name)
|
186 |
|
187 |
-
# Pre-process
|
188 |
-
|
189 |
-
|
190 |
-
# Encrypt and serialize the image
|
191 |
-
encrypted_image = client.quantize_encrypt_serialize(preprocessed_input_image)
|
192 |
|
193 |
# Compute the input's size in Megabytes
|
194 |
encrypted_input_size = len(encrypted_image) / 1000000
|
@@ -341,7 +338,7 @@ def decrypt_output(user_id, filter_name):
|
|
341 |
client = get_client(user_id, filter_name)
|
342 |
|
343 |
# Deserialize, decrypt and post-process the encrypted output
|
344 |
-
output_image = client.
|
345 |
|
346 |
return output_image, False, False
|
347 |
|
|
|
19 |
REPO_DIR,
|
20 |
SERVER_URL,
|
21 |
)
|
22 |
+
from client_server_interface import FHEClient
|
23 |
|
24 |
# Uncomment here to have both the server and client in the same terminal
|
25 |
subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
|
|
|
33 |
filter_path = FILTERS_PATH / f"{filter_name}/deployment"
|
34 |
|
35 |
# Instantiate the client interface and generate a new private key
|
36 |
+
wrong_client = FHEClient(filter_path, WRONG_KEYS_PATH)
|
37 |
wrong_client.generate_private_and_evaluation_keys(force=True)
|
38 |
|
39 |
+
# Deserialize, decrypt and post-process the encrypted output using the new private key
|
40 |
+
output_image = wrong_client.deserialize_decrypt_post_process(encrypted_image)
|
41 |
|
42 |
return output_image
|
43 |
|
|
|
53 |
limit (int): The length to consider. Default to 500.
|
54 |
|
55 |
Returns:
|
56 |
+
str: Hexadecimal string shorten representation of the input byte object.
|
57 |
|
58 |
"""
|
59 |
# Define a shift for better display
|
|
|
69 |
filter_name (str): The filter chosen by the user
|
70 |
|
71 |
Returns:
|
72 |
+
FHEClient: The client API.
|
73 |
"""
|
74 |
+
return FHEClient(
|
75 |
FILTERS_PATH / f"{filter_name}/deployment", KEYS_PATH / f"{filter_name}_{user_id}"
|
76 |
)
|
77 |
|
|
|
184 |
# Retrieve the client API
|
185 |
client = get_client(user_id, filter_name)
|
186 |
|
187 |
+
# Pre-process, encrypt and serialize the image
|
188 |
+
encrypted_image = client.pre_process_encrypt_serialize(input_image)
|
|
|
|
|
|
|
189 |
|
190 |
# Compute the input's size in Megabytes
|
191 |
encrypted_input_size = len(encrypted_image) / 1000000
|
|
|
338 |
client = get_client(user_id, filter_name)
|
339 |
|
340 |
# Deserialize, decrypt and post-process the encrypted output
|
341 |
+
output_image = client.deserialize_decrypt_post_process(encrypted_output_image)
|
342 |
|
343 |
return output_image, False, False
|
344 |
|
client_server_interface.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"Client-server interface custom implementation for filter models."
|
2 |
+
|
3 |
+
import zipfile
|
4 |
+
import json
|
5 |
+
from filters import Filter
|
6 |
+
|
7 |
+
import concrete.numpy as cnp
|
8 |
+
|
9 |
+
class FHEServer:
|
10 |
+
"""Server interface run a FHE circuit."""
|
11 |
+
|
12 |
+
def __init__(self, path_dir):
|
13 |
+
"""Initialize the FHE interface.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
path_dir (Path): The path to the directory where the circuit is saved.
|
17 |
+
"""
|
18 |
+
self.path_dir = path_dir
|
19 |
+
|
20 |
+
# Load the FHE circuit
|
21 |
+
self.server = cnp.Server.load(self.path_dir / "server.zip")
|
22 |
+
|
23 |
+
def run(self, serialized_encrypted_image, serialized_evaluation_keys):
|
24 |
+
"""Run the filter on the server over an encrypted image.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
serialized_encrypted_image (bytes): The encrypted and serialized image.
|
28 |
+
serialized_evaluation_keys (bytes): The serialized evaluation keys.
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
bytes: The filter's output.
|
32 |
+
"""
|
33 |
+
# Deserialize the encrypted input image and the evaluation keys
|
34 |
+
deserialized_encrypted_image = self.server.client_specs.unserialize_public_args(
|
35 |
+
serialized_encrypted_image
|
36 |
+
)
|
37 |
+
deserialized_evaluation_keys = cnp.EvaluationKeys.unserialize(serialized_evaluation_keys)
|
38 |
+
|
39 |
+
# Execute the filter in FHE
|
40 |
+
result = self.server.run(
|
41 |
+
deserialized_encrypted_image, deserialized_evaluation_keys
|
42 |
+
)
|
43 |
+
|
44 |
+
# Serialize the encrypted output image
|
45 |
+
serialized_result = self.server.client_specs.serialize_public_result(result)
|
46 |
+
|
47 |
+
return serialized_result
|
48 |
+
|
49 |
+
|
50 |
+
class FHEDev:
|
51 |
+
"""Development interface to save and load the filter."""
|
52 |
+
|
53 |
+
def __init__(self, filter, path_dir):
|
54 |
+
"""Initialize the FHE interface.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
path_dir (str): The path to the directory where the circuit is saved.
|
58 |
+
filter (Filter): The filter to use in the FHE interface.
|
59 |
+
"""
|
60 |
+
|
61 |
+
self.filter = filter
|
62 |
+
self.path_dir = path_dir
|
63 |
+
|
64 |
+
self.path_dir.mkdir(parents=True, exist_ok=True)
|
65 |
+
|
66 |
+
def save(self):
|
67 |
+
"""Export all needed artifacts for the client and server interfaces."""
|
68 |
+
|
69 |
+
assert self.filter.fhe_circuit is not None, (
|
70 |
+
"The model must be compiled before saving it."
|
71 |
+
)
|
72 |
+
|
73 |
+
# Export to json the parameters needed for loading the filter in the other interfaces
|
74 |
+
serialized_processing = {"filter_name": self.filter.filter_name}
|
75 |
+
|
76 |
+
json_path = self.path_dir / "serialized_processing.json"
|
77 |
+
with open(json_path, "w", encoding="utf-8") as file:
|
78 |
+
json.dump(serialized_processing, file)
|
79 |
+
|
80 |
+
# Save the circuit for the server
|
81 |
+
path_circuit_server = self.path_dir / "server.zip"
|
82 |
+
self.filter.fhe_circuit.server.save(path_circuit_server)
|
83 |
+
|
84 |
+
# Save the circuit for the client
|
85 |
+
path_circuit_client = self.path_dir / "client.zip"
|
86 |
+
self.filter.fhe_circuit.client.save(path_circuit_client)
|
87 |
+
|
88 |
+
with zipfile.ZipFile(path_circuit_client, "a") as zip_file:
|
89 |
+
zip_file.write(filename=json_path, arcname="serialized_processing.json")
|
90 |
+
|
91 |
+
|
92 |
+
class FHEClient:
|
93 |
+
"""Client interface to encrypt and decrypt FHE data associated to a Filter."""
|
94 |
+
|
95 |
+
def __init__(self, path_dir, key_dir):
|
96 |
+
"""Initialize the FHE interface.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
path_dir (Path): the path to the directory where the circuit is saved
|
100 |
+
key_dir (Path): the path to the directory where the keys are stored
|
101 |
+
"""
|
102 |
+
self.path_dir = path_dir
|
103 |
+
self.key_dir = key_dir
|
104 |
+
|
105 |
+
# If path_dir does not exist raise
|
106 |
+
assert path_dir.exists(), f"{path_dir} does not exist. Please specify a valid path."
|
107 |
+
|
108 |
+
# Load the client
|
109 |
+
self.client = cnp.Client.load(self.path_dir / "client.zip", self.key_dir)
|
110 |
+
|
111 |
+
# Load the parameters
|
112 |
+
with zipfile.ZipFile(self.path_dir / "client.zip") as client_zip:
|
113 |
+
with client_zip.open("serialized_processing.json", mode="r") as file:
|
114 |
+
serialized_processing = json.load(file)
|
115 |
+
|
116 |
+
# Instantiate the filter
|
117 |
+
filter_name = serialized_processing["filter_name"]
|
118 |
+
self.filter = Filter(filter_name)
|
119 |
+
|
120 |
+
def generate_private_and_evaluation_keys(self, force=False):
|
121 |
+
"""Generate the private and evaluation keys.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
force (bool): If True, regenerate the keys even if they already exist.
|
125 |
+
"""
|
126 |
+
self.client.keygen(force)
|
127 |
+
|
128 |
+
def get_serialized_evaluation_keys(self):
|
129 |
+
"""Get the serialized evaluation keys.
|
130 |
+
|
131 |
+
Returns:
|
132 |
+
bytes: The evaluation keys.
|
133 |
+
"""
|
134 |
+
return self.client.evaluation_keys.serialize()
|
135 |
+
|
136 |
+
def pre_process_encrypt_serialize(self, input_image):
|
137 |
+
"""Pre-process, encrypt and serialize the input image.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
input_image (numpy.ndarray): The image to pre-process, encrypt and serialize.
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
bytes: The pre-processed, encrypted and serialized image.
|
144 |
+
"""
|
145 |
+
# Pre-process the image
|
146 |
+
preprocessed_image = self.filter.pre_processing(input_image)
|
147 |
+
|
148 |
+
# Encrypt the image
|
149 |
+
encrypted_image = self.client.encrypt(preprocessed_image)
|
150 |
+
|
151 |
+
# Serialize the encrypted image to be sent to the server
|
152 |
+
serialized_encrypted_image = self.client.specs.serialize_public_args(encrypted_image)
|
153 |
+
return serialized_encrypted_image
|
154 |
+
|
155 |
+
def deserialize_decrypt_post_process(self, serialized_encrypted_output_image):
|
156 |
+
"""Deserialize, decrypt and post-process the output image.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
serialized_encrypted_output_image (bytes): The serialized and encrypted output image.
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
numpy.ndarray: The decrypted, deserialized and post-processed image.
|
163 |
+
"""
|
164 |
+
# Deserialize the encrypted image
|
165 |
+
encrypted_output_image = self.client.specs.unserialize_public_result(
|
166 |
+
serialized_encrypted_output_image
|
167 |
+
)
|
168 |
+
|
169 |
+
# Decrypt the image
|
170 |
+
output_image = self.client.decrypt(encrypted_output_image)
|
171 |
+
|
172 |
+
# Post-process the image
|
173 |
+
post_processed_output_image = self.filter.post_processing(output_image)
|
174 |
+
|
175 |
+
return post_processed_output_image
|
compile.py
CHANGED
@@ -3,9 +3,9 @@
|
|
3 |
import json
|
4 |
import shutil
|
5 |
import onnx
|
|
|
6 |
from common import AVAILABLE_FILTERS, FILTERS_PATH, KEYS_PATH
|
7 |
-
from
|
8 |
-
from concrete.ml.deployment import FHEModelDev
|
9 |
|
10 |
print("Starting compiling the filters.")
|
11 |
|
@@ -16,13 +16,13 @@ for filter_name in AVAILABLE_FILTERS:
|
|
16 |
deployment_path = FILTERS_PATH / f"{filter_name}/deployment"
|
17 |
|
18 |
# Retrieve the client associated to the current filter
|
19 |
-
|
20 |
|
21 |
-
# Load the onnx
|
22 |
-
|
23 |
|
24 |
-
# Compile the
|
25 |
-
|
26 |
|
27 |
processing_json_path = deployment_path / "serialized_processing.json"
|
28 |
|
@@ -35,7 +35,7 @@ for filter_name in AVAILABLE_FILTERS:
|
|
35 |
shutil.rmtree(deployment_path)
|
36 |
|
37 |
# Save the development files needed for deployment
|
38 |
-
fhe_dev =
|
39 |
fhe_dev.save()
|
40 |
|
41 |
# Write the serialized_processing.json file in the deployment directory
|
|
|
3 |
import json
|
4 |
import shutil
|
5 |
import onnx
|
6 |
+
|
7 |
from common import AVAILABLE_FILTERS, FILTERS_PATH, KEYS_PATH
|
8 |
+
from client_server_interface import FHEClient, FHEDev
|
|
|
9 |
|
10 |
print("Starting compiling the filters.")
|
11 |
|
|
|
16 |
deployment_path = FILTERS_PATH / f"{filter_name}/deployment"
|
17 |
|
18 |
# Retrieve the client associated to the current filter
|
19 |
+
filter = FHEClient(deployment_path, KEYS_PATH).filter
|
20 |
|
21 |
+
# Load the onnx graph
|
22 |
+
onnx_graph = onnx.load(FILTERS_PATH / f"{filter_name}/server.onnx")
|
23 |
|
24 |
+
# Compile the filter on a representative inputset, using the loaded onnx graph
|
25 |
+
filter.compile(onnx_graph=onnx_graph)
|
26 |
|
27 |
processing_json_path = deployment_path / "serialized_processing.json"
|
28 |
|
|
|
35 |
shutil.rmtree(deployment_path)
|
36 |
|
37 |
# Save the development files needed for deployment
|
38 |
+
fhe_dev = FHEDev(filter, deployment_path)
|
39 |
fhe_dev.save()
|
40 |
|
41 |
# Write the serialized_processing.json file in the deployment directory
|
custom_client_server.py
DELETED
@@ -1,35 +0,0 @@
|
|
1 |
-
"Client-server interface custom implementation for filter models."
|
2 |
-
|
3 |
-
import json
|
4 |
-
import concrete.numpy as cnp
|
5 |
-
from filters import Filter
|
6 |
-
|
7 |
-
from concrete.ml.deployment import FHEModelClient
|
8 |
-
from concrete.ml.version import __version__ as CML_VERSION
|
9 |
-
|
10 |
-
|
11 |
-
class CustomFHEClient(FHEModelClient):
|
12 |
-
"""Client interface to encrypt and decrypt FHE data associated to a Filter."""
|
13 |
-
|
14 |
-
def load(self):
|
15 |
-
"""Load the parameters along with the FHE specs."""
|
16 |
-
|
17 |
-
# Load the client
|
18 |
-
self.client = cnp.Client.load(self.path_dir / "client.zip", self.key_dir)
|
19 |
-
|
20 |
-
# Load the filter's parameters from the json file
|
21 |
-
with (self.path_dir / "serialized_processing.json").open("r", encoding="utf-8") as f:
|
22 |
-
serialized_processing = json.load(f)
|
23 |
-
|
24 |
-
# Make sure the version in serialized_model is the same as CML_VERSION
|
25 |
-
assert serialized_processing["cml_version"] == CML_VERSION, (
|
26 |
-
f"The version of Concrete ML library ({CML_VERSION}) is different "
|
27 |
-
f"from the one used to save the model ({serialized_processing['cml_version']}). "
|
28 |
-
"Please update to the proper Concrete ML version.",
|
29 |
-
)
|
30 |
-
|
31 |
-
# Initialize the filter model using its filter name
|
32 |
-
filter_name = serialized_processing["model_post_processing_params"]["filter_name"]
|
33 |
-
self.model = Filter(filter_name)
|
34 |
-
|
35 |
-
return self.model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
filters.py
CHANGED
@@ -2,17 +2,16 @@
|
|
2 |
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
-
from common import AVAILABLE_FILTERS, INPUT_SHAPE
|
6 |
-
from concrete.numpy.compilation.compiler import Compiler
|
7 |
from torch import nn
|
|
|
8 |
|
9 |
-
from concrete.
|
10 |
from concrete.ml.common.utils import generate_proxy_function
|
11 |
from concrete.ml.onnx.convert import get_equivalent_numpy_forward
|
12 |
from concrete.ml.torch.numpy_module import NumpyModule
|
13 |
|
14 |
|
15 |
-
class
|
16 |
"""Torch identity model."""
|
17 |
|
18 |
def forward(self, x):
|
@@ -27,7 +26,7 @@ class _TorchIdentity(nn.Module):
|
|
27 |
return x
|
28 |
|
29 |
|
30 |
-
class
|
31 |
"""Torch inverted model."""
|
32 |
|
33 |
def forward(self, x):
|
@@ -42,7 +41,7 @@ class _TorchInverted(nn.Module):
|
|
42 |
return 255 - x
|
43 |
|
44 |
|
45 |
-
class
|
46 |
"""Torch rotated model."""
|
47 |
|
48 |
def forward(self, x):
|
@@ -57,8 +56,8 @@ class _TorchRotate(nn.Module):
|
|
57 |
return x.transpose(2, 3)
|
58 |
|
59 |
|
60 |
-
class
|
61 |
-
"""Torch model for applying
|
62 |
|
63 |
def __init__(self, kernel, n_in_channels=3, n_out_channels=3, groups=1, threshold=None):
|
64 |
"""Initialize the filter.
|
@@ -74,7 +73,7 @@ class _TorchConv2D(nn.Module):
|
|
74 |
self.threshold = threshold
|
75 |
|
76 |
def forward(self, x):
|
77 |
-
"""Forward pass for filtering the image using a 2D kernel.
|
78 |
|
79 |
Args:
|
80 |
x (torch.Tensor): The input image.
|
@@ -133,20 +132,13 @@ class Filter:
|
|
133 |
filter_name (str): The filter to consider.
|
134 |
"""
|
135 |
|
136 |
-
|
137 |
-
filter_name in AVAILABLE_FILTERS,
|
138 |
f"Unsupported image filter or transformation. Expected one of {*AVAILABLE_FILTERS,}, "
|
139 |
f"but got {filter_name}",
|
140 |
)
|
141 |
|
142 |
-
# Define attributes needed in order to prevent the Concrete-ML client-server interface
|
143 |
-
# from breaking
|
144 |
-
self.post_processing_params = {"filter_name": filter_name}
|
145 |
-
self.input_quantizers = []
|
146 |
-
self.output_quantizers = []
|
147 |
-
|
148 |
# Define attributes associated to the filter
|
149 |
-
self.
|
150 |
self.onnx_model = None
|
151 |
self.fhe_circuit = None
|
152 |
self.divide = None
|
@@ -154,13 +146,13 @@ class Filter:
|
|
154 |
|
155 |
# Instantiate the torch module associated to the given filter name
|
156 |
if filter_name == "identity":
|
157 |
-
self.torch_model =
|
158 |
|
159 |
elif filter_name == "inverted":
|
160 |
-
self.torch_model =
|
161 |
|
162 |
elif filter_name == "rotate":
|
163 |
-
self.torch_model =
|
164 |
|
165 |
elif filter_name == "black and white":
|
166 |
# Define the grayscale weights (RGB order)
|
@@ -173,7 +165,7 @@ class Filter:
|
|
173 |
# post-processing in order to retrieve the correct result
|
174 |
kernel = [299, 587, 114]
|
175 |
|
176 |
-
self.torch_model =
|
177 |
|
178 |
# Define the value used when for dividing the output values in post-processing
|
179 |
self.divide = 1000
|
@@ -185,7 +177,7 @@ class Filter:
|
|
185 |
elif filter_name == "blur":
|
186 |
kernel = np.ones((3, 3))
|
187 |
|
188 |
-
self.torch_model =
|
189 |
|
190 |
# Define the value used when for dividing the output values in post-processing
|
191 |
self.divide = 9
|
@@ -197,7 +189,7 @@ class Filter:
|
|
197 |
[0, -1, 0],
|
198 |
]
|
199 |
|
200 |
-
self.torch_model =
|
201 |
|
202 |
elif filter_name == "ridge detection":
|
203 |
kernel = [
|
@@ -208,18 +200,18 @@ class Filter:
|
|
208 |
|
209 |
# Additionally to the convolution operator, the filter will subtract a given threshold
|
210 |
# value to the result in order to better display the ridges
|
211 |
-
self.torch_model =
|
212 |
|
213 |
# Indicate that the out_channels will need to be repeated, as Gradio requires all
|
214 |
# images to have a RGB format, even for grayscaled ones. Ridge detection images are
|
215 |
# ususally displayed as such
|
216 |
self.repeat_out_channels = True
|
217 |
|
218 |
-
def compile(self,
|
219 |
"""Compile the model on a representative inputset.
|
220 |
|
221 |
Args:
|
222 |
-
|
223 |
generated automatically using a NumpyModule. Default to None.
|
224 |
"""
|
225 |
# Generate a random representative set of images used for compilation, following Torch's
|
@@ -232,17 +224,17 @@ class Filter:
|
|
232 |
)
|
233 |
|
234 |
# If no onnx model was given, generate a new one.
|
235 |
-
if
|
236 |
numpy_module = NumpyModule(
|
237 |
self.torch_model,
|
238 |
dummy_input=torch.from_numpy(inputset[0]),
|
239 |
)
|
240 |
|
241 |
-
|
242 |
|
243 |
# Get the proxy function and parameter mappings for initializing the compiler
|
244 |
-
self.
|
245 |
-
numpy_filter = get_equivalent_numpy_forward(
|
246 |
|
247 |
numpy_filter_proxy, parameters_mapping = generate_proxy_function(numpy_filter, ["inputs"])
|
248 |
|
@@ -256,20 +248,6 @@ class Filter:
|
|
256 |
|
257 |
return self.fhe_circuit
|
258 |
|
259 |
-
def quantize_input(self, input_image):
|
260 |
-
"""Quantize the input.
|
261 |
-
|
262 |
-
Images are already quantized in this case, however we need to define this method in order
|
263 |
-
to prevent the Concrete-ML client-server interface from breaking.
|
264 |
-
|
265 |
-
Args:
|
266 |
-
input_image (np.ndarray): The input to quantize.
|
267 |
-
|
268 |
-
Returns:
|
269 |
-
np.ndarray: The quantized input.
|
270 |
-
"""
|
271 |
-
return input_image
|
272 |
-
|
273 |
def pre_processing(self, input_image):
|
274 |
"""Apply pre-processing to the encrypted input images.
|
275 |
|
|
|
2 |
|
3 |
import numpy as np
|
4 |
import torch
|
|
|
|
|
5 |
from torch import nn
|
6 |
+
from common import AVAILABLE_FILTERS, INPUT_SHAPE
|
7 |
|
8 |
+
from concrete.numpy.compilation.compiler import Compiler
|
9 |
from concrete.ml.common.utils import generate_proxy_function
|
10 |
from concrete.ml.onnx.convert import get_equivalent_numpy_forward
|
11 |
from concrete.ml.torch.numpy_module import NumpyModule
|
12 |
|
13 |
|
14 |
+
class TorchIdentity(nn.Module):
|
15 |
"""Torch identity model."""
|
16 |
|
17 |
def forward(self, x):
|
|
|
26 |
return x
|
27 |
|
28 |
|
29 |
+
class TorchInverted(nn.Module):
|
30 |
"""Torch inverted model."""
|
31 |
|
32 |
def forward(self, x):
|
|
|
41 |
return 255 - x
|
42 |
|
43 |
|
44 |
+
class TorchRotate(nn.Module):
|
45 |
"""Torch rotated model."""
|
46 |
|
47 |
def forward(self, x):
|
|
|
56 |
return x.transpose(2, 3)
|
57 |
|
58 |
|
59 |
+
class TorchConv(nn.Module):
|
60 |
+
"""Torch model for applying convolution operators on images."""
|
61 |
|
62 |
def __init__(self, kernel, n_in_channels=3, n_out_channels=3, groups=1, threshold=None):
|
63 |
"""Initialize the filter.
|
|
|
73 |
self.threshold = threshold
|
74 |
|
75 |
def forward(self, x):
|
76 |
+
"""Forward pass for filtering the image using a 1D or 2D kernel.
|
77 |
|
78 |
Args:
|
79 |
x (torch.Tensor): The input image.
|
|
|
132 |
filter_name (str): The filter to consider.
|
133 |
"""
|
134 |
|
135 |
+
assert filter_name in AVAILABLE_FILTERS, (
|
|
|
136 |
f"Unsupported image filter or transformation. Expected one of {*AVAILABLE_FILTERS,}, "
|
137 |
f"but got {filter_name}",
|
138 |
)
|
139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
# Define attributes associated to the filter
|
141 |
+
self.filter_name = filter_name
|
142 |
self.onnx_model = None
|
143 |
self.fhe_circuit = None
|
144 |
self.divide = None
|
|
|
146 |
|
147 |
# Instantiate the torch module associated to the given filter name
|
148 |
if filter_name == "identity":
|
149 |
+
self.torch_model = TorchIdentity()
|
150 |
|
151 |
elif filter_name == "inverted":
|
152 |
+
self.torch_model = TorchInverted()
|
153 |
|
154 |
elif filter_name == "rotate":
|
155 |
+
self.torch_model = TorchRotate()
|
156 |
|
157 |
elif filter_name == "black and white":
|
158 |
# Define the grayscale weights (RGB order)
|
|
|
165 |
# post-processing in order to retrieve the correct result
|
166 |
kernel = [299, 587, 114]
|
167 |
|
168 |
+
self.torch_model = TorchConv(kernel, n_out_channels=1, groups=1)
|
169 |
|
170 |
# Define the value used when for dividing the output values in post-processing
|
171 |
self.divide = 1000
|
|
|
177 |
elif filter_name == "blur":
|
178 |
kernel = np.ones((3, 3))
|
179 |
|
180 |
+
self.torch_model = TorchConv(kernel, n_out_channels=3, groups=3)
|
181 |
|
182 |
# Define the value used when for dividing the output values in post-processing
|
183 |
self.divide = 9
|
|
|
189 |
[0, -1, 0],
|
190 |
]
|
191 |
|
192 |
+
self.torch_model = TorchConv(kernel, n_out_channels=3, groups=3)
|
193 |
|
194 |
elif filter_name == "ridge detection":
|
195 |
kernel = [
|
|
|
200 |
|
201 |
# Additionally to the convolution operator, the filter will subtract a given threshold
|
202 |
# value to the result in order to better display the ridges
|
203 |
+
self.torch_model = TorchConv(kernel, n_out_channels=1, groups=1, threshold=900)
|
204 |
|
205 |
# Indicate that the out_channels will need to be repeated, as Gradio requires all
|
206 |
# images to have a RGB format, even for grayscaled ones. Ridge detection images are
|
207 |
# ususally displayed as such
|
208 |
self.repeat_out_channels = True
|
209 |
|
210 |
+
def compile(self, onnx_graph=None):
|
211 |
"""Compile the model on a representative inputset.
|
212 |
|
213 |
Args:
|
214 |
+
onnx_graph (onnx.ModelProto): The loaded onnx model to consider. If None, it will be
|
215 |
generated automatically using a NumpyModule. Default to None.
|
216 |
"""
|
217 |
# Generate a random representative set of images used for compilation, following Torch's
|
|
|
224 |
)
|
225 |
|
226 |
# If no onnx model was given, generate a new one.
|
227 |
+
if onnx_graph is None:
|
228 |
numpy_module = NumpyModule(
|
229 |
self.torch_model,
|
230 |
dummy_input=torch.from_numpy(inputset[0]),
|
231 |
)
|
232 |
|
233 |
+
onnx_graph = numpy_module.onnx_model
|
234 |
|
235 |
# Get the proxy function and parameter mappings for initializing the compiler
|
236 |
+
self.onnx_graph = onnx_graph
|
237 |
+
numpy_filter = get_equivalent_numpy_forward(onnx_graph)
|
238 |
|
239 |
numpy_filter_proxy, parameters_mapping = generate_proxy_function(numpy_filter, ["inputs"])
|
240 |
|
|
|
248 |
|
249 |
return self.fhe_circuit
|
250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
def pre_processing(self, input_image):
|
252 |
"""Apply pre-processing to the encrypted input images.
|
253 |
|
filters/black and white/deployment/client.zip
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fa7aa78811be0810d523a8d94a1aa27e24e40d29eced0395fd3bc743568b62b8
|
3 |
+
size 550
|
filters/black and white/deployment/serialized_processing.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"
|
|
|
1 |
+
{"filter_name": "black and white"}
|
filters/black and white/deployment/server.zip
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 4364
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5fe0885c010b076062a9b5887d0ee5ebf07f7b879633324af1b14e58a2fefeec
|
3 |
size 4364
|
filters/blur/deployment/client.zip
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9fafcd6cd32109e17bad3ee5945b54c64525f3db5d3e893a5618cfb765a8748e
|
3 |
+
size 542
|
filters/blur/deployment/serialized_processing.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"
|
|
|
1 |
+
{"filter_name": "blur"}
|
filters/blur/deployment/server.zip
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 7263
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:49b6c8a391e67ba424f156ef6049175e5c49b13d5b92052fddf05214741175c6
|
3 |
size 7263
|
filters/identity/deployment/client.zip
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:19a7d7831af7f4a7a55a734a12e772ec41058502138e15925e229c89fcd8b195
|
3 |
+
size 533
|
filters/identity/deployment/serialized_processing.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"
|
|
|
1 |
+
{"filter_name": "identity"}
|
filters/identity/deployment/server.zip
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 2559
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d2891ffa3e35d14d40a79b533fb331d557c82b4a8fe20568aa095aa7a22164a9
|
3 |
size 2559
|
filters/inverted/deployment/client.zip
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:67169abe3c33f7c7f377cd7e3b17031dd43054432a9d1b39f5469417156b5f2d
|
3 |
+
size 533
|
filters/inverted/deployment/serialized_processing.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"
|
|
|
1 |
+
{"filter_name": "inverted"}
|
filters/inverted/deployment/server.zip
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 4179
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:781488531b0049ecd05d3cc0e0eb95a9350553848bf218e726e97dce2b3ebd42
|
3 |
size 4179
|
filters/ridge detection/deployment/client.zip
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:05b54c87d88297316aeb864d7292c9a4c930d486e7d0b7232bdf77e9b76a7692
|
3 |
+
size 559
|
filters/ridge detection/deployment/serialized_processing.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"
|
|
|
1 |
+
{"filter_name": "ridge detection"}
|
filters/ridge detection/deployment/server.zip
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 5043
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:238980dd76c8155164b84d0096d11a8cbba25c933f4335fc7369e77f2328bd26
|
3 |
size 5043
|
filters/rotate/deployment/client.zip
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a7a3c2ae45ef9887682e3e89d4138b2ce74b8e560b858f3adc0461f98f223f3f
|
3 |
+
size 531
|
filters/rotate/deployment/serialized_processing.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"
|
|
|
1 |
+
{"filter_name": "rotate"}
|
filters/rotate/deployment/server.zip
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 4431
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a92a49387f05f4548cb4910506e66a6a2fa591b8f27818934d4283c8c2981a99
|
3 |
size 4431
|
filters/sharpen/deployment/client.zip
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8cf53584e83a91cb975e9f078ef63e777f11453582b664f65685b3a6da89f17e
|
3 |
+
size 550
|
filters/sharpen/deployment/serialized_processing.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"
|
|
|
1 |
+
{"filter_name": "sharpen"}
|
filters/sharpen/deployment/server.zip
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 7311
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4710fbe92afdd4f8f6beff7eca302a46ee4fde5b85fe8aa7d6ab832080ae5a2e
|
3 |
size 7311
|
generate_dev_files.py
CHANGED
@@ -4,7 +4,7 @@ import shutil
|
|
4 |
import onnx
|
5 |
from common import AVAILABLE_FILTERS, FILTERS_PATH
|
6 |
from filters import Filter
|
7 |
-
from
|
8 |
|
9 |
print("Generating deployment files for all available filters")
|
10 |
|
@@ -28,10 +28,10 @@ for filter_name in AVAILABLE_FILTERS:
|
|
28 |
shutil.rmtree(deployment_path)
|
29 |
|
30 |
# Save the files needed for deployment
|
31 |
-
fhe_dev_filter =
|
32 |
fhe_dev_filter.save()
|
33 |
|
34 |
# Save the ONNX model
|
35 |
-
onnx.save(filter.
|
36 |
|
37 |
print("Done !")
|
|
|
4 |
import onnx
|
5 |
from common import AVAILABLE_FILTERS, FILTERS_PATH
|
6 |
from filters import Filter
|
7 |
+
from client_server_interface import FHEDev
|
8 |
|
9 |
print("Generating deployment files for all available filters")
|
10 |
|
|
|
28 |
shutil.rmtree(deployment_path)
|
29 |
|
30 |
# Save the files needed for deployment
|
31 |
+
fhe_dev_filter = FHEDev(filter, deployment_path)
|
32 |
fhe_dev_filter.save()
|
33 |
|
34 |
# Save the ONNX model
|
35 |
+
onnx.save(filter.onnx_graph, filter_path / "server.onnx")
|
36 |
|
37 |
print("Done !")
|
server.py
CHANGED
@@ -2,12 +2,11 @@
|
|
2 |
|
3 |
import time
|
4 |
from typing import List
|
5 |
-
|
6 |
-
from common import FILTERS_PATH, SERVER_TMP_PATH
|
7 |
from fastapi import FastAPI, File, Form, UploadFile
|
8 |
from fastapi.responses import JSONResponse, Response
|
9 |
-
|
10 |
-
from
|
|
|
11 |
|
12 |
|
13 |
def get_server_file_path(name, user_id, filter_name):
|
@@ -24,10 +23,6 @@ def get_server_file_path(name, user_id, filter_name):
|
|
24 |
return SERVER_TMP_PATH / f"{name}_{filter_name}_{user_id}"
|
25 |
|
26 |
|
27 |
-
class FilterRequest(BaseModel):
|
28 |
-
filter: str
|
29 |
-
|
30 |
-
|
31 |
# Initialize an instance of FastAPI
|
32 |
app = FastAPI()
|
33 |
|
@@ -74,7 +69,7 @@ def run_fhe(
|
|
74 |
evaluation_key = evaluation_key_file.read()
|
75 |
|
76 |
# Load the FHE server
|
77 |
-
fhe_server =
|
78 |
|
79 |
# Run the FHE execution
|
80 |
start = time.time()
|
|
|
2 |
|
3 |
import time
|
4 |
from typing import List
|
|
|
|
|
5 |
from fastapi import FastAPI, File, Form, UploadFile
|
6 |
from fastapi.responses import JSONResponse, Response
|
7 |
+
|
8 |
+
from common import FILTERS_PATH, SERVER_TMP_PATH
|
9 |
+
from client_server_interface import FHEServer
|
10 |
|
11 |
|
12 |
def get_server_file_path(name, user_id, filter_name):
|
|
|
23 |
return SERVER_TMP_PATH / f"{name}_{filter_name}_{user_id}"
|
24 |
|
25 |
|
|
|
|
|
|
|
|
|
26 |
# Initialize an instance of FastAPI
|
27 |
app = FastAPI()
|
28 |
|
|
|
69 |
evaluation_key = evaluation_key_file.read()
|
70 |
|
71 |
# Load the FHE server
|
72 |
+
fhe_server = FHEServer(FILTERS_PATH / f"{filter}/deployment")
|
73 |
|
74 |
# Run the FHE execution
|
75 |
start = time.time()
|