|
"""Train and compile the model.""" |
|
|
|
import shutil |
|
import numpy |
|
import pandas |
|
import pickle |
|
|
|
from settings import ( |
|
DEPLOYMENT_PATH, |
|
DATA_PATH, |
|
INPUT_SLICES, |
|
PRE_PROCESSOR_APPLICANT_PATH, |
|
PRE_PROCESSOR_BANK_PATH, |
|
PRE_PROCESSOR_CREDIT_BUREAU_PATH, |
|
APPLICANT_COLUMNS, |
|
BANK_COLUMNS, |
|
CREDIT_BUREAU_COLUMNS, |
|
) |
|
from utils.client_server_interface import MultiInputsFHEModelDev |
|
from utils.model import MultiInputDecisionTreeClassifier |
|
from utils.pre_processing import get_pre_processors |
|
|
|
|
|
def get_multi_inputs(data, input_slices=None):
    """Get inputs for all three parties from the input data, using fixed slices.

    The columns of `data` are expected to be laid out party by party (applicant, bank, then credit
    bureau), matching the column-wise concatenation of the pre-processed data done in this script.

    Args:
        data (numpy.ndarray): The 2D input data to consider (rows are samples, columns are features).
        input_slices (Optional[Dict[str, Any]]): Column indexers (e.g. slices) for each party, keyed
            by "applicant", "bank" and "credit_bureau". Defaults to `INPUT_SLICES` from settings.

    Returns:
        (Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]): The inputs for the applicant, the bank
            and the credit bureau, in that order.
    """
    # Resolve the default lazily so callers can pass their own slicing without touching settings
    if input_slices is None:
        input_slices = INPUT_SLICES

    # Fixed party order: callers unpack the result positionally (e.g. `model.compile(*inputs)`)
    return tuple(data[:, input_slices[party]] for party in ("applicant", "bank", "credit_bureau"))
|
|
|
|
|
# Load the raw dataset, split the "Target" column off as labels, and fit one pre-processor per
# party (applicant, bank, credit bureau) on that party's own feature columns.
print("Load and pre-process the data")

data = pandas.read_csv(DATA_PATH, encoding="utf-8")

# Work on a copy so that popping the target column below does not mutate `data`
data_x = data.copy()
# `pop` removes "Target" from the features; `to_frame` turns the label Series into a
# single-column DataFrame
data_y = data_x.pop("Target").copy().to_frame()

# Split the features into the columns owned by each party
data_applicant = data_x[APPLICANT_COLUMNS].copy()
data_bank = data_x[BANK_COLUMNS].copy()
data_credit_bureau = data_x[CREDIT_BUREAU_COLUMNS].copy()

# Each party gets its own pre-processor, fitted only on that party's columns; the fitted
# pre-processors are pickled at the end of this script
pre_processor_applicant, pre_processor_bank, pre_processor_credit_bureau = get_pre_processors()

preprocessed_data_applicant = pre_processor_applicant.fit_transform(data_applicant)
preprocessed_data_bank = pre_processor_bank.fit_transform(data_bank)
preprocessed_data_credit_bureau = pre_processor_credit_bureau.fit_transform(data_credit_bureau)

# Concatenate column-wise in applicant / bank / credit bureau order.
# NOTE(review): this order must agree with INPUT_SLICES, which `get_multi_inputs` uses to split the
# columns back per party — confirm in settings.
# NOTE(review): numpy.concatenate assumes the pre-processors return dense arrays (not sparse
# matrices) — confirm in utils.pre_processing.
preprocessed_data_x = numpy.concatenate((preprocessed_data_applicant, preprocessed_data_bank, preprocessed_data_credit_bureau), axis=1)
|
|
|
|
|
# Fit the multi-input decision tree on the pre-processed data, then compile it with one input per
# party.
print("\nTrain and compile the model")

model = MultiInputDecisionTreeClassifier()

# `fit_benchmark` returns the fitted model plus a second model (bound to `sklearn_model`, which is
# unused below) — presumably the plain scikit-learn counterpart; verify in utils.model
model, sklearn_model = model.fit_benchmark(preprocessed_data_x, data_y)

# Re-split the pre-processed training data into one input array per party
multi_inputs_train = get_multi_inputs(preprocessed_data_x)

# Compile with all three party inputs marked as encrypted. The training inputs are passed here,
# presumably as the calibration/representative set for compilation — confirm in utils.model
model.compile(*multi_inputs_train, inputs_encryption_status=["encrypted", "encrypted", "encrypted"])
|
|
|
# Write the deployment artifacts: the compiled model, then each party's fitted pre-processor.
print("\nSave deployment files")

# Remove any previous deployment directory so artifacts from an earlier run do not linger
if DEPLOYMENT_PATH.is_dir():
    shutil.rmtree(DEPLOYMENT_PATH)

# Save the compiled model to DEPLOYMENT_PATH.
# NOTE(review): assumes `save` (re)creates DEPLOYMENT_PATH after the rmtree above — confirm in
# utils.client_server_interface.
fhe_model_dev = MultiInputsFHEModelDev(DEPLOYMENT_PATH, model)
fhe_model_dev.save(via_mlir=True)

# Pickle each party's fitted pre-processor to its own file. All three files are opened (and
# truncated) before any is written.
# NOTE(review): if these paths are not inside DEPLOYMENT_PATH, their parent directories must
# already exist — confirm in settings.
# Parenthesized multi-item `with` requires Python 3.10+.
with (
    PRE_PROCESSOR_APPLICANT_PATH.open('wb') as file_applicant,
    PRE_PROCESSOR_BANK_PATH.open('wb') as file_bank,
    PRE_PROCESSOR_CREDIT_BUREAU_PATH.open('wb') as file_credit_bureau,
):
    pickle.dump(pre_processor_applicant, file_applicant)
    pickle.dump(pre_processor_bank, file_bank)
    pickle.dump(pre_processor_credit_bureau, file_credit_bureau)

print("\nDone !")
|
|