# diffusion/generate.py
import json
import re
import time
from contextlib import contextmanager
from datetime import datetime
from itertools import product
from os import environ
from types import MethodType
from typing import Callable
import spaces
import tomesd
import torch
from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
from compel.prompt_parser import PromptParser
from DeepCache import DeepCacheSDHelper
from diffusers import (
DEISMultistepScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
HeunDiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionPipeline,
)
from diffusers.models import AutoencoderKL, AutoencoderTiny
from torch._dynamo import OptimizedModule
# some models use the deprecated CLIPFeatureExtractor class (should use CLIPImageProcessor)
__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
__import__("transformers").logging.set_verbosity_error()
ZERO_GPU = environ.get("SPACES_ZERO_GPU", "").lower() in ("true", "1")
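# negative textual-inversion embeddings: local file path -> trigger token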
EMBEDDINGS = {
"./embeddings/bad_prompt_version2.pt": "<bad_prompt>",
"./embeddings/BadDream.pt": "<bad_dream>",
"./embeddings/FastNegativeV2.pt": "<fast_negative>",
"./embeddings/negative_hand.pt": "<negative_hand>",
"./embeddings/UnrealisticDream.pt": "<unrealistic_dream>",
}
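# prompt style templates ({name, prompt, negative_prompt} objects)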
with open("./styles/twri.json") as f:
styles = json.load(f)
# inspired by ComfyUI
# https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/model_management.py
class Loader:
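    """Lazy-loading singleton that caches the current pipeline between calls."""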
_instance = None
def __new__(cls):
if cls._instance is None:
            cls._instance = super().__new__(cls)
cls._instance.cpu = torch.device("cpu")
cls._instance.gpu = torch.device("cuda")
cls._instance.pipe = None
return cls._instance
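    # DeepCache caches high-level U-Net features and reuses them, recomputing
    # only every `cache_interval` steps (an interval of 1 disables caching)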
def _load_deepcache(self, interval=1):
has_deepcache = hasattr(self.pipe, "deepcache")
if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
return self.pipe.deepcache
if has_deepcache:
self.pipe.deepcache.disable()
else:
self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
self.pipe.deepcache.set_params(cache_interval=interval)
self.pipe.deepcache.enable()
return self.pipe.deepcache
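    # swaps between the model's full KL VAE and madebyollin/taesd (Tiny VAE)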
def _load_vae(self, model_name=None, taesd=False, dtype=None):
vae_type = type(self.pipe.vae)
is_kl = issubclass(vae_type, (AutoencoderKL, OptimizedModule))
is_tiny = issubclass(vae_type, AutoencoderTiny)
# by default all models use KL
if is_kl and taesd:
# can't compile tiny VAE
print("Switching to Tiny VAE...")
self.pipe.vae = AutoencoderTiny.from_pretrained(
pretrained_model_name_or_path="madebyollin/taesd",
use_safetensors=True,
torch_dtype=dtype,
).to(self.gpu)
return self.pipe.vae
if is_tiny and not taesd:
print("Switching to KL VAE...")
self.pipe.vae = torch.compile(
fullgraph=True,
mode="reduce-overhead",
model=AutoencoderKL.from_pretrained(
pretrained_model_name_or_path=model_name,
use_safetensors=True,
torch_dtype=dtype,
subfolder="vae",
).to(self.gpu),
)
return self.pipe.vae
def load(self, model, scheduler, karras, taesd, deepcache_interval, dtype=None):
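        """Load the requested pipeline, reusing the cached one when possible."""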
model_lower = model.lower()
schedulers = {
"DEIS 2M": DEISMultistepScheduler,
"DPM++ 2M": DPMSolverMultistepScheduler,
"DPM2 a": KDPM2AncestralDiscreteScheduler,
"Euler a": EulerAncestralDiscreteScheduler,
"Heun": HeunDiscreteScheduler,
"LMS": LMSDiscreteScheduler,
"PNDM": PNDMScheduler,
}
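        # Stable Diffusion v1 defaults for the beta schedule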
scheduler_kwargs = {
"beta_schedule": "scaled_linear",
"timestep_spacing": "leading",
"use_karras_sigmas": karras,
"beta_start": 0.00085,
"beta_end": 0.012,
"steps_offset": 1,
}
        # PNDM and Euler a don't accept the use_karras_sigmas kwarg
        if scheduler in ("PNDM", "Euler a"):
            del scheduler_kwargs["use_karras_sigmas"]
pipe_kwargs = {
"scheduler": schedulers[scheduler](**scheduler_kwargs),
"pretrained_model_name_or_path": model_lower,
"requires_safety_checker": False,
"use_safetensors": True,
"safety_checker": None,
"torch_dtype": dtype,
}
# already loaded
if self.pipe is not None:
model_name = self.pipe.config._name_or_path
same_model = model_name.lower() == model_lower
same_scheduler = isinstance(self.pipe.scheduler, schedulers[scheduler])
same_karras = (
not hasattr(self.pipe.scheduler.config, "use_karras_sigmas")
or self.pipe.scheduler.config.use_karras_sigmas == karras
)
if same_model:
if not same_scheduler:
print(f"Switching to {scheduler}...")
if not same_karras:
print(f"{'Enabling' if karras else 'Disabling'} Karras sigmas...")
if not same_scheduler or not same_karras:
self.pipe.scheduler = schedulers[scheduler](**scheduler_kwargs)
self._load_vae(model_lower, taesd, dtype)
self._load_deepcache(interval=deepcache_interval)
return self.pipe
else:
print(f"Unloading {model_name.lower()}...")
self.pipe = None
torch.cuda.empty_cache()
        # request the fp16 weights variant, except on ZeroGPU and for models
        # that don't publish one
if not ZERO_GPU and model_lower not in [
"sg161222/realistic_vision_v5.1_novae",
"prompthero/openjourney-v4",
"linaqruf/anything-v3-1",
]:
pipe_kwargs["variant"] = "fp16"
print(f"Loading {model_lower} with {'Tiny' if taesd else 'KL'} VAE...")
self.pipe = StableDiffusionPipeline.from_pretrained(**pipe_kwargs).to(self.gpu)
self.pipe.load_textual_inversion(
pretrained_model_name_or_path=list(EMBEDDINGS.keys()),
tokens=list(EMBEDDINGS.values()),
)
self._load_vae(model_lower, taesd, dtype)
self._load_deepcache(interval=deepcache_interval)
return self.pipe
# applies ToMe (token merging, via tomesd) to the pipeline for the duration of the context
@contextmanager
def token_merging(pipe, tome_ratio=0):
try:
if tome_ratio > 0:
tomesd.apply_patch(pipe, max_downsample=1, sx=2, sy=2, ratio=tome_ratio)
yield
finally:
tomesd.remove_patch(pipe) # idempotent
# parse prompts with arrays
def parse_prompt(prompt: str) -> list[str]:
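    """Expand [[a,b]] arrays into one prompt per combination.

    For example, "a [[red,blue]] ball" yields ["a red ball", "a blue ball"].
    """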
arrays = re.findall(r"\[\[(.*?)\]\]", prompt)
if not arrays:
return [prompt]
tokens = [item.split(",") for item in arrays]
combinations = list(product(*tokens))
prompts = []
for combo in combinations:
current_prompt = prompt
for i, token in enumerate(combo):
current_prompt = current_prompt.replace(f"[[{arrays[i]}]]", token.strip(), 1)
prompts.append(current_prompt)
return prompts
def apply_style(prompt, style_name, negative=False):
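    """Substitute the prompt into the style template's {prompt} placeholder,
    or append the template's negative_prompt when negative=True."""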
if not style_name or style_name == "None":
return prompt
for style in styles:
if style["name"] == style_name:
if negative:
return prompt + " . " + style["negative_prompt"]
else:
return style["prompt"].format(prompt=prompt)
return prompt
# ZeroGPU allocation window in seconds; 1024x1024 at 50 steps can take ~10s per image
@spaces.GPU(duration=44)
def generate(
positive_prompt,
negative_prompt="",
style=None,
seed=None,
model="runwayml/stable-diffusion-v1-5",
scheduler="PNDM",
width=512,
height=512,
guidance_scale=7.5,
inference_steps=50,
num_images=1,
karras=False,
taesd=False,
clip_skip=False,
truncate_prompts=False,
increment_seed=True,
deepcache_interval=1,
tome_ratio=0,
    log: Callable[[str], None] | None = None,
    Error=Exception,  # exception class to raise (e.g., gr.Error in the app)
):
if not torch.cuda.is_available():
raise Error("CUDA not available")
# https://pytorch.org/docs/stable/generated/torch.manual_seed.html
if seed is None or seed < 0:
seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)
TORCH_DTYPE = (
torch.bfloat16
if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
else torch.float16
)
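    # "clip skip" uses the text encoder's penultimate hidden states instead of the last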
EMBEDDINGS_TYPE = (
ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
if clip_skip
else ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED
)
with torch.inference_mode():
start = time.perf_counter()
loader = Loader()
pipe = loader.load(model, scheduler, karras, taesd, deepcache_interval, TORCH_DTYPE)
        # Compel builds weighted prompt embeddings and resolves textual-inversion tokens
compel = Compel(
textual_inversion_manager=DiffusersTextualInversionManager(pipe),
dtype_for_device_getter=lambda _: TORCH_DTYPE,
returned_embeddings_type=EMBEDDINGS_TYPE,
truncate_long_prompts=truncate_prompts,
text_encoder=pipe.text_encoder,
tokenizer=pipe.tokenizer,
device=pipe.device,
)
images = []
current_seed = seed
try:
styled_negative_prompt = apply_style(negative_prompt, style, negative=True)
neg_embeds = compel(styled_negative_prompt)
except PromptParser.ParsingException:
raise Error("ParsingException: Invalid negative prompt")
for i in range(num_images):
# seeded generator for each iteration
generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
try:
all_positive_prompts = parse_prompt(positive_prompt)
prompt_index = i % len(all_positive_prompts)
pos_prompt = all_positive_prompts[prompt_index]
styled_pos_prompt = apply_style(pos_prompt, style)
pos_embeds = compel(styled_pos_prompt)
pos_embeds, neg_embeds = compel.pad_conditioning_tensors_to_same_length(
[pos_embeds, neg_embeds]
)
except PromptParser.ParsingException:
raise Error("ParsingException: Invalid prompt")
with token_merging(pipe, tome_ratio=tome_ratio):
result = pipe(
num_inference_steps=inference_steps,
negative_prompt_embeds=neg_embeds,
guidance_scale=guidance_scale,
prompt_embeds=pos_embeds,
generator=generator,
height=height,
width=width,
)
images.append((result.images[0], str(current_seed)))
if increment_seed:
current_seed += 1
        if ZERO_GPU:
            # ZeroGPU workers don't persist between calls, so drop the cached pipe
            loader.pipe = None
end = time.perf_counter()
diff = end - start
if log:
log(f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {diff:.2f}s")
return images
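

# Minimal local usage sketch; an illustration, not the Space's real entry point
# (the Gradio app calls `generate` directly). Assumes a CUDA device is available.
if __name__ == "__main__":
    results = generate(
        "a photo of an astronaut riding a horse, [[photorealistic,impressionist]]",
        negative_prompt="<fast_negative>",
        seed=42,
        num_images=2,
        log=print,
    )
    for image, used_seed in results:
        image.save(f"output_{used_seed}.png")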