# Hugging Face Spaces page header ("Spaces: Sleeping") - scrape residue, not part of the module.
import os
import shutil
from pathlib import Path
from typing import List, Tuple, Union

import numpy
import pandas
from concrete.ml.sklearn import XGBClassifier as ConcreteXGBoostClassifier
# Maximum input size to be displayed in the HuggingFace Space browser using Gradio.
# Inputs that are too large slow down the server:
# https://github.com/gradio-app/gradio/issues/1877
INPUT_BROWSER_LIMIT = 380

# Store the server's URL
SERVER_URL = "http://localhost:8000/"

# Directory layout, all resolved relative to the directory containing this file
CURRENT_DIR = Path(__file__).parent
DEPLOYMENT_DIR = CURRENT_DIR / "deployment_files_generic"
DEPLOYMENT_DIR_MODEL1 = CURRENT_DIR / "deployment_files_model1"
DEPLOYMENT_DIR_MODEL2 = CURRENT_DIR / "deployment_files_model2"
DEPLOYMENT_DIR_MODEL3 = CURRENT_DIR / "deployment_files_model3"
KEYS_DIR = DEPLOYMENT_DIR / ".fhe_keys"
CLIENT_DIR = DEPLOYMENT_DIR / "client_dir"
SERVER_DIR = DEPLOYMENT_DIR / "server_dir"

# Directories that `clean_directory` wipes and re-creates
ALL_DIRS = [KEYS_DIR, CLIENT_DIR, SERVER_DIR]

# Columns that define the target:
# index 0 is the numeric label of the disease, index 1 is its name
TARGET_COLUMNS = ["prognosis_encoded", "prognosis"]

# NOTE: these paths are relative to the current working directory, not to CURRENT_DIR
TRAINING_FILENAME = "./data/Training_preprocessed.csv"
TESTING_FILENAME = "./data/Testing_preprocessed.csv"

# pylint: disable=invalid-name
def pretty_print(
    inputs,
    case_conversion=str.title,
    which_replace: str = "_",
    to_what: str = " ",
    delimiter: Union[str, None] = None,
) -> Union[List[str], str]:
    """
    Prettify and sort the input as a list of strings.

    Args:
        inputs (Any): An iterable of strings, possibly containing nested lists, tuples or
            sets of strings (one level of nesting is flattened).
        case_conversion (Callable[[str], str]): Function applied to each item after the
            replacement step. Defaults to `str.title`.
        which_replace (str): The substring to be replaced in each item.
        to_what (str): The replacement for `which_replace`.
        delimiter (Optional[str]): If truthy, the items are joined with it and a trailing
            "." is appended, so a single string is returned instead of a list.

    Returns:
        Union[List[str], str]: The prettified, de-duplicated and sorted list of inputs, or
            the joined string when `delimiter` is given.
    """
    # Flatten one level of nesting if required
    flat_items = []
    for item in inputs:
        if isinstance(item, (list, tuple, set)):
            flat_items.extend(item)
        else:
            flat_items.append(item)

    # De-duplicate and sort on the raw values, i.e. before any replacement / case conversion
    unique_sorted = sorted(set(flat_items))

    # Replace, then convert the case
    pretty_list = [case_conversion(item.replace(which_replace, to_what)) for item in unique_sorted]

    if delimiter:
        return f"{delimiter.join(pretty_list)}."

    return pretty_list
def clean_directory() -> None:
    """
    Clear the directories listed in `ALL_DIRS`.

    Each directory is removed with its whole content if it exists, then re-created empty
    (including any missing parents).
    """
    print("Cleaning...\n")
    for target_dir in ALL_DIRS:
        # `is_dir` is False for a missing path, so no separate existence check is needed
        if target_dir.is_dir():
            shutil.rmtree(target_dir)
        target_dir.mkdir(exist_ok=True, parents=True)
def get_disease_name(encoded_prediction: int, file_name: str = TRAINING_FILENAME) -> str:
    """Return the disease name given its encoded label.

    Args:
        encoded_prediction (int): The encoded prediction
        file_name (str): The data file path

    Returns:
        str: The according disease name

    Raises:
        ValueError: If the encoded label does not map to exactly one disease name.
    """
    df = pandas.read_csv(file_name, usecols=TARGET_COLUMNS).drop_duplicates()

    # Select the name column explicitly: `usecols` does not control the column order of the
    # resulting frame, so positional unpacking of the row values is not reliable
    label_column, name_column = TARGET_COLUMNS
    matches = df.loc[df[label_column] == encoded_prediction, name_column]

    if len(matches) != 1:
        raise ValueError(
            f"Expected exactly one disease name for encoded label {encoded_prediction!r} "
            f"in {file_name!r}, found {len(matches)}."
        )

    return matches.iloc[0]
def load_data() -> Tuple[
    Tuple[pandas.DataFrame, pandas.DataFrame],
    Tuple[pandas.Series, pandas.Series],
    List[str],
    List[str],
]:
    """
    Return the data

    Args:
        None

    Return:
        A 4-tuple made of:
            - (X_train, X_test): the training and testing features,
            - (y_train, y_test): the training and testing encoded targets,
            - the list of feature column names of the training set (the valid symptoms),
            - the list of unique disease names found in the training set.
    """
    # Load data
    df_train = pandas.read_csv(TRAINING_FILENAME)
    df_test = pandas.read_csv(TESTING_FILENAME)

    # Separate the target from the training / testing set:
    # TARGET_COLUMNS[0] -> "prognosis_encoded" -> contains the numeric label of the disease
    # TARGET_COLUMNS[1] -> "prognosis" -> contains the name of the disease
    y_train = df_train[TARGET_COLUMNS[0]]
    X_train = df_train.drop(columns=TARGET_COLUMNS, errors="ignore")

    y_test = df_test[TARGET_COLUMNS[0]]
    X_test = df_test.drop(columns=TARGET_COLUMNS, errors="ignore")

    return (
        (X_train, X_test),
        (y_train, y_test),
        X_train.columns.to_list(),
        df_train[TARGET_COLUMNS[1]].unique().tolist(),
    )
def load_model(X_train: pandas.DataFrame, y_train: numpy.ndarray):
    """
    Train a Concrete ML XGBoost classifier and compile it.

    Despite its name, this function does not load a serialized model: a new classifier is
    fitted on the given training set at every call, then compiled.

    Args:
        X_train (pandas.DataFrame): Training set
        y_train (numpy.ndarray): Targets of the training set

    Return:
        The Concrete ML model and its circuit
    """
    # Parameters
    concrete_args = {"max_depth": 1, "n_bits": 3, "n_estimators": 3, "n_jobs": -1}
    classifier = ConcreteXGBoostClassifier(**concrete_args)

    # Train the model
    classifier.fit(X_train, y_train)

    # Compile the model, using the training set as the compilation input
    circuit = classifier.compile(X_train)

    return classifier, circuit