"""Client-side Gradio demo: encrypt user symptoms, query the FHE server, decrypt the prediction."""

import subprocess
import time
from typing import Dict, List, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import requests
from symptoms_categories import SYMPTOMS_LIST
from utils import (  # pylint: disable=no-name-in-module
    CLIENT_DIR,
    CURRENT_DIR,
    DEPLOYMENT_DIR,
    INPUT_BROWSER_LIMIT,
    KEYS_DIR,
    SERVER_URL,
    TARGET_COLUMNS,
    TRAINING_FILENAME,
    clean_directory,
    get_disease_name,
    load_data,
    pretty_print,
)

from concrete.ml.deployment import FHEModelClient

# Launch the `server:app` application with uvicorn in the background, then wait a few seconds,
# presumably so the server is up before the UI starts sending requests to it
subprocess.Popen(["uvicorn", "server:app"], cwd=CURRENT_DIR)
time.sleep(3)

# pylint: disable=c-extension-no-member,invalid-name


def is_nan(inputs) -> bool:
    """
    Check if the input is missing.

    Despite its name, this does not test for the float NaN value: it returns True when the input
    is None or has a length of 0 (such as an empty textbox value).

    Args:
        inputs (any): The input to be checked. Must support `len()` when not None.

    Returns:
        bool: True if the input is None or empty, False otherwise.
    """
    return inputs is None or len(inputs) < 1


# def fill_in_fn(default_disease: str, *checkbox_symptoms: Tuple[str]) -> Dict:
#     """
#     Fill in the gr.CheckBoxGroup list with the predefined symptoms of a selected default disease.
#     Args:
#         default_disease (str): The default disease
#         *checkbox_symptoms (Tuple[str]): Tuple of selected symptoms
#     Returns:
#         dict: The updated gr.CheckBoxesGroup.
#     """
#     df = pd.read_csv(TRAINING_FILENAME)
#     df_filtred = df[df[TARGET_COLUMNS[1]] == default_disease]
#     symptoms = pretty_print(df_filtred.columns[df_filtred.eq(1).any()].to_list())
#     if any(lst for lst in checkbox_symptoms if lst):
#         for sublist in checkbox_symptoms:
#             symptoms.extend(sublist)
#     return {box: symptoms for box in check_boxes}


def get_user_symptoms_from_checkboxgroup(checkbox_symptoms: List) -> np.ndarray:
    """
    Convert the user symptoms into a binary vector representation.

    Args:
        checkbox_symptoms (list): A list of user symptoms in their displayed form (words
            separated by spaces, any capitalization).

    Returns:
        np.ndarray: A float array of shape (1, len(valid_columns)) holding 1 for every selected
            symptom and 0 otherwise, ordered like the module-level `valid_columns`.

    Raises:
        KeyError: If a provided symptom is not recognized as a valid symptom.
    """
    # `valid_columns` is set in the `__main__` section from the training data columns
    symptoms_vector = {key: 0 for key in valid_columns}

    for pretty_symptom in checkbox_symptoms:
        # Map the displayed name back to the column name, e.g. "Skin Rash" -> "skin_rash"
        original_symptom = "_".join(pretty_symptom.lower().split(" "))
        if original_symptom not in symptoms_vector:
            raise KeyError(
                f"The symptom '{original_symptom}' you provided is not recognized as a valid "
                f"symptom.\nHere is the list of valid symptoms: {list(symptoms_vector)}"
            )
        symptoms_vector[original_symptom] = 1

    user_symptoms_vect = np.fromiter(symptoms_vector.values(), dtype=float)[np.newaxis, :]

    assert all(value in (0, 1) for value in user_symptoms_vect.flatten())

    return user_symptoms_vect


def get_features_fn(*checked_symptoms: Tuple[str]) -> Dict:
    """
    Get vector features based on the selected symptoms.

    Args:
        checked_symptoms (Tuple[str]): The selected symptoms, one collection per checkbox group.

    Returns:
        Dict: The encoded user symptoms vector, or an error update if nothing was selected.
    """
    # Every checkbox group is empty: nothing to encode
    if not any(checked_symptoms):
        return {
            error_box1: gr.update(
                visible=True, value="Enter a default disease or select your own symptoms"
            ),
        }

    return {
        error_box1: gr.update(visible=False),
        user_vect_box1: get_user_symptoms_from_checkboxgroup(pretty_print(checked_symptoms)),
    }


def key_gen_fn(user_symptoms: List[str]) -> Dict:
    """
    Generate the private and evaluation keys for a new, randomly identified user.

    Args:
        user_symptoms (List[str]): The user symptoms vector, as shown in the symptoms textbox.
            Only checked for emptiness here.

    Returns:
        dict: Updates for the truncated evaluation key, the user ID and the key size boxes, or an
            error update if no symptoms were submitted.
    """
    clean_directory()

    if is_nan(user_symptoms):
        print("Error: Please submit your symptoms or select a default disease.")
        return {
            error_box2: gr.update(visible=True, value="Please submit your symptoms first"),
        }

    # Generate a random user ID. An explicit 64-bit dtype is required because the upper bound
    # 2**32 does not fit in a 32-bit default integer (e.g. on Windows with NumPy < 2)
    user_id = np.random.randint(0, 2**32, dtype=np.int64)
    print(f"Your user ID is: {user_id}....")

    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
    client.load()

    # Creates the private and evaluation keys on the client side
    client.generate_private_and_evaluation_keys()

    # Get the serialized evaluation keys
    serialized_evaluation_keys = client.get_serialized_evaluation_keys()
    assert isinstance(serialized_evaluation_keys, bytes)

    # Save the evaluation key so that it can later be sent to the server
    evaluation_key_path = KEYS_DIR / f"{user_id}/evaluation_key"
    with evaluation_key_path.open("wb") as f:
        f.write(serialized_evaluation_keys)

    # Only a prefix of the hex key is displayed, to stay within the browser input limit
    serialized_evaluation_keys_shorten_hex = serialized_evaluation_keys.hex()[:INPUT_BROWSER_LIMIT]

    return {
        error_box2: gr.update(visible=False),
        key_box: serialized_evaluation_keys_shorten_hex,
        user_id_box: user_id,
        key_len_box: f"{len(serialized_evaluation_keys) / (10**6):.2f} MB",
    }


def encrypt_fn(user_symptoms: np.ndarray, user_id: str) -> Dict:
    """
    Encrypt the user symptoms vector in the `Client Side`.

    Args:
        user_symptoms (np.ndarray): The user symptoms vector as received from the symptoms
            textbox, presumably the string form of a (1, n) float array such as "[[0. 1. 0.]]".
        user_id (str): The current user's ID.

    Returns:
        Dict: Updates for the clear, quantized and encrypted (truncated) vector boxes, or an
            error update if the symptoms or the user ID are missing.
    """
    if is_nan(user_id) or is_nan(user_symptoms):
        print("Error in encryption step: Provide your symptoms and generate the evaluation keys.")
        return {
            error_box3: gr.update(
                visible=True, value="Please provide your symptoms and generate the evaluation keys."
            )
        }

    # Retrieve the client API
    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
    client.load()

    # Strip the surrounding "[[" and "]]" and parse the "."-separated values back into a row vector
    user_symptoms = np.fromstring(user_symptoms[2:-2], dtype=int, sep=".").reshape(1, -1)
    quant_user_symptoms = client.model.quantize_input(user_symptoms)

    encrypted_quantized_user_symptoms = client.quantize_encrypt_serialize(user_symptoms)
    assert isinstance(encrypted_quantized_user_symptoms, bytes)

    # Save the encrypted input so that it can later be sent to the server
    encrypted_input_path = KEYS_DIR / f"{user_id}/encrypted_symptoms"
    with encrypted_input_path.open("wb") as f:
        f.write(encrypted_quantized_user_symptoms)

    encrypted_quantized_user_symptoms_shorten_hex = encrypted_quantized_user_symptoms.hex()[
        :INPUT_BROWSER_LIMIT
    ]

    return {
        error_box3: gr.update(visible=False),
        user_vect_box2: user_symptoms,
        quant_vect_box: quant_user_symptoms,
        enc_vect_box: encrypted_quantized_user_symptoms_shorten_hex,
    }


def send_input_fn(user_id: str, user_symptoms: np.ndarray) -> Dict:
    """Send the encrypted data and the evaluation key to the server.

    Args:
        user_id (str): The current user's ID
        user_symptoms (numpy.ndarray): The user symptoms, forwarded to the server as `filter`

    Returns:
        Dict: A confirmation update if the server accepted the request, an error update otherwise.
    """
    if is_nan(user_id) or is_nan(user_symptoms):
        return {
            error_box4: gr.update(
                visible=True,
                value="Please ensure that the evaluation key has been generated "
                "and the symptoms have been submitted before sending the data to the server",
            )
        }

    evaluation_key_path = KEYS_DIR / f"{user_id}/evaluation_key"
    encrypted_input_path = KEYS_DIR / f"{user_id}/encrypted_symptoms"

    if not evaluation_key_path.is_file():
        print(
            "Error Encountered While Sending Data to the Server: "
            f"The key has not been generated correctly - {evaluation_key_path.is_file()=}"
        )
        return {error_box4: gr.update(visible=True, value="Please generate the private key first.")}

    if not encrypted_input_path.is_file():
        print(
            "Error Encountered While Sending Data to the Server: The data has not been encrypted "
            f"correctly on the client side - {encrypted_input_path.is_file()=}"
        )
        return {
            error_box4: gr.update(
                visible=True,
                value="Please encrypt the data with the private key first.",
            ),
        }

    # Define the data to post
    data = {
        "user_id": user_id,
        "filter": user_symptoms,
    }

    url = SERVER_URL + "send_input"

    # The files are opened in `with` blocks so that they are closed once the upload is done
    with open(encrypted_input_path, "rb") as encrypted_input_file:
        with open(evaluation_key_path, "rb") as evaluation_key_file:
            files = [
                ("files", encrypted_input_file),
                ("files", evaluation_key_file),
            ]

            # Send the encrypted input and evaluation key to the server
            with requests.post(url=url, data=data, files=files) as response:
                print(f"Sending Data: {response.ok=}")
                request_ok = response.ok

    if not request_ok:
        return {
            error_box4: gr.update(
                visible=True,
                value="An error occurred while sending the data to the server, please try again.",
            ),
            srv_resp_send_data_box: False,
        }

    return {error_box4: gr.update(visible=False), srv_resp_send_data_box: "Data sent"}


def run_fhe_fn(user_id: str) -> Dict:
    """Ask the server to run the FHE evaluation on the encrypted data previously sent.

    Args:
        user_id (str): The current user's ID.

    Returns:
        Dict: The execution time reported by the server, or an error update.
    """
    if is_nan(user_id):
        return {
            error_box5: gr.update(
                visible=True,
                value="Please ensure that the evaluation key has been generated "
                "and the symptoms have been submitted before sending the data to the server",
            )
        }

    data = {
        "user_id": user_id,
    }

    # Trigger the FHE execution on the encrypted data previously sent
    url = SERVER_URL + "run_fhe"
    with requests.post(url=url, data=data) as response:
        if not response.ok:
            return {
                error_box5: gr.update(visible=True, value="Please wait."),
                fhe_execution_time_box: gr.update(visible=True),
            }

        print(f"response.ok: {response.ok}, {response.json()} - Computed")
        return {
            error_box5: gr.update(visible=False),
            fhe_execution_time_box: gr.update(value=f"{response.json()} seconds"),
        }


def get_output_fn(user_id: str, user_symptoms: np.ndarray) -> Dict:
    """Retrieve the encrypted output from the server and save it on the client side.

    Args:
        user_id (str): The current user's ID
        user_symptoms (numpy.ndarray): The user symptoms (only checked for emptiness)

    Returns:
        Dict: A confirmation update if the output was received, an error update otherwise.
    """
    if is_nan(user_id) or is_nan(user_symptoms):
        return {
            error_box6: gr.update(
                visible=True,
                value="Please ensure that the evaluation key has been generated "
                "and the symptoms have been submitted before sending the data to the server",
            )
        }

    data = {
        "user_id": user_id,
    }

    # Retrieve the encrypted output
    url = SERVER_URL + "get_output"
    with requests.post(url=url, data=data) as response:
        if not response.ok:
            # Do not report success: nothing was received, so nothing can be decrypted
            return {
                error_box6: gr.update(
                    visible=True,
                    value="An error occurred while retrieving the data from the server.",
                ),
                srv_resp_retrieve_data_box: False,
            }

        print(f"Receive Data: {response.ok=}")
        encrypted_output = response.content

    # Save the encrypted output to bytes in a file as it is too large to pass through
    # regular Gradio buttons (see https://github.com/gradio-app/gradio/issues/1877)
    encrypted_output_path = CLIENT_DIR / f"{user_id}_encrypted_output"
    with encrypted_output_path.open("wb") as f:
        f.write(encrypted_output)

    return {error_box6: gr.update(visible=False), srv_resp_retrieve_data_box: "Data received"}


def decrypt_fn(user_id: str, user_symptoms: np.ndarray) -> Dict:
    """Decrypt the data on the `Client Side`.

    Args:
        user_id (str): The current user's ID
        user_symptoms (numpy.ndarray): The user symptoms (only checked for emptiness)

    Returns:
        Dict: The name of the predicted disease (arg-max of the decrypted output), or an error
            update.
    """
    if is_nan(user_id) or is_nan(user_symptoms):
        return {
            error_box7: gr.update(
                visible=True,
                value="Please ensure that the symptoms have been submitted and the evaluation "
                "key has been generated",
            )
        }

    # Get the encrypted output path
    encrypted_output_path = CLIENT_DIR / f"{user_id}_encrypted_output"

    if not encrypted_output_path.is_file():
        print("Error in decryption step: Please run the FHE execution, first.")
        return {
            error_box7: gr.update(
                visible=True,
                value="Please ensure that the symptoms have been submitted, the evaluation "
                "key has been generated and step 5 and 6 have been performed on the Server "
                "side before decrypting the prediction",
            )
        }

    # Load the encrypted output as bytes
    with encrypted_output_path.open("rb") as f:
        encrypted_output = f.read()

    # Retrieve the client API
    client = FHEModelClient(path_dir=DEPLOYMENT_DIR, key_dir=KEYS_DIR / f"{user_id}")
    client.load()

    # Deserialize, decrypt and post-process the encrypted output
    output = client.deserialize_decrypt_dequantize(encrypted_output)

    return {
        error_box7: gr.update(visible=False),
        decrypt_target_box: get_disease_name(output.argmax()),
    }


def clear_all_btn():
    """Clean the working directories and reset every output box, error box and checkbox."""
    clean_directory()

    return {
        # disease_box: None,
        user_id_box: None,
        user_vect_box1: None,
        user_vect_box2: None,
        quant_vect_box: None,
        enc_vect_box: None,
        key_box: None,
        key_len_box: None,
        fhe_execution_time_box: None,
        decrypt_target_box: None,
        error_box7: gr.update(visible=False),
        error_box1: gr.update(visible=False),
        error_box2: gr.update(visible=False),
        error_box3: gr.update(visible=False),
        error_box4: gr.update(visible=False),
        error_box5: gr.update(visible=False),
        error_box6: gr.update(visible=False),
        srv_resp_send_data_box: None,
        srv_resp_retrieve_data_box: None,
        **{box: None for box in check_boxes},
    }


CSS = """
#them {color: orange}
#them {font-size: 25px}
#them {font-weight: bold}
.gradio-container {background-color: white}
.feedback {font-size: 3px !important}
/* #them {text-align: center} */
"""

if __name__ == "__main__":

    print("Starting demo ...")

    clean_directory()

    (X_train, X_test), (y_train, y_test) = load_data()

    # Module-level global read by `get_user_symptoms_from_checkboxgroup`
    valid_columns = X_train.columns.to_list()

    # The components created below are module-level globals referenced by the callbacks above
    with gr.Blocks(css=CSS) as demo:

        # Link + images
        gr.Markdown(
            """
Concrete-ML — Documentation — Community — @zama_fhe
"""
        )

        with gr.Tabs(elem_id="them"):
            with gr.TabItem("1. Symptoms Selection"):
                gr.Markdown("Client Side")
                gr.Markdown("## Step 1: Provide your symptoms")
                gr.Markdown(
                    "You can provide your health condition either by checking "
                    "the symptoms available in the boxes or by selecting a known disease with "
                    "its predefined set of symptoms."
                )

                # Box symptoms: one collapsible checkbox group per symptom category
                check_boxes = []
                for category in SYMPTOMS_LIST:
                    with gr.Accordion(
                        pretty_print(category.keys()), open=False, elem_classes="feedback"
                    ):
                        check_box = gr.CheckboxGroup(
                            pretty_print(category.values()),
                            label=pretty_print(category.keys()),
                            info=f"Symptoms related to `{pretty_print(category.values())}`",
                        )
                    check_boxes.append(check_box)

                error_box1 = gr.Textbox(label="Error", visible=False)

                # Default disease, picked from the dataframe
                # disease_box = gr.Dropdown(list(sorted(set(df_test["prognosis"]))),
                #                           label="Disease:")
                # disease_box.change(
                #     fn=fill_in_fn,
                #     inputs=[disease_box, *check_boxes],
                #     outputs=[*check_boxes],
                # )

                # User symptom vector
                user_vect_box1 = gr.Textbox(label="User Symptoms Vector:", interactive=False)

                # Submit button
                submit_button = gr.Button("Submit")

                with gr.Row():
                    # Clear button
                    clear_button = gr.Button("Reset")

                submit_button.click(
                    fn=get_features_fn,
                    inputs=[*check_boxes],
                    outputs=[user_vect_box1, error_box1],
                )

            with gr.TabItem("2. Data Encryption"):
                gr.Markdown("Client Side")
                gr.Markdown("## Step 2: Generate the keys")

                gen_key_btn = gr.Button("Generate the keys")
                error_box2 = gr.Textbox(label="Error", visible=False)

                with gr.Row():
                    # User ID
                    with gr.Column(scale=1, min_width=600):
                        user_id_box = gr.Textbox(label="User ID:", interactive=False)
                    # Evaluation key size
                    with gr.Column(scale=1, min_width=600):
                        key_len_box = gr.Textbox(label="Evaluation Key Size:", interactive=False)
                    # Evaluation key (truncated)
                    with gr.Column(scale=2, min_width=600):
                        key_box = gr.Textbox(
                            label="Evaluation key (truncated):",
                            max_lines=3,
                            interactive=False,
                        )

                gen_key_btn.click(
                    key_gen_fn,
                    inputs=user_vect_box1,
                    outputs=[
                        key_box,
                        user_id_box,
                        key_len_box,
                        error_box2,
                    ],
                )

                gr.Markdown("## Step 3: Encrypt the symptoms")

                encrypt_btn = gr.Button("Encrypt the symptoms with the private key")
                error_box3 = gr.Textbox(label="Error", visible=False)

                with gr.Row():
                    with gr.Column(scale=1, min_width=600):
                        user_vect_box2 = gr.Textbox(
                            label="User Symptoms Vector:", interactive=False
                        )
                    with gr.Column(scale=1, min_width=600):
                        quant_vect_box = gr.Textbox(label="Quantized Vector:", interactive=False)
                    with gr.Column(scale=1, min_width=600):
                        enc_vect_box = gr.Textbox(
                            label="Encrypted Vector:", max_lines=3, interactive=False
                        )

                encrypt_btn.click(
                    encrypt_fn,
                    inputs=[user_vect_box1, user_id_box],
                    outputs=[
                        user_vect_box2,
                        quant_vect_box,
                        enc_vect_box,
                        error_box3,
                    ],
                )

                gr.Markdown("## Step 4: Send the encrypted data to the Server Side")

                error_box4 = gr.Textbox(label="Error", visible=False)

                with gr.Row().style(equal_height=False):
                    with gr.Column(scale=4):
                        send_input_btn = gr.Button("Send the encrypted data")
                    with gr.Column(scale=1):
                        srv_resp_send_data_box = gr.Checkbox(
                            label="Data Sent", show_label=False, interactive=False
                        )

                send_input_btn.click(
                    send_input_fn,
                    inputs=[user_id_box, user_vect_box1],
                    outputs=[error_box4, srv_resp_send_data_box],
                )

            with gr.TabItem("3. FHE execution"):
                gr.Markdown("Client Side")
                gr.Markdown("## Step 5: Run the FHE evaluation")

                run_fhe_btn = gr.Button("Run the FHE evaluation")
                error_box5 = gr.Textbox(label="Error", visible=False)
                fhe_execution_time_box = gr.Textbox(
                    label="Total FHE Execution Time:", interactive=False
                )

                run_fhe_btn.click(
                    run_fhe_fn,
                    inputs=[user_id_box],
                    outputs=[fhe_execution_time_box, error_box5],
                )

            with gr.TabItem("4. Data Decryption"):
                gr.Markdown("Client Side")
                gr.Markdown("## Step 6: Get the data from the Server Side")

                error_box6 = gr.Textbox(label="Error", visible=False)

                with gr.Row().style(equal_height=True):
                    with gr.Column(scale=4):
                        get_output_btn = gr.Button("Get data")
                    with gr.Column(scale=1):
                        srv_resp_retrieve_data_box = gr.Checkbox(
                            label="Data Received", show_label=False, interactive=False
                        )

                get_output_btn.click(
                    get_output_fn,
                    inputs=[user_id_box, user_vect_box1],
                    outputs=[srv_resp_retrieve_data_box, error_box6],
                )

                gr.Markdown("## Step 7: Decrypt the output")

                decrypt_target_btn = gr.Button("Decrypt the output")
                error_box7 = gr.Textbox(label="Error", visible=False)
                decrypt_target_box = gr.Textbox(label="Decrypted Output:", interactive=False)

                decrypt_target_btn.click(
                    decrypt_fn,
                    inputs=[user_id_box, user_vect_box1],
                    outputs=[decrypt_target_box, error_box7],
                )

        clear_button.click(
            clear_all_btn,
            outputs=[
                user_vect_box1,
                user_vect_box2,
                # disease_box,
                error_box1,
                error_box2,
                error_box3,
                error_box4,
                error_box5,
                error_box6,
                error_box7,
                user_id_box,
                key_len_box,
                key_box,
                quant_vect_box,
                enc_vect_box,
                srv_resp_send_data_box,
                srv_resp_retrieve_data_box,
                fhe_execution_time_box,
                decrypt_target_box,
                *check_boxes,
            ],
        )

    demo.launch()