team1_Dhiria / app.py
lucacolombo97's picture
Update app.py
99e13c9 verified
raw
history blame
25.3 kB
import subprocess
import time
from typing import Dict, List, Tuple
import gradio as gr # pylint: disable=import-error
import numpy as np
import pandas as pd
import requests
from stuff import get_emoticon, plot_tachometer
from utils import (
CLIENT_DIR,
CURRENT_DIR,
DEPLOYMENT_DIR_MODEL1,
DEPLOYMENT_DIR_MODEL2,
DEPLOYMENT_DIR_MODEL3,
INPUT_BROWSER_LIMIT,
KEYS_DIR,
SERVER_URL,
clean_directory,
)
from dev_dhiria import frequency_domain, interpolation, statistics
from concrete.ml.deployment import FHEModelClient
# Module-level state shared by the Gradio callbacks below. These are plain process-wide
# globals, so every session served by this process reads and writes the same values.
global_df1 = global_df2 = None  # uploaded running / PPG DataFrames (set by process_files)
global_output_1 = global_output_2 = None  # decrypted first-layer results (set by run_fhe_fn)

# Start the FHE server ("server:app" under uvicorn) as a child process in CURRENT_DIR, then
# wait a fixed grace period so it can come up before the UI starts sending requests.
_SERVER_COMMAND = ["uvicorn", "server:app"]
_SERVER_STARTUP_DELAY_S = 3
subprocess.Popen(_SERVER_COMMAND, cwd=CURRENT_DIR)
time.sleep(_SERVER_STARTUP_DELAY_S)
# pylint: disable=c-extension-no-member,invalid-name
def is_none(obj) -> bool:
    """
    Check if the object is None or empty.

    Args:
        obj (any): The input to be checked.

    Returns:
        bool: True if the object is None or has a length smaller than 1, False otherwise.
        Objects that define no length (e.g. an int user ID or a numpy scalar) are treated
        as non-empty values instead of raising a TypeError.
    """
    if obj is None:
        return True
    try:
        return len(obj) < 1
    except TypeError:
        # Unsized value (no __len__): it carries a value, so it is not "empty".
        return False
def _generate_first_layer_keys(deployment_dir, user_id: int, model_idx: int) -> None:
    """Generate the keys of first-layer model `model_idx` and save its evaluation key."""
    key_dir = KEYS_DIR / f"{user_id}_{model_idx}"
    client = FHEModelClient(path_dir=deployment_dir, key_dir=key_dir)
    client.load()
    # Creates the private and evaluation keys on the client side
    client.generate_private_and_evaluation_keys()
    serialized_evaluation_keys = client.get_serialized_evaluation_keys()
    assert isinstance(serialized_evaluation_keys, bytes)
    # Saved to disk so send_input_fn can upload it to the server later.
    (key_dir / f"evaluation_key_{model_idx}").write_bytes(serialized_evaluation_keys)


def key_gen_fn() -> Dict:
    """
    Generate the FHE keys of the two first-layer models for a new user.

    A fresh random user ID is drawn; for model ``i`` (1 and 2) the private and evaluation
    keys are generated in ``KEYS_DIR / f"{user_id}_{i}"`` and the serialized evaluation key
    is written to ``evaluation_key_{i}`` in that directory.

    Returns:
        dict: Gradio updates hiding the error box, storing the user ID in the (hidden)
        user-ID box and relabelling the key-generation button.
    """
    # `secrets` rather than np.random: the ID names the user's key folders, so it should not
    # be predictable, and np.random.randint(0, 2**32) overflows where numpy's default int is
    # 32-bit.
    import secrets

    clean_directory()

    user_id = secrets.randbelow(2**32)
    print(f"Your user ID is: {user_id}....")

    _generate_first_layer_keys(DEPLOYMENT_DIR_MODEL1, user_id, 1)
    _generate_first_layer_keys(DEPLOYMENT_DIR_MODEL2, user_id, 2)

    return {
        error_box2: gr.update(visible=False),
        user_id_box: gr.update(visible=False, value=user_id),
        gen_key_btn: gr.update(value="Keys have been generated ✅"),
    }
def encrypt_fn(arr: np.ndarray, user_id: str, input_id: int) -> Dict:
    """
    Quantize, encrypt and serialize one first-layer model input on the client side.

    The ciphertext is written to
    ``KEYS_DIR / f"{user_id}_{input_id}/encrypted_input_{input_id}"``, next to the keys of
    the same model.

    Args:
        arr (np.ndarray): The pre-processed input of the model.
        user_id (str): The current user's ID.
        input_id (int): Which first-layer model the input belongs to (1 or 2).

    Returns:
        dict: An empty dict (no Gradio component is updated).

    Raises:
        ValueError: If `input_id` is not 1 or 2.
    """
    if is_none(user_id) or is_none(arr):
        print("Error in encryption step: Provide your data and generate the evaluation keys.")
        return {}

    deployment_dirs = {1: DEPLOYMENT_DIR_MODEL1, 2: DEPLOYMENT_DIR_MODEL2}
    if input_id not in deployment_dirs:
        # Any other id would pair one model's client with another model's key folder.
        raise ValueError(f"Unsupported input_id {input_id!r}: expected 1 or 2.")

    # Retrieve the client API with the keys generated for this model.
    key_dir = KEYS_DIR / f"{user_id}_{input_id}"
    client = FHEModelClient(path_dir=deployment_dirs[input_id], key_dir=key_dir)
    client.load()

    encrypted_quantized_arr = client.quantize_encrypt_serialize(arr)
    assert isinstance(encrypted_quantized_arr, bytes)
    (key_dir / f"encrypted_input_{input_id}").write_bytes(encrypted_quantized_arr)
    return {}
def send_input_fn(user_id: str, models_layer: int = 1) -> Dict:
    """Send the encrypted input(s) and the matching evaluation key(s) to the server.

    Args:
        user_id (str): The current user's ID.
        models_layer (int): ``1`` (default) sends the two first-layer encrypted inputs and
            their evaluation keys to ``send_input_first_layer``. Any other value sends the
            second-layer encrypted input and its evaluation key to ``send_input_second_layer``.

    Returns:
        dict: Gradio updates for the error box and the "data sent" checkbox. An empty dict
        (no UI update) is returned when the user ID or a required local file is missing.
    """
    from contextlib import ExitStack

    if is_none(user_id):
        return {}

    if models_layer == 1:
        # First layer of models: two encrypted inputs, each with its own evaluation key.
        input_paths = [
            KEYS_DIR / f"{user_id}_1/encrypted_input_1",
            KEYS_DIR / f"{user_id}_2/encrypted_input_2",
        ]
        key_paths = [
            KEYS_DIR / f"{user_id}_1/evaluation_key_1",
            KEYS_DIR / f"{user_id}_2/evaluation_key_2",
        ]
        endpoint = "send_input_first_layer"
    else:
        # Second layer: the re-encrypted first-layer results and their evaluation key,
        # mirroring what run_fhe_fn uploads for the second layer.
        input_paths = [KEYS_DIR / f"{user_id}/encrypted_input_3"]
        key_paths = [KEYS_DIR / f"{user_id}/evaluation_key_second_layer"]
        endpoint = "send_input_second_layer"

    for key_path in key_paths:
        if not key_path.is_file():
            print(
                "Error Encountered While Sending Data to the Server: "
                f"The key has not been generated correctly - missing {key_path}"
            )
            return {}
    for input_path in input_paths:
        if not input_path.is_file():
            print(
                "Error Encountered While Sending Data to the Server: The data has not been "
                f"encrypted correctly on the client side - missing {input_path}"
            )
            return {}

    data = {"user_id": user_id}
    # ExitStack closes every opened file even if the request raises.
    with ExitStack() as stack:
        # Inputs first, then keys; the server presumably relies on this order — keep it.
        files = [
            ("files", stack.enter_context(path.open("rb")))
            for path in input_paths + key_paths
        ]
        with requests.post(url=SERVER_URL + endpoint, data=data, files=files) as response:
            print(f"Sending Data: {response.ok=}")
            sent_ok = response.ok

    if not sent_ok:
        return {
            error_box4: gr.update(
                visible=True,
                value="⚠️ An error occurred on the Server Side. "
                "Please check connectivity and data transmission.",
            ),
            srv_resp_send_data_box: False,
        }
    return {
        error_box4: gr.update(visible=False),
        srv_resp_send_data_box: "Data sent",
    }
_FHE_SERVER_ERROR_MESSAGE = (
    "⚠️ An error occurred on the Server Side. "
    "Please check connectivity and data transmission."
)


def _fhe_server_error() -> Dict:
    """Return the Gradio updates shown when a server-side step of `run_fhe_fn` fails."""
    return {
        error_box5: gr.update(visible=True, value=_FHE_SERVER_ERROR_MESSAGE),
        fhe_execution_time_box: gr.update(visible=False),
    }


def _post_to_server(endpoint: str, data: Dict, file_paths: Tuple = ()) -> "requests.Response":
    """POST `data` (plus the given files, opened and closed here) to `SERVER_URL + endpoint`.

    The request is not streamed, so `requests` reads the body eagerly and `.content` /
    `.json()` stay usable on the returned response after the connection is released.
    """
    from contextlib import ExitStack

    with ExitStack() as stack:
        files = [("files", stack.enter_context(path.open("rb"))) for path in file_paths]
        with requests.post(url=SERVER_URL + endpoint, data=data, files=files or None) as response:
            return response


def run_fhe_fn(user_id: str) -> Dict:
    """Run the two-layer encrypted inference for the current user.

    Steps:
        1. Ask the server to evaluate the two first-layer models on the data previously
           uploaded by `send_input_fn`.
        2. Download both encrypted first-layer outputs (also saved under `CLIENT_DIR`) and
           decrypt them locally, storing the rescaled / thresholded results in
           `global_output_1` / `global_output_2`.
        3. Re-encrypt those two values as the input of the third model, generating a fresh
           key pair for it, and upload input + evaluation key to the server.
        4. Ask the server to evaluate the second-layer model.

    Args:
        user_id (str): The current user's ID.

    Returns:
        dict: Gradio updates for the error box and the execution-time box. Any failed server
        request stops the pipeline and shows an error instead.
    """
    global global_output_1, global_output_2

    if is_none(user_id):
        return {
            error_box5: gr.update(
                visible=True,
                value="⚠️ Please check your connectivity \n"
                "⚠️ Ensure that the symptoms have been submitted, the evaluation "
                "key has been generated and the server received the data "
                "before processing the data.",
            ),
            fhe_execution_time_box: None,
        }

    start_time = time.time()
    data = {"user_id": user_id}

    # 1. Run the first layer of models.
    response = _post_to_server("run_fhe_first_layer", data)
    if not response.ok:
        return _fhe_server_error()
    time.sleep(1)
    print(f"response.ok: {response.ok}, {response.json()} - Computed")
    print("First layer done!")

    # 2. Download the first-layer outputs (decrypted here because Concrete ML does not
    #    allow chaining the encrypted outputs into another model).
    encrypted_outputs = {}
    for model_idx in (1, 2):
        response = _post_to_server(f"get_output_first_layer_{model_idx}", data)
        if not response.ok:
            # Stop here rather than decrypting a missing or stale output file.
            return _fhe_server_error()
        print(f"Receive Data: {response.ok=}")
        # Save the encrypted output to a file as it is too large to pass through
        # regular Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877)
        (CLIENT_DIR / f"{user_id}_encrypted_output_{model_idx}").write_bytes(response.content)
        encrypted_outputs[model_idx] = response.content

    client = FHEModelClient(path_dir=DEPLOYMENT_DIR_MODEL1, key_dir=KEYS_DIR / f"{user_id}_1")
    client.load()
    raw_risk_score = client.deserialize_decrypt_dequantize(encrypted_outputs[1])[0][0]
    # Min-max rescaling of the first model's score with fixed bounds.
    min_risk_score = 1.8145127821625648
    max_risk_score = 1.9523557655864805
    global_output_1 = (raw_risk_score - min_risk_score) / (max_risk_score - min_risk_score)

    client = FHEModelClient(path_dir=DEPLOYMENT_DIR_MODEL2, key_dir=KEYS_DIR / f"{user_id}_2")
    client.load()
    # Binarize the second model's output with a 0.6 threshold.
    global_output_2 = int(client.deserialize_decrypt_dequantize(encrypted_outputs[2]) > 0.6)

    # 3. Re-encrypt the two values because Concrete ML does not allow using the output of
    #    two models as input of a third one.
    new_input = np.array([[global_output_1, global_output_2]])
    second_layer_dir = KEYS_DIR / f"{user_id}"
    client = FHEModelClient(path_dir=DEPLOYMENT_DIR_MODEL3, key_dir=second_layer_dir)
    client.load()
    # Creates the private and evaluation keys of the third model on the client side
    client.generate_private_and_evaluation_keys()
    serialized_evaluation_keys = client.get_serialized_evaluation_keys()
    assert isinstance(serialized_evaluation_keys, bytes)
    evaluation_key_path = second_layer_dir / "evaluation_key_second_layer"
    evaluation_key_path.write_bytes(serialized_evaluation_keys)

    encrypted_quantized_arr = client.quantize_encrypt_serialize(new_input)
    assert isinstance(encrypted_quantized_arr, bytes)
    encrypted_input_path = second_layer_dir / "encrypted_input_3"
    encrypted_input_path.write_bytes(encrypted_quantized_arr)

    response = _post_to_server(
        "send_input_second_layer", data, (encrypted_input_path, evaluation_key_path)
    )
    print(f"Sending Data: {response.ok}")
    if not response.ok:
        return _fhe_server_error()

    # 4. Run the second layer.
    response = _post_to_server("run_fhe_second_layer", data)
    if not response.ok:
        return _fhe_server_error()
    time.sleep(1)
    print(f"response.ok: {response.ok}, {response.json()} - Computed")
    print("Second layer done!")

    total_time = time.time() - start_time
    return {
        error_box5: gr.update(visible=False),
        fhe_execution_time_box: gr.update(visible=True, value=f"{total_time:.2f} seconds"),
    }
def get_output_fn(user_id: str) -> Dict:
    """Retrieve the encrypted second-layer output from the server.

    The ciphertext is written to ``CLIENT_DIR / f"{user_id}_encrypted_output_3"`` because it
    is too large to pass through regular Gradio components
    (see https://github.com/gradio-app/gradio/issues/1877); `decrypt_fn` reads it from there.

    Args:
        user_id (str): The current user's ID.

    Returns:
        dict: Gradio updates for the error box and the "data received" checkbox.
    """
    if is_none(user_id):
        return {
            error_box6: gr.update(
                visible=True,
                value="⚠️ Please check your connectivity \n"
                "⚠️ Ensure that the server has successfully processed and transmitted the data to the client.",
            )
        }

    data = {"user_id": user_id}
    with requests.post(url=SERVER_URL + "get_output_second_layer", data=data) as response:
        print(f"Receive Data: {response.ok=}")
        if not response.ok:
            # Do not report success when nothing was received.
            return {
                error_box6: gr.update(
                    visible=True,
                    value="⚠️ An error occurred on the Server Side. "
                    "Please check connectivity and data transmission.",
                ),
                srv_resp_retrieve_data_box: False,
            }
        encrypted_output = response.content

    (CLIENT_DIR / f"{user_id}_encrypted_output_3").write_bytes(encrypted_output)
    return {error_box6: gr.update(visible=False), srv_resp_retrieve_data_box: "Data received"}
def decrypt_fn(user_id: str) -> Dict:
    """
    Decrypt the second-layer output and build the final report shown in the UI.

    The report combines the class predicted by the third model with the two first-layer
    results kept in memory by `run_fhe_fn`: `global_output_1` drives the running-quality
    tachometer and `global_output_2` the mental-health emoticon.

    Args:
        user_id (str): The current user's ID.

    Returns:
        dict: Gradio updates keyed by component (decrypted text, error box, decrypt button,
        tachometer plot, emoticon box). On a failed precondition only the error box (and
        possibly the decrypted-output box) is updated.
    """
    if is_none(user_id):
        return {
            error_box7: gr.update(
                visible=True,
                value="⚠️ Please check your connectivity \n"
                "⚠️ Ensure that the client has successfully received the data from the server.",
            )
        }

    # Get the encrypted output path
    encrypted_output_path = CLIENT_DIR / f"{user_id}_encrypted_output_3"
    if not encrypted_output_path.is_file():
        print("Error in decryption step: Please run the FHE execution, first.")
        return {
            error_box7: gr.update(
                visible=True,
                value="⚠️ Please ensure that: \n"
                "- the connectivity \n"
                "- the symptoms have been submitted \n"
                "- the evaluation key has been generated \n"
                "- the server processed the encrypted data \n"
                "- the Client received the data from the Server before decrypting the prediction",
            ),
            decrypt_box: None,
        }

    # The first-layer results only live in this process' memory (set by run_fhe_fn).
    if global_output_1 is None or global_output_2 is None:
        print("Error in decryption step: first-layer results are missing, run the FHE execution first.")
        return {
            error_box7: gr.update(
                visible=True,
                value="⚠️ Please run the FHE evaluation before decrypting the prediction.",
            ),
            decrypt_box: None,
        }

    encrypted_output = encrypted_output_path.read_bytes()
    client = FHEModelClient(path_dir=DEPLOYMENT_DIR_MODEL3, key_dir=KEYS_DIR / f"{user_id}")
    client.load()
    # Deserialize, decrypt and post-process the encrypted output
    output = client.deserialize_decrypt_dequantize(encrypted_output)

    # Named so it does not shadow the `tachometer_plot` Gradio component used as a key below.
    tachometer_figure = plot_tachometer(global_output_1 * 100)
    emoticon_image = get_emoticon(global_output_2)

    labels = {
        0: "Continue what you are doing!",
        1: "Focus on technique!",
        2: "Focus on mental health!",
        3: "Rest!",
    }
    predicted_class = int(np.argmax(output))
    out = (
        f"Given your recent running and mental stress statistics, you should... "
        f"{labels[predicted_class]}"
    )
    return {
        decrypt_box: gr.update(value=out, visible=True),
        error_box7: gr.update(visible=False),
        # NOTE(review): relabels the decrypt button to "Submit"; confirm this is intended.
        decrypt_btn: gr.update(value="Submit"),
        tachometer_plot: gr.update(value=tachometer_figure, visible=True),
        emoticon_display: gr.update(value=emoticon_image, visible=True),
    }
def reset_fn():
    """Clean the working directory via `clean_directory()`.

    No UI component is reset: the returned update mapping is empty.
    """
    clean_directory()
    no_updates = {}
    return no_updates
def process_files(file1, file2):
    """
    Read the uploaded running time-series and PPG CSV files into the module globals.

    Args:
        file1: Upload of the running time-series CSV. Either an object exposing the file
            path as ``.name`` or a plain path.
        file2: Upload of the PPG CSV, in the same accepted forms as `file1`.

    Returns:
        dict: Gradio update relabelling the upload button.
    """
    global global_df1, global_df2

    # Clicking "Upload" before selecting both files would otherwise crash on `None.name`.
    if file1 is None or file2 is None:
        return {upload_button: gr.update(value="Please select both files first ⚠️")}

    # `.name` handles upload wrappers; plain path strings are used as-is.
    df1 = pd.read_csv(getattr(file1, "name", file1))
    df2 = pd.read_csv(getattr(file2, "name", file2))

    # Publish both DataFrames only once both files were parsed successfully.
    global_df1, global_df2 = df1, df2
    return {upload_button: gr.update(value="Data uploaded! ✅")}
def encrypt_layer1(user_id_box):
    """
    Pre-process both uploaded datasets and encrypt the two first-layer model inputs.

    Input 1 is the first row of ``statistics(global_df1)``. Input 2 is the frequency-domain
    feature vector of the first recording in ``global_df2`` (columns from index 2 onwards).

    Args:
        user_id_box (str): The current user's ID (value of the hidden user-ID textbox).

    Returns:
        dict: Gradio updates for the encryption error box and the encrypt button.
    """
    if global_df1 is None or global_df2 is None or is_none(user_id_box):
        return {
            error_box3: gr.update(
                visible=True,
                value="⚠️ Please upload your data and generate the keys first.",
            ),
        }

    # INPUT ONE - RUNNING DATA
    running_data, _risk = statistics(global_df1)
    input_model_1 = pd.DataFrame(running_data).iloc[0, :].to_numpy().reshape(1, -1)

    # INPUT TWO - MENTAL HEALTH DATA
    # Keep the columns from index 2 onwards, dropping every column that has a missing value.
    rr_values = global_df2.iloc[:, 2:].dropna(how="any", axis=1).to_numpy()
    # Values outside [600, 1000] are treated as outliers and replaced by the global median.
    rr_values = np.where(
        (rr_values > 1000) | (rr_values < 600), np.median(rr_values), rr_values
    )
    rr_interpolated = interpolation(rr_values, 4.0)
    freq_col = ['vlf', 'lf', 'hf', 'tot_pow', 'lf_hf_ratio', 'peak_vlf', 'peak_lf', 'peak_hf']
    # Only the first recording feeds the model, so only its features are computed.
    freq_features = pd.DataFrame([frequency_domain(rr_interpolated[0])], columns=freq_col)
    input_model_2 = freq_features.iloc[0, :].to_numpy().reshape(1, -1)

    encrypt_fn(input_model_1, user_id_box, 1)
    encrypt_fn(input_model_2, user_id_box, 2)
    return {
        error_box3: gr.update(visible=False, value="Error"),
        encrypt_btn: gr.update(value="Data encrypted! ✅"),
    }
if __name__ == "__main__":
# Entry point: build the Gradio UI and launch it. Each section below wires one step of the
# client-side workflow (upload -> keys -> encryption -> send -> FHE run -> download ->
# decrypt) to the callback defined earlier in this file.
print("Starting demo ...")
# Reset the working directories via utils.clean_directory (presumably removes files left
# over from earlier sessions; see its implementation).
clean_directory()
# CSS classes used below: "centered-textbox" (decrypted output), "large-emoticon"
# (mental-health emoticon box) and "logo" (header image).
css = """
.centered-textbox textarea {
font-size: 24px !important;
text-align: center;
}
.large-emoticon textarea {
font-size: 72px !important;
text-align: center;
}
.logo {
display:block;
margin-left: auto;
margin-right: auto;
}
"""
with gr.Blocks(theme="light", css=css, title='AtlHEte') as demo:
# Link + images
gr.Markdown()
# gr.Markdown(
# """
# <p align="center">
# <img width=300 src="file/atlhete-high-resolution-logo-black-transparent.png">
# </p>
# """)
gr.Image('atlhete-high-resolution-logo-black-transparent.png', width=500, elem_classes="logo")
# Title
gr.Markdown("""
# AtlHEte
## Data loading
Upload your running time-series, and your PPG.
> The app of AtlHEte would do this automatically.
""")
# Two file pickers (running time-series and PPG CSVs) and the button that loads them
# into memory through process_files.
with gr.Row():
file1 = gr.File(label="Upload running time-series")
file2 = gr.File(label="Upload PPG")
upload_button = gr.Button("Upload")
upload_button.click(process_files, inputs=[file1, file2], outputs=[upload_button])
# Keys generation
gr.Markdown("""
## Keys generation
Generate the TFHE keys.
""")
gen_key_btn = gr.Button("Generate the private and evaluation keys.")
error_box2 = gr.Textbox(label="Error ❌", visible=False)
# Hidden textbox holding the generated user ID; every later step receives it as input.
user_id_box = gr.Textbox(label="User ID:", visible=False)
gen_key_btn.click(
key_gen_fn,
outputs=[
user_id_box,
error_box2,
gen_key_btn,
],
)
# Data encryption
gr.Markdown("""
## Data encryption
Encrypt both your running time-series and your PPG.
""")
encrypt_btn = gr.Button("Encrypt the data using the private secret key")
error_box3 = gr.Textbox(label="Error ❌", visible=False)
encrypt_btn.click(encrypt_layer1, inputs=[user_id_box], outputs=[error_box3, encrypt_btn])
# Data uploading: send the encrypted inputs and evaluation keys to the server
gr.Markdown("""
## Data upload
Upload your data safely to us.
""")
error_box4 = gr.Textbox(label="Error ❌", visible=False)
with gr.Row():
with gr.Column(scale=4):
send_input_btn = gr.Button("Send data")
with gr.Column(scale=1):
srv_resp_send_data_box = gr.Checkbox(label="Data Sent", show_label=False)
# Only user_id_box is passed, so send_input_fn runs with its default models_layer.
send_input_btn.click(
send_input_fn,
inputs=[user_id_box],
outputs=[error_box4, srv_resp_send_data_box],
)
# Encrypted processing: run_fhe_fn drives both model layers on the server
gr.Markdown("""
## Encrypted processing
Process your <span style='color:grey'>encrypted data</span> with AtlHEte!
""")
run_fhe_btn = gr.Button("Run the FHE evaluation")
error_box5 = gr.Textbox(label="Error ❌", visible=False)
fhe_execution_time_box = gr.Textbox(label="Total FHE Execution Time:", visible=True)
run_fhe_btn.click(
run_fhe_fn,
inputs=[user_id_box],
outputs=[fhe_execution_time_box, error_box5],
)
# Download the encrypted report (second-layer output)
gr.Markdown("""
## Download the encrypted report
Download your personalized encrypted report...
""")
error_box6 = gr.Textbox(label="Error ❌", visible=False)
with gr.Row():
with gr.Column(scale=4):
get_output_btn = gr.Button("Get data")
with gr.Column(scale=1):
srv_resp_retrieve_data_box = gr.Checkbox(label="Data Received", show_label=False)
get_output_btn.click(
get_output_fn,
inputs=[user_id_box],
outputs=[srv_resp_retrieve_data_box, error_box6],
)
# Decrypt the report
gr.Markdown("""
## Decrypt the report
Decrypt the report to know how you are doing!
""")
decrypt_btn = gr.Button("Decrypt the output using the private secret key")
error_box7 = gr.Textbox(label="Error ❌", visible=False)
# Layout components: tachometer and emoticon start hidden (visible=False) and are shown
# by decrypt_fn, together with the decrypted text.
with gr.Row():
tachometer_plot = gr.Plot(label="Running Quality", visible=False)
emoticon_display = gr.Textbox(label="Mental Health", visible=False, elem_classes="large-emoticon")
with gr.Column():
decrypt_box = gr.Textbox(label="Decrypted Output:", visible=False, elem_classes="centered-textbox")
decrypt_btn.click(
decrypt_fn,
inputs=[user_id_box],
outputs=[decrypt_box,
error_box7,
decrypt_btn,
tachometer_plot,
emoticon_display],
)
demo.launch(favicon_path='atlhete-high-resolution-logo-black-transparent.png')