T2I-Adapter-SDXL / model.py
valhalla's picture
Update model.py
97bd822
raw
history blame
12.1 kB
import gc
import os
from abc import ABC, abstractmethod
import PIL.Image
import torch
from controlnet_aux import (
CannyDetector,
LineartDetector,
MidasDetector,
PidiNetDetector,
ZoeDetector,
)
from diffusers import (
AutoencoderKL,
EulerAncestralDiscreteScheduler,
StableDiffusionXLAdapterPipeline,
T2IAdapter,
)
SD_XL_BASE_RATIOS = {
"0.5": (704, 1408),
"0.52": (704, 1344),
"0.57": (768, 1344),
"0.6": (768, 1280),
"0.68": (832, 1216),
"0.72": (832, 1152),
"0.78": (896, 1152),
"0.82": (896, 1088),
"0.88": (960, 1088),
"0.94": (960, 1024),
"1.0": (1024, 1024),
"1.07": (1024, 960),
"1.13": (1088, 960),
"1.21": (1088, 896),
"1.29": (1152, 896),
"1.38": (1152, 832),
"1.46": (1216, 832),
"1.67": (1280, 768),
"1.75": (1344, 768),
"1.91": (1344, 704),
"2.0": (1408, 704),
"2.09": (1472, 704),
"2.4": (1536, 640),
"2.5": (1600, 640),
"2.89": (1664, 576),
"3.0": (1728, 576),
}
def find_closest_aspect_ratio(target_width, target_height):
target_ratio = target_width / target_height
closest_ratio = None
min_difference = float('inf')
for ratio_str, (width, height) in SD_XL_BASE_RATIOS.items():
ratio = width / height
difference = abs(target_ratio - ratio)
if difference < min_difference:
min_difference = difference
closest_ratio = ratio_str
return closest_ratio
def resize_to_closest_aspect_ratio(image):
target_width, target_height = image.size
closest_ratio = find_closest_aspect_ratio(target_width, target_height)
# Get the dimensions from the closest aspect ratio in the dictionary
new_width, new_height = SD_XL_BASE_RATIOS[closest_ratio]
# Resize the image to the new dimensions while preserving the aspect ratio
resized_image = image.resize((new_width, new_height), PIL.Image.ANTIALIAS)
return resized_image
ADAPTER_NAMES = [
"TencentARC/t2i-adapter-canny-sdxl-1.0",
"TencentARC/t2i-adapter-sketch-sdxl-1.0",
"TencentARC/t2i-adapter-lineart-sdxl-1.0",
"TencentARC/t2i-adapter-depth-midas-sdxl-1.0",
"TencentARC/t2i-adapter-depth-zoe-sdxl-1.0",
# "TencentARC/t2i-adapter-recolor-sdxl-1.0",
]
class Preprocessor(ABC):
@abstractmethod
def to(self, device: torch.device | str) -> "Preprocessor":
pass
@abstractmethod
def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
pass
class CannyPreprocessor(Preprocessor):
def __init__(self):
self.model = CannyDetector()
def to(self, device: torch.device | str) -> Preprocessor:
return self
def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
return self.model(image, detect_resolution=384, image_resolution=1024)
class LineartPreprocessor(Preprocessor):
def __init__(self):
self.model = LineartDetector.from_pretrained("lllyasviel/Annotators")
def to(self, device: torch.device | str) -> Preprocessor:
return self.model.to(device)
def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
return self.model(image, detect_resolution=512, image_resolution=1024)
class MidasPreprocessor(Preprocessor):
def __init__(self):
self.model = MidasDetector.from_pretrained(
"valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
)
def to(self, device: torch.device | str) -> Preprocessor:
return self.model.to(device)
def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
return self.model(image, detect_resolution=512, image_resolution=1024)
class PidiNetPreprocessor(Preprocessor):
def __init__(self):
self.model = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
def to(self, device: torch.device | str) -> Preprocessor:
return self.model.to(device)
def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
return self.model(image, detect_resolution=512, image_resolution=1024, apply_filter=True)
class RecolorPreprocessor(Preprocessor):
def to(self, device: torch.device | str) -> Preprocessor:
return self
def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
return image.convert("L").convert("RGB")
class ZoePreprocessor(Preprocessor):
def __init__(self):
self.model = ZoeDetector.from_pretrained(
"valhalla/t2iadapter-aux-models", filename="zoed_nk.pth", model_type="zoedepth_nk"
)
def to(self, device: torch.device | str) -> Preprocessor:
return self.model.to(device)
def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
return self.model(image, gamma_corrected=True, image_resolution=1024)
PRELOAD_PREPROCESSORS_IN_GPU_MEMORY = os.getenv("PRELOAD_PREPROCESSORS_IN_GPU_MEMORY", "1") == "1"
PRELOAD_PREPROCESSORS_IN_CPU_MEMORY = os.getenv("PRELOAD_PREPROCESSORS_IN_CPU_MEMORY", "0") == "1"
if PRELOAD_PREPROCESSORS_IN_GPU_MEMORY:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
preprocessors_gpu: dict[str, Preprocessor] = {
"TencentARC/t2i-adapter-canny-sdxl-1.0": CannyPreprocessor().to(device),
"TencentARC/t2i-adapter-sketch-sdxl-1.0": PidiNetPreprocessor().to(device),
"TencentARC/t2i-adapter-lineart-sdxl-1.0": LineartPreprocessor().to(device),
"TencentARC/t2i-adapter-depth-midas-sdxl-1.0": MidasPreprocessor().to(device),
"TencentARC/t2i-adapter-depth-zoe-sdxl-1.0": ZoePreprocessor().to(device),
"TencentARC/t2i-adapter-recolor-sdxl-1.0": RecolorPreprocessor().to(device),
}
def get_preprocessor(adapter_name: str) -> Preprocessor:
return preprocessors_gpu[adapter_name]
elif PRELOAD_PREPROCESSORS_IN_CPU_MEMORY:
preprocessors_cpu: dict[str, Preprocessor] = {
"TencentARC/t2i-adapter-canny-sdxl-1.0": CannyPreprocessor(),
"TencentARC/t2i-adapter-sketch-sdxl-1.0": PidiNetPreprocessor(),
"TencentARC/t2i-adapter-lineart-sdxl-1.0": LineartPreprocessor(),
"TencentARC/t2i-adapter-depth-midas-sdxl-1.0": MidasPreprocessor(),
"TencentARC/t2i-adapter-depth-zoe-sdxl-1.0": ZoePreprocessor(),
"TencentARC/t2i-adapter-recolor-sdxl-1.0": RecolorPreprocessor(),
}
def get_preprocessor(adapter_name: str) -> Preprocessor:
return preprocessors_cpu[adapter_name]
else:
def get_preprocessor(adapter_name: str) -> Preprocessor:
if adapter_name == "TencentARC/t2i-adapter-canny-sdxl-1.0":
return CannyPreprocessor()
elif adapter_name == "TencentARC/t2i-adapter-sketch-sdxl-1.0":
return PidiNetPreprocessor()
elif adapter_name == "TencentARC/t2i-adapter-lineart-sdxl-1.0":
return LineartPreprocessor()
elif adapter_name == "TencentARC/t2i-adapter-depth-midas-sdxl-1.0":
return MidasPreprocessor()
elif adapter_name == "TencentARC/t2i-adapter-depth-zoe-sdxl-1.0":
return ZoePreprocessor()
elif adapter_name == "TencentARC/t2i-adapter-recolor-sdxl-1.0":
return RecolorPreprocessor()
else:
raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
def download_all_preprocessors():
for adapter_name in ADAPTER_NAMES:
get_preprocessor(adapter_name)
gc.collect()
download_all_preprocessors()
def download_all_adapters():
for adapter_name in ADAPTER_NAMES:
T2IAdapter.from_pretrained(
adapter_name,
torch_dtype=torch.float16,
varient="fp16",
)
gc.collect()
download_all_adapters()
class Model:
MAX_NUM_INFERENCE_STEPS = 50
def __init__(self, adapter_name: str):
if adapter_name not in ADAPTER_NAMES:
raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
self.preprocessor_name = adapter_name
self.adapter_name = adapter_name
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
self.preprocessor = get_preprocessor(adapter_name).to(self.device)
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
adapter = T2IAdapter.from_pretrained(
adapter_name,
torch_dtype=torch.float16,
varient="fp16",
).to(self.device)
self.pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
model_id,
vae=AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16),
adapter=adapter,
scheduler=EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler"),
torch_dtype=torch.float16,
variant="fp16",
).to(self.device)
self.pipe.enable_xformers_memory_efficient_attention()
self.pipe.load_lora_weights("stabilityai/stable-diffusion-xl-base-1.0", weight_name="sd_xl_offset_example-lora_1.0.safetensors")
self.pipe.fuse_lora(lora_scale=0.4)
else:
self.preprocessor = None # type: ignore
self.pipe = None
def change_preprocessor(self, adapter_name: str) -> None:
if adapter_name not in ADAPTER_NAMES:
raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
if adapter_name == self.preprocessor_name:
return
if PRELOAD_PREPROCESSORS_IN_GPU_MEMORY:
pass
elif PRELOAD_PREPROCESSORS_IN_CPU_MEMORY:
self.preprocessor.to("cpu")
else:
del self.preprocessor
self.preprocessor = get_preprocessor(adapter_name).to(self.device)
self.preprocessor_name = adapter_name
gc.collect()
torch.cuda.empty_cache()
def change_adapter(self, adapter_name: str) -> None:
if adapter_name not in ADAPTER_NAMES:
raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
if adapter_name == self.adapter_name:
return
self.pipe.adapter = T2IAdapter.from_pretrained(
adapter_name,
torch_dtype=torch.float16,
varient="fp16",
).to(self.device)
self.adapter_name = adapter_name
gc.collect()
torch.cuda.empty_cache()
def resize_image(self, image: PIL.Image.Image) -> PIL.Image.Image:
w, h = image.size
scale = 1024 / max(w, h)
new_w = int(w * scale)
new_h = int(h * scale)
return image.resize((new_w, new_h), PIL.Image.LANCZOS)
def run(
self,
image: PIL.Image.Image,
prompt: str,
negative_prompt: str,
adapter_name: str,
num_inference_steps: int = 30,
guidance_scale: float = 5.0,
adapter_conditioning_scale: float = 1.0,
cond_tau: float = 1.0,
seed: int = 0,
apply_preprocess: bool = True,
) -> list[PIL.Image.Image]:
if not torch.cuda.is_available():
raise RuntimeError("This demo does not work on CPU.")
if num_inference_steps > self.MAX_NUM_INFERENCE_STEPS:
raise ValueError(f"Number of steps must be less than {self.MAX_NUM_INFERENCE_STEPS}")
# Resize image to avoid OOM
image = self.resize_image(image)
self.change_preprocessor(adapter_name)
self.change_adapter(adapter_name)
if apply_preprocess:
image = self.preprocessor(image)
image = resize_to_closest_aspect_ratio(image)
generator = torch.Generator(device=self.device).manual_seed(seed)
out = self.pipe(
prompt=prompt,
negative_prompt=negative_prompt,
image=image,
num_inference_steps=num_inference_steps,
adapter_conditioning_scale=adapter_conditioning_scale,
adapter_conditioning_factor=cond_tau,
generator=generator,
guidance_scale=guidance_scale,
).images[0]
return [image, out]