# Unit tests for the Real-ESRGAN Inference Endpoints handler (real-esrgan / unit_tests.py).
# Author: garg-aayush; commit d47736e ("Update and add more unit tests"); 7.41 kB.
import unittest
from unittest.mock import patch, MagicMock
import os
from PIL import Image
from io import BytesIO
import numpy as np
from handler import EndpointHandler
class TestEndpointHandler(unittest.TestCase):
    """Unit tests for ``handler.EndpointHandler``.

    The Real-ESRGAN model (``handler.RealESRGANer``), ``handler.boto3`` and
    ``handler.requests.get`` are replaced with mocks, so no model weights,
    network access or AWS credentials are required.
    """

    TEST_IMAGE_URL = "http://example.com/test.png"

    def setUp(self):
        """Set up test environment before each test.

        Environment variables and the handler's model/boto3 dependencies are
        patched for the whole duration of the test (not only while the handler
        is being constructed) and are restored automatically afterwards, so no
        state leaks into other tests.
        """
        # Required environment variables. patch.dict restores os.environ on
        # cleanup instead of permanently overwriting the process environment.
        env_patcher = patch.dict(os.environ, {
            # Presumably '0' disables tiling; test_large_image_no_tiling
            # relies on tiling being off.
            'TILING_SIZE': '0',
            'AWS_ACCESS_KEY_ID': 'test_key',
            'AWS_SECRET_ACCESS_KEY': 'test_secret',
            'S3_BUCKET_NAME': 'test-bucket',
        })
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

        mock_RealESRGANer = self._start_patch('handler.RealESRGANer')
        mock_boto3 = self._start_patch('handler.boto3')

        self.handler = EndpointHandler()
        self.mock_model = mock_RealESRGANer.return_value
        self.mock_s3 = mock_boto3.client.return_value

    def _start_patch(self, target):
        """Patch *target* until the end of the current test and return the mock."""
        patcher = patch(target)
        mock = patcher.start()
        self.addCleanup(patcher.stop)
        return mock

    def image_to_bytes(self, image):
        """Helper method to convert PIL Image to PNG-encoded bytes."""
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        return buffered.getvalue()

    def _mock_download(self, mock_get, image):
        """Make the patched ``requests.get`` return *image* as PNG bytes."""
        mock_response = MagicMock()
        mock_response.content = self.image_to_bytes(image)
        mock_get.return_value = mock_response

    def _make_request(self, outscale=2):
        """Build a handler payload for the test image URL and *outscale*."""
        return {
            "inputs": {
                "image_url": self.TEST_IMAGE_URL,
                "outscale": outscale,
            }
        }

    def _assert_success(self, result):
        """Assert that *result* reports an uploaded image and no error."""
        self.assertIsNotNone(result["image_url"])
        self.assertIsNotNone(result["image_key"])
        self.assertIsNone(result["error"])

    def _assert_failure(self, result, expected_error):
        """Assert that *result* has all keys, no image, and an error containing *expected_error*."""
        for key in ("image_url", "image_key", "error"):
            self.assertIn(key, result)
        self.assertIsNone(result["image_url"])
        self.assertIsNone(result["image_key"])
        self.assertIsNotNone(result["error"])
        self.assertIn(expected_error, result["error"])

    @patch('handler.requests.get')
    def test_successful_upscale(self, mock_get):
        """Test successful image upscaling"""
        self._mock_download(mock_get, Image.new('RGB', (100, 100)))
        # Mock model output: enhance() returns (image_array, img_mode)
        self.mock_model.enhance.return_value = (np.zeros((200, 200, 3), dtype=np.uint8), None)

        result = self.handler(self._make_request(outscale=2))

        self._assert_success(result)
        # A successful run must actually push the result to S3.
        self.mock_s3.upload_fileobj.assert_called_once()

    @patch('handler.requests.get')
    def test_invalid_outscale(self, mock_get):
        """Test handling of outscale values outside the accepted range"""
        self._mock_download(mock_get, Image.new('RGB', (100, 100)))

        # Check both a too-small and a too-large value.
        for outscale in (0.5, 11):
            with self.subTest(outscale=outscale):
                result = self.handler(self._make_request(outscale=outscale))
                self._assert_failure(result, "Outscale must be between 1 and 10")

        self.mock_s3.upload_fileobj.assert_not_called()

    @patch('handler.requests.get')
    def test_download_failure(self, mock_get):
        """Test handling of failed image downloads"""
        mock_get.side_effect = Exception("Download failed")

        result = self.handler(self._make_request(outscale=2))

        self._assert_failure(result, "Failed to download image")
        self.mock_s3.upload_fileobj.assert_not_called()

    @patch('handler.requests.get')
    def test_large_image_no_tiling(self, mock_get):
        """Test handling of large images when tiling is disabled"""
        # Create an image larger than the handler's max_image_size
        self._mock_download(mock_get, Image.new('RGB', (1500, 1500)))

        result = self.handler(self._make_request(outscale=2))

        self._assert_failure(result, "Image is too large")

    @patch('handler.requests.get')
    def test_s3_upload_failure(self, mock_get):
        """Test handling of S3 upload failures"""
        self._mock_download(mock_get, Image.new('RGB', (100, 100)))
        self.mock_model.enhance.return_value = (np.zeros((200, 200, 3), dtype=np.uint8), None)
        # Mock S3 upload failure
        self.mock_s3.upload_fileobj.side_effect = Exception("Upload failed")

        result = self.handler(self._make_request(outscale=2))

        self._assert_failure(result, "Failed to upload image to s3")

    def test_missing_image_url(self):
        """Test handling of missing image URL"""
        input_data = {
            "inputs": {
                "outscale": 2
            }
        }

        result = self.handler(input_data)

        self._assert_failure(result, "Failed to get inputs")

    @patch('handler.requests.get')
    def test_grayscale_image(self, mock_get):
        """Test handling of grayscale images"""
        self._mock_download(mock_get, Image.new('L', (100, 100)))
        # Single-channel model output for a grayscale input
        self.mock_model.enhance.return_value = (np.zeros((200, 200), dtype=np.uint8), None)

        result = self.handler(self._make_request(outscale=2))

        self._assert_success(result)

    @patch('handler.requests.get')
    def test_rgba_image(self, mock_get):
        """Test handling of RGBA images"""
        self._mock_download(mock_get, Image.new('RGBA', (100, 100)))
        # Four-channel model output for an RGBA input
        self.mock_model.enhance.return_value = (np.zeros((200, 200, 4), dtype=np.uint8), None)

        result = self.handler(self._make_request(outscale=2))

        self._assert_success(result)
if __name__ == "__main__":
    # Run the whole test suite when this file is executed as a script.
    unittest.main()