|
import random |
|
import numpy as np |
|
from torch import nn |
|
import torch |
|
|
|
from concrete.fhe.compilation.compiler import Compiler |
|
from concrete.ml.common.utils import generate_proxy_function |
|
from concrete.ml.torch.numpy_module import NumpyModule |
|
|
|
from common import AVAILABLE_MATCHERS |
|
|
|
|
|
class TorchRandomGuessing(nn.Module):
    """Torch module that answers with a randomly drawn class label.

    The label is drawn with Python's ``random`` module each time ``forward``
    runs; the input only contributes a term that cancels out numerically.
    """

    def __init__(self, classes_=None):
        """Store the candidate labels.

        Args:
            classes_ (Sequence[int] | None): Labels to draw from. When omitted,
                a fresh ``[0, 1]`` list is created for this instance.
        """
        super().__init__()
        # Build the default per instance: a list literal in the signature would
        # be a single object shared (and mutable) across all instances.
        self.classes_ = [0, 1] if classes_ is None else classes_

    def forward(self, x):
        """Random guessing forward pass.

        Args:
            x (torch.Tensor): concat of query and reference.

        Returns:
            (torch.Tensor): One-element tensor holding a label drawn from
                ``self.classes_``.
        """
        x = x.sum()
        # `+ x - x` adds nothing numerically; it keeps the input in the output
        # expression (presumably so a traced graph still depends on it -- confirm).
        return torch.tensor([random.choice(self.classes_)]) + x - x
|
|
|
|
|
class Matcher:
    """Hold a named matcher's torch model and compile it into an FHE circuit."""

    def __init__(self, matcher_name):
        """Select the torch model that backs ``matcher_name``.

        Args:
            matcher_name (str): Name of the matcher; must be one of
                ``AVAILABLE_MATCHERS``.

        Raises:
            AssertionError: If ``matcher_name`` is not in ``AVAILABLE_MATCHERS``.
        """
        # The message must be a plain string: a trailing comma inside the
        # parentheses would turn it into a one-element tuple.
        assert matcher_name in AVAILABLE_MATCHERS, (
            f"Unsupported image matcher. Expected one of {*AVAILABLE_MATCHERS,}, "
            f"but got {matcher_name}"
        )
        self.fhe_circuit = None
        self.matcher_name = matcher_name
        # Always define the attribute so `compile` can report a missing model
        # explicitly instead of failing with an AttributeError.
        self.torch_model = None

        if self.matcher_name == "random guessing":
            self.torch_model = TorchRandomGuessing()

    def compile(self):
        """Compile the matcher's torch model and store the resulting FHE circuit.

        The torch model is wrapped in a ``NumpyModule``, its ``numpy_forward``
        is turned into a proxy function taking a single ``"inputs"`` argument,
        and that proxy is compiled with the argument marked ``"encrypted"``
        over a fixed two-sample inputset.

        Returns:
            The object returned by ``Compiler.compile``; it is also stored in
            ``self.fhe_circuit``.

        Raises:
            NotImplementedError: If no torch model was set for this matcher.
        """
        if self.torch_model is None:
            raise NotImplementedError(
                f"No torch model is available for matcher {self.matcher_name!r}"
            )

        # Two one-element samples; the first one also serves as the dummy input.
        inputset = (np.array([10]), np.array([5]))

        print("torch module > numpy module ...")
        numpy_module = NumpyModule(
            self.torch_model,
            dummy_input=torch.from_numpy(inputset[0]),
        )

        print("get proxy function ...")
        numpy_filter_proxy, parameters_mapping = generate_proxy_function(
            numpy_module.numpy_forward, ["inputs"]
        )

        print("Compile the filter and retrieve its FHE circuit ...")
        compiler = Compiler(
            numpy_filter_proxy,
            {
                parameters_mapping["inputs"]: "encrypted",
            },
        )
        self.fhe_circuit = compiler.compile(inputset)
        return self.fhe_circuit

    def post_processing(self, output_result):
        """Apply post-processing to the decrypted output result.

        Args:
            output_result (np.ndarray): The decrypted result to post-process.

        Returns:
            str: ``"PASS"`` if the first element of ``output_result`` equals 1,
                ``"FAIL"`` otherwise.
        """
        print(f"{output_result=}")

        return "PASS" if output_result[0] == 1 else "FAIL"
|
|
|
|
|
|
|
|
|
|