garg-aayush
commited on
Commit
•
6ead77d
1
Parent(s):
be14d64
update handler file: add checks for image size, mode. Add exceptions
Browse files- handler.py +85 -28
handler.py
CHANGED
@@ -8,13 +8,11 @@ from pathlib import Path
|
|
8 |
from basicsr.archs.rrdbnet_arch import RRDBNet
|
9 |
import numpy as np
|
10 |
import cv2
|
|
|
11 |
|
12 |
|
13 |
import torch
|
14 |
import base64
|
15 |
-
# torch.cuda.empty_cache()
|
16 |
-
# torch.cuda.set_per_process_memory_fraction(0.99)
|
17 |
-
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64,garbage_collection_threshold:0.7"
|
18 |
|
19 |
|
20 |
class EndpointHandler:
|
@@ -22,9 +20,16 @@ class EndpointHandler:
|
|
22 |
|
23 |
self.model = RealESRGANer(
|
24 |
scale=4,
|
25 |
-
model_path=f"/repository/weights/Real-ESRGAN-x4plus.pth",
|
|
|
26 |
# dni_weight=dni_weight,
|
27 |
-
model= RRDBNet(num_in_ch=3,
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
tile=0,
|
29 |
tile_pad=10,
|
30 |
# pre_pad=args.pre_pad,
|
@@ -33,28 +38,80 @@ class EndpointHandler:
|
|
33 |
)
|
34 |
|
35 |
def __call__(self, data: Any) -> Dict[str, List[float]]:
|
36 |
-
inputs = data.pop("inputs", data)
|
37 |
-
outscale = 3
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from basicsr.archs.rrdbnet_arch import RRDBNet
|
9 |
import numpy as np
|
10 |
import cv2
|
11 |
+
import PIL
|
12 |
|
13 |
|
14 |
import torch
|
15 |
import base64
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
class EndpointHandler:
|
|
|
20 |
|
21 |
self.model = RealESRGANer(
|
22 |
scale=4,
|
23 |
+
# model_path=f"/repository/weights/Real-ESRGAN-x4plus.pth",
|
24 |
+
model_path="/workspace/real-esrgan/weights/Real-ESRGAN-x4plus.pth",
|
25 |
# dni_weight=dni_weight,
|
26 |
+
model= RRDBNet(num_in_ch=3,
|
27 |
+
num_out_ch=3,
|
28 |
+
num_feat=64,
|
29 |
+
num_block=23,
|
30 |
+
num_grow_ch=32,
|
31 |
+
scale=4
|
32 |
+
),
|
33 |
tile=0,
|
34 |
tile_pad=10,
|
35 |
# pre_pad=args.pre_pad,
|
|
|
38 |
)
|
39 |
|
40 |
def __call__(self, data: Any) -> Dict[str, List[float]]:
|
|
|
|
|
41 |
|
42 |
+
try:
|
43 |
+
|
44 |
+
# get inputs
|
45 |
+
inputs = data.pop("inputs", data)
|
46 |
+
|
47 |
+
# get outscale
|
48 |
+
outscale = float(inputs.pop("outscale", 3))
|
49 |
+
|
50 |
+
# decode base64 image to PIL
|
51 |
+
image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
|
52 |
+
in_size, in_mode = image.size, image.mode
|
53 |
+
|
54 |
+
# check image size and mode and return dict
|
55 |
+
assert in_mode in ["RGB", "RGBA", "L"], f"Unsupported image mode: {in_mode}"
|
56 |
+
assert in_size[0] * in_size[1] < 1400*1400, f"Image is too large: {in_size}: {in_size[0] * in_size[1]} is greater than {1400*1400}"
|
57 |
+
assert outscale > 1 and outscale <=10, f"Outscale must be between 1 and 10: {outscale}"
|
58 |
+
|
59 |
+
# debug
|
60 |
+
print(f"image.size: {in_size}, image.mode: {in_mode}, outscale: {outscale}")
|
61 |
+
|
62 |
+
# Convert RGB to BGR (PIL uses RGB, OpenCV expects BGR)
|
63 |
+
opencv_image = np.array(image)
|
64 |
+
if in_mode == "RGB":
|
65 |
+
opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2BGR)
|
66 |
+
elif in_mode == "RGBA":
|
67 |
+
opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGBA2BGRA)
|
68 |
+
elif in_mode == "L":
|
69 |
+
opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_GRAY2RGB)
|
70 |
+
else:
|
71 |
+
raise ValueError(f"Unsupported image mode: {in_mode}")
|
72 |
+
|
73 |
+
# enhance image
|
74 |
+
output, _ = self.model.enhance(opencv_image, outscale=outscale)
|
75 |
+
|
76 |
+
# debug
|
77 |
+
print(f"output.shape: {output.shape}")
|
78 |
+
|
79 |
+
# convert to RGB/RGBA format
|
80 |
+
out_shape = output.shape
|
81 |
+
if len(out_shape) == 3:
|
82 |
+
if out_shape[2] == 3:
|
83 |
+
output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
|
84 |
+
elif out_shape[2] == 4:
|
85 |
+
output = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
|
86 |
+
else:
|
87 |
+
output = cv2.cvtColor(output, cv2.COLOR_GRAY2RGB)
|
88 |
+
|
89 |
+
# convert to PIL image
|
90 |
+
img_byte_arr = BytesIO()
|
91 |
+
output = Image.fromarray(output)
|
92 |
+
|
93 |
+
# save to BytesIO
|
94 |
+
output.save(img_byte_arr, format='PNG')
|
95 |
+
img_str = base64.b64encode(img_byte_arr.getvalue())
|
96 |
+
img_str = img_str.decode()
|
97 |
+
|
98 |
+
return {"out_image": img_str,
|
99 |
+
"error": None
|
100 |
+
}
|
101 |
|
102 |
+
# handle errors
|
103 |
+
except AssertionError as e:
|
104 |
+
print(f"AssertionError: {e}")
|
105 |
+
return {"out_image": None, "error": str(e)}
|
106 |
+
except KeyError as e:
|
107 |
+
print(f"KeyError: {e}")
|
108 |
+
return {"out_image": None, "error": f"Missing key: {e}"}
|
109 |
+
except ValueError as e:
|
110 |
+
print(f"ValueError: {e}")
|
111 |
+
return {"out_image": None, "error": str(e)}
|
112 |
+
except PIL.UnidentifiedImageError as e:
|
113 |
+
print(f"PIL.UnidentifiedImageError: {e}")
|
114 |
+
return {"out_image": None, "error": "Invalid image format"}
|
115 |
+
except Exception as e:
|
116 |
+
print(f"Exception: {e}")
|
117 |
+
return {"out_image": None, "error": "An unexpected error occurred"}
|