# custom_nodes/comfyui_controlnet_aux/tests/test_controlnet_aux.py
# Uploaded by gartajackhats1985 ("Upload 1633 files"), verified commit 681fa96, 4.97 kB.
import os
import shutil
from io import BytesIO
import numpy as np
import pytest
import requests
from PIL import Image
from custom_controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector,
LeresDetector, LineartAnimeDetector,
LineartDetector, MediapipeFaceDetector,
MidasDetector, MLSDdetector, NormalBaeDetector,
OpenposeDetector, PidiNetDetector, SamDetector,
ZoeDetector, TileDetector)
# Directory (relative to the current working directory) that receives every generated image.
OUTPUT_DIR = "tests/outputs"


def output(name, img):
    """Save *img* as ``<OUTPUT_DIR>/<name>.png``.

    Args:
        name: base file name without extension. It is expected to be a ``str``;
            the ``:s`` format spec rejects most other types.
        img: any object exposing ``save(path)``, e.g. a ``PIL.Image.Image``.
    """
    filename = f"{name:s}.png"
    target = os.path.join(OUTPUT_DIR, filename)
    img.save(target)
def common(name, processor, img):
    """Run *processor* over the shared input/output-type combinations and save each result.

    Five images are written, in this order:
      * ``<name>``         - default call on *img*
      * ``<name>_pil_np``  - *img* in, ``output_type="np"`` out (converted back to an image)
      * ``<name>_np_np``   - uint8 array in, ``output_type="np"`` out (converted back to an image)
      * ``<name>_np_pil``  - uint8 array in, ``output_type="pil"`` out
      * ``<name>_scaled``  - ``detect_resolution=640``, ``image_resolution=768``
    """

    def as_array():
        # np.array builds a new uint8 array on every call, so each variant gets its own copy.
        return np.array(img, dtype=np.uint8)

    variants = (
        ("", lambda: processor(img)),
        ("_pil_np", lambda: Image.fromarray(processor(img, output_type="np"))),
        ("_np_np", lambda: Image.fromarray(processor(as_array(), output_type="np"))),
        ("_np_pil", lambda: processor(as_array(), output_type="pil")),
        ("_scaled", lambda: processor(img, detect_resolution=640, image_resolution=768)),
    )
    for suffix, render in variants:
        output(name + suffix, render())
def return_pil(name, processor, img):
    """Exercise the ``return_pil`` keyword of *processor* with both values.

    The ``return_pil=False`` result is passed through ``Image.fromarray`` before
    saving as ``<name>_pil_false``. The ``return_pil=True`` result is saved as-is
    as ``<name>_pil_true``.
    """
    for flag, suffix in ((False, "_pil_false"), (True, "_pil_true")):
        label = name + suffix
        result = processor(img, return_pil=flag)
        output(label, result if flag else Image.fromarray(result))
@pytest.fixture(scope="module")
def img():
    """Prepare a clean output directory and download the shared test image.

    Runs once per test module. Any existing ``OUTPUT_DIR`` is removed and
    recreated so no stale outputs survive. The sample pose image from the
    ``lllyasviel/sd-controlnet-openpose`` repository is then downloaded and
    normalized to a 512x512 RGB image.

    Returns:
        PIL.Image.Image: the 512x512 RGB input shared by every test.

    Raises:
        requests.HTTPError: if the download returns an error status.
        requests.Timeout: if the server does not respond within the timeout.
    """
    # Start from an empty output directory so old images never mix with new ones.
    if os.path.exists(OUTPUT_DIR):
        shutil.rmtree(OUTPUT_DIR)
    # makedirs (unlike mkdir) also creates a missing parent "tests" directory
    # when pytest is launched from a directory that does not contain one.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    url = "https://huggingface.co/lllyasviel/sd-controlnet-openpose/resolve/main/images/pose.png"
    # The timeout keeps a stalled connection from hanging the whole session.
    # raise_for_status turns an HTTP error page into a clear failure instead of
    # an opaque "cannot identify image file" error from PIL.
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB").resize((512, 512))
def test_canny(img):
    """Exercise CannyDetector through the shared checks plus a call passing ``img=`` by keyword."""
    detector = CannyDetector()
    common("canny", detector, img)
    # Also verify the input can be supplied as the ``img`` keyword argument.
    keyword_result = detector(img=img)
    output("canny_img", keyword_result)
def test_hed(img):
    """Exercise HEDdetector (``lllyasviel/Annotators``) including its ``safe`` and ``scribble`` flags."""
    detector = HEDdetector.from_pretrained("lllyasviel/Annotators")
    common("hed", detector, img)
    return_pil("hed", detector, img)
    # Each flag is enabled on its own, producing hed_safe then hed_scribble.
    for mode in ("safe", "scribble"):
        output("hed_" + mode, detector(img, **{mode: True}))
def test_leres(img):
    """Exercise LeresDetector (``lllyasviel/Annotators``) with and without ``boost=True``."""
    detector = LeresDetector.from_pretrained("lllyasviel/Annotators")
    common("leres", detector, img)
    boosted = detector(img, boost=True)
    output("leres_boost", boosted)
def test_lineart(img):
    """Exercise LineartDetector (``lllyasviel/Annotators``) including the ``coarse=True`` mode."""
    detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
    common("lineart", detector, img)
    return_pil("lineart", detector, img)
    coarse_result = detector(img, coarse=True)
    output("lineart_coarse", coarse_result)
def test_lineart_anime(img):
    """Exercise LineartAnimeDetector (``lllyasviel/Annotators``) with the shared and ``return_pil`` checks."""
    detector = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
    for check in (common, return_pil):
        check("lineart_anime", detector, img)
def test_mediapipe_face(img):
    """Exercise MediapipeFaceDetector, including a call passing the input as ``image=``."""
    face_detector = MediapipeFaceDetector()
    common("mediapipe", face_detector, img)
    keyword_result = face_detector(image=img)
    output("mediapipe_image", keyword_result)
def test_midas(img):
    """Exercise MidasDetector (``lllyasviel/Annotators``) including ``depth_and_normal=True``."""
    depth_model = MidasDetector.from_pretrained("lllyasviel/Annotators")
    common("midas", depth_model, img)
    # With depth_and_normal=True the result is indexed; element 1 is saved as "midas_normal".
    combined = depth_model(img, depth_and_normal=True)
    output("midas_normal", combined[1])
def test_mlsd(img):
    """Exercise MLSDdetector (``lllyasviel/Annotators``) with the shared and ``return_pil`` checks."""
    detector = MLSDdetector.from_pretrained("lllyasviel/Annotators")
    for check in (common, return_pil):
        check("mlsd", detector, img)
def test_normalbae(img):
    """Exercise NormalBaeDetector (``lllyasviel/Annotators``) with the shared and ``return_pil`` checks."""
    detector = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
    for check in (common, return_pil):
        check("normal_bae", detector, img)
def test_openpose(img):
    """Exercise OpenposeDetector (``lllyasviel/Annotators``) across its body/hand/face option sets."""
    detector = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
    common("openpose", detector, img)
    return_pil("openpose", detector, img)
    # Output-name suffix -> keyword options; dict order fixes the call order.
    variants = {
        "hand_and_face_false": {"hand_and_face": False},
        "hand_and_face_true": {"hand_and_face": True},
        "face": {"include_body": True, "include_hand": False, "include_face": True},
        "faceonly": {"include_body": False, "include_hand": False, "include_face": True},
        "full": {"include_body": True, "include_hand": True, "include_face": True},
        "hand": {"include_body": True, "include_hand": True, "include_face": False},
    }
    for suffix, options in variants.items():
        output("openpose_" + suffix, detector(img, **options))
def test_pidi(img):
    """Exercise PidiNetDetector (``lllyasviel/Annotators``) including its ``safe`` and ``scribble`` flags."""
    detector = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
    common("pidi", detector, img)
    return_pil("pidi", detector, img)
    # Each flag is enabled on its own, producing pidi_safe then pidi_scribble.
    for mode in ("safe", "scribble"):
        output("pidi_" + mode, detector(img, **{mode: True}))
def test_sam(img):
    """Exercise SamDetector loaded from the ``checkpoints`` subfolder of ``ybelkada/segment-anything``."""
    segmenter = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
    common("sam", segmenter, img)
    keyword_result = segmenter(image=img)
    output("sam_image", keyword_result)
def test_shuffle(img):
    """Exercise ContentShuffleDetector with the shared and ``return_pil`` checks."""
    detector = ContentShuffleDetector()
    for check in (common, return_pil):
        check("shuffle", detector, img)
def test_zoe(img):
    """Exercise ZoeDetector (``lllyasviel/Annotators``) with the shared input/output checks."""
    depth_model = ZoeDetector.from_pretrained("lllyasviel/Annotators")
    common("zoe", depth_model, img)
def test_tile(img):
    """Exercise TileDetector with the shared checks and one extra positional call."""
    detector = TileDetector()
    common("tile", detector, img)
    # NOTE(review): this positional call repeats the first call made inside common().
    # test_canny's "_img" variant passes img=img by keyword; confirm which was intended here.
    extra = detector(img)
    output("tile_img", extra)