import random

import numpy as np
import torch
from concrete.fhe.compilation.compiler import Compiler
from concrete.ml.common.utils import generate_proxy_function
from concrete.ml.torch.numpy_module import NumpyModule
from torch import nn

from common import AVAILABLE_MATCHERS


class TorchRandomGuessing(nn.Module):
    """Torch model that ignores its inputs' values and outputs a random 0/1 guess."""

    def __init__(self, classes_=None):
        """Initialize the model.

        Args:
            classes_ (list, optional): Class labels stored on the model as
                ``classes_``; not used by :meth:`forward`. Defaults to
                ``[0, 1]``. A fresh list is built per instance so instances
                never share one mutable default.
        """
        super().__init__()
        self.classes_ = [0, 1] if classes_ is None else classes_

    def forward(self, q, r):
        """Random guessing forward pass.

        Args:
            q (torch.Tensor): The input query.
            r (torch.Tensor): The input reference.

        Returns:
            (torch.Tensor): A one-element tensor holding 0 or 1.
        """
        # Reduce both inputs to scalars. "+ q - q + r - r" cancels out
        # numerically but keeps both inputs connected to the output,
        # presumably so graph export/tracing keeps them as real inputs.
        q = q.sum()
        r = r.sum()
        # NOTE(review): when this module is exported/traced (as
        # Matcher.compile does through NumpyModule), the random draw is likely
        # captured once as a constant instead of being re-drawn per call —
        # confirm this is the intended "random guessing" behavior.
        return torch.tensor([random.choice([0, 1])]) + q - q + r - r


class Matcher:
    """Wrap an image-matching torch model and compile it into an FHE circuit."""

    def __init__(self, matcher_name):
        """Select the torch model for the requested matcher.

        Args:
            matcher_name (str): One of ``AVAILABLE_MATCHERS``.

        Raises:
            AssertionError: If ``matcher_name`` is not in ``AVAILABLE_MATCHERS``.
        """
        # The message is a single string (no trailing comma), so a failing
        # assert shows readable text rather than a tuple.
        assert matcher_name in AVAILABLE_MATCHERS, (
            f"Unsupported image matcher. Expected one of {*AVAILABLE_MATCHERS,}, "
            f"but got {matcher_name}"
        )

        self.fhe_circuit = None
        self.matcher_name = matcher_name
        # Only "random guessing" has a torch implementation so far; other
        # matchers leave this as None and are rejected by compile().
        self.torch_model = None

        if self.matcher_name == "random guessing":
            self.torch_model = TorchRandomGuessing()

    def compile(self, inputset=None):
        """Convert the torch model to NumPy and compile it into an FHE circuit.

        Args:
            inputset (list, optional): ``(query, reference)`` pairs of NumPy
                arrays passed to ``Compiler.compile``. The first pair is also
                used as the dummy input for the torch-to-NumPy conversion.
                Defaults to ``[(np.array([10]), np.array([5]))]``.

        Returns:
            The compiled FHE circuit, also stored in ``self.fhe_circuit``.

        Raises:
            NotImplementedError: If no torch model exists for this matcher.
        """
        if self.torch_model is None:
            raise NotImplementedError(
                f"No torch model is implemented for matcher {self.matcher_name!r}"
            )

        if inputset is None:
            inputset = [(np.array([10]), np.array([5]))]

        # Use the query sample for the query input and the reference sample
        # for the reference input.
        query_sample, reference_sample = inputset[0]

        print("torch module > numpy module ...")
        numpy_module = NumpyModule(
            self.torch_model,
            dummy_input=(
                torch.from_numpy(query_sample),
                torch.from_numpy(reference_sample),
            ),
        )

        print("get proxy function ...")
        # Get the proxy function and parameter mappings used for initializing the compiler
        # This is done in order to be able to provide any modules with arbitrary numbers of
        # encrypted arguments to Concrete Numpy's compiler
        numpy_filter_proxy, parameters_mapping = generate_proxy_function(
            numpy_module.numpy_forward, ["query", "reference"]
        )

        print("Compile the filter and retrieve its FHE circuit ...")
        compiler = Compiler(
            numpy_filter_proxy,
            {
                parameters_mapping["query"]: "encrypted",
                parameters_mapping["reference"]: "encrypted",
            },
        )
        self.fhe_circuit = compiler.compile(inputset)

        return self.fhe_circuit

    def post_processing(self, output_result):
        """Map the decrypted circuit output to a PASS/FAIL verdict.

        Args:
            output_result (np.ndarray): The decrypted result to post-process.

        Returns:
            str: ``"PASS"`` if the first element equals 1, otherwise ``"FAIL"``.
        """
        print(f"{output_result=}")
        return "PASS" if output_result[0] == 1 else "FAIL"


# Example usage:
# matcher = Matcher(matcher_name=AVAILABLE_MATCHERS[0])
# fhe_circuit = matcher.compile()