# team14 / matchers.py
# Author: tagny
# Commit 0908a41: image pairs matching - issue on config file in server.zip
# (Page residue from the file viewer: "raw" / "history blame" links, file size 3.02 kB.)
import random
import numpy as np
from torch import nn
import torch
from concrete.fhe.compilation.compiler import Compiler
from concrete.ml.common.utils import generate_proxy_function
from concrete.ml.torch.numpy_module import NumpyModule
from common import AVAILABLE_MATCHERS
class TorchRandomGuessing(nn.Module):
    """Torch model that outputs a randomly chosen class label.

    The two inputs do not influence the predicted value; they are only
    folded into the output arithmetically (see ``forward``).
    """

    def __init__(self, classes_=None):
        """Store the candidate class labels.

        Args:
            classes_ (Sequence[int] | None): Labels the model may output.
                Defaults to ``[0, 1]`` when omitted.
        """
        super().__init__()
        # A literal list default would be one object shared by every instance
        # built without an argument; build a fresh list per instance instead.
        self.classes_ = [0, 1] if classes_ is None else classes_

    def forward(self, q, r):
        """Random guessing forward pass.

        Args:
            q (torch.Tensor): The input query.
            r (torch.Tensor): The input reference.

        Returns:
            (torch.Tensor): A one-element tensor holding a label drawn
                uniformly from ``self.classes_``.
        """
        q = q.sum()
        r = r.sum()
        # Draw from the configured labels rather than a hard-coded [0, 1], so
        # the constructor argument is actually honoured.
        # NOTE(review): the draw happens in Python each time forward() runs; if
        # the module is traced/exported the value is presumably frozen at
        # trace time - confirm against the conversion pipeline.
        guess = torch.tensor([random.choice(self.classes_)])
        # Adding and subtracting each input sum contributes zero to the value
        # while still making the output an expression of both inputs
        # (presumably so that both remain part of the traced graph).
        return guess + q - q + r - r
class Matcher:
    """Image-pair matcher that can be compiled into an FHE circuit."""

    def __init__(self, matcher_name):
        """Select the torch model backing the requested matcher.

        Args:
            matcher_name (str): Name of the matcher; must be one of
                ``AVAILABLE_MATCHERS``.

        Raises:
            AssertionError: If ``matcher_name`` is not in ``AVAILABLE_MATCHERS``.
        """
        # Raised explicitly instead of via `assert` so the check is not
        # stripped under `python -O`; the exception type is kept for backward
        # compatibility. The message is a plain string (the original trailing
        # comma turned it into a 1-tuple).
        if matcher_name not in AVAILABLE_MATCHERS:
            raise AssertionError(
                f"Unsupported image matcher. Expected one of {*AVAILABLE_MATCHERS,}, "
                f"but got {matcher_name}"
            )

        self.fhe_circuit = None
        self.matcher_name = matcher_name
        # Always define the attribute so that an available-but-unimplemented
        # matcher fails with a clear error in compile() instead of an
        # AttributeError.
        self.torch_model = None
        if self.matcher_name == "random guessing":
            self.torch_model = TorchRandomGuessing()

    def compile(self):
        """Compile the matcher's torch model and store the resulting FHE circuit.

        Returns:
            The circuit returned by ``Compiler.compile`` (also stored in
            ``self.fhe_circuit``).

        Raises:
            NotImplementedError: If no torch model was set for this matcher.
        """
        if self.torch_model is None:
            raise NotImplementedError(
                f"No torch model is implemented for matcher {self.matcher_name!r}"
            )

        # Single (query, reference) sample used both as the tracing dummy
        # input and as the compilation inputset.
        inputset = [(np.array([10]), np.array([5]))]
        query_sample, reference_sample = inputset[0]

        print("torch module > numpy module ...")
        numpy_module = NumpyModule(
            self.torch_model,
            # The query slot must come from the query sample; the original
            # passed the reference sample twice.
            dummy_input=(
                torch.from_numpy(query_sample),
                torch.from_numpy(reference_sample),
            ),
        )

        print("get proxy function ...")
        # Get the proxy function and parameter mappings used for initializing the compiler
        # This is done in order to be able to provide any modules with arbitrary numbers of
        # encrypted arguments to Concrete Numpy's compiler
        numpy_filter_proxy, parameters_mapping = generate_proxy_function(
            numpy_module.numpy_forward, ["query", "reference"]
        )

        print("Compile the filter and retrieve its FHE circuit ...")
        # Both inputs are declared as encrypted.
        compiler = Compiler(
            numpy_filter_proxy,
            {
                parameters_mapping["query"]: "encrypted",
                parameters_mapping["reference"]: "encrypted",
            },
        )
        self.fhe_circuit = compiler.compile(inputset)
        return self.fhe_circuit

    def post_processing(self, output_result):
        """Apply post-processing to the decrypted output result.

        Args:
            output_result (np.ndarray): The decrypted result to post-process;
                only its first element is inspected.

        Returns:
            str: ``"PASS"`` if the first element equals 1, otherwise ``"FAIL"``.
        """
        print(f"{output_result=}")
        return "PASS" if output_result[0] == 1 else "FAIL"
# Example usage (left commented out so that importing this module has no side effects):
# matcher = Matcher(matcher_name=AVAILABLE_MATCHERS[0])
# fhe_circuit = matcher.compile()