Spaces:
Paused
Paused
import base64 | |
from io import BytesIO | |
import os | |
import time | |
import numpy as np | |
from PIL import Image, ImageChops | |
import pytest | |
import requests | |
def local_run(
    model_endpoint: str,
    model_input: dict,
    max_wait_time: int = 1000,
    retry_interval: int = 100,
):
    """POST ``model_input`` to a prediction endpoint and return the output image.

    The request body is ``{"input": model_input}``. The JSON response is
    handled as follows:

    * If it has an ``"output"`` key, the first element is treated as a data
      URI (``<header>,<base64 payload>``); the payload is decoded and opened
      as a PIL image, which is returned.
    * If its ``"detail"`` equals ``"Already running a prediction"``, the
      function sleeps ``retry_interval`` seconds and posts again, until
      ``max_wait_time`` seconds of sleeping have accumulated.
    * Anything else is reported and raised as an error.

    Args:
        model_endpoint: URL of the predictions endpoint.
        model_input: Input payload sent under the ``"input"`` key.
        max_wait_time: Maximum total time, in seconds, spent sleeping
            between retries.
        retry_interval: Seconds to sleep between retries.

    Returns:
        The image opened from the first entry of ``data["output"]``.

    Raises:
        RuntimeError: If the response is neither an output nor the
            "already running" notice.
        TimeoutError: If ``max_wait_time`` is exhausted without an output.
        Exception: Any error raised while decoding the output is re-raised
            after the offending input and response have been printed.
    """
    total_wait_time = 0

    while total_wait_time < max_wait_time:
        response = requests.post(model_endpoint, json={"input": model_input})
        data = response.json()

        if "output" in data:
            try:
                datauri = data["output"][0]
                # Keep only the base64 payload that follows the first comma.
                base64_encoded_data = datauri.split(",")[1]
                decoded_data = base64.b64decode(base64_encoded_data)
                return Image.open(BytesIO(decoded_data))
            except Exception:
                print("Error while processing output:")
                print("input:", model_input)
                print(data)
                # Bare raise keeps the original traceback intact.
                raise

        if "detail" in data and data["detail"] == "Already running a prediction":
            print(f"Prediction in progress, waited {total_wait_time}s, waiting more...")
            time.sleep(retry_interval)
            total_wait_time += retry_interval
            continue

        # Previously this path left the loop and implicitly returned None,
        # which only surfaced later as an unrelated error in the caller.
        print("Unexpected response data:", data)
        raise RuntimeError(f"Unexpected response data: {data}")

    raise TimeoutError("Max wait time exceeded, unable to get valid response")
def image_equal_fuzzy(img_expected, img_actual, test_name="default", tol=20):
    """Return True when two images differ by less than ``tol`` on average.

    Both images are converted to int32 arrays and the mean absolute
    difference over all values is compared against ``tol``. Images whose
    array shapes differ are reported as not equal.

    Tol determined empirically - holding everything else equal but varying
    seed generates images that vary by at least 50.

    On a mismatch, ``expected.png`` and ``actual.png`` are written to
    ``/tmp/<test_name>`` for quick inspection, together with ``delta.png``
    when the two shapes match.

    Args:
        img_expected: Reference image.
        img_actual: Image under test.
        test_name: Name of the sub-directory of ``/tmp`` used for failures.
        tol: Exclusive upper bound on the mean absolute difference.

    Returns:
        bool: Whether the images are considered equal.
    """
    # Signed 32-bit so that differences of 8-bit pixel values cannot wrap.
    expected_arr = np.array(img_expected, dtype=np.int32)
    actual_arr = np.array(img_actual, dtype=np.int32)

    # Differently shaped images can never match; checking first also avoids
    # a broadcasting error (or a silent broadcast) in the subtraction below.
    same_shape = expected_arr.shape == actual_arr.shape
    if same_shape:
        mean_delta = np.mean(np.abs(expected_arr - actual_arr))
        # Cast so callers get a plain bool rather than numpy.bool_.
        imgs_equal = bool(mean_delta < tol)
    else:
        imgs_equal = False

    if not imgs_equal:
        # save failures for quick inspection
        save_dir = f"/tmp/{test_name}"
        os.makedirs(save_dir, exist_ok=True)
        img_expected.save(os.path.join(save_dir, "expected.png"))
        img_actual.save(os.path.join(save_dir, "actual.png"))
        if same_shape:
            difference = ImageChops.difference(img_expected, img_actual)
            difference.save(os.path.join(save_dir, "delta.png"))

    return imgs_equal
@pytest.fixture
def expected_image():
    """Pytest fixture that opens the reference image ``tests/assets/out.png``.

    Registered as a fixture so that tests can request it by naming an
    ``expected_image`` argument.
    """
    return Image.open("tests/assets/out.png")
def test_seeded_prediction(expected_image):
    """Run a fixed-seed prediction and compare it with the stored reference.

    Posts a fully specified input (including ``seed``) to the predictions
    endpoint on ``localhost:5000`` and asserts that the returned image
    matches ``tests/assets/out.png`` within the fuzzy tolerance of
    ``image_equal_fuzzy``.
    """
    data = {
        "image": "https://replicate.delivery/pbxt/KIIutO7jIleskKaWebhvurgBUlHR6M6KN7KHaMMWSt4OnVrF/musk_resize.jpeg",
        "prompt": "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality",
        "scheduler": "EulerDiscreteScheduler",
        "enable_lcm": False,
        "pose_image": "https://replicate.delivery/pbxt/KJmFdQRQVDXGDVdVXftLvFrrvgOPXXRXbzIVEyExPYYOFPyF/80048a6e6586759dbcb529e74a9042ca.jpeg",
        "sdxl_weights": "protovision-xl-high-fidel",
        "pose_strength": 0.4,
        "canny_strength": 0.3,
        "depth_strength": 0.5,
        "guidance_scale": 5,
        "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured",
        "ip_adapter_scale": 0.8,
        "lcm_guidance_scale": 1.5,
        "num_inference_steps": 30,
        "enable_pose_controlnet": True,
        "enhance_nonface_region": True,
        "enable_canny_controlnet": False,
        "enable_depth_controlnet": False,
        "lcm_num_inference_steps": 5,
        "controlnet_conditioning_scale": 0.8,
        "seed": 1337,
    }
    actual_image = local_run("http://localhost:5000/predictions", data)

    # NOTE(review): the ``expected_image`` argument is replaced here by a
    # fresh load of the checked-in asset, so the comparison target is always
    # this file regardless of what was passed in.
    expected_image = Image.open("tests/assets/out.png")

    # Keyword arguments keep expected/actual from being swapped, which would
    # mislabel the expected.png / actual.png files saved on failure.
    test_result = image_equal_fuzzy(
        img_expected=expected_image,
        img_actual=actual_image,
        test_name="test_seeded_prediction",
    )

    if test_result:
        print("Test passed successfully.")
    else:
        print("Test failed.")

    assert test_result, "prediction differs from tests/assets/out.png; see /tmp/test_seeded_prediction"