File size: 6,617 Bytes
cb5daed
7736f5f
4d6f2bc
 
 
48c31e7
dffd0bb
23f4f95
1e250ff
ca5a1e4
f70898c
ca5a1e4
aafe7f2
7a7cda5
 
51fab87
7a7cda5
 
 
 
4d6f2bc
 
7a7cda5
6681256
4470520
 
 
 
6681256
 
4470520
 
 
6681256
 
4470520
6681256
 
51fab87
1e250ff
4d6f2bc
 
 
60849d7
98afd85
7a7cda5
4d6f2bc
af07f4b
4470520
98afd85
48c31e7
 
ca2f5d2
af07f4b
60849d7
6829539
 
4d6f2bc
1a688bc
98afd85
5c4e8c1
1e250ff
 
4d6f2bc
069fc81
 
972fe7d
069fc81
9e8b99d
 
069fc81
1128e78
5c4e8c1
1128e78
1a688bc
 
 
48c31e7
4470520
 
 
 
98afd85
 
4d5d84d
4d6f2bc
ca2f5d2
 
fd9e8de
61ad3d2
6829539
 
 
51fab87
6829539
4470520
1e250ff
 
 
 
 
 
 
 
 
6829539
 
 
4470520
6829539
 
 
 
98afd85
039ff6d
 
6829539
1e250ff
6829539
4d6f2bc
4470520
f70898c
4470520
 
039ff6d
4470520
ca2f5d2
 
79ce657
 
 
ca2f5d2
 
 
 
6829539
51fab87
6829539
 
 
1a7f234
6829539
 
 
 
 
4d6f2bc
6829539
 
9e8b99d
79ce657
6829539
dffd0bb
aafe7f2
4470520
972fe7d
6829539
dffd0bb
f70898c
6829539
 
 
 
 
4470520
6829539
 
f70898c
6829539
 
 
 
 
 
7a7cda5
6829539
 
7a7cda5
6829539
98afd85
7a7cda5
98afd85
 
7a7cda5
98afd85
6829539
7a7cda5
6829539
 
 
 
4470520
6829539
ca2f5d2
f70898c
ca2f5d2
4470520
 
6829539
51fab87
9e8b99d
 
 
 
 
 
 
 
 
51fab87
 
039ff6d
069fc81
 
aafe7f2
51fab87
 
6829539
aafe7f2
51fab87
6829539
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import os
import time
from datetime import datetime

import torch
from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
from compel.prompt_parser import PromptParser
from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError
from spaces import GPU

from .config import Config
from .loader import Loader
from .logger import Logger
from .utils import (
    annotate_image,
    clear_cuda_cache,
    resize_image,
    safe_progress,
    timer,
)


# Dynamic signature for the GPU duration function
def gpu_duration(*args, **kwargs):
    """Estimate the ZeroGPU time budget, in seconds, for one `generate` call.

    The estimate is a fixed 20s for loading plus, per image, a 10s base with
    +5s when ``width * height`` exceeds 500,000 pixels and +5s when
    ``scale == 4``.

    NOTE(review): this is presumably invoked by ``spaces.GPU`` with the same
    arguments as the decorated function. Positional arguments are accepted (so a
    positional call such as ``generate(prompt, ...)`` does not raise
    ``TypeError``) but ignored; only keyword arguments are read, and missing ones
    fall back to ``generate``'s defaults. Callers should pass ``width``,
    ``height``, ``scale`` and ``num_images`` by keyword for an accurate estimate.

    Args:
        *args: Ignored.
        **kwargs: May contain ``width``, ``height``, ``scale``, ``num_images``.

    Returns:
        The requested GPU duration in seconds (int).
    """
    loading = 20
    per_image = 10
    width = kwargs.get("width", 512)
    height = kwargs.get("height", 512)
    scale = kwargs.get("scale", 1)
    num_images = kwargs.get("num_images", 1)

    # Large canvases take noticeably longer per image
    if width * height > 500_000:
        per_image += 5
    # 4x upscaling adds extra per-image work
    if scale == 4:
        per_image += 5
    return loading + per_image * num_images


# Request GPU when deployed to Hugging Face
@GPU(duration=gpu_duration)
def generate(
    positive_prompt,
    negative_prompt="",
    image_prompt=None,
    control_image_prompt=None,
    ip_image_prompt=None,
    seed=None,
    model="Lykon/dreamshaper-8",
    scheduler="DDIM",
    annotator="canny",
    width=512,
    height=512,
    guidance_scale=6.0,
    inference_steps=40,
    denoising_strength=0.8,
    deepcache=1,
    scale=1,
    num_images=1,
    karras=False,
    ip_face=False,
    Error=Exception,
    Info=None,
    progress=None,
):
    """Generate ``num_images`` images from a prompt and optional image inputs.

    The pipeline kind is derived from the inputs: ``txt2img``, or ``img2img``
    when ``image_prompt`` is given, prefixed with ``controlnet_`` when
    ``control_image_prompt`` is given. An IP-Adapter ("plus", or "full-face"
    when ``ip_face``) is requested when ``ip_image_prompt`` is truthy.

    Args:
        positive_prompt: Prompt text, parsed with Compel's weighting syntax.
        negative_prompt: Negative prompt text. If it contains
            ``<fast_negative>``, the ``embeddings/fast_negative.pt`` textual
            inversion is loaded for the duration of this call.
        image_prompt: Init image for (ControlNet) img2img; resized to
            ``(width, height)``.
        control_image_prompt: Image passed through ``annotate_image`` to build
            the ControlNet conditioning image.
        ip_image_prompt: Image for the IP-Adapter.
        seed: Seed of the first image; image ``i`` uses ``seed + i``. ``None``
            or a negative value picks a time-based seed.
        model, scheduler, annotator, deepcache, karras: Forwarded to
            ``Loader.load``.
        width, height: Output size in pixels.
        guidance_scale, inference_steps: Forwarded to the pipeline.
        denoising_strength: img2img ``strength``; unused for txt2img kinds.
        scale: Upscale factor; when > 1 every image is passed through
            ``loader.upscaler.predict``.
        num_images: Number of images to generate.
        ip_face: Use the "full-face" IP-Adapter instead of "plus".
        Error: Exception class raised on failure.
        Info: Optional callable invoked with the final timing message.
        progress: Optional progress callable; besides ``safe_progress`` it is
            called per denoising step as ``progress((step, total), desc=...)``.

    Returns:
        A list of ``(image, seed)`` tuples, ``seed`` being a string, for both
        upscaled and non-upscaled output.

    Raises:
        Error: If CUDA is unavailable, the model fails to load, or the prompt
            cannot be parsed.
    """
    start = time.perf_counter()
    log = Logger("generate")
    log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}...")

    if Config.ZERO_GPU:
        safe_progress(progress, 100, 100, "ZeroGPU init")

    if not torch.cuda.is_available():
        raise Error("CUDA not available")

    # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
    if seed is None or seed < 0:
        seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)

    KIND = "img2img" if image_prompt is not None else "txt2img"
    KIND = f"controlnet_{KIND}" if control_image_prompt is not None else KIND

    # Plain and ControlNet img2img both start from `image_prompt` and are
    # governed by `denoising_strength`
    IS_IMG2IMG = KIND in ("img2img", "controlnet_img2img")

    EMBEDDINGS_TYPE = ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED

    FAST_NEGATIVE = "<fast_negative>" in negative_prompt

    if ip_image_prompt:
        IP_ADAPTER = "full-face" if ip_face else "plus"
    else:
        IP_ADAPTER = ""

    # Progress total per image: img2img runs int(steps * strength) steps, capped
    # at the requested number of steps
    strength = denoising_strength if IS_IMG2IMG else 1
    total_steps = min(int(inference_steps * strength), inference_steps)
    current_image = 1  # 1-based index of the image being generated (read by the callback)

    # Custom progress bar for multiple images
    def callback_on_step_end(pipeline, step, timestep, callback_kwargs):
        progress(
            (step + 1, total_steps),
            desc=f"Generating image {current_image}/{num_images}",
        )
        # Hand the callback kwargs (which include the latents) back unchanged
        return callback_kwargs

    loader = Loader()
    loader.load(
        KIND,
        IP_ADAPTER,
        model,
        scheduler,
        annotator,
        deepcache,
        scale,
        karras,
        progress,
    )

    if loader.pipe is None:
        raise Error(f"Error loading {model}")

    pipe = loader.pipe
    upscaler = loader.upscaler

    # Load fast negative embedding
    if FAST_NEGATIVE:
        embeddings_dir = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "embeddings")
        )
        pipe.load_textual_inversion(
            pretrained_model_name_or_path=f"{embeddings_dir}/fast_negative.pt",
            token="<fast_negative>",
        )

    # The embedding is unloaded exactly once, after *all* images (unloading per
    # image dropped the token for every image after the first), and also when
    # anything below fails, so a reused pipeline is not left with it loaded.
    try:
        # Embed prompts with weights; the prompts are identical for every image,
        # so they are embedded once
        compel = Compel(
            device=pipe.device,
            tokenizer=pipe.tokenizer,
            truncate_long_prompts=False,
            text_encoder=pipe.text_encoder,
            returned_embeddings_type=EMBEDDINGS_TYPE,
            dtype_for_device_getter=lambda _: pipe.dtype,
            textual_inversion_manager=DiffusersTextualInversionManager(pipe),
        )
        try:
            positive_embeds, negative_embeds = compel.pad_conditioning_tensors_to_same_length(
                [compel(positive_prompt), compel(negative_prompt)]
            )
        except PromptParser.ParsingException as e:
            raise Error("Invalid prompt") from e

        # Pipeline arguments shared by every image; only the generator changes
        pipe_kwargs = {
            "width": width,
            "height": height,
            "prompt_embeds": positive_embeds,
            "guidance_scale": guidance_scale,
            "num_inference_steps": inference_steps,
            "negative_prompt_embeds": negative_embeds,
            # The upscaler consumes numpy arrays
            "output_type": "np" if scale > 1 else "pil",
        }

        if progress is not None:
            pipe_kwargs["callback_on_step_end"] = callback_on_step_end

        # Resizing so the initial latents are the same size as the generated image
        if IS_IMG2IMG:
            pipe_kwargs["strength"] = denoising_strength
            pipe_kwargs["image"] = resize_image(image_prompt, (width, height))

        if KIND == "controlnet_txt2img":
            pipe_kwargs["image"] = annotate_image(control_image_prompt, annotator)

        if KIND == "controlnet_img2img":
            pipe_kwargs["control_image"] = annotate_image(control_image_prompt, annotator)

        if IP_ADAPTER:
            pipe_kwargs["ip_adapter_image"] = resize_image(ip_image_prompt)

        images = []
        current_seed = seed
        safe_progress(progress, 0, num_images, f"Generating image 0/{num_images}")

        for _ in range(num_images):
            generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
            image = pipe(**pipe_kwargs, generator=generator).images[0]
            images.append((image, str(current_seed)))
            current_seed += 1
            current_image += 1

        # Upscale, keeping each image paired with its seed
        if scale > 1:
            msg = f"Upscaling {scale}x"
            with timer(msg, logger=log.info):
                safe_progress(progress, 0, num_images, desc=msg)
                for i, (image, image_seed) in enumerate(images):
                    images[i] = (upscaler.predict(image), image_seed)
                    safe_progress(progress, i + 1, num_images, desc=msg)
    finally:
        if FAST_NEGATIVE:
            pipe.unload_textual_inversion()
        # Flush memory after generating (also when generation fails)
        clear_cuda_cache()

    end = time.perf_counter()
    msg = f"Generating {len(images)} image{'s' if len(images) > 1 else ''} took {end - start:.2f}s"
    log.info(msg)

    # Alert if notifier provided
    if Info:
        Info(msg)

    return images