Spaces:
Runtime error
Runtime error
""" | |
Usage: | |
python3 -m unittest tests.test_image_utils | |
""" | |
import base64 | |
from io import BytesIO | |
import os | |
import unittest | |
import numpy as np | |
from PIL import Image | |
from fastchat.utils import ( | |
resize_image_and_return_image_in_bytes, | |
image_moderation_filter, | |
) | |
from fastchat.conversation import get_conv_template | |
def check_byte_size_in_mb(image_base64_str):
    """Return ``len(image_base64_str)`` converted to MiB (divided by 1024 * 1024).

    Works for anything with a length: raw encoded-image ``bytes`` or a
    base64-encoded ``str``.
    """
    bytes_per_mb = 1024 * 1024
    return len(image_base64_str) / bytes_per_mb
def generate_random_image(target_size_mb, image_format="PNG"):
    """Create a random RGB image whose encoded size is at least ``target_size_mb``.

    The side length of a square image is first estimated from the target
    byte count (3 bytes per RGB pixel). The side then grows one pixel at a
    time, regenerating random pixels each time, until the image encoded as
    ``image_format`` reaches the target size.

    The encoded size is measured in an in-memory buffer, so no temporary
    file is written to (or left behind in) the working directory.

    Args:
        target_size_mb: Minimum encoded size in MiB; may be fractional.
        image_format: PIL format name used to measure the encoded size.

    Returns:
        The generated PIL image.
    """
    target_size_bytes = target_size_mb * 1024 * 1024

    def _encoded_size(image):
        # Encode into memory and report the exact number of bytes produced.
        buffer = BytesIO()
        image.save(buffer, format=image_format)
        return buffer.getbuffer().nbytes

    # Three bytes per RGB pixel. Clamp to 1 so a zero or tiny target still
    # yields a non-empty image.
    dimension = max(1, int((target_size_bytes / 3) ** 0.5))
    while True:
        pixel_data = np.random.randint(
            0, 256, (dimension, dimension, 3), dtype=np.uint8
        )
        img = Image.fromarray(pixel_data)
        if _encoded_size(img) >= target_size_bytes:
            return img
        # Still too small after compression: enlarge and try again.
        dimension += 1
class DontResizeIfLessThanMaxTest(unittest.TestCase):
    """An image already under ``max_image_size_mb`` must keep its encoded size."""

    def test_dont_resize_if_less_than_max(self):
        max_image_size = 5
        initial_size_mb = 0.1  # Initial image size, well below max_image_size
        img = generate_random_image(initial_size_mb)
        image_bytes = BytesIO()
        # Reference encoding is PNG. NOTE(review): equality below presumes
        # resize_image_and_return_image_in_bytes also emits PNG — confirm.
        img.save(image_bytes, format="PNG")
        previous_image_size = check_byte_size_in_mb(image_bytes.getvalue())
        image_bytes = resize_image_and_return_image_in_bytes(
            img, max_image_size_mb=max_image_size
        )
        new_image_size = check_byte_size_in_mb(image_bytes.getvalue())
        self.assertEqual(previous_image_size, new_image_size)
class ResizeLargeImageForModerationEndpoint(unittest.TestCase):
    """An oversized random image sent to moderation must not be flagged."""

    def test_resize_large_image_and_send_to_moderation_filter(self):
        # Initial image size which we know is greater than what the endpoint
        # can take. Presumably image_moderation_filter resizes it internally
        # before calling the endpoint.
        initial_size_mb = 6
        img = generate_random_image(initial_size_mb)
        nsfw_flag, csam_flag = image_moderation_filter(img)
        # Random noise should trip neither classifier; check both flags.
        # (Previously nsfw_flag was asserted twice and csam_flag never.)
        self.assertFalse(nsfw_flag)
        self.assertFalse(csam_flag)
class DontResizeIfMaxImageSizeIsNone(unittest.TestCase):
    """With ``max_image_size_mb=None`` the image must keep its encoded size."""

    def test_dont_resize_if_max_image_size_is_none(self):
        initial_size_mb = 0.2  # Initial image size
        img = generate_random_image(initial_size_mb)
        image_bytes = BytesIO()
        # Reference encoding is PNG. NOTE(review): equality below presumes
        # resize_image_and_return_image_in_bytes also emits PNG — confirm.
        img.save(image_bytes, format="PNG")
        previous_image_size = check_byte_size_in_mb(image_bytes.getvalue())
        image_bytes = resize_image_and_return_image_in_bytes(
            img, max_image_size_mb=None
        )
        new_image_size = check_byte_size_in_mb(image_bytes.getvalue())
        self.assertEqual(previous_image_size, new_image_size)
class OpenAIConversationDontResizeImage(unittest.TestCase):
    """The "chatgpt" template's base64 conversion must not shrink a small image."""

    def test(self):
        conv = get_conv_template("chatgpt")
        initial_size_mb = 0.2  # Initial image size
        img = generate_random_image(initial_size_mb)
        image_bytes = BytesIO()
        img.save(image_bytes, format="PNG")  # Reference encoding is PNG
        previous_image_size = check_byte_size_in_mb(image_bytes.getvalue())
        resized_img = conv.convert_image_to_base64(img)
        # Compare decoded bytes, not base64 text, against the PNG reference.
        resized_img_bytes = base64.b64decode(resized_img)
        new_image_size = check_byte_size_in_mb(resized_img_bytes)
        self.assertEqual(previous_image_size, new_image_size)
class ClaudeConversationResizesCorrectly(unittest.TestCase):
    """The Claude template must shrink a large image below its size limits."""

    def test(self):
        conv = get_conv_template("claude-3-haiku-20240307")
        initial_size_mb = 5  # Initial image size
        # Cap on the base64 text itself. Presumably this mirrors the Claude
        # API's per-image limit — TODO confirm.
        max_base64_size_mb = 5
        img = generate_random_image(initial_size_mb)
        image_bytes = BytesIO()
        img.save(image_bytes, format="PNG")  # Reference encoding is PNG
        previous_image_size = check_byte_size_in_mb(image_bytes.getvalue())
        resized_img = conv.convert_image_to_base64(img)
        # Measure both the base64 payload and the decoded image bytes.
        new_base64_image_size = check_byte_size_in_mb(resized_img)
        new_image_bytes_size = check_byte_size_in_mb(base64.b64decode(resized_img))
        self.assertLess(new_image_bytes_size, previous_image_size)
        self.assertLessEqual(new_image_bytes_size, conv.max_image_size_mb)
        self.assertLessEqual(new_base64_image_size, max_base64_size_mb)