|
import unittest |
|
from unittest.mock import patch, MagicMock |
|
from PIL import Image |
|
import base64 |
|
import numpy as np |
|
from io import BytesIO |
|
from handler import EndpointHandler |
|
|
|
class TestEndpointHandler(unittest.TestCase):
    """Unit tests for ``EndpointHandler`` with ``handler.RealESRGANer`` mocked out.

    Each test sends the handler a request dict shaped like
    ``{"inputs": {"image": <base64 PNG>, "outscale": <number>}}`` and checks
    the ``"out_image"`` and ``"error"`` entries of the returned dict.
    """

    def setUp(self):
        """Patch ``RealESRGANer`` for the whole test and build the handler."""
        # Start the patch here (instead of decorating setUp) so it stays active
        # for the entire test method, not only while setUp executes. This keeps
        # the real model out of the tests even if the handler constructs
        # RealESRGANer lazily inside __call__.
        patcher = patch('handler.RealESRGANer')
        mock_RealESRGANer = patcher.start()
        self.addCleanup(patcher.stop)

        self.handler = EndpointHandler(path=".")
        # Every RealESRGANer(...) call made while patched returns this same
        # mock instance, so tests stub and inspect ``enhance`` through it.
        self.mock_model = mock_RealESRGANer.return_value

    def create_test_image(self, mode='RGB', size=(100, 100)):
        """Return a blank PIL image of *mode* and *size* as a base64 PNG string."""
        image = Image.new(mode, size)
        with BytesIO() as buffered:
            image.save(buffered, format="PNG")
            return base64.b64encode(buffered.getvalue()).decode()

    def get_svg_image(self):
        """Return the relative path of the SVG test fixture.

        NOTE(review): not used by any test in this class.
        """
        test_image = "test_data/834989.svg"
        return test_image

    def test_float_outscale(self):
        """A non-integer outscale inside the allowed range yields an image."""
        test_image = self.create_test_image()
        input_data = {"inputs": {"image": test_image, "outscale": 2.5}}

        # 100x100 input scaled by 2.5 -> 250x250 RGB output.
        self.mock_model.enhance.return_value = (np.zeros((250, 250, 3), dtype=np.uint8), None)
        result = self.handler(input_data)

        self.assertIsNotNone(result["out_image"])
        self.assertIsNone(result["error"])
        self.mock_model.enhance.assert_called_once()

    def test_outscale_too_small(self):
        """An outscale below 1 is rejected before the model is invoked."""
        test_image = self.create_test_image()
        input_data = {"inputs": {"image": test_image, "outscale": 0.5}}

        result = self.handler(input_data)

        self.assertIsNone(result["out_image"])
        self.assertIn("Outscale must be between 1 and 10", result["error"])
        self.mock_model.enhance.assert_not_called()

    def test_outscale_too_large(self):
        """An outscale above 10 is rejected before the model is invoked."""
        test_image = self.create_test_image()
        input_data = {"inputs": {"image": test_image, "outscale": 11}}

        result = self.handler(input_data)

        self.assertIsNone(result["out_image"])
        self.assertIn("Outscale must be between 1 and 10", result["error"])
        self.mock_model.enhance.assert_not_called()

    def test_valid_rgb_image(self):
        """An RGB image with a valid integer outscale is enhanced exactly once."""
        test_image = self.create_test_image()
        input_data = {"inputs": {"image": test_image, "outscale": 2}}

        # 100x100 input scaled by 2 -> 200x200 RGB output.
        self.mock_model.enhance.return_value = (np.zeros((200, 200, 3), dtype=np.uint8), None)
        result = self.handler(input_data)

        self.assertIsNotNone(result["out_image"])
        self.assertIsNone(result["error"])
        self.mock_model.enhance.assert_called_once()

    def test_valid_rgba_image(self):
        """An RGBA image with a valid outscale is processed without error."""
        test_image = self.create_test_image(mode='RGBA')
        input_data = {"inputs": {"image": test_image, "outscale": 2}}

        # 100x100 input scaled by 2 -> 200x200 output; 4 channels to keep alpha.
        self.mock_model.enhance.return_value = (np.zeros((200, 200, 4), dtype=np.uint8), None)
        result = self.handler(input_data)

        self.assertIsNotNone(result["out_image"])
        self.assertIsNone(result["error"])

    def test_missing_image_key(self):
        """A request without an ``image`` entry reports a missing-key error."""
        input_data = {"inputs": {}}

        result = self.handler(input_data)

        self.assertIsNone(result["out_image"])
        self.assertIn("Missing key", result["error"])
        self.mock_model.enhance.assert_not_called()
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()