import os import time from datetime import datetime import torch from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType from compel.prompt_parser import PromptParser from gradio import Error, Info, Progress from spaces import GPU, config from .loader import get_loader from .logger import Logger from .utils import annotate_image, cuda_collect, resize_image, timer @GPU def generate( positive_prompt="", negative_prompt="", image_input=None, controlnet_input=None, ip_adapter_input=None, seed=None, model="XpucT/Reliberate", scheduler="UniPC", controlnet_annotator="canny", width=512, height=512, guidance_scale=6.0, inference_steps=40, denoising_strength=0.8, deepcache_interval=1, scale=1, num_images=1, use_karras=False, use_ip_adapter_face=False, _=Progress(track_tqdm=True), ): if not torch.cuda.is_available(): raise Error("CUDA not available") if positive_prompt.strip() == "": raise Error("You must enter a prompt") start = time.perf_counter() log = Logger("generate") log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}...") KIND = "img2img" if image_input is not None else "txt2img" KIND = f"controlnet_{KIND}" if controlnet_input is not None else KIND EMBEDDINGS_TYPE = ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED FAST_NEGATIVE = "" in negative_prompt if ip_adapter_input: IP_KIND = "full-face" if use_ip_adapter_face else "plus" else: IP_KIND = "" # ZeroGPU is serverless so you want ephemeral instances # You want a singleton on localhost so the pipeline stays in memory loader = get_loader(singleton=not config.Config.zero_gpu) loader.load( KIND, IP_KIND, model, scheduler, controlnet_annotator, deepcache_interval, scale, use_karras, ) pipeline = loader.pipeline upscaler = loader.upscaler # Probably a typo in the config if pipeline is None: raise Error(f"Error loading {model}") # Load fast negative embedding if FAST_NEGATIVE: embeddings_dir = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "embeddings") ) pipeline.load_textual_inversion( pretrained_model_name_or_path=f"{embeddings_dir}/fast_negative.pt", token="", ) # Embed prompts with weights compel = Compel( device=pipeline.device, tokenizer=pipeline.tokenizer, truncate_long_prompts=False, text_encoder=pipeline.text_encoder, returned_embeddings_type=EMBEDDINGS_TYPE, dtype_for_device_getter=lambda _: pipeline.dtype, textual_inversion_manager=DiffusersTextualInversionManager(pipeline), ) # https://pytorch.org/docs/stable/generated/torch.manual_seed.html if seed is None or seed < 0: seed = int(datetime.now().timestamp() * 1_000_000) % (2**64) # Increment the seed after each iteration images = [] current_seed = seed for i in range(num_images): try: generator = torch.Generator(device=pipeline.device).manual_seed(current_seed) positive_embeds, negative_embeds = compel.pad_conditioning_tensors_to_same_length( [compel(positive_prompt), compel(negative_prompt)] ) except PromptParser.ParsingException: raise Error("Invalid prompt") kwargs = { "width": width, "height": height, "generator": generator, "prompt_embeds": positive_embeds, "guidance_scale": guidance_scale, "num_inference_steps": inference_steps, "negative_prompt_embeds": negative_embeds, "output_type": "np" if scale > 1 else "pil", } if KIND == "img2img" or KIND == "controlnet_img2img": kwargs["strength"] = denoising_strength kwargs["image"] = resize_image(image_input, (width, height)) if KIND == "controlnet_txt2img": kwargs["image"] = annotate_image(controlnet_input, controlnet_annotator) if KIND == "controlnet_img2img": kwargs["control_image"] = annotate_image(controlnet_input, controlnet_annotator) if IP_KIND: # No size means preserve aspect ratio kwargs["ip_adapter_image"] = resize_image(ip_adapter_input) try: image = pipeline(**kwargs).images[0] images.append((image, str(current_seed))) # tuple with seed for gallery caption current_seed += 1 finally: if FAST_NEGATIVE: pipeline.unload_textual_inversion() # Upscale if scale > 1: with timer(f"Upscaling {num_images} images {scale}x", logger=log.info): for i, image in enumerate(images): image = upscaler.predict(image[0]) seed = images[i][1] images[i] = (image, seed) # tuple again end = time.perf_counter() msg = f"Generating {len(images)} image{'s' if len(images) > 1 else ''} took {end - start:.2f}s" log.info(msg) if Info: Info(msg) # Flush cache before returning cuda_collect() return images