|
import unittest |
|
from unittest.mock import patch, MagicMock |
|
import os |
|
from PIL import Image |
|
from io import BytesIO |
|
import numpy as np |
|
from handler import EndpointHandler |
|
|
|
class TestEndpointHandler(unittest.TestCase):
    """Unit tests for ``handler.EndpointHandler``.

    ``handler.RealESRGANer`` and ``handler.boto3`` are patched for the whole of
    every test, and ``handler.requests.get`` is patched per test, so no real
    model is constructed and no network or S3 traffic happens.
    """

    # Environment variables set before the handler is constructed; they are
    # restored to their previous state after each test.
    TEST_ENV = {
        'TILING_SIZE': '0',
        'AWS_ACCESS_KEY_ID': 'test_key',
        'AWS_SECRET_ACCESS_KEY': 'test_secret',
        'S3_BUCKET_NAME': 'test-bucket',
    }

    # URL placed in request payloads; requests.get is mocked, so it is never fetched.
    TEST_IMAGE_URL = "http://example.com/test.png"

    def setUp(self):
        """Set up test environment before each test.

        Patches are started here and stopped via ``addCleanup`` instead of
        decorating ``setUp`` with ``@patch`` (which would undo them as soon as
        ``setUp`` returned), so the mocks stay in place while the test body
        runs. ``os.environ`` is patched the same way so the variables do not
        leak into other tests.
        """
        env_patcher = patch.dict(os.environ, self.TEST_ENV)
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

        mock_RealESRGANer = self._start_patch('handler.RealESRGANer')
        mock_boto3 = self._start_patch('handler.boto3')

        self.handler = EndpointHandler()
        # Instance returned by the patched RealESRGANer(...) constructor.
        self.mock_model = mock_RealESRGANer.return_value
        # Client returned by the patched boto3.client(...).
        self.mock_s3 = mock_boto3.client.return_value

    def _start_patch(self, target):
        """Patch *target* until the current test finishes and return the mock."""
        patcher = patch(target)
        mock = patcher.start()
        self.addCleanup(patcher.stop)
        return mock

    def image_to_bytes(self, image):
        """Helper method to convert PIL Image to bytes (PNG-encoded)."""
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        return buffered.getvalue()

    def _mock_download(self, mock_get, image):
        """Make the patched ``requests.get`` return *image* as a PNG response.

        Returns the mocked response object.
        """
        mock_response = MagicMock()
        # A bare MagicMock attribute compares unequal to 200, which would make
        # any status-code check in the handler fail spuriously.
        mock_response.status_code = 200
        mock_response.content = self.image_to_bytes(image)
        mock_get.return_value = mock_response
        return mock_response

    def _make_input(self, outscale=2):
        """Build a request payload for ``TEST_IMAGE_URL`` with *outscale*."""
        return {
            "inputs": {
                "image_url": self.TEST_IMAGE_URL,
                "outscale": outscale,
            }
        }

    def _assert_success(self, result):
        """Assert *result* describes a successful upscale."""
        for key in ("image_url", "image_key", "error"):
            self.assertIn(key, result)
        self.assertIsNotNone(result["image_url"])
        self.assertIsNotNone(result["image_key"])
        self.assertIsNone(result["error"])

    def _assert_failure(self, result, message):
        """Assert *result* is an error response whose message contains *message*."""
        for key in ("image_url", "image_key", "error"):
            self.assertIn(key, result)
        self.assertIsNone(result["image_url"])
        self.assertIsNone(result["image_key"])
        self.assertIsNotNone(result["error"])
        self.assertIn(message, result["error"])

    @patch('handler.requests.get')
    def test_successful_upscale(self, mock_get):
        """Test successful image upscaling"""
        self._mock_download(mock_get, Image.new('RGB', (100, 100)))
        # RealESRGANer.enhance returns an (output_image, img_mode) tuple.
        self.mock_model.enhance.return_value = (np.zeros((200, 200, 3), dtype=np.uint8), None)

        result = self.handler(self._make_input(outscale=2))

        self._assert_success(result)
        self.mock_model.enhance.assert_called()
        self.mock_s3.upload_fileobj.assert_called()

    @patch('handler.requests.get')
    def test_invalid_outscale(self, mock_get):
        """Test handling of invalid outscale values"""
        self._mock_download(mock_get, Image.new('RGB', (100, 100)))

        result = self.handler(self._make_input(outscale=0.5))

        self._assert_failure(result, "Outscale must be between 1 and 10")
        # Invalid input must be rejected without running the model or uploading.
        self.mock_model.enhance.assert_not_called()
        self.mock_s3.upload_fileobj.assert_not_called()

    @patch('handler.requests.get')
    def test_download_failure(self, mock_get):
        """Test handling of failed image downloads"""
        mock_get.side_effect = Exception("Download failed")

        result = self.handler(self._make_input(outscale=2))

        self._assert_failure(result, "Failed to download image")
        self.mock_model.enhance.assert_not_called()
        self.mock_s3.upload_fileobj.assert_not_called()

    @patch('handler.requests.get')
    def test_large_image_no_tiling(self, mock_get):
        """Test handling of large images when tiling is disabled"""
        # TEST_ENV sets TILING_SIZE to '0' (tiling disabled).
        self._mock_download(mock_get, Image.new('RGB', (1500, 1500)))

        result = self.handler(self._make_input(outscale=2))

        self._assert_failure(result, "Image is too large")
        self.mock_model.enhance.assert_not_called()
        self.mock_s3.upload_fileobj.assert_not_called()

    @patch('handler.requests.get')
    def test_s3_upload_failure(self, mock_get):
        """Test handling of S3 upload failures"""
        self._mock_download(mock_get, Image.new('RGB', (100, 100)))
        self.mock_model.enhance.return_value = (np.zeros((200, 200, 3), dtype=np.uint8), None)
        self.mock_s3.upload_fileobj.side_effect = Exception("Upload failed")

        result = self.handler(self._make_input(outscale=2))

        self._assert_failure(result, "Failed to upload image to s3")
        self.mock_s3.upload_fileobj.assert_called()

    def test_missing_image_url(self):
        """Test handling of missing image URL"""
        input_data = {
            "inputs": {
                "outscale": 2
            }
        }

        result = self.handler(input_data)

        self._assert_failure(result, "Failed to get inputs")
        self.mock_model.enhance.assert_not_called()
        self.mock_s3.upload_fileobj.assert_not_called()

    @patch('handler.requests.get')
    def test_grayscale_image(self, mock_get):
        """Test handling of grayscale images"""
        self._mock_download(mock_get, Image.new('L', (100, 100)))
        # Single-channel output, matching the 'L' input mode.
        self.mock_model.enhance.return_value = (np.zeros((200, 200), dtype=np.uint8), None)

        result = self.handler(self._make_input(outscale=2))

        self._assert_success(result)
        self.mock_s3.upload_fileobj.assert_called()

    @patch('handler.requests.get')
    def test_rgba_image(self, mock_get):
        """Test handling of RGBA images"""
        self._mock_download(mock_get, Image.new('RGBA', (100, 100)))
        # Four-channel output, matching the 'RGBA' input mode.
        self.mock_model.enhance.return_value = (np.zeros((200, 200, 4), dtype=np.uint8), None)

        result = self.handler(self._make_input(outscale=2))

        self._assert_success(result)
        self.mock_s3.upload_fileobj.assert_called()
|
if __name__ == "__main__":
    # Allow running this module directly with the standard unittest runner.
    unittest.main()