kcelia commited on
Commit
bf8c653
•
1 Parent(s): 6f25c7a

chore: add resize_img

Browse files
Files changed (1) hide show
  1. app.py +26 -8
app.py CHANGED
@@ -197,12 +197,9 @@ def encrypt(user_id, input_image, filter_name):
197
 
198
  # Resize the image if it hasn't the shape (100, 100, 3)
199
  if input_image.shape != (100 , 100, 3):
200
- print(f"Before: {type(input_image)=}, {input_image.shape=}")
201
  input_image_pil = Image.fromarray(input_image)
202
- # Resize the image
203
  input_image_pil = input_image_pil.resize((100, 100))
204
  input_image = numpy.array(input_image_pil)
205
- print(f"After: {type(input_image)=}, {input_image.shape=}")
206
 
207
  # Retrieve the client API
208
  client = get_client(user_id, filter_name)
@@ -220,7 +217,7 @@ def encrypt(user_id, input_image, filter_name):
220
  # Create a truncated version of the encrypted image for display
221
  encrypted_image_short = shorten_bytes_object(encrypted_image)
222
 
223
- return (input_image, encrypted_image_short)
224
 
225
 
226
  def send_input(user_id, filter_name):
@@ -321,7 +318,8 @@ def get_output(user_id, filter_name):
321
  # Decrypt the image using a different (wrong) key for display
322
  output_image_representation = decrypt_output_with_wrong_key(encrypted_output, filter_name)
323
 
324
- return output_image_representation
 
325
  else:
326
  raise gr.Error("Please wait for the FHE execution to be completed.")
327
 
@@ -338,6 +336,9 @@ def decrypt_output(user_id, filter_name):
338
  well as two booleans used for resetting Gradio checkboxes
339
 
340
  """
 
 
 
341
  if user_id == "":
342
  raise gr.Error("Please generate the private key first.")
343
 
@@ -355,10 +356,27 @@ def decrypt_output(user_id, filter_name):
355
  client = get_client(user_id, filter_name)
356
 
357
  # Deserialize, decrypt and post-process the encrypted output
358
- output_image = client.deserialize_decrypt_post_process(encrypted_output_image)
359
 
360
- return output_image, False, False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
 
 
 
362
 
363
  demo = gr.Blocks()
364
 
@@ -464,7 +482,7 @@ with demo:
464
 
465
  with gr.Row():
466
  encrypted_output_representation = gr.Image(
467
- label=f"Encrypted output representation ({INPUT_SHAPE[0]}x{INPUT_SHAPE[1]}):",
468
  interactive=False,
469
  height=256,
470
  width=256,
 
197
 
198
  # Resize the image if it hasn't the shape (100, 100, 3)
199
  if input_image.shape != (100 , 100, 3):
 
200
  input_image_pil = Image.fromarray(input_image)
 
201
  input_image_pil = input_image_pil.resize((100, 100))
202
  input_image = numpy.array(input_image_pil)
 
203
 
204
  # Retrieve the client API
205
  client = get_client(user_id, filter_name)
 
217
  # Create a truncated version of the encrypted image for display
218
  encrypted_image_short = shorten_bytes_object(encrypted_image)
219
 
220
+ return (resize_img(input_image), encrypted_image_short)
221
 
222
 
223
  def send_input(user_id, filter_name):
 
318
  # Decrypt the image using a different (wrong) key for display
319
  output_image_representation = decrypt_output_with_wrong_key(encrypted_output, filter_name)
320
 
321
+ return {encrypted_output_representation: gr.update(value=resize_img(output_image_representation))}
322
+
323
  else:
324
  raise gr.Error("Please wait for the FHE execution to be completed.")
325
 
 
336
  well as two booleans used for resetting Gradio checkboxes
337
 
338
  """
339
+
340
+ print("Decrypt output ------------------------------------------------")
341
+
342
  if user_id == "":
343
  raise gr.Error("Please generate the private key first.")
344
 
 
356
  client = get_client(user_id, filter_name)
357
 
358
  # Deserialize, decrypt and post-process the encrypted output
359
+ decrypted_ouput = client.deserialize_decrypt_post_process(encrypted_output_image)
360
 
361
+ print(f"Decrypted output: {decrypted_ouput.shape=}")
362
+
363
+ return {output_image: gr.update(value=resize_img(decrypted_ouput))}
364
+
365
+
366
+ def resize_img(img, width=256, height=256):
367
+ # Convert to PIL Image
368
+ print("Reshape img before", img.shape, type(img))
369
+ if img.dtype != numpy.uint8:
370
+ img = img.astype(numpy.uint8)
371
+
372
+ img_pil = Image.fromarray(img)
373
+ print(type(img_pil))
374
+ # Resize the image
375
+ resized_img_pil = img_pil.resize((width, height))
376
+ print("Reshape img before", resized_img_pil.size)
377
 
378
+ # Convert back to a NumPy array
379
+ return numpy.array(resized_img_pil)
380
 
381
  demo = gr.Blocks()
382
 
 
482
 
483
  with gr.Row():
484
  encrypted_output_representation = gr.Image(
485
+ label=f"Encrypted output representation ({INPUT_SHAPE[0]}x{INPUT_SHAPE[1]}):",
486
  interactive=False,
487
  height=256,
488
  width=256,