"""Unit tests for ``handler.EndpointHandler``.

The RealESRGANer model, boto3, and ``requests.get`` used by ``handler`` are
replaced with mocks, so no real model, S3 client, or HTTP download is used.
"""

import os
import unittest
from io import BytesIO
from unittest.mock import MagicMock, patch

import numpy as np
from PIL import Image

from handler import EndpointHandler

TEST_IMAGE_URL = "http://example.com/test.png"

# Environment the handler is exercised under. Applied with patch.dict (see
# setUp) so the developer's real variables -- notably AWS credentials -- are
# restored after every test instead of being overwritten for the whole process.
TEST_ENV = {
    "TILING_SIZE": "0",
    "AWS_ACCESS_KEY_ID": "test_key",
    "AWS_SECRET_ACCESS_KEY": "test_secret",
    "S3_BUCKET_NAME": "test-bucket",
}


class TestEndpointHandler(unittest.TestCase):
    """Tests for the result dict (``image_url``/``image_key``/``error``) of EndpointHandler."""

    def setUp(self):
        """Install env and dependency patches for the whole test, then build the handler."""
        # Patchers are started here and stopped via addCleanup so they stay
        # active for the entire test, not just while setUp runs. This keeps the
        # tests valid whether the handler reads env vars / creates its model and
        # S3 client in __init__ or lazily at call time.
        env_patcher = patch.dict(os.environ, TEST_ENV)
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

        mock_realesrganer = self._start_patch("handler.RealESRGANer")
        mock_boto3 = self._start_patch("handler.boto3")

        self.handler = EndpointHandler()
        self.mock_model = mock_realesrganer.return_value
        self.mock_s3 = mock_boto3.client.return_value

    def _start_patch(self, target):
        """Start ``patch(target)`` for the rest of the current test and return the mock."""
        patcher = patch(target)
        mock = patcher.start()
        self.addCleanup(patcher.stop)
        return mock

    @staticmethod
    def image_to_bytes(image):
        """Return *image* (a PIL Image) encoded as PNG bytes."""
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        return buffered.getvalue()

    def _mock_download(self, mock_get, image):
        """Make the patched ``requests.get`` return a response whose body is *image* as PNG."""
        mock_response = MagicMock()
        mock_response.content = self.image_to_bytes(image)
        # Look like a real successful HTTP response in case the handler checks it.
        mock_response.status_code = 200
        mock_get.return_value = mock_response
        return mock_response

    @staticmethod
    def _request(outscale=2):
        """Build a handler payload for TEST_IMAGE_URL with the given *outscale*."""
        return {"inputs": {"image_url": TEST_IMAGE_URL, "outscale": outscale}}

    def _assert_success(self, result):
        """Assert *result* carries an image URL and key and no error."""
        self.assertIsNotNone(result["image_url"])
        self.assertIsNotNone(result["image_key"])
        self.assertIsNone(result["error"])

    def _assert_failure(self, result, message):
        """Assert *result* has no image URL/key and its error contains *message*."""
        self.assertIsNone(result["image_url"])
        self.assertIsNone(result["image_key"])
        self.assertIn(message, result["error"])

    @patch("handler.requests.get")
    def test_successful_upscale(self, mock_get):
        """Test successful image upscaling"""
        self._mock_download(mock_get, Image.new("RGB", (100, 100)))
        self.mock_model.enhance.return_value = (
            np.zeros((200, 200, 3), dtype=np.uint8),
            None,
        )

        result = self.handler(self._request(outscale=2))

        self._assert_success(result)

    @patch("handler.requests.get")
    def test_invalid_outscale(self, mock_get):
        """Test handling of invalid outscale values"""
        self._mock_download(mock_get, Image.new("RGB", (100, 100)))

        result = self.handler(self._request(outscale=0.5))  # Too small

        self._assert_failure(result, "Outscale must be between 1 and 10")

    @patch("handler.requests.get")
    def test_download_failure(self, mock_get):
        """Test handling of failed image downloads"""
        mock_get.side_effect = Exception("Download failed")

        result = self.handler(self._request(outscale=2))

        self._assert_failure(result, "Failed to download image")

    @patch("handler.requests.get")
    def test_large_image_no_tiling(self, mock_get):
        """Test handling of large images when tiling is disabled"""
        # Intended to exceed the handler's max image size; TILING_SIZE is "0" in
        # TEST_ENV, which presumably disables tiling -- confirm against handler.
        self._mock_download(mock_get, Image.new("RGB", (1500, 1500)))

        result = self.handler(self._request(outscale=2))

        self._assert_failure(result, "Image is too large")

    @patch("handler.requests.get")
    def test_s3_upload_failure(self, mock_get):
        """Test handling of S3 upload failures"""
        self._mock_download(mock_get, Image.new("RGB", (100, 100)))
        self.mock_model.enhance.return_value = (
            np.zeros((200, 200, 3), dtype=np.uint8),
            None,
        )
        self.mock_s3.upload_fileobj.side_effect = Exception("Upload failed")

        result = self.handler(self._request(outscale=2))

        self._assert_failure(result, "Failed to upload image to s3")

    def test_missing_image_url(self):
        """Test handling of missing image URL"""
        result = self.handler({"inputs": {"outscale": 2}})

        # The result must always expose all three keys, even on failure.
        for key in ("image_url", "image_key", "error"):
            self.assertIn(key, result)
        self._assert_failure(result, "Failed to get inputs")

    @patch("handler.requests.get")
    def test_grayscale_image(self, mock_get):
        """Test handling of grayscale images"""
        self._mock_download(mock_get, Image.new("L", (100, 100)))
        self.mock_model.enhance.return_value = (
            np.zeros((200, 200), dtype=np.uint8),
            None,
        )

        result = self.handler(self._request(outscale=2))

        self._assert_success(result)

    @patch("handler.requests.get")
    def test_rgba_image(self, mock_get):
        """Test handling of RGBA images"""
        self._mock_download(mock_get, Image.new("RGBA", (100, 100)))
        self.mock_model.enhance.return_value = (
            np.zeros((200, 200, 4), dtype=np.uint8),
            None,
        )

        result = self.handler(self._request(outscale=2))

        self._assert_success(result)


if __name__ == "__main__":
    unittest.main()