File size: 5,231 Bytes
7b32412
 
2a040cc
 
7b32412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a040cc
7b32412
 
2a040cc
7b32412
 
2a040cc
 
7b32412
 
 
2a040cc
 
 
7b32412
2a040cc
7b32412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a040cc
7b32412
 
 
2a040cc
 
7b32412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1974e22
 
7b32412
 
1974e22
7b32412
 
 
 
 
1974e22
7b32412
 
 
 
 
 
2a040cc
7b32412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"Client-server interface custom implementation for filter models."

from pathlib import Path

import concrete.numpy as cnp

from filters import Filter


class FHEServer:
    """Server-side interface that evaluates a saved FHE circuit."""

    def __init__(self, path_dir):
        """Load the server circuit artifact from disk.

        Args:
            path_dir (Path): Directory containing the saved "server.zip" circuit.
        """
        self.path_dir = path_dir

        # FHEDev.save writes the server artifact under this exact file name
        server_archive = self.path_dir / "server.zip"
        self.server = cnp.Server.load(server_archive)

    def run(self, serialized_encrypted_image, serialized_evaluation_keys):
        """Apply the compiled filter to an encrypted image.

        Args:
            serialized_encrypted_image (bytes): Serialized encrypted input image.
            serialized_evaluation_keys (bytes): Serialized evaluation keys sent by the client.

        Returns:
            bytes: The serialized encrypted output of the filter.
        """
        specs = self.server.client_specs

        # Rebuild the runtime objects from their serialized form
        encrypted_args = specs.unserialize_public_args(serialized_encrypted_image)
        keys = cnp.EvaluationKeys.unserialize(serialized_evaluation_keys)

        # Evaluate the circuit on the encrypted arguments, then serialize the result
        return specs.serialize_public_result(self.server.run(encrypted_args, keys))


class FHEDev:
    """Development interface to export a compiled filter's client and server artifacts."""

    def __init__(self, filter, path_dir):
        """Initialize the FHE interface.

        Args:
            filter (Filter): The filter to use in the FHE interface. Its
                ``fhe_circuit`` attribute must be set (compiled) before ``save``
                is called.
            path_dir (str | Path): The path to the directory where the circuit
                is saved. It is created, including parents, if missing.
        """
        # ``filter`` shadows the builtin, but the name is part of the public signature.
        self.filter = filter
        # Accept plain strings as documented: ``mkdir`` and the ``/`` operator
        # used below are only available on ``Path`` objects.
        self.path_dir = Path(path_dir)

        self.path_dir.mkdir(parents=True, exist_ok=True)

    def save(self):
        """Export all needed artifacts for the client and server interfaces.

        Writes ``server.zip`` (for FHEServer) and ``client.zip`` (for FHEClient)
        into ``path_dir``.

        Raises:
            AssertionError: If the filter's ``fhe_circuit`` is None.
        """
        # NOTE(review): ``assert`` is stripped under ``python -O``; the check is
        # kept as an assert so callers see the same exception type as before.
        assert self.filter.fhe_circuit is not None, (
            "The model must be compiled before saving it."
        )

        # Save the circuit for the server
        path_circuit_server = self.path_dir / "server.zip"
        self.filter.fhe_circuit.server.save(path_circuit_server)

        # Save the circuit for the client
        path_circuit_client = self.path_dir / "client.zip"
        self.filter.fhe_circuit.client.save(path_circuit_client)


class FHEClient:
    """Client interface to encrypt and decrypt FHE data associated to a Filter."""

    def __init__(self, path_dir, key_dir, filter_name):
        """Initialize the FHE interface.

        Args:
            path_dir (str | Path): The path to the directory where the circuit
                ("client.zip") is saved.
            key_dir (Path): The path to the directory where the keys are stored;
                forwarded to ``cnp.Client.load``.
            filter_name: Identifier passed to ``Filter`` to build the filter whose
                ``post_processing`` is applied after decryption (presumably a
                str — confirm against ``filters.Filter``).

        Raises:
            AssertionError: If ``path_dir`` does not exist.
        """
        # Accept plain strings too: ``exists`` and the ``/`` operator used below
        # are only available on ``Path`` objects.
        self.path_dir = Path(path_dir)
        self.key_dir = key_dir

        # Fail early with a clear message if the circuit directory is missing
        assert self.path_dir.exists(), f"{path_dir} does not exist. Please specify a valid path."

        # Load the client
        self.client = cnp.Client.load(self.path_dir / "client.zip", self.key_dir)

        # Instantiate the filter
        self.filter = Filter(filter_name)

    def generate_private_and_evaluation_keys(self, force=False):
        """Generate the private and evaluation keys.

        Args:
            force (bool): If True, regenerate the keys even if they already exist.
        """
        self.client.keygen(force)

    def get_serialized_evaluation_keys(self):
        """Get the serialized evaluation keys.

        Returns:
            bytes: The evaluation keys.
        """
        return self.client.evaluation_keys.serialize()

    def encrypt_serialize(self, input_image):
        """Encrypt and serialize the input image in the clear.

        Args:
            input_image (numpy.ndarray): The image to encrypt and serialize.

        Returns:
            bytes: The encrypted and serialized image.
        """
        # Encrypt the image
        encrypted_image = self.client.encrypt(input_image)

        # Serialize the encrypted image to be sent to the server
        serialized_encrypted_image = self.client.specs.serialize_public_args(encrypted_image)
        return serialized_encrypted_image

    def deserialize_decrypt_post_process(self, serialized_encrypted_output_image):
        """Deserialize, decrypt and post-process the output image in the clear.

        Args:
            serialized_encrypted_output_image (bytes): The serialized and encrypted output image.

        Returns:
            numpy.ndarray: The decrypted, deserialized and post-processed image.
        """
        # Deserialize the encrypted image
        encrypted_output_image = self.client.specs.unserialize_public_result(
            serialized_encrypted_output_image
        )

        # Decrypt the image
        output_image = self.client.decrypt(encrypted_output_image)

        # Post-process the image
        post_processed_output_image = self.filter.post_processing(output_image)

        return post_processed_output_image