English
Inference Endpoints
garg-aayush commited on
Commit
d47736e
1 Parent(s): 926a1ef

Update and add more unit tests

Browse files
Files changed (1) hide show
  1. unit_tests.py +172 -44
unit_tests.py CHANGED
@@ -1,86 +1,214 @@
1
  import unittest
2
  from unittest.mock import patch, MagicMock
 
3
  from PIL import Image
4
- import base64
5
- import numpy as np
6
  from io import BytesIO
 
7
  from handler import EndpointHandler
8
 
9
  class TestEndpointHandler(unittest.TestCase):
10
-
11
  @patch('handler.RealESRGANer')
12
- def setUp(self, mock_RealESRGANer):
13
- self.handler = EndpointHandler(path=".")
 
 
 
 
 
 
 
 
14
  self.mock_model = mock_RealESRGANer.return_value
 
15
 
16
- def create_test_image(self, mode='RGB', size=(100, 100)):
17
- image = Image.new(mode, size)
18
  buffered = BytesIO()
19
  image.save(buffered, format="PNG")
20
- return base64.b64encode(buffered.getvalue()).decode()
21
 
22
- def get_svg_image(self):
23
- test_image = "test_data/834989.svg"
24
- return test_image
25
-
26
- def test_float_outscale(self):
27
- test_image = self.create_test_image()
28
- input_data = {"inputs": {"image": test_image, "outscale": 2.5}}
29
-
30
- self.mock_model.enhance.return_value = (np.zeros((250, 250, 3), dtype=np.uint8), None)
 
 
 
 
 
 
 
 
 
 
31
  result = self.handler(input_data)
32
 
33
- self.assertIn("out_image", result)
 
34
  self.assertIsNone(result["error"])
35
-
36
- def test_outscale_too_small(self):
37
- test_image = self.create_test_image()
38
- input_data = {"inputs": {"image": test_image, "outscale": 0.5}}
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  result = self.handler(input_data)
41
 
42
- self.assertIsNone(result["out_image"])
 
43
  self.assertIn("Outscale must be between 1 and 10", result["error"])
44
-
45
- def test_outscale_too_large(self):
46
- test_image = self.create_test_image()
47
- input_data = {"inputs": {"image": test_image, "outscale": 11}}
 
 
 
 
 
 
 
 
48
 
49
  result = self.handler(input_data)
50
 
51
- self.assertIsNone(result["out_image"])
52
- self.assertIn("Outscale must be between 1 and 10", result["error"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- def test_valid_rgb_image(self):
55
- test_image = self.create_test_image()
56
- input_data = {"inputs": {"image": test_image, "outscale": 2}}
 
 
 
 
 
 
 
 
 
 
 
57
 
 
58
  self.mock_model.enhance.return_value = (np.zeros((200, 200, 3), dtype=np.uint8), None)
59
 
 
 
 
 
 
 
 
 
 
 
60
  result = self.handler(input_data)
61
 
62
- self.assertIn("out_image", result)
63
- self.assertIsNone(result["error"])
64
- self.mock_model.enhance.assert_called_once()
65
 
66
- def test_valid_rgba_image(self):
67
- test_image = self.create_test_image(mode='RGBA')
68
- input_data = {"inputs": {"image": test_image, "outscale": 2}}
 
 
 
 
 
 
69
 
70
- self.mock_model.enhance.return_value = (np.zeros((400, 400, 4), dtype=np.uint8), None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  result = self.handler(input_data)
73
 
74
- self.assertIn("out_image", result)
 
75
  self.assertIsNone(result["error"])
76
 
77
- def test_missing_image_key(self):
78
- input_data = {"inputs": {}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  result = self.handler(input_data)
81
 
82
- self.assertIsNone(result["out_image"])
83
- self.assertIn("Missing key", result["error"])
 
84
 
85
  if __name__ == '__main__':
86
  unittest.main()
 
1
  import unittest
2
  from unittest.mock import patch, MagicMock
3
+ import os
4
  from PIL import Image
 
 
5
  from io import BytesIO
6
+ import numpy as np
7
  from handler import EndpointHandler
8
 
9
  class TestEndpointHandler(unittest.TestCase):
 
10
  @patch('handler.RealESRGANer')
11
+ @patch('handler.boto3')
12
+ def setUp(self, mock_boto3, mock_RealESRGANer):
13
+ """Set up test environment before each test"""
14
+ # Set required environment variables
15
+ os.environ['TILING_SIZE'] = '0'
16
+ os.environ['AWS_ACCESS_KEY_ID'] = 'test_key'
17
+ os.environ['AWS_SECRET_ACCESS_KEY'] = 'test_secret'
18
+ os.environ['S3_BUCKET_NAME'] = 'test-bucket'
19
+
20
+ self.handler = EndpointHandler()
21
  self.mock_model = mock_RealESRGANer.return_value
22
+ self.mock_s3 = mock_boto3.client.return_value
23
 
24
+ def image_to_bytes(self, image):
25
+ """Helper method to convert PIL Image to bytes"""
26
  buffered = BytesIO()
27
  image.save(buffered, format="PNG")
28
+ return buffered.getvalue()
29
 
30
+ @patch('handler.requests.get')
31
+ def test_successful_upscale(self, mock_get):
32
+ """Test successful image upscaling"""
33
+ # Create test image and mock response
34
+ test_image = Image.new('RGB', (100, 100))
35
+ mock_response = MagicMock()
36
+ mock_response.content = self.image_to_bytes(test_image)
37
+ mock_get.return_value = mock_response
38
+
39
+ # Mock model output
40
+ self.mock_model.enhance.return_value = (np.zeros((200, 200, 3), dtype=np.uint8), None)
41
+
42
+ input_data = {
43
+ "inputs": {
44
+ "image_url": "http://example.com/test.png",
45
+ "outscale": 2
46
+ }
47
+ }
48
+
49
  result = self.handler(input_data)
50
 
51
+ self.assertIsNotNone(result["image_url"])
52
+ self.assertIsNotNone(result["image_key"])
53
  self.assertIsNone(result["error"])
54
+
55
+ @patch('handler.requests.get')
56
+ def test_invalid_outscale(self, mock_get):
57
+ """Test handling of invalid outscale values"""
58
+ # Create test image and mock response
59
+ test_image = Image.new('RGB', (100, 100))
60
+ mock_response = MagicMock()
61
+ mock_response.content = self.image_to_bytes(test_image)
62
+ mock_get.return_value = mock_response
63
+
64
+ input_data = {
65
+ "inputs": {
66
+ "image_url": "http://example.com/test.png",
67
+ "outscale": 0.5 # Too small
68
+ }
69
+ }
70
 
71
  result = self.handler(input_data)
72
 
73
+ self.assertIsNone(result["image_url"])
74
+ self.assertIsNone(result["image_key"])
75
  self.assertIn("Outscale must be between 1 and 10", result["error"])
76
+
77
+ @patch('handler.requests.get')
78
+ def test_download_failure(self, mock_get):
79
+ """Test handling of failed image downloads"""
80
+ mock_get.side_effect = Exception("Download failed")
81
+
82
+ input_data = {
83
+ "inputs": {
84
+ "image_url": "http://example.com/test.png",
85
+ "outscale": 2
86
+ }
87
+ }
88
 
89
  result = self.handler(input_data)
90
 
91
+ self.assertIsNone(result["image_url"])
92
+ self.assertIsNone(result["image_key"])
93
+ self.assertIn("Failed to download image", result["error"])
94
+
95
+ @patch('handler.requests.get')
96
+ def test_large_image_no_tiling(self, mock_get):
97
+ """Test handling of large images when tiling is disabled"""
98
+ # Create an image larger than max_image_size
99
+ test_image = Image.new('RGB', (1500, 1500))
100
+ mock_response = MagicMock()
101
+ mock_response.content = self.image_to_bytes(test_image)
102
+ mock_get.return_value = mock_response
103
+
104
+ input_data = {
105
+ "inputs": {
106
+ "image_url": "http://example.com/test.png",
107
+ "outscale": 2
108
+ }
109
+ }
110
 
111
+ result = self.handler(input_data)
112
+
113
+ self.assertIsNone(result["image_url"])
114
+ self.assertIsNone(result["image_key"])
115
+ self.assertIn("Image is too large", result["error"])
116
+
117
+ @patch('handler.requests.get')
118
+ def test_s3_upload_failure(self, mock_get):
119
+ """Test handling of S3 upload failures"""
120
+ # Create test image and mock response
121
+ test_image = Image.new('RGB', (100, 100))
122
+ mock_response = MagicMock()
123
+ mock_response.content = self.image_to_bytes(test_image)
124
+ mock_get.return_value = mock_response
125
 
126
+ # Mock model output
127
  self.mock_model.enhance.return_value = (np.zeros((200, 200, 3), dtype=np.uint8), None)
128
 
129
+ # Mock S3 upload failure
130
+ self.mock_s3.upload_fileobj.side_effect = Exception("Upload failed")
131
+
132
+ input_data = {
133
+ "inputs": {
134
+ "image_url": "http://example.com/test.png",
135
+ "outscale": 2
136
+ }
137
+ }
138
+
139
  result = self.handler(input_data)
140
 
141
+ self.assertIsNone(result["image_url"])
142
+ self.assertIsNone(result["image_key"])
143
+ self.assertIn("Failed to upload image to s3", result["error"])
144
 
145
+ def test_missing_image_url(self):
146
+ """Test handling of missing image URL"""
147
+ input_data = {
148
+ "inputs": {
149
+ "outscale": 2
150
+ }
151
+ }
152
+
153
+ result = self.handler(input_data)
154
 
155
+ # Check if result contains all required keys
156
+ self.assertIn("image_url", result)
157
+ self.assertIn("image_key", result)
158
+ self.assertIn("error", result)
159
+
160
+ # Check if values are as expected
161
+ self.assertIsNone(result["image_url"])
162
+ self.assertIsNone(result["image_key"])
163
+ self.assertIn("Failed to get inputs", result["error"])
164
+
165
+ @patch('handler.requests.get')
166
+ def test_grayscale_image(self, mock_get):
167
+ """Test handling of grayscale images"""
168
+ test_image = Image.new('L', (100, 100))
169
+ mock_response = MagicMock()
170
+ mock_response.content = self.image_to_bytes(test_image)
171
+ mock_get.return_value = mock_response
172
+
173
+ # Mock model output
174
+ self.mock_model.enhance.return_value = (np.zeros((200, 200), dtype=np.uint8), None)
175
+
176
+ input_data = {
177
+ "inputs": {
178
+ "image_url": "http://example.com/test.png",
179
+ "outscale": 2
180
+ }
181
+ }
182
 
183
  result = self.handler(input_data)
184
 
185
+ self.assertIsNotNone(result["image_url"])
186
+ self.assertIsNotNone(result["image_key"])
187
  self.assertIsNone(result["error"])
188
 
189
+ @patch('handler.requests.get')
190
+ def test_rgba_image(self, mock_get):
191
+ """Test handling of RGBA images"""
192
+ test_image = Image.new('RGBA', (100, 100))
193
+ mock_response = MagicMock()
194
+ mock_response.content = self.image_to_bytes(test_image)
195
+ mock_get.return_value = mock_response
196
+
197
+ # Mock model output
198
+ self.mock_model.enhance.return_value = (np.zeros((200, 200, 4), dtype=np.uint8), None)
199
+
200
+ input_data = {
201
+ "inputs": {
202
+ "image_url": "http://example.com/test.png",
203
+ "outscale": 2
204
+ }
205
+ }
206
 
207
  result = self.handler(input_data)
208
 
209
+ self.assertIsNotNone(result["image_url"])
210
+ self.assertIsNotNone(result["image_key"])
211
+ self.assertIsNone(result["error"])
212
 
213
  if __name__ == '__main__':
214
  unittest.main()