batuergun commited on
Commit
4d2771f
1 Parent(s): 7ea0738

memory fix

Browse files
Files changed (3) hide show
  1. app.py +22 -21
  2. requirements.txt +2 -1
  3. server.py +53 -28
app.py CHANGED
@@ -238,32 +238,33 @@ def send_input(user_id):
238
  return response.ok
239
 
240
  def run_fhe(user_id):
241
- """Apply the seizure detection model on the encrypted image previously sent using FHE."""
242
  data = {"user_id": user_id}
243
  url = SERVER_URL + "run_fhe"
244
 
245
  try:
246
- logger.info(f"Sending request to {url} with user_id: {user_id}")
247
- with requests_retry_session().post(url=url, data=data, timeout=300) as response:
248
- logger.info(f"Received response with status code: {response.status_code}")
249
- response.raise_for_status() # Raises an HTTPError for bad responses
250
- if response.ok:
251
- return response.json()
252
- else:
253
- logger.error(f"Server responded with status code {response.status_code}")
254
- raise gr.Error(f"Server responded with status code {response.status_code}")
255
- except requests.exceptions.Timeout:
256
- logger.error("The request timed out. The server might be overloaded.")
257
- raise gr.Error("The request timed out. The server might be overloaded.")
258
- except requests.exceptions.ConnectionError as e:
259
- logger.error(f"Failed to connect to the server. Error: {str(e)}")
260
- raise gr.Error("Failed to connect to the server. Please check your network connection.")
 
 
 
 
261
  except requests.exceptions.RequestException as e:
262
- logger.error(f"An error occurred: {str(e)}")
263
- raise gr.Error(f"An error occurred: {str(e)}")
264
- except Exception as e:
265
- logger.error(f"An unexpected error occurred: {str(e)}")
266
- raise gr.Error(f"An unexpected error occurred: {str(e)}")
267
 
268
  def get_output(user_id):
269
  """Retrieve the encrypted output (boolean).
 
238
  return response.ok
239
 
240
  def run_fhe(user_id):
241
+ """Start the FHE execution and poll for results."""
242
  data = {"user_id": user_id}
243
  url = SERVER_URL + "run_fhe"
244
 
245
  try:
246
+ with requests_retry_session().post(url=url, data=data, timeout=30) as response:
247
+ response.raise_for_status()
248
+ logger.info("FHE execution started successfully")
249
+
250
+ # Poll for FHE execution status
251
+ status_url = SERVER_URL + f"fhe_status/{user_id}"
252
+ max_attempts = 60 # Adjust this value based on expected execution time
253
+ attempt = 0
254
+ while attempt < max_attempts:
255
+ with requests_retry_session().get(status_url, timeout=10) as status_response:
256
+ status_response.raise_for_status()
257
+ status = status_response.json()["status"]
258
+ if status == "completed":
259
+ logger.info("FHE execution completed successfully")
260
+ return "FHE execution completed successfully"
261
+ time.sleep(5) # Wait for 5 seconds before next poll
262
+ attempt += 1
263
+
264
+ raise gr.Error("FHE execution timed out")
265
  except requests.exceptions.RequestException as e:
266
+ logger.error(f"Error during FHE execution: {str(e)}")
267
+ raise gr.Error(f"An error occurred during FHE execution: {str(e)}")
 
 
 
268
 
269
  def get_output(user_id):
270
  """Retrieve the encrypted output (boolean).
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  concrete-ml
2
  gradio
3
- fastapi
 
 
1
  concrete-ml
2
  gradio
3
+ fastapi
4
+ psutil
server.py CHANGED
@@ -4,8 +4,10 @@ import time
4
  import logging
5
  from pathlib import Path
6
  from typing import List
 
 
7
 
8
- from fastapi import FastAPI, File, Form, UploadFile, HTTPException
9
  from fastapi.responses import JSONResponse, Response
10
  from fastapi.middleware.cors import CORSMiddleware
11
 
@@ -38,6 +40,12 @@ def get_server_file_path(file_type: str, user_id: str) -> Path:
38
  file_type = "evaluation"
39
  return Path(__file__).parent / "server_tmp" / f"{file_type}_{user_id}"
40
 
 
 
 
 
 
 
41
  @app.post("/send_input")
42
  async def send_input(user_id: str = Form(), files: List[UploadFile] = File(...)):
43
  """Receive the encrypted input image and the evaluation key from the client."""
@@ -62,45 +70,62 @@ async def send_input(user_id: str = Form(), files: List[UploadFile] = File(...))
62
  raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
63
 
64
  @app.post("/run_fhe")
65
- def run_fhe(user_id: str = Form()):
66
- """Execute seizure detection on the encrypted input image using FHE."""
67
  logger.info(f"Starting FHE execution for user {user_id}")
68
  try:
69
- # Retrieve the encrypted input image and the evaluation key paths
70
  encrypted_image_path = get_server_file_path("encrypted", user_id)
71
  evaluation_key_path = get_server_file_path("evaluation", user_id)
72
 
73
- logger.info(f"Looking for encrypted_image at: {encrypted_image_path}")
74
- logger.info(f"Looking for evaluation_key at: {evaluation_key_path}")
75
 
76
- # Check if files exist
77
- if not encrypted_image_path.exists():
78
- raise FileNotFoundError(f"Encrypted image file not found at {encrypted_image_path}")
79
- if not evaluation_key_path.exists():
80
- raise FileNotFoundError(f"Evaluation key file not found at {evaluation_key_path}")
81
 
82
- # Read the files using the above paths
83
- with encrypted_image_path.open("rb") as encrypted_image_file, evaluation_key_path.open("rb") as evaluation_key_file:
84
- encrypted_image = encrypted_image_file.read()
85
- evaluation_key = evaluation_key_file.read()
86
 
87
- # Run the FHE execution
88
- start = time.time()
89
- encrypted_output = FHE_SERVER.run(encrypted_image, evaluation_key)
90
- fhe_execution_time = round(time.time() - start, 2)
 
91
 
92
- # Retrieve the encrypted output path
93
- encrypted_output_path = get_server_file_path("encrypted_output", user_id)
94
 
95
- # Write the file using the above path
96
- with encrypted_output_path.open("wb") as encrypted_output_file:
97
- encrypted_output_file.write(encrypted_output)
 
 
 
 
 
 
 
 
 
98
 
99
- logger.info(f"FHE execution completed for user {user_id} in {fhe_execution_time} seconds")
100
- return JSONResponse(content=fhe_execution_time)
 
 
 
101
  except Exception as e:
102
- logger.error(f"Error in run_fhe for user {user_id}: {str(e)}")
103
- raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 
 
 
 
 
 
 
104
 
105
  @app.post("/get_output")
106
  def get_output(user_id: str = Form()):
 
4
  import logging
5
  from pathlib import Path
6
  from typing import List
7
+ import psutil
8
+ import numpy as np
9
 
10
+ from fastapi import FastAPI, File, Form, UploadFile, HTTPException, BackgroundTasks
11
  from fastapi.responses import JSONResponse, Response
12
  from fastapi.middleware.cors import CORSMiddleware
13
 
 
40
  file_type = "evaluation"
41
  return Path(__file__).parent / "server_tmp" / f"{file_type}_{user_id}"
42
 
43
+ def get_available_memory():
44
+ return psutil.virtual_memory().available
45
+
46
+ def chunk_data(data, num_chunks):
47
+ return np.array_split(data, num_chunks)
48
+
49
  @app.post("/send_input")
50
  async def send_input(user_id: str = Form(), files: List[UploadFile] = File(...)):
51
  """Receive the encrypted input image and the evaluation key from the client."""
 
70
  raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
71
 
72
  @app.post("/run_fhe")
73
+ async def run_fhe(background_tasks: BackgroundTasks, user_id: str = Form()):
74
+ """Start the FHE execution in the background with memory management."""
75
  logger.info(f"Starting FHE execution for user {user_id}")
76
  try:
 
77
  encrypted_image_path = get_server_file_path("encrypted", user_id)
78
  evaluation_key_path = get_server_file_path("evaluation", user_id)
79
 
80
+ if not encrypted_image_path.exists() or not evaluation_key_path.exists():
81
+ raise FileNotFoundError("Required files not found")
82
 
83
+ # Start the FHE execution in the background
84
+ background_tasks.add_task(execute_fhe_with_memory_management, user_id, encrypted_image_path, evaluation_key_path)
 
 
 
85
 
86
+ return JSONResponse(content={"message": "FHE execution started"})
87
+ except Exception as e:
88
+ logger.error(f"Error in run_fhe for user {user_id}: {str(e)}")
89
+ raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
90
 
91
+ async def execute_fhe_with_memory_management(user_id: str, encrypted_image_path: Path, evaluation_key_path: Path):
92
+ try:
93
+ with encrypted_image_path.open("rb") as f, evaluation_key_path.open("rb") as g:
94
+ encrypted_image = f.read()
95
+ evaluation_key = g.read()
96
 
97
+ available_memory = get_available_memory()
98
+ estimated_memory_needed = len(encrypted_image) * 2 # Rough estimate, adjust as needed
99
 
100
+ if estimated_memory_needed > available_memory:
101
+ num_chunks = (estimated_memory_needed // available_memory) + 1
102
+ encrypted_image_chunks = chunk_data(np.frombuffer(encrypted_image, dtype=np.uint8), num_chunks)
103
+
104
+ results = []
105
+ for chunk in encrypted_image_chunks:
106
+ chunk_result = FHE_SERVER.run(chunk.tobytes(), evaluation_key)
107
+ results.append(chunk_result)
108
+
109
+ encrypted_output = b''.join(results)
110
+ else:
111
+ encrypted_output = FHE_SERVER.run(encrypted_image, evaluation_key)
112
 
113
+ encrypted_output_path = get_server_file_path("encrypted_output", user_id)
114
+ with encrypted_output_path.open("wb") as f:
115
+ f.write(encrypted_output)
116
+
117
+ logger.info(f"FHE execution completed for user {user_id}")
118
  except Exception as e:
119
+ logger.error(f"Error during FHE execution for user {user_id}: {str(e)}")
120
+
121
+ @app.get("/fhe_status/{user_id}")
122
+ async def fhe_status(user_id: str):
123
+ """Check the status of the FHE execution."""
124
+ encrypted_output_path = get_server_file_path("encrypted_output", user_id)
125
+ if encrypted_output_path.exists():
126
+ return JSONResponse(content={"status": "completed"})
127
+ else:
128
+ return JSONResponse(content={"status": "in_progress"})
129
 
130
  @app.post("/get_output")
131
  def get_output(user_id: str = Form()):