# team14/matchers.py
# Author: tagny — commit fa707a9 ("team14: verio - working version 1")
import random
import numpy as np
from torch import nn
import torch
from concrete.fhe.compilation.compiler import Compiler
from concrete.ml.common.utils import generate_proxy_function
from concrete.ml.torch.numpy_module import NumpyModule
from common import AVAILABLE_MATCHERS
class TorchRandomGuessing(nn.Module):
    """Torch model that ignores its input and returns a randomly chosen class label."""

    def __init__(self, classes_=None):
        """Store the candidate class labels.

        Args:
            classes_ (list, optional): labels to draw from. Defaults to ``[0, 1]``.
                A fresh default list is created per instance, so instances never
                share (and cannot accidentally mutate) a common default.
        """
        super().__init__()
        self.classes_ = [0, 1] if classes_ is None else classes_

    def forward(self, x):
        """Random guessing forward pass.

        Args:
            x (torch.Tensor): concat of query and reference.

        Returns:
            (torch.Tensor): 1-element tensor holding a label drawn uniformly
            from ``self.classes_``.
        """
        x = x.sum()
        # `+ x - x` leaves the value unchanged; presumably it exists so the output
        # stays connected to the input when the module is traced/converted.
        # NOTE(review): if the module is traced (e.g. exported by NumpyModule),
        # random.choice runs once at trace time and its result becomes a constant
        # of the compiled circuit rather than a fresh guess per call — confirm.
        return torch.tensor([random.choice(self.classes_)]) + x - x
class Matcher:
    """Wrap a torch matcher model and compile it into a Concrete FHE circuit."""

    def __init__(self, matcher_name):
        """Select the torch model backing ``matcher_name``.

        Args:
            matcher_name (str): one of ``AVAILABLE_MATCHERS``.

        Raises:
            ValueError: if ``matcher_name`` is not in ``AVAILABLE_MATCHERS``.
        """
        # An explicit raise (instead of `assert`) keeps this check active under
        # `python -O`, and the message is a plain string rather than a 1-tuple.
        if matcher_name not in AVAILABLE_MATCHERS:
            raise ValueError(
                f"Unsupported image matcher. Expected one of {*AVAILABLE_MATCHERS,}, "
                f"but got {matcher_name}"
            )
        self.fhe_circuit = None
        self.matcher_name = matcher_name
        # Only "random guessing" has a torch model here; other names keep None so
        # compile() can report that clearly instead of hitting a missing attribute.
        self.torch_model = None
        if self.matcher_name == "random guessing":
            self.torch_model = TorchRandomGuessing()

    def compile(self):
        """Convert the torch model to numpy and compile it into an FHE circuit.

        Returns:
            The compiled FHE circuit, also stored in ``self.fhe_circuit``.

        Raises:
            NotImplementedError: if no torch model exists for ``self.matcher_name``.
        """
        if self.torch_model is None:
            raise NotImplementedError(
                f"No torch model is implemented for matcher {self.matcher_name!r}"
            )

        # Sample inputs: the first is the dummy input for the torch -> numpy
        # conversion; all of them are passed to `compiler.compile` (presumably to
        # determine value ranges for the circuit — confirm with Concrete docs).
        inputset = (np.array([10]), np.array([5]))

        print("torch module > numpy module ...")
        numpy_module = NumpyModule(
            self.torch_model,
            dummy_input=torch.from_numpy(inputset[0]),
        )

        print("get proxy function ...")
        # Get the proxy function and parameter mappings used for initializing the compiler
        # This is done in order to be able to provide any modules with arbitrary numbers of
        # encrypted arguments to Concrete Numpy's compiler
        numpy_filter_proxy, parameters_mapping = generate_proxy_function(
            numpy_module.numpy_forward, ["inputs"]
        )

        print("Compile the filter and retrieve its FHE circuit ...")
        compiler = Compiler(
            numpy_filter_proxy,
            {parameters_mapping["inputs"]: "encrypted"},
        )
        self.fhe_circuit = compiler.compile(inputset)
        return self.fhe_circuit

    def post_processing(self, output_result):
        """Map the decrypted circuit output to a match verdict.

        Args:
            output_result (np.ndarray): the decrypted result to post-process.

        Returns:
            str: "PASS" if the first element equals 1, otherwise "FAIL".
        """
        print(f"{output_result=}")
        return "PASS" if output_result[0] == 1 else "FAIL"
# matcher = Matcher(matcher_name=AVAILABLE_MATCHERS[0])
# fhe_circuit = matcher.compile()