File size: 2,790 Bytes
0908a41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa707a9
0908a41
 
 
fa707a9
0908a41
 
 
 
fa707a9
 
0908a41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa707a9
0908a41
 
 
 
 
 
fa707a9
0908a41
 
 
 
 
 
 
fa707a9
0908a41
 
 
 
 
 
fa707a9
0908a41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import random
import numpy as np
from torch import nn
import torch

from concrete.fhe.compilation.compiler import Compiler
from concrete.ml.common.utils import generate_proxy_function
from concrete.ml.torch.numpy_module import NumpyModule

from common import AVAILABLE_MATCHERS


class TorchRandomGuessing(nn.Module):
    """Torch baseline model that ignores its input and outputs a random label.

    Backs the "random guessing" matcher: the output is a single label drawn
    uniformly from ``classes_``, independent of the input values.
    """

    def __init__(self, classes_=None):
        """Initialize the model.

        Args:
            classes_ (list, optional): candidate labels to draw from. Defaults
                to ``[0, 1]``. Labels must be numeric so they can be wrapped in
                a ``torch.Tensor``.
        """
        super().__init__()
        # Avoid a shared mutable default: each instance gets its own list.
        self.classes_ = [0, 1] if classes_ is None else classes_

    def forward(self, x):
        """Random guessing forward pass.

        Args:
            x (torch.Tensor): concat of query and reference.

        Returns:
            torch.Tensor: 1-element tensor holding a label drawn from
            ``classes_``. The values of ``x`` do not affect it.
        """
        # Adding and subtracting the summed input leaves the label unchanged.
        # It presumably exists so the exported graph still depends on ``x``.
        # NOTE(review): the draw runs in Python at call time. Once the model
        # is traced/converted, the label is likely frozen as a constant.
        # Confirm this is the intended behaviour for the FHE circuit.
        x = x.sum()
        return torch.tensor([random.choice(self.classes_)]) + x - x


class Matcher:
    """Wrap a Torch matcher model and compile it into a Concrete FHE circuit."""

    def __init__(self, matcher_name):
        """Select the Torch model corresponding to ``matcher_name``.

        Args:
            matcher_name (str): name of the matcher; must be one of
                ``AVAILABLE_MATCHERS``.

        Raises:
            AssertionError: if ``matcher_name`` is not in ``AVAILABLE_MATCHERS``.
        """
        # The message is a single string (no trailing comma). Otherwise it
        # would be a tuple and print as "('...',)".
        assert matcher_name in AVAILABLE_MATCHERS, (
            f"Unsupported image matcher. Expected one of {*AVAILABLE_MATCHERS,}, "
            f"but got {matcher_name}"
        )
        self.fhe_circuit = None
        self.matcher_name = matcher_name
        # Stays None for matchers with no Torch model wired up yet. ``compile``
        # checks it and raises a clear error instead of an AttributeError.
        self.torch_model = None

        if self.matcher_name == "random guessing":
            self.torch_model = TorchRandomGuessing()

    def compile(self):
        """Convert the Torch model to NumPy and compile it to an FHE circuit.

        The compiled circuit is stored on ``self.fhe_circuit`` and returned.

        Returns:
            The circuit returned by ``Compiler.compile``.

        Raises:
            NotImplementedError: if no Torch model exists for this matcher.
        """
        if self.torch_model is None:
            raise NotImplementedError(
                f"No Torch model is implemented for matcher {self.matcher_name!r}"
            )

        # Sample inputs for the single encrypted ``inputs`` argument. Concrete
        # uses the inputset to measure value ranges during compilation.
        # The first sample doubles as the tracing input for NumpyModule.
        inputset = (np.array([10]), np.array([5]))

        print("torch module > numpy module ...")
        numpy_module = NumpyModule(
            self.torch_model,
            dummy_input=torch.from_numpy(inputset[0]),
        )

        print("get proxy function ...")
        # Get the proxy function and parameter mappings used to initialize the
        # compiler. This allows modules with arbitrary numbers of encrypted
        # arguments to be passed to Concrete's compiler.
        numpy_filter_proxy, parameters_mapping = generate_proxy_function(
            numpy_module.numpy_forward, ["inputs"]
        )

        print("Compile the filter and retrieve its FHE circuit ...")
        compiler = Compiler(
            numpy_filter_proxy,
            {
                parameters_mapping["inputs"]: "encrypted",
            },
        )
        self.fhe_circuit = compiler.compile(inputset)
        return self.fhe_circuit

    def post_processing(self, output_result):
        """Turn the decrypted circuit output into a PASS/FAIL verdict.

        The raw ``output_result`` is printed for debugging.

        Args:
            output_result (np.ndarray): the decrypted result. Only its first
                element is inspected.

        Returns:
            str: ``"PASS"`` if the first element equals 1, otherwise ``"FAIL"``.
        """
        print(f"{output_result=}")

        return "PASS" if output_result[0] == 1 else "FAIL"


# matcher = Matcher(matcher_name=AVAILABLE_MATCHERS[0])
# fhe_circuit = matcher.compile()