File size: 5,642 Bytes
0908a41
7b32412
a13c444
2a040cc
0908a41
7b32412
 
 
 
 
 
 
 
 
 
 
 
 
 
a13c444
7b32412
0908a41
 
 
 
 
 
 
7b32412
 
0908a41
 
7b32412
 
 
0908a41
7b32412
 
0908a41
 
 
 
a13c444
7b32412
0908a41
 
 
 
 
 
7b32412
 
a13c444
7b32412
2a040cc
7b32412
 
 
0908a41
7b32412
0908a41
7b32412
 
 
0908a41
5369c66
7b32412
 
0908a41
7b32412
 
 
 
 
 
0908a41
 
 
 
7b32412
5a6dcda
 
7b32412
0908a41
7b32412
 
 
0908a41
7b32412
 
 
0908a41
7b32412
0908a41
7b32412
 
 
2a040cc
0908a41
5369c66
7b32412
 
 
 
 
0908a41
 
 
7b32412
 
a13c444
7b32412
0908a41
 
7b32412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1974e22
 
7b32412
 
1974e22
7b32412
 
 
 
 
1974e22
7b32412
 
a13c444
7b32412
 
 
0908a41
7b32412
 
 
 
 
 
 
 
0908a41
7b32412
0908a41
 
7b32412
0908a41
 
7b32412
0908a41
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"Client-server interface custom implementation for matcher models."

from pathlib import Path

from concrete import fhe

from matchers import Matcher


class FHEServer:
    """Server interface to run an FHE circuit."""

    def __init__(self, path_dir):
        """Initialize the FHE interface.

        Args:
            path_dir (Union[str, Path]): The path to the directory where the circuit is saved.
                The circuit is loaded from the "server.zip" file found in this directory.
        """
        # Plain strings do not support the "/" operator used below, so they are converted.
        # Any other object is kept unchanged so that existing callers keep their path type.
        self.path_dir = Path(path_dir) if isinstance(path_dir, str) else path_dir

        # Load the FHE circuit
        self.server = fhe.Server.load(self.path_dir / "server.zip")

    def run(
        self,
        serialized_encrypted_query_image,
        serialized_encrypted_reference_image,
        serialized_evaluation_keys,
    ):
        """Run the matcher on the server over a pair of encrypted images.

        Args:
            serialized_encrypted_query_image (bytes): The encrypted and serialized query image.
            serialized_encrypted_reference_image (bytes): The encrypted and serialized reference
                image.
            serialized_evaluation_keys (bytes): The serialized evaluation keys.

        Returns:
            bytes: The matcher's output, still encrypted and serialized.
        """
        # Deserialize both encrypted inputs and the evaluation keys
        encrypted_query_image = fhe.Value.deserialize(serialized_encrypted_query_image)
        encrypted_reference_image = fhe.Value.deserialize(
            serialized_encrypted_reference_image
        )
        evaluation_keys = fhe.EvaluationKeys.deserialize(serialized_evaluation_keys)

        # Execute the matcher in FHE: the query image is the circuit's first input and the
        # reference image its second one
        encrypted_output = self.server.run(
            encrypted_query_image,
            encrypted_reference_image,
            evaluation_keys=evaluation_keys,
        )

        # Serialize the encrypted output so that it can be sent back to the client
        return encrypted_output.serialize()


class FHEDev:
    """Development interface to export the matcher's client and server artifacts."""

    def __init__(self, matcher, path_dir):
        """Initialize the FHE interface.

        The target directory is created (including missing parents) if it does not exist yet.

        Args:
            matcher (Matcher): The matcher to use in the FHE interface.
            path_dir (Union[str, Path]): The path to the directory where the circuit is saved.
        """
        self.matcher = matcher

        # Plain strings support neither `mkdir` nor the "/" operator used in `save`, so they
        # are converted. Any other object is kept unchanged so that existing callers keep
        # their path type.
        self.path_dir = Path(path_dir) if isinstance(path_dir, str) else path_dir

        self.path_dir.mkdir(parents=True, exist_ok=True)

    def save(self):
        """Export all needed artifacts for the client and server interfaces.

        The server circuit is written to "server.zip" and the client one to "client.zip",
        both inside `path_dir`.

        Raises:
            AssertionError: If the matcher's `fhe_circuit` attribute is None, meaning it has
                not been compiled yet.
        """
        circuit = self.matcher.fhe_circuit

        # Raise explicitly instead of using an `assert` statement so that the check is not
        # stripped when Python runs with the -O option. The exception type is kept as is for
        # backward compatibility.
        if circuit is None:
            raise AssertionError("The model must be compiled before saving it.")

        # Save the circuit for the server, using the via_mlir in order to handle cross-platform
        # execution
        path_circuit_server = self.path_dir / "server.zip"
        circuit.server.save(path_circuit_server, via_mlir=True)

        # Save the circuit for the client
        path_circuit_client = self.path_dir / "client.zip"
        circuit.client.save(path_circuit_client)


class FHEClient:
    """Client interface to encrypt and decrypt FHE data associated to a matcher."""

    def __init__(self, path_dir, matcher_name, key_dir=None):
        """Initialize the FHE interface.

        Args:
            path_dir (Union[str, Path]): The path to the directory where the circuit is saved.
                The client is loaded from the "client.zip" file found in this directory.
            matcher_name (str): The matcher's name to consider.
            key_dir (Path): The path to the directory where the keys are stored. Default to None.

        Raises:
            AssertionError: If `path_dir` does not exist.
        """
        # Plain strings support neither `exists` nor the "/" operator used below, so they are
        # converted. Any other object is kept unchanged so that existing callers keep their
        # path type.
        self.path_dir = Path(path_dir) if isinstance(path_dir, str) else path_dir
        self.key_dir = key_dir

        # Raise explicitly instead of using an `assert` statement so that the check is not
        # stripped when Python runs with the -O option. The exception type is kept as is for
        # backward compatibility.
        if not self.path_dir.exists():
            raise AssertionError(f"{path_dir} does not exist. Please specify a valid path.")

        # Load the client, forwarding the key directory as the second argument
        self.client = fhe.Client.load(self.path_dir / "client.zip", self.key_dir)

        # Instantiate the matcher, only used for post-processing the decrypted outputs
        self.matcher = Matcher(matcher_name)

    def generate_private_and_evaluation_keys(self, force=False):
        """Generate the private and evaluation keys.

        Args:
            force (bool): If True, regenerate the keys even if they already exist.
        """
        self.client.keygen(force)

    def get_serialized_evaluation_keys(self):
        """Get the serialized evaluation keys.

        Returns:
            bytes: The evaluation keys.
        """
        return self.client.evaluation_keys.serialize()

    def encrypt_serialize(self, input_image):
        """Encrypt and serialize the input image in the clear.

        Args:
            input_image (numpy.ndarray): The image to encrypt and serialize.

        Returns:
            bytes: The encrypted and serialized image.
        """
        # Encrypt the image
        encrypted_image = self.client.encrypt(input_image)

        # Serialize the encrypted image to be sent to the server
        return encrypted_image.serialize()

    def deserialize_decrypt_post_process(self, serialized_encrypted_output_image):
        """Deserialize, decrypt and post-process the output result in the clear.

        Args:
            serialized_encrypted_output_image (bytes): The serialized and encrypted output image.

        Returns:
            numpy.ndarray: The deserialized, decrypted and post-processed output.
        """
        # Deserialize the encrypted output
        encrypted_output = fhe.Value.deserialize(serialized_encrypted_output_image)

        # Decrypt the result
        output_result = self.client.decrypt(encrypted_output)

        # Post-process the result with the matcher's own post-processing step
        return self.matcher.post_processing(output_result)