chore: reshape the image to (100, 100, 3) if not and add a check for rgb format
Browse files
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
"""A local gradio app that filters images using FHE."""
|
2 |
-
|
3 |
import os
|
4 |
import shutil
|
5 |
import subprocess
|
@@ -191,6 +191,18 @@ def encrypt(user_id, input_image, filter_name):
|
|
191 |
|
192 |
if input_image is None:
|
193 |
raise gr.Error("Please choose an image first.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
|
195 |
# Retrieve the client API
|
196 |
client = get_client(user_id, filter_name)
|
@@ -482,7 +494,7 @@ with demo:
|
|
482 |
)
|
483 |
|
484 |
output_image = gr.Image(
|
485 |
-
label=f"Output image ({INPUT_SHAPE[0]}x{INPUT_SHAPE[1]}):",
|
486 |
interactive=False,
|
487 |
height=256,
|
488 |
width=256,
|
@@ -513,7 +525,7 @@ with demo:
|
|
513 |
# Button to send the encodings to the server using post method
|
514 |
get_output_button.click(
|
515 |
get_output,
|
516 |
-
inputs=[user_id, filter_name],
|
517 |
outputs=[encrypted_output_representation]
|
518 |
)
|
519 |
|
|
|
1 |
"""A local gradio app that filters images using FHE."""
|
2 |
+
from PIL import Image
|
3 |
import os
|
4 |
import shutil
|
5 |
import subprocess
|
|
|
191 |
|
192 |
if input_image is None:
|
193 |
raise gr.Error("Please choose an image first.")
|
194 |
+
|
195 |
+
if input_image.shape[-1] != 3:
|
196 |
+
raise ValueError(f"Input image must have 3 channels (RGB). Current shape: {input_image.shape}")
|
197 |
+
|
198 |
+
# Resize the image if it hasn't the shape (100, 100, 3)
|
199 |
+
if input_image.shape != (100 , 100, 3):
|
200 |
+
print(f"Before: {type(input_image)=}, {input_image.shape=}")
|
201 |
+
input_image_pil = Image.fromarray(input_image)
|
202 |
+
# Resize the image
|
203 |
+
input_image_pil = input_image_pil.resize((100, 100))
|
204 |
+
input_image = numpy.array(input_image_pil)
|
205 |
+
print(f"After: {type(input_image)=}, {input_image.shape=}")
|
206 |
|
207 |
# Retrieve the client API
|
208 |
client = get_client(user_id, filter_name)
|
|
|
494 |
)
|
495 |
|
496 |
output_image = gr.Image(
|
497 |
+
label=f"Output image ({INPUT_SHAPE[0]}x{INPUT_SHAPE[1]}):",
|
498 |
interactive=False,
|
499 |
height=256,
|
500 |
width=256,
|
|
|
525 |
# Button to send the encodings to the server using post method
|
526 |
get_output_button.click(
|
527 |
get_output,
|
528 |
+
inputs=[user_id, filter_name],
|
529 |
outputs=[encrypted_output_representation]
|
530 |
)
|
531 |
|