English
Inference Endpoints
real-esrgan / unit_tests.py
garg-aayush's picture
add unit tests, create endpoint test notebook
022c628
raw
history blame
3.31 kB
import unittest
from unittest.mock import patch, MagicMock
from PIL import Image
import base64
import numpy as np
from io import BytesIO
from handler import EndpointHandler
class TestEndpointHandler(unittest.TestCase):
    """Unit tests for ``EndpointHandler`` with ``handler.RealESRGANer`` mocked.

    ``handler.RealESRGANer`` is replaced by a ``MagicMock`` for the whole
    duration of every test, so the real upsampler is never constructed.
    Each test configures the model's output through ``self.mock_model``.
    """

    def setUp(self):
        """Patch the upsampler class, build a handler and expose the mock model."""
        # Start the patcher manually instead of decorating setUp: a decorator
        # would undo the patch as soon as setUp returns, leaving test bodies
        # unpatched. Registering the cleanup before building the handler
        # guarantees the patch is stopped even if construction raises.
        patcher = patch('handler.RealESRGANer')
        mock_RealESRGANer = patcher.start()
        self.addCleanup(patcher.stop)
        self.handler = EndpointHandler(path=".")
        # Every RealESRGANer(...) call returns this same mock instance, so
        # configuring it here configures whatever the handler instantiated.
        self.mock_model = mock_RealESRGANer.return_value

    def create_test_image(self, mode='RGB', size=(100, 100)):
        """Return a blank PIL image of ``mode``/``size`` as a base64 PNG string."""
        image = Image.new(mode, size)
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    def get_svg_image(self):
        """Return the path of the SVG fixture under ``test_data/``.

        NOTE(review): not referenced by any test in this class.
        """
        test_image = "test_data/834989.svg"
        return test_image

    def test_float_outscale(self):
        """A non-integer outscale (2.5) is accepted and yields an output image."""
        test_image = self.create_test_image()
        input_data = {"inputs": {"image": test_image, "outscale": 2.5}}
        # 100x100 input * 2.5 -> 250x250 RGB output from the mocked model.
        self.mock_model.enhance.return_value = (np.zeros((250, 250, 3), dtype=np.uint8), None)
        result = self.handler(input_data)
        self.assertIn("out_image", result)
        self.assertIsNone(result["error"])

    def test_outscale_too_small(self):
        """An outscale below 1 is rejected with an error and no image."""
        test_image = self.create_test_image()
        input_data = {"inputs": {"image": test_image, "outscale": 0.5}}
        result = self.handler(input_data)
        self.assertIsNone(result["out_image"])
        self.assertIn("Outscale must be between 1 and 10", result["error"])

    def test_outscale_too_large(self):
        """An outscale above 10 is rejected with an error and no image."""
        test_image = self.create_test_image()
        input_data = {"inputs": {"image": test_image, "outscale": 11}}
        result = self.handler(input_data)
        self.assertIsNone(result["out_image"])
        self.assertIn("Outscale must be between 1 and 10", result["error"])

    def test_valid_rgb_image(self):
        """An RGB image with outscale 2 is enhanced exactly once."""
        test_image = self.create_test_image()
        input_data = {"inputs": {"image": test_image, "outscale": 2}}
        # 100x100 input * 2 -> 200x200 RGB output from the mocked model.
        self.mock_model.enhance.return_value = (np.zeros((200, 200, 3), dtype=np.uint8), None)
        result = self.handler(input_data)
        self.assertIn("out_image", result)
        self.assertIsNone(result["error"])
        self.mock_model.enhance.assert_called_once()

    def test_valid_rgba_image(self):
        """An RGBA image with outscale 2 is processed without error."""
        test_image = self.create_test_image(mode='RGBA')
        input_data = {"inputs": {"image": test_image, "outscale": 2}}
        # 100x100 input * 2 -> 200x200 with 4 channels, consistent with the
        # RGB test above (the previous 400x400 did not match outscale 2).
        self.mock_model.enhance.return_value = (np.zeros((200, 200, 4), dtype=np.uint8), None)
        result = self.handler(input_data)
        self.assertIn("out_image", result)
        self.assertIsNone(result["error"])

    def test_image_too_large(self):
        """An oversized input image is rejected with an error and no image."""
        # NOTE(review): 1500x1500 presumably exceeds the handler's maximum
        # input size; confirm the limit in handler.py. No outscale is passed,
        # so this relies on the handler's default.
        test_image = self.create_test_image(size=(1500, 1500))
        input_data = {"inputs": {"image": test_image}}
        result = self.handler(input_data)
        self.assertIsNone(result["out_image"])
        self.assertIn("Image is too large", result["error"])

    def test_missing_image_key(self):
        """A request without an ``image`` entry reports a missing key."""
        input_data = {"inputs": {}}
        result = self.handler(input_data)
        self.assertIsNone(result["out_image"])
        self.assertIn("Missing key", result["error"])
if __name__ == "__main__":
    # Run the test suite when this file is executed directly.
    unittest.main()