File size: 6,757 Bytes
21c7197 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
"Client-server interface implementation for custom models."
from pathlib import Path
from typing import Any, Optional

import concrete.numpy as cnp
import numpy as np
from concrete.ml.common.debugging.custom_assert import assert_true

from filters import Filter
class CustomFHEDev:
    """Dev API to save the custom model and then load and run the FHE circuit."""

    model: Any = None

    def __init__(self, path_dir: str, model: Any = None):
        """Initialize the FHE API.

        The directory (and any missing parents) is created if it does not exist yet.

        Args:
            path_dir (str): the path to the directory where the circuit is saved
            model (Any): the model to use for the FHE API
        """
        self.path_dir = Path(path_dir)
        self.model = model

        # Create the directory path if it does not exist yet
        self.path_dir.mkdir(parents=True, exist_ok=True)

    def save(self):
        """Export all needed artifacts for the client and server.

        Writes the model's processing parameters (through ``model.to_json`` with
        ``file_name="serialized_processing"``), the server circuit as ``server.zip`` and
        the client circuit as ``client.zip``, all inside ``path_dir``.

        Raises:
            Exception: path_dir is not empty
        """
        # Refuse to overwrite existing artifacts. A single entry is enough to know the
        # directory is non-empty, so stop at the first one instead of walking the tree.
        if self.path_dir.is_dir() and any(self.path_dir.iterdir()):
            raise Exception(
                f"path_dir: {self.path_dir} is not empty. "
                "Please delete it before saving a new model."
            )

        assert_true(
            hasattr(self.model, "fhe_circuit"),
            "The model must be compiled and have a fhe_circuit object",
        )

        # Model must be compiled with jit=False
        # In a jit model, everything is in memory so it is not serializable.
        assert_true(
            not self.model.fhe_circuit.configuration.jit,
            "The model must be compiled with the configuration option jit=False.",
        )

        # Export the parameters
        self.model.to_json(path_dir=self.path_dir, file_name="serialized_processing")

        # Save the circuit for the server
        path_circuit_server = self.path_dir / "server.zip"
        self.model.fhe_circuit.server.save(path_circuit_server)

        # Save the circuit for the client
        path_circuit_client = self.path_dir / "client.zip"
        self.model.fhe_circuit.client.save(path_circuit_client)
class CustomFHEClient:
    """Client API to encrypt and decrypt FHE data."""

    client: cnp.Client

    def __init__(self, path_dir: str, key_dir: Optional[str] = None):
        """Initialize the FHE API.

        Args:
            path_dir (str): the path to the directory where the circuit is saved
            key_dir (Optional[str]): the path to the directory where the keys are stored.
                When None, None is forwarded as-is to ``cnp.Client.load``.
        """
        self.path_dir = Path(path_dir)
        # Path(None) raises a TypeError, so only wrap key_dir when it is actually given
        self.key_dir = Path(key_dir) if key_dir is not None else None

        # If path_dir does not exist, raise an error
        assert_true(
            self.path_dir.exists(), f"{path_dir} does not exist. Please specify a valid path."
        )

        # Load the client circuit and the model's processing parameters
        self.load()

    def load(self):  # pylint: disable=no-value-for-parameter
        """Load the parameters along with the FHE specs."""
        # Load the client (key_dir is passed through, possibly None)
        self.client = cnp.Client.load(self.path_dir / "client.zip", self.key_dir)

        # Load the model's pre/post-processing parameters
        self.model = Filter.from_json(self.path_dir / "serialized_processing.json")

    def generate_private_and_evaluation_keys(self, force=False):
        """Generate the private and evaluation keys.

        Args:
            force (bool): if True, regenerate the keys even if they already exist
        """
        self.client.keygen(force)

    def get_serialized_evaluation_keys(self) -> bytes:
        """Get the serialized evaluation keys.

        Returns:
            bytes: the serialized evaluation keys, meant to be sent to the server
        """
        return self.client.evaluation_keys.serialize()

    def pre_process_encrypt_serialize(self, x: np.ndarray) -> bytes:
        """Pre-process, encrypt and serialize the values.

        Args:
            x (numpy.ndarray): the values to encrypt and serialize

        Returns:
            bytes: the encrypted and serialized values
        """
        # Pre-process the values
        x = self.model.pre_processing(x)

        # Encrypt the values
        enc_x = self.client.encrypt(x)

        # Serialize the encrypted values to be sent to the server
        serialized_enc_x = self.client.specs.serialize_public_args(enc_x)
        return serialized_enc_x

    def deserialize_decrypt_post_process(self, serialized_encrypted_output: bytes) -> np.ndarray:
        """Deserialize, decrypt and post-process the values.

        Args:
            serialized_encrypted_output (bytes): the serialized and encrypted output

        Returns:
            numpy.ndarray: the decrypted and post-processed values
        """
        # Deserialize the encrypted values
        deserialized_encrypted_output = self.client.specs.unserialize_public_result(
            serialized_encrypted_output
        )

        # Decrypt the values
        deserialized_decrypted_output = self.client.decrypt(deserialized_encrypted_output)

        # Apply the model post processing
        deserialized_decrypted_output = self.model.post_processing(deserialized_decrypted_output)
        return deserialized_decrypted_output
class CustomFHEServer:
    """Server API to load and run the FHE circuit."""

    server: cnp.Server

    def __init__(self, path_dir: str):
        """Initialize the FHE API.

        Args:
            path_dir (str): the path to the directory where the circuit is saved
        """
        self.path_dir = Path(path_dir)

        # Load the FHE circuit
        self.load()

    def load(self):
        """Load the circuit."""
        self.server = cnp.Server.load(self.path_dir / "server.zip")

    def run(
        self,
        serialized_encrypted_data: bytes,
        serialized_evaluation_keys: bytes,
    ) -> bytes:
        """Run the model on the server over encrypted data.

        Args:
            serialized_encrypted_data (bytes): the encrypted and serialized data
            serialized_evaluation_keys (bytes): the serialized evaluation keys

        Returns:
            bytes: the serialized encrypted result of the model
        """
        # The class-level annotation does not create an attribute, so use getattr to make
        # this guard fire (instead of an AttributeError) when load() was never called.
        assert_true(getattr(self, "server", None) is not None, "Model has not been loaded.")

        deserialized_encrypted_data = self.server.client_specs.unserialize_public_args(
            serialized_encrypted_data
        )
        deserialized_evaluation_keys = cnp.EvaluationKeys.unserialize(serialized_evaluation_keys)

        result = self.server.run(deserialized_encrypted_data, deserialized_evaluation_keys)

        serialized_result = self.server.client_specs.serialize_public_result(result)
        return serialized_result
|