madebyollin committed on
Commit
767b242
1 Parent(s): 0b14af5
Files changed (1)
  1. app.py +414 -0
app.py ADDED
@@ -0,0 +1,414 @@
+ #!/usr/bin/env python3
+ import inspect
+ import gradio as gr
+ import numpy as np
+ import random
+ import torch
+ from diffusers import (
+     StableDiffusion3Pipeline,
+     SD3Transformer2DModel,
+     FlowMatchEulerDiscreteScheduler,
+     AutoencoderTiny,
+ )
+ from typing import Any, Callable, Dict, List, Optional, Union
+
+ # import spaces
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.float16
+
+ # Full SD3 pipeline (text encoders, transformer, and full-size VAE).
+ repo = "stabilityai/stable-diffusion-3-medium-diffusers"
+ pipe = StableDiffusion3Pipeline.from_pretrained(repo, torch_dtype=torch.float16).to(
+     device
+ )
+
+ # Tiny autoencoder (TAESD3) used to decode fast per-step previews.
+ taesd3 = (
+     AutoencoderTiny.from_pretrained("madebyollin/taesd3", torch_dtype=torch.float16)
+     .half()
+     .eval()
+     .requires_grad_(False)
+     .to(device)
+ )
+ # Compile the preview decoder once so per-step decoding stays cheap.
+ taesd3.decoder.layers = torch.compile(
+     taesd3.decoder.layers,
+     fullgraph=True,
+     dynamic=False,
+     mode="max-autotune-no-cudagraphs",
+ )
+
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1344
+
+
+ def get_pred_original_sample(sched, model_output, timestep, sample):
+     return (
+         sample
+         - sched.sigmas[(sched.timesteps == timestep).nonzero().item()] * model_output
+     )
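+ # Background for the helper above: FlowMatchEulerDiscreteScheduler parameterizes
+ # x_t = (1 - sigma_t) * x_0 + sigma_t * noise, and the transformer predicts a
+ # velocity v ≈ noise - x_0, so a rough denoised estimate is x_0 ≈ x_t - sigma_t * v.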
+
+
+ def retrieve_timesteps(
+     scheduler,
+     num_inference_steps: Optional[int] = None,
+     device: Optional[Union[str, torch.device]] = None,
+     timesteps: Optional[List[int]] = None,
+     sigmas: Optional[List[float]] = None,
+     **kwargs,
+ ):
+     """
+     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+     Args:
+         scheduler (`SchedulerMixin`):
+             The scheduler to get timesteps from.
+         num_inference_steps (`int`):
+             The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+             must be `None`.
+         device (`str` or `torch.device`, *optional*):
+             The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+         timesteps (`List[int]`, *optional*):
+             Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+             `num_inference_steps` and `sigmas` must be `None`.
+         sigmas (`List[float]`, *optional*):
+             Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+             `num_inference_steps` and `timesteps` must be `None`.
+
+     Returns:
+         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+         second element is the number of inference steps.
+     """
+     if timesteps is not None and sigmas is not None:
+         raise ValueError(
+             "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
+         )
+     if timesteps is not None:
+         accepts_timesteps = "timesteps" in set(
+             inspect.signature(scheduler.set_timesteps).parameters.keys()
+         )
+         if not accepts_timesteps:
+             raise ValueError(
+                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                 f" timestep schedules. Please check whether you are using the correct scheduler."
+             )
+         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     elif sigmas is not None:
+         accept_sigmas = "sigmas" in set(
+             inspect.signature(scheduler.set_timesteps).parameters.keys()
+         )
+         if not accept_sigmas:
+             raise ValueError(
+                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                 f" sigmas schedules. Please check whether you are using the correct scheduler."
+             )
+         scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     else:
+         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+     return timesteps, num_inference_steps
+
+
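+ # Illustrative usage of `retrieve_timesteps` (values here are hypothetical): in this
+ # app it is called with just a step count, e.g.
+ #
+ #     timesteps, num_steps = retrieve_timesteps(pipe.scheduler, 28, device)
+ #
+ # while passing `timesteps=[...]` or `sigmas=[...]` (mutually exclusive) overrides
+ # the scheduler's default spacing.
+
+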
+ @torch.no_grad()
+ def sd3_pipe_call_that_returns_an_iterable_of_images(
+     self,
+     prompt: Union[str, List[str]] = None,
+     prompt_2: Optional[Union[str, List[str]]] = None,
+     prompt_3: Optional[Union[str, List[str]]] = None,
+     height: Optional[int] = None,
+     width: Optional[int] = None,
+     num_inference_steps: int = 28,
+     timesteps: List[int] = None,
+     guidance_scale: float = 7.0,
+     negative_prompt: Optional[Union[str, List[str]]] = None,
+     negative_prompt_2: Optional[Union[str, List[str]]] = None,
+     negative_prompt_3: Optional[Union[str, List[str]]] = None,
+     num_images_per_prompt: Optional[int] = 1,
+     generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+     latents: Optional[torch.FloatTensor] = None,
+     prompt_embeds: Optional[torch.FloatTensor] = None,
+     negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+     pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+     negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+     output_type: Optional[str] = "pil",
+     return_dict: bool = True,
+     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+     clip_skip: Optional[int] = None,
+     callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+     callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ ):
+     height = height or self.default_sample_size * self.vae_scale_factor
+     width = width or self.default_sample_size * self.vae_scale_factor
+
+     # 1. Check inputs. Raise error if not correct
+     self.check_inputs(
+         prompt,
+         prompt_2,
+         prompt_3,
+         height,
+         width,
+         negative_prompt=negative_prompt,
+         negative_prompt_2=negative_prompt_2,
+         negative_prompt_3=negative_prompt_3,
+         prompt_embeds=prompt_embeds,
+         negative_prompt_embeds=negative_prompt_embeds,
+         pooled_prompt_embeds=pooled_prompt_embeds,
+         negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+         callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+     )
+
+     self._guidance_scale = guidance_scale
+     self._clip_skip = clip_skip
+     self._joint_attention_kwargs = joint_attention_kwargs
+     self._interrupt = False
+
+     # 2. Define call parameters
+     if prompt is not None and isinstance(prompt, str):
+         batch_size = 1
+     elif prompt is not None and isinstance(prompt, list):
+         batch_size = len(prompt)
+     else:
+         batch_size = prompt_embeds.shape[0]
+
+     device = self._execution_device
+
+     # 3. Encode input prompt
+     (
+         prompt_embeds,
+         negative_prompt_embeds,
+         pooled_prompt_embeds,
+         negative_pooled_prompt_embeds,
+     ) = self.encode_prompt(
+         prompt=prompt,
+         prompt_2=prompt_2,
+         prompt_3=prompt_3,
+         negative_prompt=negative_prompt,
+         negative_prompt_2=negative_prompt_2,
+         negative_prompt_3=negative_prompt_3,
+         do_classifier_free_guidance=self.do_classifier_free_guidance,
+         prompt_embeds=prompt_embeds,
+         negative_prompt_embeds=negative_prompt_embeds,
+         pooled_prompt_embeds=pooled_prompt_embeds,
+         negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+         device=device,
+         clip_skip=self.clip_skip,
+         num_images_per_prompt=num_images_per_prompt,
+     )
+
+     if self.do_classifier_free_guidance:
+         prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+         pooled_prompt_embeds = torch.cat(
+             [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
+         )
+
+     # 4. Prepare timesteps
+     timesteps, num_inference_steps = retrieve_timesteps(
+         self.scheduler, num_inference_steps, device, timesteps
+     )
+     num_warmup_steps = max(
+         len(timesteps) - num_inference_steps * self.scheduler.order, 0
+     )
+     self._num_timesteps = len(timesteps)
+
+     # 5. Prepare latent variables
+     num_channels_latents = self.transformer.config.in_channels
+     latents = self.prepare_latents(
+         batch_size * num_images_per_prompt,
+         num_channels_latents,
+         height,
+         width,
+         prompt_embeds.dtype,
+         device,
+         generator,
+         latents,
+     )
+
+     # 6. Denoising loop
+     # with self.progress_bar(total=num_inference_steps) as progress_bar:
+     if True:
+         for i, t in enumerate(timesteps):
+             if self.interrupt:
+                 continue
+
+             # expand the latents if we are doing classifier free guidance
+             latent_model_input = (
+                 torch.cat([latents] * 2)
+                 if self.do_classifier_free_guidance
+                 else latents
+             )
+             # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+             timestep = t.expand(latent_model_input.shape[0])
+
+             noise_pred = self.transformer(
+                 hidden_states=latent_model_input,
+                 timestep=timestep,
+                 encoder_hidden_states=prompt_embeds,
+                 pooled_projections=pooled_prompt_embeds,
+                 joint_attention_kwargs=self.joint_attention_kwargs,
+                 return_dict=False,
+             )[0]
+
+             # perform guidance
+             if self.do_classifier_free_guidance:
+                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                 noise_pred = noise_pred_uncond + self.guidance_scale * (
+                     noise_pred_text - noise_pred_uncond
+                 )
+
+             # compute the previous noisy sample x_t -> x_t-1
+             latents_dtype = latents.dtype
+             latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+             # estimate the clean latents and yield a fast TAESD3-decoded preview
+             x0_pred = get_pred_original_sample(self.scheduler, noise_pred, t, latents)
+             yield self.image_processor.postprocess(taesd3.decode(x0_pred)[0])[0]
+
+             # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+             #     progress_bar.update()
+
+     # final image: decode the finished latents with the full-quality VAE
+     yield self.image_processor.postprocess(
+         self.vae.decode(
+             (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor,
+             return_dict=False,
+         )[0]
+     )[0]
+
+
+ # @spaces.GPU
+ def infer(
+     prompt,
+     negative_prompt,
+     seed,
+     randomize_seed,
+     width,
+     height,
+     guidance_scale,
+     num_inference_steps,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     generator = torch.Generator().manual_seed(seed)
+
+     # stream intermediate previews (and the final image) to the Gradio output
+     yield from sd3_pipe_call_that_returns_an_iterable_of_images(
+         pipe,
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         width=width,
+         height=height,
+         generator=generator,
+     )
+
+
+ examples = [
+     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+     "An astronaut riding a green horse",
+     "A delicious ceviche cheesecake slice",
+ ]
+
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 580px;
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown(
+             f"""
+         # Demo [Stable Diffusion 3 Medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
+         Learn more about the [Stable Diffusion 3 series](https://stability.ai/news/stable-diffusion-3). Try on [Stability AI API](https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post), [Stable Assistant](https://stability.ai/stable-assistant), or on Discord via [Stable Artisan](https://stability.ai/stable-artisan). Run locally with [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [diffusers](https://github.com/huggingface/diffusers)
+         """
+         )
+
+         with gr.Row():
+
+             prompt = gr.Text(
+                 label="Prompt",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="Enter your prompt",
+                 container=False,
+             )
+
+             run_button = gr.Button("Run", scale=0)
+
+         result = gr.Image(label="Result", show_label=False)
+
+         with gr.Accordion("Advanced Settings", open=False):
+
+             negative_prompt = gr.Text(
+                 label="Negative prompt",
+                 max_lines=1,
+                 placeholder="Enter a negative prompt",
+             )
+
+             seed = gr.Slider(
+                 label="Seed",
+                 minimum=0,
+                 maximum=MAX_SEED,
+                 step=1,
+                 value=0,
+             )
+
+             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+             with gr.Row():
+
+                 width = gr.Slider(
+                     label="Width",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=64,
+                     value=1024,
+                 )
+
+                 height = gr.Slider(
+                     label="Height",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=64,
+                     value=1024,
+                 )
+
+             with gr.Row():
+
+                 guidance_scale = gr.Slider(
+                     label="Guidance scale",
+                     minimum=0.0,
+                     maximum=10.0,
+                     step=0.1,
+                     value=5.0,
+                 )
+
+                 num_inference_steps = gr.Slider(
+                     label="Number of inference steps",
+                     minimum=1,
+                     maximum=50,
+                     step=1,
+                     value=28,
+                 )
+
+         gr.Examples(examples=examples, inputs=[prompt])
+     gr.on(
+         triggers=[run_button.click, prompt.submit, negative_prompt.submit],
+         fn=infer,
+         inputs=[
+             prompt,
+             negative_prompt,
+             seed,
+             randomize_seed,
+             width,
+             height,
+             guidance_scale,
+             num_inference_steps,
+         ],
+         outputs=result,
+     )
+
+ demo.launch(share=True)