garg-aayush committed on
Commit
6ead77d
1 Parent(s): be14d64

update handler file: add checks for image size, mode. Add exceptions

Files changed (1):
  1. handler.py +85 -28
handler.py CHANGED
@@ -8,13 +8,11 @@ from pathlib import Path
 from basicsr.archs.rrdbnet_arch import RRDBNet
 import numpy as np
 import cv2
+import PIL
 
 
 import torch
 import base64
-# torch.cuda.empty_cache()
-# torch.cuda.set_per_process_memory_fraction(0.99)
-# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64,garbage_collection_threshold:0.7"
 
 
 class EndpointHandler:
@@ -22,9 +20,16 @@ class EndpointHandler:
 
         self.model = RealESRGANer(
             scale=4,
-            model_path=f"/repository/weights/Real-ESRGAN-x4plus.pth",
+            # model_path=f"/repository/weights/Real-ESRGAN-x4plus.pth",
+            model_path="/workspace/real-esrgan/weights/Real-ESRGAN-x4plus.pth",
             # dni_weight=dni_weight,
-            model= RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
+            model= RRDBNet(num_in_ch=3,
+                           num_out_ch=3,
+                           num_feat=64,
+                           num_block=23,
+                           num_grow_ch=32,
+                           scale=4
+                           ),
             tile=0,
             tile_pad=10,
             # pre_pad=args.pre_pad,
@@ -33,28 +38,80 @@ class EndpointHandler:
         )
 
     def __call__(self, data: Any) -> Dict[str, List[float]]:
-        inputs = data.pop("inputs", data)
-        outscale = 3
 
-        # decode base64 image to PIL
-        image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
-        # Convert PIL image to NumPy array
-        opencv_image = np.array(image)
-        # Convert RGB to BGR (PIL uses RGB, OpenCV expects BGR)
-        opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2BGR)
-        output, _ = self.model.enhance(opencv_image, outscale=outscale)
-
-        out_shape = output.shape
-        if len(out_shape) == 3:
-            if out_shape[2] == 3:
-                output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
-            elif out_shape[2] == 4:
-                output = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
-        else:
-            output = cv2.cvtColor(output, cv2.COLOR_GRAY2RGB)
-
-        img_byte_arr = BytesIO()
-        output = Image.fromarray(output)
-        output.save(img_byte_arr, format='PNG')
-        img_str = base64.b64encode(img_byte_arr.getvalue())
-        return {"out_image": img_str.decode()}
+        try:
+
+            # get inputs
+            inputs = data.pop("inputs", data)
+
+            # get outscale
+            outscale = float(inputs.pop("outscale", 3))
+
+            # decode base64 image to PIL
+            image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
+            in_size, in_mode = image.size, image.mode
+
+            # check image size and mode and return dict
+            assert in_mode in ["RGB", "RGBA", "L"], f"Unsupported image mode: {in_mode}"
+            assert in_size[0] * in_size[1] < 1400*1400, f"Image is too large: {in_size}: {in_size[0] * in_size[1]} is greater than {1400*1400}"
+            assert outscale > 1 and outscale <= 10, f"Outscale must be between 1 and 10: {outscale}"
+
+            # debug
+            print(f"image.size: {in_size}, image.mode: {in_mode}, outscale: {outscale}")
+
+            # Convert RGB to BGR (PIL uses RGB, OpenCV expects BGR)
+            opencv_image = np.array(image)
+            if in_mode == "RGB":
+                opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGB2BGR)
+            elif in_mode == "RGBA":
+                opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_RGBA2BGRA)
+            elif in_mode == "L":
+                opencv_image = cv2.cvtColor(opencv_image, cv2.COLOR_GRAY2RGB)
+            else:
+                raise ValueError(f"Unsupported image mode: {in_mode}")
+
+            # enhance image
+            output, _ = self.model.enhance(opencv_image, outscale=outscale)
+
+            # debug
+            print(f"output.shape: {output.shape}")
+
+            # convert to RGB/RGBA format
+            out_shape = output.shape
+            if len(out_shape) == 3:
+                if out_shape[2] == 3:
+                    output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
+                elif out_shape[2] == 4:
+                    output = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
+            else:
+                output = cv2.cvtColor(output, cv2.COLOR_GRAY2RGB)
+
+            # convert to PIL image
+            img_byte_arr = BytesIO()
+            output = Image.fromarray(output)
+
+            # save to BytesIO
+            output.save(img_byte_arr, format='PNG')
+            img_str = base64.b64encode(img_byte_arr.getvalue())
+            img_str = img_str.decode()
+
+            return {"out_image": img_str,
+                    "error": None
+                    }
+
+        # handle errors
+        except AssertionError as e:
+            print(f"AssertionError: {e}")
+            return {"out_image": None, "error": str(e)}
+        except KeyError as e:
+            print(f"KeyError: {e}")
+            return {"out_image": None, "error": f"Missing key: {e}"}
+        except ValueError as e:
+            print(f"ValueError: {e}")
+            return {"out_image": None, "error": str(e)}
+        except PIL.UnidentifiedImageError as e:
+            print(f"PIL.UnidentifiedImageError: {e}")
+            return {"out_image": None, "error": "Invalid image format"}
+        except Exception as e:
+            print(f"Exception: {e}")
+            return {"out_image": None, "error": "An unexpected error occurred"}