hugoycj committed on
Commit
6781e5a
1 Parent(s): cd3eef9

refactor: Clean code and refactor app to use torch.hub

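The refactor drops the bundled stablenormal package and setup.py in favor of a single torch.hub entry point. A minimal sketch of the pattern the new app.py relies on (the input path below is a placeholder):

import torch
from PIL import Image

# Load the StableNormal predictor via torch.hub; local_cache_dir is forwarded
# to the hub entrypoint and sets where the weights are cached.
predictor = torch.hub.load("hugoycj/StableNormal", "StableNormal",
                           trust_repo=True, local_cache_dir="./weights")

# Single-image inference, mirroring process_image() in the new app.py:
# the predictor returns a PIL image of the estimated normals.
image = Image.open("example.png")  # placeholder input
normal = predictor(image, num_inference_steps=2,
                   match_input_resolution=False, data_type="object")
normal.save("example_normal.png")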
app.py CHANGED
@@ -1,46 +1,19 @@
1
- # Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # --------------------------------------------------------------------------
15
- # If you find this code useful, we kindly ask you to cite our paper in your work.
16
- # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
- # More information about the method can be found at https://marigoldmonodepth.github.io
18
- # --------------------------------------------------------------------------
19
  from __future__ import annotations
20
 
21
  import functools
22
  import os
23
  import tempfile
24
-
25
- import diffusers
26
  import gradio as gr
27
- import imageio as imageio
28
- import numpy as np
29
- import spaces
30
- import torch as torch
31
- torch.backends.cuda.matmul.allow_tf32 = True
32
  from PIL import Image
33
  from gradio_imageslider import ImageSlider
34
- from tqdm import tqdm
35
-
36
  from pathlib import Path
37
- import gradio
38
  from gradio.utils import get_cache_folder
39
- from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline
40
- from stablenormal.pipeline_stablenormal import StableNormalPipeline
41
- from stablenormal.scheduler.heuristics_ddimsampler import HEURI_DDIMScheduler
42
 
43
- class Examples(gradio.helpers.Examples):
44
  def __init__(self, *args, directory_name=None, **kwargs):
45
  super().__init__(*args, **kwargs, _initiated_directly=False)
46
  if directory_name is not None:
@@ -48,250 +21,94 @@ class Examples(gradio.helpers.Examples):
48
  self.cached_file = Path(self.cached_folder) / "log.csv"
49
  self.create()
50
 
51
-
52
- default_seed = 2024
53
- default_batch_size = 1
54
-
55
- default_image_processing_resolution = 768
56
-
57
- default_video_num_inference_steps = 10
58
- default_video_processing_resolution = 768
59
- default_video_out_max_frames = 60
60
-
61
- def process_image_check(path_input):
62
- if path_input is None:
63
- raise gr.Error(
64
- "Missing image in the first pane: upload a file or use one from the gallery below."
65
- )
66
-
67
- def resize_image(input_image, resolution):
68
- # Ensure input_image is a PIL Image object
69
- if not isinstance(input_image, Image.Image):
70
- raise ValueError("input_image should be a PIL Image object")
71
-
72
- # Convert image to numpy array
73
- input_image_np = np.asarray(input_image)
74
-
75
- # Get image dimensions
76
- H, W, C = input_image_np.shape
77
- H = float(H)
78
- W = float(W)
79
-
80
- # Calculate the scaling factor
81
- k = float(resolution) / min(H, W)
82
-
83
- # Determine new dimensions
84
- H *= k
85
- W *= k
86
- H = int(np.round(H / 64.0)) * 64
87
- W = int(np.round(W / 64.0)) * 64
88
-
89
- # Resize the image using PIL's resize method
90
- img = input_image.resize((W, H), Image.Resampling.LANCZOS)
91
-
92
- return img
93
 
94
  def process_image(
95
- pipe,
96
- path_input,
97
- ):
98
- name_base, name_ext = os.path.splitext(os.path.basename(path_input))
99
- print(f"Processing image {name_base}{name_ext}")
100
 
101
- path_output_dir = tempfile.mkdtemp()
102
- path_out_png = os.path.join(path_output_dir, f"{name_base}_normal_colored.png")
103
  input_image = Image.open(path_input)
104
- input_image = resize_image(input_image, default_image_processing_resolution)
105
 
106
- pipe_out = pipe(
107
- input_image,
108
- match_input_resolution=False,
109
- processing_resolution=max(input_image.size)
110
- )
111
-
112
- normal_pred = pipe_out.prediction[0, :, :]
113
- normal_colored = pipe.image_processor.visualize_normals(pipe_out.prediction)
114
- normal_colored[-1].save(path_out_png)
115
- yield [input_image, path_out_png]
116
 
117
- def center_crop(img):
118
- # Open the image file
119
- img_width, img_height = img.size
120
- crop_width =min(img_width, img_height)
121
- # Calculate the cropping box
122
- left = (img_width - crop_width) / 2
123
- top = (img_height - crop_width) / 2
124
- right = (img_width + crop_width) / 2
125
- bottom = (img_height + crop_width) / 2
126
 
127
- # Crop the image
128
- img_cropped = img.crop((left, top, right, bottom))
129
- return img_cropped
130
-
131
- def process_video(
132
- pipe,
133
- path_input,
134
- out_max_frames=default_video_out_max_frames,
135
- target_fps=10,
136
- progress=gr.Progress(),
137
- ):
138
- if path_input is None:
139
- raise gr.Error(
140
- "Missing video in the first pane: upload a file or use one from the gallery below."
141
- )
142
-
143
- name_base, name_ext = os.path.splitext(os.path.basename(path_input))
144
- print(f"Processing video {name_base}{name_ext}")
145
-
146
- path_output_dir = tempfile.mkdtemp()
147
- path_out_vis = os.path.join(path_output_dir, f"{name_base}_normal_colored.mp4")
148
-
149
- init_latents = None
150
- reader, writer = None, None
151
- try:
152
- reader = imageio.get_reader(path_input)
153
-
154
- meta_data = reader.get_meta_data()
155
- fps = meta_data["fps"]
156
- size = meta_data["size"]
157
- duration_sec = meta_data["duration"]
158
-
159
- writer = imageio.get_writer(path_out_vis, fps=target_fps)
160
-
161
- out_frame_id = 0
162
- pbar = tqdm(desc="Processing Video", total=duration_sec)
163
-
164
- for frame_id, frame in enumerate(reader):
165
- if frame_id % (fps // target_fps) != 0:
166
- continue
167
- else:
168
- out_frame_id += 1
169
- pbar.update(1)
170
- if out_frame_id > out_max_frames:
171
- break
172
-
173
- frame_pil = Image.fromarray(frame)
174
- frame_pil = center_crop(frame_pil)
175
- pipe_out = pipe(
176
- frame_pil,
177
- match_input_resolution=False,
178
- latents=init_latents
179
- )
180
-
181
- if init_latents is None:
182
- init_latents = pipe_out.gaus_noise
183
- processed_frame = pipe.image_processor.visualize_normals( # noqa
184
- pipe_out.prediction
185
- )[0]
186
- processed_frame = np.array(processed_frame)
187
-
188
- _processed_frame = imageio.core.util.Array(processed_frame)
189
- writer.append_data(_processed_frame)
190
-
191
- yield (
192
- [frame_pil, processed_frame],
193
- None,
194
- )
195
- finally:
196
-
197
- if writer is not None:
198
- writer.close()
199
-
200
- if reader is not None:
201
- reader.close()
202
-
203
- yield (
204
- [frame_pil, processed_frame],
205
- [path_out_vis,]
206
- )
207
-
208
-
209
- def run_demo_server(pipe):
210
- process_pipe_image = spaces.GPU(functools.partial(process_image, pipe))
211
- process_pipe_video = spaces.GPU(
212
- functools.partial(process_video, pipe), duration=120
213
- )
214
-
215
- gradio_theme = gr.themes.Default()
216
-
217
- with gr.Blocks(
218
- theme=gradio_theme,
219
  title="Stable Normal Estimation",
220
  css="""
221
- #download {
222
- height: 118px;
223
- }
224
- .slider .inner {
225
- width: 5px;
226
- background: #FFF;
227
- }
228
- .viewport {
229
- aspect-ratio: 4/3;
230
- }
231
- .tabs button.selected {
232
- font-size: 20px !important;
233
- color: crimson !important;
234
- }
235
- h1 {
236
- text-align: center;
237
- display: block;
238
- }
239
- h2 {
240
- text-align: center;
241
- display: block;
242
- }
243
- h3 {
244
- text-align: center;
245
- display: block;
246
- }
247
- .md_feedback li {
248
- margin-bottom: 0px !important;
249
- }
250
- """,
251
- head="""
252
- <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
253
- <script>
254
- window.dataLayer = window.dataLayer || [];
255
- function gtag() {dataLayer.push(arguments);}
256
- gtag('js', new Date());
257
- gtag('config', 'G-1FWSVCGZTG');
258
- </script>
259
- """,
260
- ) as demo:
261
- gr.Markdown(
262
- """
263
- # StableNormal: Reducing Diffusion Variance for Stable and Sharp Normal
264
- <p align="center">
265
-
266
- <a title="Website" href="https://stable-x.github.io/StableNormal/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
267
- <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
268
- </a>
269
- <a title="arXiv" href="https://arxiv.org/abs/2406.16864" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
270
- <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
271
- </a>
272
- <a title="Github" href="https://github.com/Stable-X/StableNormal" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
273
- <img src="https://img.shields.io/github/stars/Stable-X/StableDelight?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
274
- </a>
275
- <a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
276
- <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
277
- </a>
278
  """
279
- )
280
- with gr.Tabs(elem_classes=["tabs"]):
281
- with gr.Tab("Image"):
282
  with gr.Row():
283
  with gr.Column():
284
- image_input = gr.Image(
285
- label="Input Image",
286
- type="filepath",
287
  )
288
  with gr.Row():
289
- image_submit_btn = gr.Button(
290
- value="Compute Normal", variant="primary"
291
- )
292
- image_reset_btn = gr.Button(value="Reset")
293
  with gr.Column():
294
- image_output_slider = ImageSlider(
295
  label="Normal outputs",
296
  type="filepath",
297
  show_download_button=True,
@@ -302,32 +119,38 @@ def run_demo_server(pipe):
302
  )
303
 
304
  Examples(
305
- fn=process_pipe_image,
306
  examples=sorted([
307
- os.path.join("files", "image", name)
308
- for name in os.listdir(os.path.join("files", "image"))
309
  ]),
310
- inputs=[image_input],
311
- outputs=[image_output_slider],
312
  cache_examples=True,
313
- directory_name="examples_image",
314
  )
315
 
316
- with gr.Tab("Video"):
317
  with gr.Row():
318
  with gr.Column():
319
- video_input = gr.Video(
320
- label="Input Video",
321
- sources=["upload", "webcam"],
322
  )
323
  with gr.Row():
324
- video_submit_btn = gr.Button(
325
- value="Compute Normal", variant="primary"
326
- )
327
- video_reset_btn = gr.Button(value="Reset")
328
  with gr.Column():
329
- processed_frames = ImageSlider(
330
- label="Realtime Visualization",
331
  type="filepath",
332
  show_download_button=True,
333
  show_share_button=True,
@@ -335,111 +158,127 @@ def run_demo_server(pipe):
335
  elem_classes="slider",
336
  position=0.25,
337
  )
338
339
  label="Normal outputs",
340
- elem_id="download",
341
  interactive=False,
342
  )
343
  Examples(
344
- fn=process_pipe_video,
345
  examples=sorted([
346
- os.path.join("files", "video", name)
347
- for name in os.listdir(os.path.join("files", "video"))
348
  ]),
349
- inputs=[video_input],
350
- outputs=[processed_frames, video_output_files],
351
- directory_name="examples_video",
352
- cache_examples=False,
353
  )
354
-
355
- with gr.Tab("Panorama"):
356
- with gr.Column():
357
- gr.Markdown("Coming soon")
358
 
359
- with gr.Tab("4K Image"):
360
- with gr.Column():
361
- gr.Markdown("Coming soon")
362
-
363
- ### Image tab
364
- image_submit_btn.click(
365
- fn=process_image_check,
366
- inputs=image_input,
367
  outputs=None,
368
- preprocess=False,
369
  queue=False,
370
  ).success(
371
- fn=process_pipe_image,
372
- inputs=[
373
- image_input,
374
- ],
375
- outputs=[image_output_slider],
376
- concurrency_limit=1,
377
  )
378
 
379
- image_reset_btn.click(
380
- fn=lambda: (
381
- None,
382
- None,
383
- None,
384
- ),
385
  inputs=[],
386
- outputs=[
387
- image_input,
388
- image_output_slider,
389
- ],
390
  queue=False,
391
  )
392
 
393
- ### Video tab
394
-
395
- video_submit_btn.click(
396
- fn=process_pipe_video,
397
- inputs=[video_input],
398
- outputs=[processed_frames, video_output_files],
399
- concurrency_limit=1,
400
  )
401
 
402
- video_reset_btn.click(
403
- fn=lambda: (None, None, None),
404
  inputs=[],
405
- outputs=[video_input, processed_frames, video_output_files],
406
- concurrency_limit=1,
407
  )
408
 
409
- ### Server launch
410
 
411
- demo.queue(
412
- api_open=False,
413
- ).launch(
414
- server_name="0.0.0.0",
415
- server_port=7860,
416
  )
417
 
418
 
419
  def main():
420
- os.system("pip freeze")
421
-
422
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
423
-
424
- x_start_pipeline = YOSONormalsPipeline.from_pretrained(
425
- 'Stable-X/yoso-normal-v0-3', trust_remote_code=True, variant="fp16", torch_dtype=torch.float16).to(device)
426
- pipe = StableNormalPipeline.from_pretrained('Stable-X/stable-normal-v0-1', trust_remote_code=True,
427
- variant="fp16", torch_dtype=torch.float16,
428
- scheduler=HEURI_DDIMScheduler(prediction_type='sample',
429
- beta_start=0.00085, beta_end=0.0120,
430
- beta_schedule = "scaled_linear"))
431
- pipe.x_start_pipeline = x_start_pipeline
432
- pipe.to(device)
433
- pipe.prior.to(device, torch.float16)
434
-
435
- try:
436
- import xformers
437
- pipe.enable_xformers_memory_efficient_attention()
438
- except:
439
- pass # run without xformers
440
-
441
- run_demo_server(pipe)
442
-
443
 
444
  if __name__ == "__main__":
445
- main()
1
  from __future__ import annotations
2
 
3
  import functools
4
  import os
5
  import tempfile
6
+ import torch
7
  import gradio as gr
8
  from PIL import Image
9
  from gradio_imageslider import ImageSlider
10
  from pathlib import Path
11
  from gradio.utils import get_cache_folder
12
 
13
+ # Constants
14
+ DEFAULT_SHARPNESS = 2
15
+
16
+ class Examples(gr.helpers.Examples):
17
  def __init__(self, *args, directory_name=None, **kwargs):
18
  super().__init__(*args, **kwargs, _initiated_directly=False)
19
  if directory_name is not None:
 
21
  self.cached_file = Path(self.cached_folder) / "log.csv"
22
  self.create()
23
 
24
+ def load_predictor():
25
+ """Load model predictor using torch.hub"""
26
+ predictor = torch.hub.load("hugoycj/StableNormal", "StableNormal", trust_repo=True,
27
+ local_cache_dir='./weights')
28
+ return predictor
29
 
30
  def process_image(
31
+ predictor,
32
+ path_input: str,
33
+ sharpness: int = DEFAULT_SHARPNESS,
34
+ data_type: str = "object"
35
+ ) -> tuple:
36
+ """Process single image"""
37
+ if path_input is None:
38
+ raise gr.Error("Please upload an image or select one from the gallery.")
39
+
40
+ name_base = os.path.splitext(os.path.basename(path_input))[0]
41
+ out_path = os.path.join(tempfile.mkdtemp(), f"{name_base}_normal.png")
42
 
43
 
44
  input_image = Image.open(path_input)
45
+ normal_image = predictor(input_image, num_inference_steps=sharpness,
46
+ match_input_resolution=False, data_type=data_type)
47
+ normal_image.save(out_path)
48
 
49
+ yield [input_image, out_path]
50
 
51
+ def create_demo():
52
+ # Load model
53
+ predictor = load_predictor()
54
 
55
+ # Create processing functions for each data type
56
+ process_object = functools.partial(process_image, predictor, data_type="object")
57
+ process_scene = functools.partial(process_image, predictor, data_type="indoor")
58
+ process_human = functools.partial(process_image, predictor, data_type="object")
59
+
60
+ # Define markdown content
61
+ HEADER_MD = """
62
+ # StableNormal: Reducing Diffusion Variance for Stable and Sharp Normal
63
+ <p align="center">
64
+ <a title="Website" href="https://stable-x.github.io/StableNormal/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
65
+ <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
66
+ </a>
67
+ <a title="arXiv" href="https://arxiv.org/abs/2406.16864" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
68
+ <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
69
+ </a>
70
+ <a title="Github" href="https://github.com/Stable-X/StableNormal" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
71
+ <img src="https://img.shields.io/github/stars/Stable-X/StableDelight?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
72
+ </a>
73
+ <a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
74
+ <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
75
+ </a>
76
+ """
77
+
78
+ # Create interface
79
+ demo = gr.Blocks(
80
  title="Stable Normal Estimation",
81
  css="""
82
+ .slider .inner { width: 5px; background: #FFF; }
83
+ .viewport { aspect-ratio: 4/3; }
84
+ .tabs button.selected { font-size: 20px !important; color: crimson !important; }
85
+ h1, h2, h3 { text-align: center; display: block; }
86
+ .md_feedback li { margin-bottom: 0px !important; }
87
  """
88
+ )
89
+
90
+ with demo:
91
+ gr.Markdown(HEADER_MD)
92
+
93
+ with gr.Tabs() as tabs:
94
+ # Object Tab
95
+ with gr.Tab("Object"):
96
  with gr.Row():
97
  with gr.Column():
98
+ object_input = gr.Image(label="Input Object Image", type="filepath")
99
+ object_sharpness = gr.Slider(
100
+ minimum=1,
101
+ maximum=10,
102
+ value=DEFAULT_SHARPNESS,
103
+ step=1,
104
+ label="Sharpness (inference steps)",
105
+ info="Higher values produce sharper results but take longer"
106
  )
107
  with gr.Row():
108
+ object_submit_btn = gr.Button("Compute Normal", variant="primary")
109
+ object_reset_btn = gr.Button("Reset")
110
  with gr.Column():
111
+ object_output_slider = ImageSlider(
112
  label="Normal outputs",
113
  type="filepath",
114
  show_download_button=True,
 
119
  )
120
 
121
  Examples(
122
+ fn=process_object,
123
  examples=sorted([
124
+ os.path.join("files", "object", name)
125
+ for name in os.listdir(os.path.join("files", "object"))
126
+ if os.path.exists(os.path.join("files", "object"))
127
  ]),
128
+ inputs=[object_input],
129
+ outputs=[object_output_slider],
130
  cache_examples=True,
131
+ directory_name="examples_object",
132
+ examples_per_page=50,
133
  )
134
 
135
+ # Scene Tab
136
+ with gr.Tab("Scene"):
137
  with gr.Row():
138
  with gr.Column():
139
+ scene_input = gr.Image(label="Input Scene Image", type="filepath")
140
+ scene_sharpness = gr.Slider(
141
+ minimum=1,
142
+ maximum=10,
143
+ value=DEFAULT_SHARPNESS,
144
+ step=1,
145
+ label="Sharpness (inference steps)",
146
+ info="Higher values produce sharper results but take longer"
147
  )
148
  with gr.Row():
149
+ scene_submit_btn = gr.Button("Compute Normal", variant="primary")
150
+ scene_reset_btn = gr.Button("Reset")
151
  with gr.Column():
152
+ scene_output_slider = ImageSlider(
153
+ label="Normal outputs",
154
  type="filepath",
155
  show_download_button=True,
156
  show_share_button=True,
 
158
  elem_classes="slider",
159
  position=0.25,
160
  )
161
+
162
+ Examples(
163
+ fn=process_scene,
164
+ examples=sorted([
165
+ os.path.join("files", "scene", name)
166
+ for name in os.listdir(os.path.join("files", "scene"))
167
+ if os.path.exists(os.path.join("files", "scene"))
168
+ ]),
169
+ inputs=[scene_input],
170
+ outputs=[scene_output_slider],
171
+ cache_examples=True,
172
+ directory_name="examples_scene",
173
+ examples_per_page=50,
174
+ )
175
+
176
+ # Human Tab
177
+ with gr.Tab("Human"):
178
+ with gr.Row():
179
+ with gr.Column():
180
+ human_input = gr.Image(label="Input Human Image", type="filepath")
181
+ human_sharpness = gr.Slider(
182
+ minimum=1,
183
+ maximum=10,
184
+ value=DEFAULT_SHARPNESS,
185
+ step=1,
186
+ label="Sharpness (inference steps)",
187
+ info="Higher values produce sharper results but take longer"
188
+ )
189
+ with gr.Row():
190
+ human_submit_btn = gr.Button("Compute Normal", variant="primary")
191
+ human_reset_btn = gr.Button("Reset")
192
+ with gr.Column():
193
+ human_output_slider = ImageSlider(
194
  label="Normal outputs",
195
+ type="filepath",
196
+ show_download_button=True,
197
+ show_share_button=True,
198
  interactive=False,
199
+ elem_classes="slider",
200
+ position=0.25,
201
  )
202
+
203
  Examples(
204
+ fn=process_human,
205
  examples=sorted([
206
+ os.path.join("files", "human", name)
207
+ for name in os.listdir(os.path.join("files", "human"))
208
+ if os.path.exists(os.path.join("files", "human"))
209
  ]),
210
+ inputs=[human_input],
211
+ outputs=[human_output_slider],
212
+ cache_examples=True,
213
+ directory_name="examples_human",
214
+ examples_per_page=50,
215
  )
216
 
217
+ # Event Handlers for Object Tab
218
+ object_submit_btn.click(
219
+ fn=lambda x, _: None if x else gr.Error("Please upload an image"),
220
+ inputs=[object_input, object_sharpness],
221
  outputs=None,
222
  queue=False,
223
  ).success(
224
+ fn=process_object,
225
+ inputs=[object_input, object_sharpness],
226
+ outputs=[object_output_slider],
227
  )
228
 
229
+ object_reset_btn.click(
230
+ fn=lambda: (None, DEFAULT_SHARPNESS, None),
231
  inputs=[],
232
+ outputs=[object_input, object_sharpness, object_output_slider],
233
  queue=False,
234
  )
235
 
236
+ # Event Handlers for Scene Tab
237
+ scene_submit_btn.click(
238
+ fn=lambda x, _: None if x else gr.Error("Please upload an image"),
239
+ inputs=[scene_input, scene_sharpness],
240
+ outputs=None,
241
+ queue=False,
242
+ ).success(
243
+ fn=process_scene,
244
+ inputs=[scene_input, scene_sharpness],
245
+ outputs=[scene_output_slider],
246
  )
247
 
248
+ scene_reset_btn.click(
249
+ fn=lambda: (None, DEFAULT_SHARPNESS, None),
250
  inputs=[],
251
+ outputs=[scene_input, scene_sharpness, scene_output_slider],
252
+ queue=False,
253
  )
254
 
255
+ # Event Handlers for Human Tab
256
+ human_submit_btn.click(
257
+ fn=lambda x, _: None if x else gr.Error("Please upload an image"),
258
+ inputs=[human_input, human_sharpness],
259
+ outputs=None,
260
+ queue=False,
261
+ ).success(
262
+ fn=process_human,
263
+ inputs=[human_input, human_sharpness],
264
+ outputs=[human_output_slider],
265
+ )
266
 
267
+ human_reset_btn.click(
268
+ fn=lambda: (None, DEFAULT_SHARPNESS, None),
269
+ inputs=[],
270
+ outputs=[human_input, human_sharpness, human_output_slider],
271
+ queue=False,
272
  )
273
 
274
+ return demo
275
 
276
  def main():
277
+ demo = create_demo()
278
+ demo.queue(api_open=False).launch(
279
+ server_name="0.0.0.0",
280
+ server_port=7860,
281
+ )
282
 
283
  if __name__ == "__main__":
284
+ main()
setup.py DELETED
@@ -1,9 +0,0 @@
1
- from pathlib import Path
2
- from setuptools import setup, find_packages
3
-
4
- setup_path = Path(__file__).parent
5
-
6
- setup(
7
- name = "stablenormal",
8
- packages=find_packages()
9
- )
stablenormal/__init__.py DELETED
File without changes
stablenormal/pipeline_stablenormal.py DELETED
@@ -1,1279 +0,0 @@
1
- # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
2
- # Copyright 2024 The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- # --------------------------------------------------------------------------
16
- # More information and citation instructions are available on the
17
- # --------------------------------------------------------------------------
18
- from dataclasses import dataclass
19
- from typing import Any, Dict, List, Optional, Tuple, Union
20
-
21
- import numpy as np
22
- import torch
23
- from PIL import Image
24
- from tqdm.auto import tqdm
25
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
26
-
27
-
28
- from diffusers.image_processor import PipelineImageInput
29
- from diffusers.models import (
30
- AutoencoderKL,
31
- UNet2DConditionModel,
32
- ControlNetModel,
33
- )
34
- from diffusers.schedulers import (
35
- DDIMScheduler
36
- )
37
-
38
- from diffusers.utils import (
39
- BaseOutput,
40
- logging,
41
- replace_example_docstring,
42
- )
43
-
44
- from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
45
-
46
- from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
47
-
48
-
49
-
50
- from diffusers.utils.torch_utils import randn_tensor
51
- from diffusers.pipelines.controlnet import StableDiffusionControlNetPipeline
52
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
53
- from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor
54
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
55
- import torch.nn.functional as F
56
-
57
- import pdb
58
-
59
-
60
-
61
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
62
-
63
-
64
- EXAMPLE_DOC_STRING = """
65
- Examples:
66
- ```py
67
- >>> import diffusers
68
- >>> import torch
69
-
70
- >>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(
71
- ... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16
72
- ... ).to("cuda")
73
-
74
- >>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
75
- >>> normals = pipe(image)
76
-
77
- >>> vis = pipe.image_processor.visualize_normals(normals.prediction)
78
- >>> vis[0].save("einstein_normals.png")
79
- ```
80
- """
81
-
82
-
83
- @dataclass
84
- class StableNormalOutput(BaseOutput):
85
- """
86
- Output class for Marigold monocular normals prediction pipeline.
87
-
88
- Args:
89
- prediction (`np.ndarray`, `torch.Tensor`):
90
- Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height
91
- \times width$, regardless of whether the images were passed as a 4D array or a list.
92
- uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
93
- Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
94
- \times 1 \times height \times width$.
95
- latent (`None`, `torch.Tensor`):
96
- Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
97
- The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
98
- """
99
-
100
- prediction: Union[np.ndarray, torch.Tensor]
101
- latent: Union[None, torch.Tensor]
102
- gaus_noise: Union[None, torch.Tensor]
103
-
104
- from einops import rearrange
105
- class DINOv2_Encoder(torch.nn.Module):
106
- IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
107
- IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
108
-
109
- def __init__(
110
- self,
111
- model_name = 'dinov2_vitl14',
112
- freeze = True,
113
- antialias=True,
114
- device="cuda",
115
- size = 448,
116
- ):
117
- super(DINOv2_Encoder, self).__init__()
118
-
119
- self.model = torch.hub.load('facebookresearch/dinov2', model_name)
120
- self.model.eval().to(device)
121
- self.device = device
122
- self.antialias = antialias
123
- self.dtype = torch.float32
124
-
125
- self.mean = torch.Tensor(self.IMAGENET_DEFAULT_MEAN)
126
- self.std = torch.Tensor(self.IMAGENET_DEFAULT_STD)
127
- self.size = size
128
- if freeze:
129
- self.freeze()
130
-
131
-
132
- def freeze(self):
133
- for param in self.model.parameters():
134
- param.requires_grad = False
135
-
136
- @torch.no_grad()
137
- def encoder(self, x):
138
- '''
139
- x: [b h w c], range from (-1, 1), rbg
140
- '''
141
-
142
- x = self.preprocess(x).to(self.device, self.dtype)
143
-
144
- b, c, h, w = x.shape
145
- patch_h, patch_w = h // 14, w // 14
146
-
147
- embeddings = self.model.forward_features(x)['x_norm_patchtokens']
148
- embeddings = rearrange(embeddings, 'b (h w) c -> b h w c', h = patch_h, w = patch_w)
149
-
150
- return rearrange(embeddings, 'b h w c -> b c h w')
151
-
152
- def preprocess(self, x):
153
- ''' x
154
- '''
155
- # normalize to [0,1],
156
- x = torch.nn.functional.interpolate(
157
- x,
158
- size=(self.size, self.size),
159
- mode='bicubic',
160
- align_corners=True,
161
- antialias=self.antialias,
162
- )
163
-
164
- x = (x + 1.0) / 2.0
165
- # renormalize according to dino
166
- mean = self.mean.view(1, 3, 1, 1).to(x.device)
167
- std = self.std.view(1, 3, 1, 1).to(x.device)
168
- x = (x - mean) / std
169
-
170
- return x
171
-
172
- def to(self, device, dtype=None):
173
- if dtype is not None:
174
- self.dtype = dtype
175
- self.model.to(device, dtype)
176
- self.mean.to(device, dtype)
177
- self.std.to(device, dtype)
178
- else:
179
- self.model.to(device)
180
- self.mean.to(device)
181
- self.std.to(device)
182
- return self
183
-
184
- def __call__(self, x, **kwargs):
185
- return self.encoder(x, **kwargs)
186
-
187
- class StableNormalPipeline(StableDiffusionControlNetPipeline):
188
- """ Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io.
189
- Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
190
-
191
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
192
- implemented for all pipelines (downloading, saving, running on a particular device, etc.).
193
-
194
- The pipeline also inherits the following loading methods:
195
- - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
196
- - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
197
- - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
198
- - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
199
- - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
200
-
201
- Args:
202
- vae ([`AutoencoderKL`]):
203
- Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
204
- text_encoder ([`~transformers.CLIPTextModel`]):
205
- Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
206
- tokenizer ([`~transformers.CLIPTokenizer`]):
207
- A `CLIPTokenizer` to tokenize text.
208
- unet ([`UNet2DConditionModel`]):
209
- A `UNet2DConditionModel` to denoise the encoded image latents.
210
- controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
211
- Provides additional conditioning to the `unet` during the denoising process. If you set multiple
212
- ControlNets as a list, the outputs from each ControlNet are added together to create one combined
213
- additional conditioning.
214
- scheduler ([`SchedulerMixin`]):
215
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
216
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
217
- safety_checker ([`StableDiffusionSafetyChecker`]):
218
- Classification module that estimates whether generated images could be considered offensive or harmful.
219
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
220
- about a model's potential harms.
221
- feature_extractor ([`~transformers.CLIPImageProcessor`]):
222
- A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
223
- """
224
-
225
- model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
226
- _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
227
- _exclude_from_cpu_offload = ["safety_checker"]
228
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
229
-
230
-
231
-
232
- def __init__(
233
- self,
234
- vae: AutoencoderKL,
235
- text_encoder: CLIPTextModel,
236
- tokenizer: CLIPTokenizer,
237
- unet: UNet2DConditionModel,
238
- controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]],
239
- dino_controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]],
240
- scheduler: Union[DDIMScheduler],
241
- safety_checker: StableDiffusionSafetyChecker,
242
- feature_extractor: CLIPImageProcessor,
243
- image_encoder: CLIPVisionModelWithProjection = None,
244
- requires_safety_checker: bool = True,
245
- default_denoising_steps: Optional[int] = 10,
246
- default_processing_resolution: Optional[int] = 768,
247
- prompt="The normal map",
248
- empty_text_embedding=None,
249
- ):
250
- super().__init__(
251
- vae,
252
- text_encoder,
253
- tokenizer,
254
- unet,
255
- controlnet,
256
- scheduler,
257
- safety_checker,
258
- feature_extractor,
259
- image_encoder,
260
- requires_safety_checker,
261
- )
262
-
263
- self.register_modules(
264
- dino_controlnet=dino_controlnet,
265
- )
266
-
267
- self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
268
- self.dino_image_processor = lambda x: x / 127.5 -1.
269
-
270
- self.default_denoising_steps = default_denoising_steps
271
- self.default_processing_resolution = default_processing_resolution
272
- self.prompt = prompt
273
- self.prompt_embeds = None
274
- self.empty_text_embedding = empty_text_embedding
275
- self.prior = DINOv2_Encoder(size=672)
276
-
277
- def check_inputs(
278
- self,
279
- image: PipelineImageInput,
280
- num_inference_steps: int,
281
- ensemble_size: int,
282
- processing_resolution: int,
283
- resample_method_input: str,
284
- resample_method_output: str,
285
- batch_size: int,
286
- ensembling_kwargs: Optional[Dict[str, Any]],
287
- latents: Optional[torch.Tensor],
288
- generator: Optional[Union[torch.Generator, List[torch.Generator]]],
289
- output_type: str,
290
- output_uncertainty: bool,
291
- ) -> int:
292
- if num_inference_steps is None:
293
- raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
294
- if num_inference_steps < 1:
295
- raise ValueError("`num_inference_steps` must be positive.")
296
- if ensemble_size < 1:
297
- raise ValueError("`ensemble_size` must be positive.")
298
- if ensemble_size == 2:
299
- logger.warning(
300
- "`ensemble_size` == 2 results are similar to no ensembling (1); "
301
- "consider increasing the value to at least 3."
302
- )
303
- if ensemble_size == 1 and output_uncertainty:
304
- raise ValueError(
305
- "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` "
306
- "greater than 1."
307
- )
308
- if processing_resolution is None:
309
- raise ValueError(
310
- "`processing_resolution` is not specified and could not be resolved from the model config."
311
- )
312
- if processing_resolution < 0:
313
- raise ValueError(
314
- "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for "
315
- "downsampled processing."
316
- )
317
- if processing_resolution % self.vae_scale_factor != 0:
318
- raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.")
319
- if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
320
- raise ValueError(
321
- "`resample_method_input` takes string values compatible with PIL library: "
322
- "nearest, nearest-exact, bilinear, bicubic, area."
323
- )
324
- if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
325
- raise ValueError(
326
- "`resample_method_output` takes string values compatible with PIL library: "
327
- "nearest, nearest-exact, bilinear, bicubic, area."
328
- )
329
- if batch_size < 1:
330
- raise ValueError("`batch_size` must be positive.")
331
- if output_type not in ["pt", "np"]:
332
- raise ValueError("`output_type` must be one of `pt` or `np`.")
333
- if latents is not None and generator is not None:
334
- raise ValueError("`latents` and `generator` cannot be used together.")
335
- if ensembling_kwargs is not None:
336
- if not isinstance(ensembling_kwargs, dict):
337
- raise ValueError("`ensembling_kwargs` must be a dictionary.")
338
- if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"):
339
- raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.")
340
-
341
- # image checks
342
- num_images = 0
343
- W, H = None, None
344
- if not isinstance(image, list):
345
- image = [image]
346
- for i, img in enumerate(image):
347
- if isinstance(img, np.ndarray) or torch.is_tensor(img):
348
- if img.ndim not in (2, 3, 4):
349
- raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.")
350
- H_i, W_i = img.shape[-2:]
351
- N_i = 1
352
- if img.ndim == 4:
353
- N_i = img.shape[0]
354
- elif isinstance(img, Image.Image):
355
- W_i, H_i = img.size
356
- N_i = 1
357
- else:
358
- raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.")
359
- if W is None:
360
- W, H = W_i, H_i
361
- elif (W, H) != (W_i, H_i):
362
- raise ValueError(
363
- f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}"
364
- )
365
- num_images += N_i
366
-
367
- # latents checks
368
- if latents is not None:
369
- if not torch.is_tensor(latents):
370
- raise ValueError("`latents` must be a torch.Tensor.")
371
- if latents.dim() != 4:
372
- raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.")
373
-
374
- if processing_resolution > 0:
375
- max_orig = max(H, W)
376
- new_H = H * processing_resolution // max_orig
377
- new_W = W * processing_resolution // max_orig
378
- if new_H == 0 or new_W == 0:
379
- raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]")
380
- W, H = new_W, new_H
381
- w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor
382
- h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor
383
- shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w)
384
-
385
- if latents.shape != shape_expected:
386
- raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.")
387
-
388
- # generator checks
389
- if generator is not None:
390
- if isinstance(generator, list):
391
- if len(generator) != num_images * ensemble_size:
392
- raise ValueError(
393
- "The number of generators must match the total number of ensemble members for all input images."
394
- )
395
- if not all(g.device.type == generator[0].device.type for g in generator):
396
- raise ValueError("`generator` device placement is not consistent in the list.")
397
- elif not isinstance(generator, torch.Generator):
398
- raise ValueError(f"Unsupported generator type: {type(generator)}.")
399
-
400
- return num_images
401
-
402
- def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
403
- if not hasattr(self, "_progress_bar_config"):
404
- self._progress_bar_config = {}
405
- elif not isinstance(self._progress_bar_config, dict):
406
- raise ValueError(
407
- f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
408
- )
409
-
410
- progress_bar_config = dict(**self._progress_bar_config)
411
- progress_bar_config["desc"] = progress_bar_config.get("desc", desc)
412
- progress_bar_config["leave"] = progress_bar_config.get("leave", leave)
413
- if iterable is not None:
414
- return tqdm(iterable, **progress_bar_config)
415
- elif total is not None:
416
- return tqdm(total=total, **progress_bar_config)
417
- else:
418
- raise ValueError("Either `total` or `iterable` has to be defined.")
419
-
420
- @torch.no_grad()
421
- @replace_example_docstring(EXAMPLE_DOC_STRING)
422
- def __call__(
423
- self,
424
- image: PipelineImageInput,
425
- prompt: Union[str, List[str]] = None,
426
- negative_prompt: Optional[Union[str, List[str]]] = None,
427
- num_inference_steps: Optional[int] = None,
428
- ensemble_size: int = 1,
429
- processing_resolution: Optional[int] = None,
430
- match_input_resolution: bool = True,
431
- resample_method_input: str = "bilinear",
432
- resample_method_output: str = "bilinear",
433
- batch_size: int = 1,
434
- ensembling_kwargs: Optional[Dict[str, Any]] = None,
435
- latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
436
- prompt_embeds: Optional[torch.Tensor] = None,
437
- negative_prompt_embeds: Optional[torch.Tensor] = None,
438
- num_images_per_prompt: Optional[int] = 1,
439
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
440
- controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
441
- output_type: str = "np",
442
- output_uncertainty: bool = False,
443
- output_latent: bool = False,
444
- return_dict: bool = True,
445
- ):
446
- """
447
- Function invoked when calling the pipeline.
448
-
449
- Args:
450
- image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`),
451
- `List[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For
452
- arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible
453
- by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
454
- three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
455
- same width and height.
456
- num_inference_steps (`int`, *optional*, defaults to `None`):
457
- Number of denoising diffusion steps during inference. The default value `None` results in automatic
458
- selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
459
- for Marigold-LCM models.
460
- ensemble_size (`int`, defaults to `1`):
461
- Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for
462
- faster inference.
463
- processing_resolution (`int`, *optional*, defaults to `None`):
464
- Effective processing resolution. When set to `0`, matches the larger input image dimension. This
465
- produces crisper predictions, but may also lead to the overall loss of global context. The default
466
- value `None` resolves to the optimal value from the model config.
467
- match_input_resolution (`bool`, *optional*, defaults to `True`):
468
- When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
469
- side of the output will equal to `processing_resolution`.
470
- resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
471
- Resampling method used to resize input images to `processing_resolution`. The accepted values are:
472
- `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
473
- resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
474
- Resampling method used to resize output predictions to match the input resolution. The accepted values
475
- are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
476
- batch_size (`int`, *optional*, defaults to `1`):
477
- Batch size; only matters when setting `ensemble_size` or passing a tensor of images.
478
- ensembling_kwargs (`dict`, *optional*, defaults to `None`)
479
- Extra dictionary with arguments for precise ensembling control. The following options are available:
480
- - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in
481
- every pixel location, can be either `"closest"` or `"mean"`.
482
- latents (`torch.Tensor`, *optional*, defaults to `None`):
483
- Latent noise tensors to replace the random initialization. These can be taken from the previous
484
- function call's output.
485
- generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`):
486
- Random number generator object to ensure reproducibility.
487
- output_type (`str`, *optional*, defaults to `"np"`):
488
- Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted
489
- values are: `"np"` (numpy array) or `"pt"` (torch tensor).
490
- output_uncertainty (`bool`, *optional*, defaults to `False`):
491
- When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that
492
- the `ensemble_size` argument is set to a value above 2.
493
- output_latent (`bool`, *optional*, defaults to `False`):
494
- When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
495
- within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
496
- `latents` argument.
497
- return_dict (`bool`, *optional*, defaults to `True`):
498
- Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple.
499
-
500
- Examples:
501
-
502
- Returns:
503
- [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`:
504
- If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a
505
- `tuple` is returned where the first element is the prediction, the second element is the uncertainty
506
- (or `None`), and the third is the latent (or `None`).
507
- """
508
-
509
- # 0. Resolving variables.
510
- device = self._execution_device
511
- dtype = self.dtype
512
-
513
- # Model-specific optimal default values leading to fast and reasonable results.
514
- if num_inference_steps is None:
515
- num_inference_steps = self.default_denoising_steps
516
- if processing_resolution is None:
517
- processing_resolution = self.default_processing_resolution
518
-
519
-
520
- image, padding, original_resolution = self.image_processor.preprocess(
521
- image, processing_resolution, resample_method_input, device, dtype
522
- ) # [N,3,PPH,PPW]
523
-
524
- image_latent, gaus_noise = self.prepare_latents(
525
- image, latents, generator, ensemble_size, batch_size
526
- ) # [N,4,h,w], [N,4,h,w]
527
-
528
- # 0. X_start latent obtain
529
- predictor = self.x_start_pipeline(image, latents=gaus_noise,
530
- processing_resolution=processing_resolution, skip_preprocess=True)
531
- x_start_latent = predictor.latent
532
-
533
- # 1. Check inputs.
534
- num_images = self.check_inputs(
535
- image,
536
- num_inference_steps,
537
- ensemble_size,
538
- processing_resolution,
539
- resample_method_input,
540
- resample_method_output,
541
- batch_size,
542
- ensembling_kwargs,
543
- latents,
544
- generator,
545
- output_type,
546
- output_uncertainty,
547
- )
548
-
549
-
550
- # 2. Prepare empty text conditioning.
551
- # Model invocation: self.tokenizer, self.text_encoder.
552
- if self.empty_text_embedding is None:
553
- prompt = ""
554
- text_inputs = self.tokenizer(
555
- prompt,
556
- padding="do_not_pad",
557
- max_length=self.tokenizer.model_max_length,
558
- truncation=True,
559
- return_tensors="pt",
560
- )
561
- text_input_ids = text_inputs.input_ids.to(device)
562
- self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024]
563
-
564
-
565
-
566
- # 3. prepare prompt
567
- if self.prompt_embeds is None:
568
- prompt_embeds, negative_prompt_embeds = self.encode_prompt(
569
- self.prompt,
570
- device,
571
- num_images_per_prompt,
572
- False,
573
- negative_prompt,
574
- prompt_embeds=prompt_embeds,
575
- negative_prompt_embeds=None,
576
- lora_scale=None,
577
- clip_skip=None,
578
- )
579
- self.prompt_embeds = prompt_embeds
580
- self.negative_prompt_embeds = negative_prompt_embeds
581
-
582
-
583
-
584
- # 5. dino guider features obtaining
585
- ## TODO different case-1
586
- dino_features = self.prior(image)
587
- dino_features = self.dino_controlnet.dino_controlnet_cond_embedding(dino_features)
588
- dino_features = self.match_noisy(dino_features, x_start_latent)
589
-
590
- del (
591
- image,
592
- )
593
-
594
- # 7. denoise sampling, using heuritic sampling proposed by Ye.
595
-
596
- t_start = self.x_start_pipeline.t_start
597
- self.scheduler.set_timesteps(num_inference_steps, t_start=t_start,device=device)
598
-
599
- cond_scale =controlnet_conditioning_scale
600
- pred_latent = x_start_latent
601
-
602
- cur_step = 0
603
-
604
- # dino controlnet
605
- dino_down_block_res_samples, dino_mid_block_res_sample = self.dino_controlnet(
606
- dino_features.detach(),
607
- 0, # not depend on time steps
608
- encoder_hidden_states=self.prompt_embeds,
609
- conditioning_scale=cond_scale,
610
- guess_mode=False,
611
- return_dict=False,
612
- )
613
- assert dino_mid_block_res_sample == None
614
-
615
- pred_latents = []
616
-
617
- last_pred_latent = pred_latent
618
- for (t, prev_t) in self.progress_bar(zip(self.scheduler.timesteps,self.scheduler.prev_timesteps), leave=False, desc="Diffusion steps..."):
619
-
620
- _dino_down_block_res_samples = [dino_down_block_res_sample for dino_down_block_res_sample in dino_down_block_res_samples] # copy, avoid repeat quiery
621
-
622
- # controlnet
623
- down_block_res_samples, mid_block_res_sample = self.controlnet(
624
- image_latent.detach(),
625
- t,
626
- encoder_hidden_states=self.prompt_embeds,
627
- conditioning_scale=cond_scale,
628
- guess_mode=False,
629
- return_dict=False,
630
- )
631
-
632
- # SG-DRN
633
- noise = self.dino_unet_forward(
634
- self.unet,
635
- pred_latent,
636
- t,
637
- encoder_hidden_states=self.prompt_embeds,
638
- down_block_additional_residuals=down_block_res_samples,
639
- mid_block_additional_residual=mid_block_res_sample,
640
- dino_down_block_additional_residuals= _dino_down_block_res_samples,
641
- return_dict=False,
642
- )[0] # [B,4,h,w]
643
-
644
- pred_latents.append(noise)
645
- # ddim steps
646
- out = self.scheduler.step(
647
- noise, t, prev_t, pred_latent, gaus_noise = gaus_noise, generator=generator, cur_step=cur_step+1 # NOTE that cur_step dirs to next_step
648
- )# [B,4,h,w]
649
- pred_latent = out.prev_sample
650
-
651
- cur_step += 1
652
-
653
- del (
654
- image_latent,
655
- dino_features,
656
- )
657
- pred_latent = pred_latents[-1] # using x0
658
-
659
- # decoder
660
- prediction = self.decode_prediction(pred_latent)
661
- prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW]
662
- prediction = self.image_processor.resize_antialias(prediction, original_resolution, resample_method_output, is_aa=False) # [N,3,H,W]
663
-
664
- if match_input_resolution:
665
- prediction = self.image_processor.resize_antialias(
666
- prediction, original_resolution, resample_method_output, is_aa=False
667
- ) # [N,3,H,W]
668
-
669
- if match_input_resolution:
670
- prediction = self.image_processor.resize_antialias(
671
- prediction, original_resolution, resample_method_output, is_aa=False
672
- ) # [N,3,H,W]
673
- prediction = self.normalize_normals(prediction) # [N,3,H,W]
674
-
675
- if output_type == "np":
676
- prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3]
677
- prediction = prediction.clip(min=-1, max=1)
678
-
679
- # 11. Offload all models
680
- self.maybe_free_model_hooks()
681
-
682
- return StableNormalOutput(
683
- prediction=prediction,
684
- latent=pred_latent,
685
- gaus_noise=gaus_noise
686
- )
687
-
688
- # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents
689
- def prepare_latents(
690
- self,
691
- image: torch.Tensor,
692
- latents: Optional[torch.Tensor],
693
- generator: Optional[torch.Generator],
694
- ensemble_size: int,
695
- batch_size: int,
696
- ) -> Tuple[torch.Tensor, torch.Tensor]:
697
- def retrieve_latents(encoder_output):
698
- if hasattr(encoder_output, "latent_dist"):
699
- return encoder_output.latent_dist.mode()
700
- elif hasattr(encoder_output, "latents"):
701
- return encoder_output.latents
702
- else:
703
- raise AttributeError("Could not access latents of provided encoder_output")
704
-
705
-
706
-
707
- image_latent = torch.cat(
708
- [
709
- retrieve_latents(self.vae.encode(image[i : i + batch_size]))
710
- for i in range(0, image.shape[0], batch_size)
711
- ],
712
- dim=0,
713
- ) # [N,4,h,w]
714
- image_latent = image_latent * self.vae.config.scaling_factor
715
- image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w]
716
-
717
- pred_latent = latents
718
- if pred_latent is None:
719
-
720
-
721
- pred_latent = randn_tensor(
722
- image_latent.shape,
723
- generator=generator,
724
- device=image_latent.device,
725
- dtype=image_latent.dtype,
726
- ) # [N*E,4,h,w]
727
-
728
- return image_latent, pred_latent
729
-
730
- def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor:
731
- if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels:
732
- raise ValueError(
733
- f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}."
734
- )
735
-
736
- prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W]
737
-
738
- return prediction # [B,3,H,W]
739
-
740
- @staticmethod
741
- def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
742
- if normals.dim() != 4 or normals.shape[1] != 3:
743
- raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
744
-
745
- norm = torch.norm(normals, dim=1, keepdim=True)
746
- normals /= norm.clamp(min=eps)
747
-
748
- return normals
749
-
750
- @staticmethod
751
- def match_noisy(dino, noisy):
752
- _, __, dino_h, dino_w = dino.shape
753
- _, __, h, w = noisy.shape
754
-
755
- if h == dino_h and w == dino_w:
756
- return dino
757
- else:
758
- return F.interpolate(dino, (h, w), mode='bilinear')
759
-
760
-
761
-
762
-
763
-
764
-
765
-
766
-
767
-
768
-
769
- @staticmethod
770
- def dino_unet_forward(
771
- self, # NOTE that repurpose to UNet
772
- sample: torch.Tensor,
773
- timestep: Union[torch.Tensor, float, int],
774
- encoder_hidden_states: torch.Tensor,
775
- class_labels: Optional[torch.Tensor] = None,
776
- timestep_cond: Optional[torch.Tensor] = None,
777
- attention_mask: Optional[torch.Tensor] = None,
778
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
779
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
780
- down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
781
- mid_block_additional_residual: Optional[torch.Tensor] = None,
782
- dino_down_block_additional_residuals: Optional[torch.Tensor] = None,
783
- down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
784
- encoder_attention_mask: Optional[torch.Tensor] = None,
785
- return_dict: bool = True,
786
- ) -> Union[UNet2DConditionOutput, Tuple]:
787
- r"""
788
- The [`UNet2DConditionModel`] forward method.
789
-
790
- Args:
791
- sample (`torch.Tensor`):
792
- The noisy input tensor with the following shape `(batch, channel, height, width)`.
793
- timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
794
- encoder_hidden_states (`torch.Tensor`):
795
- The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
796
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
797
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
798
- timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
799
- Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
800
- through the `self.time_embedding` layer to obtain the timestep embeddings.
801
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
802
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
803
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
804
- negative values to the attention scores corresponding to "discard" tokens.
805
- cross_attention_kwargs (`dict`, *optional*):
806
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
807
- `self.processor` in
808
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
809
- added_cond_kwargs: (`dict`, *optional*):
810
- A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
811
- are passed along to the UNet blocks.
812
- down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
813
- A tuple of tensors that if specified are added to the residuals of down unet blocks.
814
- mid_block_additional_residual: (`torch.Tensor`, *optional*):
815
- A tensor that if specified is added to the residual of the middle unet block.
816
- down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
817
- additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
818
- encoder_attention_mask (`torch.Tensor`):
819
- A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
820
- `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
821
- which adds large negative values to the attention scores corresponding to "discard" tokens.
822
- return_dict (`bool`, *optional*, defaults to `True`):
823
- Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
824
- tuple.
825
-
826
- Returns:
827
- [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
828
- If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
829
- otherwise a `tuple` is returned where the first element is the sample tensor.
830
- """
831
- # By default samples have to be AT least a multiple of the overall upsampling factor.
832
- # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
833
- # However, the upsampling interpolation output size can be forced to fit any upsampling size
834
- # on the fly if necessary.
835
-
836
-
837
- default_overall_up_factor = 2**self.num_upsamplers
838
-
839
- # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
840
- forward_upsample_size = False
841
- upsample_size = None
842
-
843
- for dim in sample.shape[-2:]:
844
- if dim % default_overall_up_factor != 0:
845
- # Forward upsample size to force interpolation output size.
846
- forward_upsample_size = True
847
- break
848
-
849
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
850
- # expects mask of shape:
851
- # [batch, key_tokens]
852
- # adds singleton query_tokens dimension:
853
- # [batch, 1, key_tokens]
854
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
855
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
856
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
857
- if attention_mask is not None:
858
- # assume that mask is expressed as:
859
- # (1 = keep, 0 = discard)
860
- # convert mask into a bias that can be added to attention scores:
861
- # (keep = +0, discard = -10000.0)
862
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
863
- attention_mask = attention_mask.unsqueeze(1)
864
-
865
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
866
- if encoder_attention_mask is not None:
867
- encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
868
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
869
-
870
- # 0. center input if necessary
871
- if self.config.center_input_sample:
872
- sample = 2 * sample - 1.0
873
-
874
- # 1. time
875
- t_emb = self.get_time_embed(sample=sample, timestep=timestep)
876
- emb = self.time_embedding(t_emb, timestep_cond)
877
- aug_emb = None
878
-
879
- class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
880
- if class_emb is not None:
881
- if self.config.class_embeddings_concat:
882
- emb = torch.cat([emb, class_emb], dim=-1)
883
- else:
884
- emb = emb + class_emb
885
-
886
- aug_emb = self.get_aug_embed(
887
- emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
888
- )
889
- if self.config.addition_embed_type == "image_hint":
890
- aug_emb, hint = aug_emb
891
- sample = torch.cat([sample, hint], dim=1)
892
-
893
- emb = emb + aug_emb if aug_emb is not None else emb
894
-
895
- if self.time_embed_act is not None:
896
- emb = self.time_embed_act(emb)
897
-
898
- encoder_hidden_states = self.process_encoder_hidden_states(
899
- encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
900
- )
901
-
902
- # 2. pre-process
903
- sample = self.conv_in(sample)
904
-
905
- # 2.5 GLIGEN position net
906
- if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
907
- cross_attention_kwargs = cross_attention_kwargs.copy()
908
- gligen_args = cross_attention_kwargs.pop("gligen")
909
- cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
910
-
911
- # 3. down
912
- # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
913
- # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
914
- if cross_attention_kwargs is not None:
915
- cross_attention_kwargs = cross_attention_kwargs.copy()
916
- lora_scale = cross_attention_kwargs.pop("scale", 1.0)
917
- else:
918
- lora_scale = 1.0
919
-
920
- if USE_PEFT_BACKEND:
921
- # weight the lora layers by setting `lora_scale` for each PEFT layer
922
- scale_lora_layers(self, lora_scale)
923
-
924
- is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
925
- # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
926
- is_adapter = down_intrablock_additional_residuals is not None
927
- # maintain backward compatibility for legacy usage, where
928
- # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
929
- # but can only use one or the other
930
- if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
931
- deprecate(
932
- "T2I should not use down_block_additional_residuals",
933
- "1.3.0",
934
- "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
935
- and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
936
- for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
937
- standard_warn=False,
938
- )
939
- down_intrablock_additional_residuals = down_block_additional_residuals
940
- is_adapter = True
941
-
942
-
943
-
944
- def residual_downforward(
945
- self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None,
946
- additional_residuals: Optional[torch.Tensor] = None,
947
- *args, **kwargs,
948
- ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
949
- if len(args) > 0 or kwargs.get("scale", None) is not None:
950
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
951
- deprecate("scale", "1.0.0", deprecation_message)
952
-
953
- output_states = ()
954
-
955
- for resnet in self.resnets:
956
- if self.training and self.gradient_checkpointing:
957
-
958
- def create_custom_forward(module):
959
- def custom_forward(*inputs):
960
- return module(*inputs)
961
-
962
- return custom_forward
963
-
964
- if is_torch_version(">=", "1.11.0"):
965
- hidden_states = torch.utils.checkpoint.checkpoint(
966
- create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
967
- )
968
- else:
969
- hidden_states = torch.utils.checkpoint.checkpoint(
970
- create_custom_forward(resnet), hidden_states, temb
971
- )
972
- else:
973
- hidden_states = resnet(hidden_states, temb)
974
- hidden_states += additional_residuals.pop(0)
975
-
976
-
977
- output_states = output_states + (hidden_states,)
978
-
979
- if self.downsamplers is not None:
980
- for downsampler in self.downsamplers:
981
- hidden_states = downsampler(hidden_states)
982
- hidden_states += additional_residuals.pop(0)
983
-
984
- output_states = output_states + (hidden_states,)
985
-
986
- return hidden_states, output_states
987
-
988
-
989
- def residual_blockforward(
990
- self,  # NOTE: this forward is repurposed for a UNet down block, so `self` here is the block module
991
- hidden_states: torch.Tensor,
992
- temb: Optional[torch.Tensor] = None,
993
- encoder_hidden_states: Optional[torch.Tensor] = None,
994
- attention_mask: Optional[torch.Tensor] = None,
995
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
996
- encoder_attention_mask: Optional[torch.Tensor] = None,
997
- additional_residuals: Optional[torch.Tensor] = None,
998
- ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
999
- if cross_attention_kwargs is not None:
1000
- if cross_attention_kwargs.get("scale", None) is not None:
1001
- logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
1002
-
1003
-
1004
-
1005
- output_states = ()
1006
-
1007
- blocks = list(zip(self.resnets, self.attentions))
1008
-
1009
- for i, (resnet, attn) in enumerate(blocks):
1010
- if self.training and self.gradient_checkpointing:
1011
-
1012
- def create_custom_forward(module, return_dict=None):
1013
- def custom_forward(*inputs):
1014
- if return_dict is not None:
1015
- return module(*inputs, return_dict=return_dict)
1016
- else:
1017
- return module(*inputs)
1018
-
1019
- return custom_forward
1020
-
1021
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1022
- hidden_states = torch.utils.checkpoint.checkpoint(
1023
- create_custom_forward(resnet),
1024
- hidden_states,
1025
- temb,
1026
- **ckpt_kwargs,
1027
- )
1028
- hidden_states = attn(
1029
- hidden_states,
1030
- encoder_hidden_states=encoder_hidden_states,
1031
- cross_attention_kwargs=cross_attention_kwargs,
1032
- attention_mask=attention_mask,
1033
- encoder_attention_mask=encoder_attention_mask,
1034
- return_dict=False,
1035
- )[0]
1036
- else:
1037
- hidden_states = resnet(hidden_states, temb)
1038
- hidden_states = attn(
1039
- hidden_states,
1040
- encoder_hidden_states=encoder_hidden_states,
1041
- cross_attention_kwargs=cross_attention_kwargs,
1042
- attention_mask=attention_mask,
1043
- encoder_attention_mask=encoder_attention_mask,
1044
- return_dict=False,
1045
- )[0]
1046
-
1047
- hidden_states += additional_residuals.pop(0)
1048
-
1049
- output_states = output_states + (hidden_states,)
1050
-
1051
- if self.downsamplers is not None:
1052
- for downsampler in self.downsamplers:
1053
- hidden_states = downsampler(hidden_states)
1054
- hidden_states += additional_residuals.pop(0)
1055
-
1056
- output_states = output_states + (hidden_states,)
1057
-
1058
- return hidden_states, output_states
1059
-
1060
-
1061
- down_intrablock_additional_residuals = dino_down_block_additional_residuals  # route the DINO residuals through the intra-block mechanism
1062
-
1063
- sample += down_intrablock_additional_residuals.pop(0)
1064
- down_block_res_samples = (sample,)
1065
-
1066
- for downsample_block in self.down_blocks:
1067
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1068
-
1069
- sample, res_samples = residual_blockforward(
1070
- downsample_block,
1071
- hidden_states=sample,
1072
- temb=emb,
1073
- encoder_hidden_states=encoder_hidden_states,
1074
- attention_mask=attention_mask,
1075
- cross_attention_kwargs=cross_attention_kwargs,
1076
- encoder_attention_mask=encoder_attention_mask,
1077
- additional_residuals = down_intrablock_additional_residuals,
1078
- )
1079
-
1080
- else:
1081
- sample, res_samples = residual_downforward(
1082
- downsample_block,
1083
- hidden_states=sample,
1084
- temb=emb,
1085
- additional_residuals = down_intrablock_additional_residuals,
1086
- )
1087
-
1088
-
1089
- down_block_res_samples += res_samples
1090
-
1091
-
1092
- if is_controlnet:
1093
- new_down_block_res_samples = ()
1094
-
1095
- for down_block_res_sample, down_block_additional_residual in zip(
1096
- down_block_res_samples, down_block_additional_residuals
1097
- ):
1098
- down_block_res_sample = down_block_res_sample + down_block_additional_residual
1099
- new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1100
-
1101
- down_block_res_samples = new_down_block_res_samples
1102
-
1103
- # 4. mid
1104
- if self.mid_block is not None:
1105
- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1106
- sample = self.mid_block(
1107
- sample,
1108
- emb,
1109
- encoder_hidden_states=encoder_hidden_states,
1110
- attention_mask=attention_mask,
1111
- cross_attention_kwargs=cross_attention_kwargs,
1112
- encoder_attention_mask=encoder_attention_mask,
1113
- )
1114
- else:
1115
- sample = self.mid_block(sample, emb)
1116
-
1117
- # To support T2I-Adapter-XL
1118
- if (
1119
- is_adapter
1120
- and len(down_intrablock_additional_residuals) > 0
1121
- and sample.shape == down_intrablock_additional_residuals[0].shape
1122
- ):
1123
- sample += down_intrablock_additional_residuals.pop(0)
1124
-
1125
- if is_controlnet:
1126
- sample = sample + mid_block_additional_residual
1127
-
1128
- # 5. up
1129
- for i, upsample_block in enumerate(self.up_blocks):
1130
- is_final_block = i == len(self.up_blocks) - 1
1131
-
1132
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1133
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1134
-
1135
- # if we have not reached the final block and need to forward the
1136
- # upsample size, we do it here
1137
- if not is_final_block and forward_upsample_size:
1138
- upsample_size = down_block_res_samples[-1].shape[2:]
1139
-
1140
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1141
- sample = upsample_block(
1142
- hidden_states=sample,
1143
- temb=emb,
1144
- res_hidden_states_tuple=res_samples,
1145
- encoder_hidden_states=encoder_hidden_states,
1146
- cross_attention_kwargs=cross_attention_kwargs,
1147
- upsample_size=upsample_size,
1148
- attention_mask=attention_mask,
1149
- encoder_attention_mask=encoder_attention_mask,
1150
- )
1151
- else:
1152
- sample = upsample_block(
1153
- hidden_states=sample,
1154
- temb=emb,
1155
- res_hidden_states_tuple=res_samples,
1156
- upsample_size=upsample_size,
1157
- )
1158
-
1159
- # 6. post-process
1160
- if self.conv_norm_out:
1161
- sample = self.conv_norm_out(sample)
1162
- sample = self.conv_act(sample)
1163
- sample = self.conv_out(sample)
1164
-
1165
- if USE_PEFT_BACKEND:
1166
- # remove `lora_scale` from each PEFT layer
1167
- unscale_lora_layers(self, lora_scale)
1168
-
1169
- if not return_dict:
1170
- return (sample,)
1171
-
1172
- return UNet2DConditionOutput(sample=sample)
1173
-
1174
-
1175
-
1176
- @staticmethod
1177
- def ensemble_normals(
1178
- normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest"
1179
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
1180
- """
1181
- Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is
1182
- the number of ensemble members for a given prediction of size `(H x W)`.
1183
-
1184
- Args:
1185
- normals (`torch.Tensor`):
1186
- Input ensemble normals maps.
1187
- output_uncertainty (`bool`, *optional*, defaults to `False`):
1188
- Whether to output uncertainty map.
1189
- reduction (`str`, *optional*, defaults to `"closest"`):
1190
- Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and
1191
- `"mean"`.
1192
-
1193
- Returns:
1194
- A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of
1195
- uncertainties of shape `(1, 1, H, W)`.
1196
- """
1197
- if normals.dim() != 4 or normals.shape[1] != 3:
1198
- raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
1199
- if reduction not in ("closest", "mean"):
1200
- raise ValueError(f"Unrecognized reduction method: {reduction}.")
1201
-
1202
- mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W]
1203
- mean_normals = StableNormalPipeline.normalize_normals(mean_normals) # [1,3,H,W]
1204
-
1205
- sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W]
1206
- sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16
1207
-
1208
- uncertainty = None
1209
- if output_uncertainty:
1210
- uncertainty = sim_cos.arccos() # [E,1,H,W]
1211
- uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W]
1212
-
1213
- if reduction == "mean":
1214
- return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W]
1215
-
1216
- closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W]
1217
- closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W]
1218
- closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W]
1219
-
1220
- return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W]
1221
-
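To make the two reduction modes of `ensemble_normals` concrete, here is a minimal torch-only sketch of the same math on toy data (nothing below depends on the deleted pipeline):

```python
import math
import torch
import torch.nn.functional as F

# Three ensemble members of a tiny normal map, shape [E, 3, H, W], unit length per pixel.
normals = F.normalize(torch.randn(3, 3, 2, 2), dim=1)

# "mean" reduction: average the members, then renormalize to unit length.
mean_normals = F.normalize(normals.mean(dim=0, keepdim=True), dim=1)      # [1,3,H,W]

# Per-pixel cosine similarity of each member to the mean direction.
sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True).clamp(-1, 1)  # [E,1,H,W]

# Uncertainty: mean angular deviation from the mean direction, scaled to [0, 1].
uncertainty = sim_cos.arccos().mean(dim=0, keepdim=True) / math.pi        # [1,1,H,W]

# "closest" reduction: per pixel, keep the member best aligned with the mean direction.
closest_idx = sim_cos.argmax(dim=0, keepdim=True).repeat(1, 3, 1, 1)      # [1,3,H,W]
closest_normals = torch.gather(normals, 0, closest_idx)                   # [1,3,H,W]
```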
1222
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
1223
- def retrieve_timesteps(
1224
- scheduler,
1225
- num_inference_steps: Optional[int] = None,
1226
- device: Optional[Union[str, torch.device]] = None,
1227
- timesteps: Optional[List[int]] = None,
1228
- sigmas: Optional[List[float]] = None,
1229
- **kwargs,
1230
- ):
1231
- """
1232
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
1233
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
1234
-
1235
- Args:
1236
- scheduler (`SchedulerMixin`):
1237
- The scheduler to get timesteps from.
1238
- num_inference_steps (`int`):
1239
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
1240
- must be `None`.
1241
- device (`str` or `torch.device`, *optional*):
1242
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
1243
- timesteps (`List[int]`, *optional*):
1244
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
1245
- `num_inference_steps` and `sigmas` must be `None`.
1246
- sigmas (`List[float]`, *optional*):
1247
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
1248
- `num_inference_steps` and `timesteps` must be `None`.
1249
-
1250
- Returns:
1251
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
1252
- second element is the number of inference steps.
1253
- """
1254
- if timesteps is not None and sigmas is not None:
1255
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
1256
- if timesteps is not None:
1257
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
1258
- if not accepts_timesteps:
1259
- raise ValueError(
1260
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
1261
- f" timestep schedules. Please check whether you are using the correct scheduler."
1262
- )
1263
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
1264
- timesteps = scheduler.timesteps
1265
- num_inference_steps = len(timesteps)
1266
- elif sigmas is not None:
1267
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
1268
- if not accept_sigmas:
1269
- raise ValueError(
1270
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
1271
- f" sigmas schedules. Please check whether you are using the correct scheduler."
1272
- )
1273
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
1274
- timesteps = scheduler.timesteps
1275
- num_inference_steps = len(timesteps)
1276
- else:
1277
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
1278
- timesteps = scheduler.timesteps
1279
- return timesteps, num_inference_steps
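`retrieve_timesteps` above is the stock diffusers helper. A small usage sketch, assuming the helper is in scope; the scheduler below is a default-configured DDIMScheduler for illustration, not this repo's trained config:

```python
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)

# Typical path: derive the schedule from a step count on a given device.
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=10, device="cpu")
print(num_steps, timesteps)  # 10; the exact values depend on the scheduler's `timestep_spacing`

# Passing `timesteps=[...]` or `sigmas=[...]` instead only works when the scheduler's
# `set_timesteps` accepts those keywords; a plain DDIMScheduler does not, so that path raises.
```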
stablenormal/pipeline_yoso_normal.py DELETED
@@ -1,727 +0,0 @@
1
- # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
2
- # Copyright 2024 The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- # --------------------------------------------------------------------------
16
- # More information and citation instructions are available on the Marigold project website: https://marigoldmonodepth.github.io
17
- # --------------------------------------------------------------------------
18
- from dataclasses import dataclass
19
- from typing import Any, Dict, List, Optional, Tuple, Union
20
-
21
- import numpy as np
22
- import torch
23
- from PIL import Image
24
- from tqdm.auto import tqdm
25
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
26
-
27
-
28
- from diffusers.image_processor import PipelineImageInput
29
- from diffusers.models import (
30
- AutoencoderKL,
31
- UNet2DConditionModel,
32
- ControlNetModel,
33
- )
34
- from diffusers.schedulers import (
35
- DDIMScheduler
36
- )
37
-
38
- from diffusers.utils import (
39
- BaseOutput,
40
- logging,
41
- replace_example_docstring,
42
- )
43
-
44
-
45
- from diffusers.utils.torch_utils import randn_tensor
46
- from diffusers.pipelines.controlnet import StableDiffusionControlNetPipeline
47
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
48
- from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor
49
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
50
-
51
- import inspect  # needed by `retrieve_timesteps` below
- import pdb
52
-
53
-
54
-
55
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
-
57
-
58
- EXAMPLE_DOC_STRING = """
59
- Examples:
60
- ```py
61
- >>> import diffusers
62
- >>> import torch
63
-
64
- >>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(
65
- ... "prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16
66
- ... ).to("cuda")
67
-
68
- >>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
69
- >>> normals = pipe(image)
70
-
71
- >>> vis = pipe.image_processor.visualize_normals(normals.prediction)
72
- >>> vis[0].save("einstein_normals.png")
73
- ```
74
- """
75
-
76
-
77
- @dataclass
78
- class YosoNormalsOutput(BaseOutput):
79
- """
80
- Output class for the YOSO (single-step) monocular normals prediction pipeline.
81
-
82
- Args:
83
- prediction (`np.ndarray`, `torch.Tensor`):
84
- Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height
85
- \times width$, regardless of whether the images were passed as a 4D array or a list.
86
- gaus_noise (`None`, `torch.Tensor`):
88
- The latent noise used to initialize the prediction, returned so that subsequent calls can reuse it. The
89
- shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
89
- latent (`None`, `torch.Tensor`):
90
- Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
91
- The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
92
- """
93
-
94
- prediction: Union[np.ndarray, torch.Tensor]
95
- latent: Union[None, torch.Tensor]
96
- gaus_noise: Union[None, torch.Tensor]
97
-
98
-
99
- class YOSONormalsPipeline(StableDiffusionControlNetPipeline):
100
- """ Pipeline for single-step (YOSO) monocular normals estimation, adapted from the Marigold method: https://marigoldmonodepth.github.io.
101
- It is built on Stable Diffusion with ControlNet guidance.
102
-
103
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
104
- implemented for all pipelines (downloading, saving, running on a particular device, etc.).
105
-
106
- The pipeline also inherits the following loading methods:
107
- - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
108
- - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
109
- - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
110
- - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
111
- - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
112
-
113
- Args:
114
- vae ([`AutoencoderKL`]):
115
- Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
116
- text_encoder ([`~transformers.CLIPTextModel`]):
117
- Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
118
- tokenizer ([`~transformers.CLIPTokenizer`]):
119
- A `CLIPTokenizer` to tokenize text.
120
- unet ([`UNet2DConditionModel`]):
121
- A `UNet2DConditionModel` to denoise the encoded image latents.
122
- controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
123
- Provides additional conditioning to the `unet` during the denoising process. If you set multiple
124
- ControlNets as a list, the outputs from each ControlNet are added together to create one combined
125
- additional conditioning.
126
- scheduler ([`SchedulerMixin`]):
127
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
128
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
129
- safety_checker ([`StableDiffusionSafetyChecker`]):
130
- Classification module that estimates whether generated images could be considered offensive or harmful.
131
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
132
- about a model's potential harms.
133
- feature_extractor ([`~transformers.CLIPImageProcessor`]):
134
- A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
135
- """
136
-
137
- model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
138
- _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
139
- _exclude_from_cpu_offload = ["safety_checker"]
140
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
141
-
142
-
143
-
144
- def __init__(
145
- self,
146
- vae: AutoencoderKL,
147
- text_encoder: CLIPTextModel,
148
- tokenizer: CLIPTokenizer,
149
- unet: UNet2DConditionModel,
150
- controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel]],
151
- scheduler: Union[DDIMScheduler],
152
- safety_checker: StableDiffusionSafetyChecker,
153
- feature_extractor: CLIPImageProcessor,
154
- image_encoder: CLIPVisionModelWithProjection = None,
155
- requires_safety_checker: bool = True,
156
- default_denoising_steps: Optional[int] = 1,
157
- default_processing_resolution: Optional[int] = 768,
158
- prompt="",
159
- empty_text_embedding=None,
160
- t_start: Optional[int] = 401,
161
- ):
162
- super().__init__(
163
- vae,
164
- text_encoder,
165
- tokenizer,
166
- unet,
167
- controlnet,
168
- scheduler,
169
- safety_checker,
170
- feature_extractor,
171
- image_encoder,
172
- requires_safety_checker,
173
- )
174
-
175
- # TODO yoso ImageProcessor
176
- self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
177
- self.control_image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
178
- self.default_denoising_steps = default_denoising_steps
179
- self.default_processing_resolution = default_processing_resolution
180
- self.prompt = prompt
181
- self.prompt_embeds = None
182
- self.empty_text_embedding = empty_text_embedding
183
- self.t_start = t_start  # fixed starting timestep used for the single-step (YOSO) prediction
184
-
185
- def check_inputs(
186
- self,
187
- image: PipelineImageInput,
188
- num_inference_steps: int,
189
- ensemble_size: int,
190
- processing_resolution: int,
191
- resample_method_input: str,
192
- resample_method_output: str,
193
- batch_size: int,
194
- ensembling_kwargs: Optional[Dict[str, Any]],
195
- latents: Optional[torch.Tensor],
196
- generator: Optional[Union[torch.Generator, List[torch.Generator]]],
197
- output_type: str,
198
- output_uncertainty: bool,
199
- ) -> int:
200
- if num_inference_steps is None:
201
- raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.")
202
- if num_inference_steps < 1:
203
- raise ValueError("`num_inference_steps` must be positive.")
204
- if ensemble_size < 1:
205
- raise ValueError("`ensemble_size` must be positive.")
206
- if ensemble_size == 2:
207
- logger.warning(
208
- "`ensemble_size` == 2 results are similar to no ensembling (1); "
209
- "consider increasing the value to at least 3."
210
- )
211
- if ensemble_size == 1 and output_uncertainty:
212
- raise ValueError(
213
- "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` "
214
- "greater than 1."
215
- )
216
- if processing_resolution is None:
217
- raise ValueError(
218
- "`processing_resolution` is not specified and could not be resolved from the model config."
219
- )
220
- if processing_resolution < 0:
221
- raise ValueError(
222
- "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for "
223
- "downsampled processing."
224
- )
225
- if processing_resolution % self.vae_scale_factor != 0:
226
- raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.")
227
- if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
228
- raise ValueError(
229
- "`resample_method_input` takes string values compatible with PIL library: "
230
- "nearest, nearest-exact, bilinear, bicubic, area."
231
- )
232
- if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
233
- raise ValueError(
234
- "`resample_method_output` takes string values compatible with PIL library: "
235
- "nearest, nearest-exact, bilinear, bicubic, area."
236
- )
237
- if batch_size < 1:
238
- raise ValueError("`batch_size` must be positive.")
239
- if output_type not in ["pt", "np"]:
240
- raise ValueError("`output_type` must be one of `pt` or `np`.")
241
- if latents is not None and generator is not None:
242
- raise ValueError("`latents` and `generator` cannot be used together.")
243
- if ensembling_kwargs is not None:
244
- if not isinstance(ensembling_kwargs, dict):
245
- raise ValueError("`ensembling_kwargs` must be a dictionary.")
246
- if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"):
247
- raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.")
248
-
249
- # image checks
250
- num_images = 0
251
- W, H = None, None
252
- if not isinstance(image, list):
253
- image = [image]
254
- for i, img in enumerate(image):
255
- if isinstance(img, np.ndarray) or torch.is_tensor(img):
256
- if img.ndim not in (2, 3, 4):
257
- raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.")
258
- H_i, W_i = img.shape[-2:]
259
- N_i = 1
260
- if img.ndim == 4:
261
- N_i = img.shape[0]
262
- elif isinstance(img, Image.Image):
263
- W_i, H_i = img.size
264
- N_i = 1
265
- else:
266
- raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.")
267
- if W is None:
268
- W, H = W_i, H_i
269
- elif (W, H) != (W_i, H_i):
270
- raise ValueError(
271
- f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}"
272
- )
273
- num_images += N_i
274
-
275
- # latents checks
276
- if latents is not None:
277
- if not torch.is_tensor(latents):
278
- raise ValueError("`latents` must be a torch.Tensor.")
279
- if latents.dim() != 4:
280
- raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.")
281
-
282
- if processing_resolution > 0:
283
- max_orig = max(H, W)
284
- new_H = H * processing_resolution // max_orig
285
- new_W = W * processing_resolution // max_orig
286
- if new_H == 0 or new_W == 0:
287
- raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]")
288
- W, H = new_W, new_H
289
- w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor
290
- h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor
291
- shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w)
292
-
293
- if latents.shape != shape_expected:
294
- raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.")
295
-
296
- # generator checks
297
- if generator is not None:
298
- if isinstance(generator, list):
299
- if len(generator) != num_images * ensemble_size:
300
- raise ValueError(
301
- "The number of generators must match the total number of ensemble members for all input images."
302
- )
303
- if not all(g.device.type == generator[0].device.type for g in generator):
304
- raise ValueError("`generator` device placement is not consistent in the list.")
305
- elif not isinstance(generator, torch.Generator):
306
- raise ValueError(f"Unsupported generator type: {type(generator)}.")
307
-
308
- return num_images
309
-
310
- def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
311
- if not hasattr(self, "_progress_bar_config"):
312
- self._progress_bar_config = {}
313
- elif not isinstance(self._progress_bar_config, dict):
314
- raise ValueError(
315
- f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
316
- )
317
-
318
- progress_bar_config = dict(**self._progress_bar_config)
319
- progress_bar_config["desc"] = progress_bar_config.get("desc", desc)
320
- progress_bar_config["leave"] = progress_bar_config.get("leave", leave)
321
- if iterable is not None:
322
- return tqdm(iterable, **progress_bar_config)
323
- elif total is not None:
324
- return tqdm(total=total, **progress_bar_config)
325
- else:
326
- raise ValueError("Either `total` or `iterable` has to be defined.")
327
-
328
- @torch.no_grad()
329
- @replace_example_docstring(EXAMPLE_DOC_STRING)
330
- def __call__(
331
- self,
332
- image: PipelineImageInput,
333
- prompt: Union[str, List[str]] = None,
334
- negative_prompt: Optional[Union[str, List[str]]] = None,
335
- num_inference_steps: Optional[int] = None,
336
- ensemble_size: int = 1,
337
- processing_resolution: Optional[int] = None,
338
- match_input_resolution: bool = True,
339
- resample_method_input: str = "bilinear",
340
- resample_method_output: str = "bilinear",
341
- batch_size: int = 1,
342
- ensembling_kwargs: Optional[Dict[str, Any]] = None,
343
- latents: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
344
- prompt_embeds: Optional[torch.Tensor] = None,
345
- negative_prompt_embeds: Optional[torch.Tensor] = None,
346
- num_images_per_prompt: Optional[int] = 1,
347
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
348
- controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
349
- output_type: str = "np",
350
- output_uncertainty: bool = False,
351
- output_latent: bool = False,
352
- skip_preprocess: bool = False,
353
- return_dict: bool = True,
354
- **kwargs,
355
- ):
356
- """
357
- Function invoked when calling the pipeline.
358
-
359
- Args:
360
- image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`),
361
- `List[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For
362
- arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible
363
- by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
364
- three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
365
- same width and height.
366
- num_inference_steps (`int`, *optional*, defaults to `None`):
367
- Number of denoising diffusion steps during inference. The default value `None` results in automatic
368
- selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
369
- for Marigold-LCM models.
370
- ensemble_size (`int`, defaults to `1`):
371
- Number of ensemble predictions. Recommended values are 5 and higher for better precision, or 1 for
372
- faster inference.
373
- processing_resolution (`int`, *optional*, defaults to `None`):
374
- Effective processing resolution. When set to `0`, matches the larger input image dimension. This
375
- produces crisper predictions, but may also lead to the overall loss of global context. The default
376
- value `None` resolves to the optimal value from the model config.
377
- match_input_resolution (`bool`, *optional*, defaults to `True`):
378
- When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
379
- side of the output will equal to `processing_resolution`.
380
- resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
381
- Resampling method used to resize input images to `processing_resolution`. The accepted values are:
382
- `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
383
- resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
384
- Resampling method used to resize output predictions to match the input resolution. The accepted values
385
- are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
386
- batch_size (`int`, *optional*, defaults to `1`):
387
- Batch size; only matters when setting `ensemble_size` or passing a tensor of images.
388
- ensembling_kwargs (`dict`, *optional*, defaults to `None`)
389
- Extra dictionary with arguments for precise ensembling control. The following options are available:
390
- - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in
391
- every pixel location, can be either `"closest"` or `"mean"`.
392
- latents (`torch.Tensor`, *optional*, defaults to `None`):
393
- Latent noise tensors to replace the random initialization. These can be taken from the previous
394
- function call's output.
395
- generator (`torch.Generator`, or `List[torch.Generator]`, *optional*, defaults to `None`):
396
- Random number generator object to ensure reproducibility.
397
- output_type (`str`, *optional*, defaults to `"np"`):
398
- Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted
399
- values are: `"np"` (numpy array) or `"pt"` (torch tensor).
400
- output_uncertainty (`bool`, *optional*, defaults to `False`):
401
- When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that
402
- the `ensemble_size` argument is set to a value above 2.
403
- output_latent (`bool`, *optional*, defaults to `False`):
404
- When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
405
- within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
406
- `latents` argument.
407
- return_dict (`bool`, *optional*, defaults to `True`):
408
- Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple.
409
-
410
- Examples:
411
-
412
- Returns:
413
- [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`:
414
- If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a
415
- `tuple` is returned where the first element is the prediction, the second element is the uncertainty
416
- (or `None`), and the third is the latent (or `None`).
417
- """
418
-
419
- # 0. Resolving variables.
420
- device = self._execution_device
421
- dtype = self.dtype
422
-
423
- # Model-specific optimal default values leading to fast and reasonable results.
424
- if num_inference_steps is None:
425
- num_inference_steps = self.default_denoising_steps
426
- if processing_resolution is None:
427
- processing_resolution = self.default_processing_resolution
428
-
429
- # 1. Check inputs.
430
- num_images = self.check_inputs(
431
- image,
432
- num_inference_steps,
433
- ensemble_size,
434
- processing_resolution,
435
- resample_method_input,
436
- resample_method_output,
437
- batch_size,
438
- ensembling_kwargs,
439
- latents,
440
- generator,
441
- output_type,
442
- output_uncertainty,
443
- )
444
-
445
-
446
- # 2. Prepare empty text conditioning.
447
- # Model invocation: self.tokenizer, self.text_encoder.
448
- if self.empty_text_embedding is None:
449
- prompt = ""
450
- text_inputs = self.tokenizer(
451
- prompt,
452
- padding="do_not_pad",
453
- max_length=self.tokenizer.model_max_length,
454
- truncation=True,
455
- return_tensors="pt",
456
- )
457
- text_input_ids = text_inputs.input_ids.to(device)
458
- self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024]
459
-
460
-
461
-
462
- # 3. prepare prompt
463
- if self.prompt_embeds is None:
464
- prompt_embeds, negative_prompt_embeds = self.encode_prompt(
465
- self.prompt,
466
- device,
467
- num_images_per_prompt,
468
- False,
469
- negative_prompt,
470
- prompt_embeds=prompt_embeds,
471
- negative_prompt_embeds=None,
472
- lora_scale=None,
473
- clip_skip=None,
474
- )
475
- self.prompt_embeds = prompt_embeds
476
- self.negative_prompt_embeds = negative_prompt_embeds
477
-
478
-
479
-
480
- # 4. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`,
481
- # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where
482
- # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are
483
- # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None`
484
- # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of
485
- # operation and leads to the most reasonable results. Using the native image resolution or any other processing
486
- # resolution can lead to loss of either fine details or global context in the output predictions.
487
- if not skip_preprocess:
488
- image, padding, original_resolution = self.image_processor.preprocess(
489
- image, processing_resolution, resample_method_input, device, dtype
490
- ) # [N,3,PPH,PPW]
491
- else:
492
- padding = (0, 0)
493
- original_resolution = image.shape[2:]
494
- # 5. Encode input image into latent space. At this step, each of the `N` input images is represented with `E`
495
- # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently.
496
- # Latents of each such predictions across all input images and all ensemble members are represented in the
497
- # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
498
- # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure
499
- # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline
500
- # code. For example, in the Marigold-LCM video processing demo, the latents initialization of a frame is taken
501
- # as a convex combination of the latents output of the pipeline for the previous frame and a newly-sampled
502
- # noise. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space
503
- # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`.
504
- # Model invocation: self.vae.encoder.
505
- image_latent, pred_latent = self.prepare_latents(
506
- image, latents, generator, ensemble_size, batch_size
507
- ) # [N*E,4,h,w], [N*E,4,h,w]
508
-
509
- gaus_noise = pred_latent.detach().clone()
510
- del image
511
-
512
-
513
- # 6. obtain control_output
514
-
515
- cond_scale = controlnet_conditioning_scale
516
- down_block_res_samples, mid_block_res_sample = self.controlnet(
517
- image_latent.detach(),
518
- self.t_start,
519
- encoder_hidden_states=self.prompt_embeds,
520
- conditioning_scale=cond_scale,
521
- guess_mode=False,
522
- return_dict=False,
523
- )
524
-
525
- # 7. YOSO sampling
526
- latent_x_t = self.unet(
527
- pred_latent,
528
- self.t_start,
529
- encoder_hidden_states=self.prompt_embeds,
530
- down_block_additional_residuals=down_block_res_samples,
531
- mid_block_additional_residual=mid_block_res_sample,
532
- return_dict=False,
533
- )[0]
534
-
535
-
536
- del (
537
- pred_latent,
538
- image_latent,
539
- )
540
-
541
- # decoder
542
- prediction = self.decode_prediction(latent_x_t)
543
- prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW]
544
-
545
- prediction = self.image_processor.resize_antialias(
546
- prediction, original_resolution, resample_method_output, is_aa=False
547
- ) # [N,3,H,W]
548
- prediction = self.normalize_normals(prediction) # [N,3,H,W]
549
-
550
- if output_type == "np":
551
- prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3]
552
-
553
- # 8. Offload all models
554
- self.maybe_free_model_hooks()
555
-
556
- return YosoNormalsOutput(
557
- prediction=prediction,
558
- latent=latent_x_t,
559
- gaus_noise=gaus_noise,
560
- )
561
-
562
- # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents
563
- def prepare_latents(
564
- self,
565
- image: torch.Tensor,
566
- latents: Optional[torch.Tensor],
567
- generator: Optional[torch.Generator],
568
- ensemble_size: int,
569
- batch_size: int,
570
- ) -> Tuple[torch.Tensor, torch.Tensor]:
571
- def retrieve_latents(encoder_output):
572
- if hasattr(encoder_output, "latent_dist"):
573
- return encoder_output.latent_dist.mode()
574
- elif hasattr(encoder_output, "latents"):
575
- return encoder_output.latents
576
- else:
577
- raise AttributeError("Could not access latents of provided encoder_output")
578
-
579
-
580
-
581
- image_latent = torch.cat(
582
- [
583
- retrieve_latents(self.vae.encode(image[i : i + batch_size]))
584
- for i in range(0, image.shape[0], batch_size)
585
- ],
586
- dim=0,
587
- ) # [N,4,h,w]
588
- image_latent = image_latent * self.vae.config.scaling_factor
589
- image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w]
590
-
591
- pred_latent = torch.zeros_like(image_latent)  # deterministic zero initialization for single-step (YOSO) sampling
592
- if pred_latent is None:  # NOTE: never taken; leftover from the copied Marigold prepare_latents
593
- pred_latent = randn_tensor(
594
- image_latent.shape,
595
- generator=generator,
596
- device=image_latent.device,
597
- dtype=image_latent.dtype,
598
- ) # [N*E,4,h,w]
599
-
600
- return image_latent, pred_latent
601
-
602
- def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor:
603
- if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels:
604
- raise ValueError(
605
- f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}."
606
- )
607
-
608
- prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W]
609
-
610
- prediction = self.normalize_normals(prediction) # [B,3,H,W]
611
-
612
- return prediction # [B,3,H,W]
613
-
614
- @staticmethod
615
- def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
616
- if normals.dim() != 4 or normals.shape[1] != 3:
617
- raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
618
-
619
- norm = torch.norm(normals, dim=1, keepdim=True)
620
- normals /= norm.clamp(min=eps)
621
-
622
- return normals
623
-
624
- @staticmethod
625
- def ensemble_normals(
626
- normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest"
627
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
628
- """
629
- Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is
630
- the number of ensemble members for a given prediction of size `(H x W)`.
631
-
632
- Args:
633
- normals (`torch.Tensor`):
634
- Input ensemble normals maps.
635
- output_uncertainty (`bool`, *optional*, defaults to `False`):
636
- Whether to output uncertainty map.
637
- reduction (`str`, *optional*, defaults to `"closest"`):
638
- Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and
639
- `"mean"`.
640
-
641
- Returns:
642
- A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of
643
- uncertainties of shape `(1, 1, H, W)`.
644
- """
645
- if normals.dim() != 4 or normals.shape[1] != 3:
646
- raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
647
- if reduction not in ("closest", "mean"):
648
- raise ValueError(f"Unrecognized reduction method: {reduction}.")
649
-
650
- mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W]
651
- mean_normals = YOSONormalsPipeline.normalize_normals(mean_normals) # [1,3,H,W]
652
-
653
- sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W]
654
- sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16
655
-
656
- uncertainty = None
657
- if output_uncertainty:
658
- uncertainty = sim_cos.arccos() # [E,1,H,W]
659
- uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W]
660
-
661
- if reduction == "mean":
662
- return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W]
663
-
664
- closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W]
665
- closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W]
666
- closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W]
667
-
668
- return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W]
669
-
670
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
671
- def retrieve_timesteps(
672
- scheduler,
673
- num_inference_steps: Optional[int] = None,
674
- device: Optional[Union[str, torch.device]] = None,
675
- timesteps: Optional[List[int]] = None,
676
- sigmas: Optional[List[float]] = None,
677
- **kwargs,
678
- ):
679
- """
680
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
681
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
682
-
683
- Args:
684
- scheduler (`SchedulerMixin`):
685
- The scheduler to get timesteps from.
686
- num_inference_steps (`int`):
687
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
688
- must be `None`.
689
- device (`str` or `torch.device`, *optional*):
690
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
691
- timesteps (`List[int]`, *optional*):
692
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
693
- `num_inference_steps` and `sigmas` must be `None`.
694
- sigmas (`List[float]`, *optional*):
695
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
696
- `num_inference_steps` and `timesteps` must be `None`.
697
-
698
- Returns:
699
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
700
- second element is the number of inference steps.
701
- """
702
- if timesteps is not None and sigmas is not None:
703
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
704
- if timesteps is not None:
705
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
706
- if not accepts_timesteps:
707
- raise ValueError(
708
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
709
- f" timestep schedules. Please check whether you are using the correct scheduler."
710
- )
711
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
712
- timesteps = scheduler.timesteps
713
- num_inference_steps = len(timesteps)
714
- elif sigmas is not None:
715
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
716
- if not accept_sigmas:
717
- raise ValueError(
718
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
719
- f" sigmas schedules. Please check whether you are using the correct scheduler."
720
- )
721
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
722
- timesteps = scheduler.timesteps
723
- num_inference_steps = len(timesteps)
724
- else:
725
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
726
- timesteps = scheduler.timesteps
727
- return timesteps, num_inference_steps
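For orientation, a hedged sketch of how this now-deleted single-step pipeline was typically driven before the refactor (the app now goes through torch.hub instead). The checkpoint id is hypothetical, and the call assumes the module is still importable:

```python
import torch
from diffusers.utils import load_image
from stablenormal.pipeline_yoso_normal import YOSONormalsPipeline  # module deleted in this commit

# Hypothetical repo id; substitute whatever actually hosts the YOSO normal weights.
pipe = YOSONormalsPipeline.from_pretrained(
    "your-org/yoso-normal-checkpoint", torch_dtype=torch.float16
).to("cuda")

image = load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
out = pipe(image)            # one forward pass through ControlNet + UNet at t_start
normals = out.prediction     # [N,H,W,3] unit normals in [-1, 1] with the default output_type="np"
```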
stablenormal/scheduler/__init__.py DELETED
File without changes
stablenormal/scheduler/heuristics_ddimsampler.py DELETED
@@ -1,243 +0,0 @@
- import math
- from dataclasses import dataclass
- from typing import List, Optional, Tuple, Union
-
- import numpy as np
- import torch
- from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler
- from diffusers.schedulers.scheduling_utils import SchedulerMixin
- from diffusers.configuration_utils import register_to_config, ConfigMixin
- from diffusers.utils.torch_utils import randn_tensor  # used by step() when eta > 0
-
-
- class HEURI_DDIMScheduler(DDIMScheduler, SchedulerMixin, ConfigMixin):
-
-     def set_timesteps(self, num_inference_steps: int, t_start: int, device: Union[str, torch.device] = None):
-         """
-         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
-
-         Args:
-             num_inference_steps (`int`):
-                 The number of diffusion steps used when generating samples with a pre-trained model.
-         """
-
-         if num_inference_steps > self.config.num_train_timesteps:
-             raise ValueError(
-                 f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
-                 f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
-                 f" maximal {self.config.num_train_timesteps} timesteps."
-             )
-
-         self.num_inference_steps = num_inference_steps
-
-         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
-         if self.config.timestep_spacing == "linspace":
-             timesteps = (
-                 np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
-                 .round()[::-1]
-                 .copy()
-                 .astype(np.int64)
-             )
-         elif self.config.timestep_spacing == "leading":
-             step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-             # creates integer timesteps by multiplying by ratio
-             # casting to int to avoid issues when num_inference_steps is a power of 3
-             timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
-             timesteps += self.config.steps_offset
-         elif self.config.timestep_spacing == "trailing":
-             step_ratio = self.config.num_train_timesteps / self.num_inference_steps
-             # creates integer timesteps by multiplying by ratio
-             # casting to int to avoid issues when num_inference_steps is a power of 3
-             timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
-             timesteps -= 1
-         else:
-             raise ValueError(
-                 f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
-             )
-
-         timesteps = torch.from_numpy(timesteps).to(device)
-
-
-         naive_sampling_step = num_inference_steps // 2
-
-         # TODO for debug
-         # naive_sampling_step = 0
-
-         self.naive_sampling_step = naive_sampling_step
-
-         timesteps[:naive_sampling_step] = timesteps[naive_sampling_step]  # refine on step 5 for 5 steps, then backward from step 6
-
-         timesteps = [timestep + 1 for timestep in timesteps]
-
-         self.timesteps = timesteps
-         self.gap = self.config.num_train_timesteps // self.num_inference_steps
-         self.prev_timesteps = [timestep for timestep in self.timesteps[1:]]
-         self.prev_timesteps.append(torch.zeros_like(self.prev_timesteps[-1]))
-
-     def step(
-         self,
-         model_output: torch.Tensor,
-         timestep: int,
-         prev_timestep: int,
-         sample: torch.Tensor,
-         eta: float = 0.0,
-         use_clipped_model_output: bool = False,
-         generator=None,
-         cur_step=None,
-         variance_noise: Optional[torch.Tensor] = None,
-         gaus_noise: Optional[torch.Tensor] = None,
-         return_dict: bool = True,
-     ) -> Union[DDIMSchedulerOutput, Tuple]:
-         """
-         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
-         process from the learned model outputs (most often the predicted noise).
-
-         Args:
-             model_output (`torch.Tensor`):
-                 The direct output from the learned diffusion model.
-             timestep (`float`):
-                 The current discrete timestep in the diffusion chain.
-             prev_timestep (`float`):
-                 The next discrete timestep to step to in the reverse diffusion chain.
-             sample (`torch.Tensor`):
-                 A current instance of a sample created by the diffusion process.
-             eta (`float`):
-                 The weight of noise for added noise in diffusion step.
-             use_clipped_model_output (`bool`, defaults to `False`):
-                 If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
-                 because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
-                 clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
-                 `use_clipped_model_output` has no effect.
-             generator (`torch.Generator`, *optional*):
-                 A random number generator.
-             variance_noise (`torch.Tensor`):
-                 Alternative to generating noise with `generator` by directly providing the noise for the variance
-                 itself. Useful for methods such as [`CycleDiffusion`].
-             return_dict (`bool`, *optional*, defaults to `True`):
-                 Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
-
-         Returns:
-             [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
-                 If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
-                 tuple is returned where the first element is the sample tensor.
-
-         """
-         if self.num_inference_steps is None:
-             raise ValueError(
-                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
-             )
-
-         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-         # Ideally, read the DDIM paper in detail for a full understanding
-
-         # Notation (<variable name> -> <name in paper>)
-         # - pred_noise_t -> e_theta(x_t, t)
-         # - pred_original_sample -> f_theta(x_t, t) or x_0
-         # - std_dev_t -> sigma_t
-         # - eta -> η
-         # - pred_sample_direction -> "direction pointing to x_t"
-         # - pred_prev_sample -> "x_t-1"
-
-         # 1. get previous step value (=t-1)
-
-         # trick from heuri_sampling
-         if cur_step == self.naive_sampling_step and timestep == prev_timestep:
-             timestep += self.gap
-
-
-         prev_timestep = prev_timestep  # NOTE naive sampling
-
-         # 2. compute alphas, betas
-         alpha_prod_t = self.alphas_cumprod[timestep]
-         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
-
-         beta_prod_t = 1 - alpha_prod_t
-
-         # 3. compute predicted original sample from predicted noise also called
-         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-         if self.config.prediction_type == "epsilon":
-             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-             pred_epsilon = model_output
-         elif self.config.prediction_type == "sample":
-             pred_original_sample = model_output
-             pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
-         elif self.config.prediction_type == "v_prediction":
-             pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
-             pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
-         else:
-             raise ValueError(
-                 f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
-                 " `v_prediction`"
-             )
-
-         # 4. Clip or threshold "predicted x_0"
-         if self.config.thresholding:
-             pred_original_sample = self._threshold_sample(pred_original_sample)
-
-         # 5. compute variance: "sigma_t(η)" -> see formula (16)
-         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-         variance = self._get_variance(timestep, prev_timestep)
-         std_dev_t = eta * variance ** (0.5)
-
-
-         if use_clipped_model_output:
-             # the pred_epsilon is always re-derived from the clipped x_0 in Glide
-             pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
-
-         # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-         pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
-
-         # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-         prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
-
-         if eta > 0:
-             if variance_noise is not None and generator is not None:
-                 raise ValueError(
-                     "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
-                     " `variance_noise` stays `None`."
-                 )
-
-             if variance_noise is None:
-                 variance_noise = randn_tensor(
-                     model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
-                 )
-             variance = std_dev_t * variance_noise
-
-             prev_sample = prev_sample + variance
-
-         if cur_step < self.naive_sampling_step:
-             prev_sample = self.add_noise(pred_original_sample, torch.randn_like(pred_original_sample), timestep)
-
-         if not return_dict:
-             return (prev_sample,)
-
-
-         return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
-
-
-
-     def add_noise(
-         self,
-         original_samples: torch.Tensor,
-         noise: torch.Tensor,
-         timesteps: torch.IntTensor,
-     ) -> torch.Tensor:
-         # Make sure alphas_cumprod and timesteps have the same device and dtype as original_samples
-         # Move self.alphas_cumprod to the device to avoid redundant CPU-to-GPU data movement
-         # for the subsequent add_noise calls
-         self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
-         alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
-         timesteps = timesteps.to(original_samples.device)
-
-         sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
-         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
-
-         sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
-         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
-
-         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
-         return noisy_samples
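
Editorial note on the heuristic above: `set_timesteps` clamps the first half of the DDIM schedule to the middle timestep, so the early iterations repeatedly refine at a single noise level before the ordinary reverse process takes over (this is what the `naive_sampling_step` bookkeeping in `step` relies on). A standalone sketch of just that schedule logic, under assumed values of 1000 training steps, 10 inference steps, and "trailing" spacing:

```python
# Standalone illustration of the schedule built in HEURI_DDIMScheduler.set_timesteps
# (assumed config: num_train_timesteps=1000, num_inference_steps=10, "trailing" spacing).
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 10
step_ratio = num_train_timesteps / num_inference_steps
timesteps = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64) - 1
# plain trailing DDIM schedule: [999 899 799 699 599 499 399 299 199 99]

naive_sampling_step = num_inference_steps // 2
timesteps[:naive_sampling_step] = timesteps[naive_sampling_step]  # clamp the first half
timesteps = [int(t) + 1 for t in timesteps]
print(timesteps)  # [500, 500, 500, 500, 500, 500, 400, 300, 200, 100]
```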
 
stablenormal/stablecontrolnet.py DELETED
@@ -1,1354 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- import inspect
17
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
-
19
- import numpy as np
20
- import PIL.Image
21
- import torch
22
- import torch.nn.functional as F
23
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
24
-
25
- from ...callbacks import MultiPipelineCallbacks, PipelineCallback
26
- from ...image_processor import PipelineImageInput, VaeImageProcessor
27
- from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
28
- from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
29
- from ...models.lora import adjust_lora_scale_text_encoder
30
- from ...schedulers import KarrasDiffusionSchedulers
31
- from ...utils import (
32
- USE_PEFT_BACKEND,
33
- deprecate,
34
- logging,
35
- replace_example_docstring,
36
- scale_lora_layers,
37
- unscale_lora_layers,
38
- )
39
- from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
40
- from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
41
- from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
42
- from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
43
- from .multicontrolnet import MultiControlNetModel
44
-
45
-
46
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
-
48
-
49
- EXAMPLE_DOC_STRING = """
50
- Examples:
51
- ```py
52
- >>> # !pip install opencv-python transformers accelerate
53
- >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
54
- >>> from diffusers.utils import load_image
55
- >>> import numpy as np
56
- >>> import torch
57
-
58
- >>> import cv2
59
- >>> from PIL import Image
60
-
61
- >>> # download an image
62
- >>> image = load_image(
63
- ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
64
- ... )
65
- >>> image = np.array(image)
66
-
67
- >>> # get canny image
68
- >>> image = cv2.Canny(image, 100, 200)
69
- >>> image = image[:, :, None]
70
- >>> image = np.concatenate([image, image, image], axis=2)
71
- >>> canny_image = Image.fromarray(image)
72
-
73
- >>> # load control net and stable diffusion v1-5
74
- >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
75
- >>> pipe = StableDiffusionControlNetPipeline.from_pretrained(
76
- ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
77
- ... )
78
-
79
- >>> # speed up diffusion process with faster scheduler and memory optimization
80
- >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
81
- >>> # remove following line if xformers is not installed
82
- >>> pipe.enable_xformers_memory_efficient_attention()
83
-
84
- >>> pipe.enable_model_cpu_offload()
85
-
86
- >>> # generate image
87
- >>> generator = torch.manual_seed(0)
88
- >>> image = pipe(
89
- ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image
90
- ... ).images[0]
91
- ```
92
- """
93
-
94
-
95
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
96
- def retrieve_timesteps(
97
- scheduler,
98
- num_inference_steps: Optional[int] = None,
99
- device: Optional[Union[str, torch.device]] = None,
100
- timesteps: Optional[List[int]] = None,
101
- sigmas: Optional[List[float]] = None,
102
- **kwargs,
103
- ):
104
- """
105
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
106
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
107
-
108
- Args:
109
- scheduler (`SchedulerMixin`):
110
- The scheduler to get timesteps from.
111
- num_inference_steps (`int`):
112
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
113
- must be `None`.
114
- device (`str` or `torch.device`, *optional*):
115
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
116
- timesteps (`List[int]`, *optional*):
117
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
118
- `num_inference_steps` and `sigmas` must be `None`.
119
- sigmas (`List[float]`, *optional*):
120
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
121
- `num_inference_steps` and `timesteps` must be `None`.
122
-
123
- Returns:
124
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
125
- second element is the number of inference steps.
126
- """
127
- if timesteps is not None and sigmas is not None:
128
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
129
- if timesteps is not None:
130
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
131
- if not accepts_timesteps:
132
- raise ValueError(
133
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
134
- f" timestep schedules. Please check whether you are using the correct scheduler."
135
- )
136
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
137
- timesteps = scheduler.timesteps
138
- num_inference_steps = len(timesteps)
139
- elif sigmas is not None:
140
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
141
- if not accept_sigmas:
142
- raise ValueError(
143
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
144
- f" sigmas schedules. Please check whether you are using the correct scheduler."
145
- )
146
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
147
- timesteps = scheduler.timesteps
148
- num_inference_steps = len(timesteps)
149
- else:
150
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
151
- timesteps = scheduler.timesteps
152
- return timesteps, num_inference_steps
153
-
154
-
155
- class StableDiffusionControlNetPipeline(
156
- DiffusionPipeline,
157
- StableDiffusionMixin,
158
- TextualInversionLoaderMixin,
159
- LoraLoaderMixin,
160
- IPAdapterMixin,
161
- FromSingleFileMixin,
162
- ):
163
- r"""
164
- Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
165
-
166
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
167
- implemented for all pipelines (downloading, saving, running on a particular device, etc.).
168
-
169
- The pipeline also inherits the following loading methods:
170
- - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
171
- - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
172
- - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
173
- - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
174
- - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
175
-
176
- Args:
177
- vae ([`AutoencoderKL`]):
178
- Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
179
- text_encoder ([`~transformers.CLIPTextModel`]):
180
- Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
181
- tokenizer ([`~transformers.CLIPTokenizer`]):
182
- A `CLIPTokenizer` to tokenize text.
183
- unet ([`UNet2DConditionModel`]):
184
- A `UNet2DConditionModel` to denoise the encoded image latents.
185
- controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
186
- Provides additional conditioning to the `unet` during the denoising process. If you set multiple
187
- ControlNets as a list, the outputs from each ControlNet are added together to create one combined
188
- additional conditioning.
189
- scheduler ([`SchedulerMixin`]):
190
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
191
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
192
- safety_checker ([`StableDiffusionSafetyChecker`]):
193
- Classification module that estimates whether generated images could be considered offensive or harmful.
194
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
195
- about a model's potential harms.
196
- feature_extractor ([`~transformers.CLIPImageProcessor`]):
197
- A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
198
- """
199
-
200
- model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
201
- _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
202
- _exclude_from_cpu_offload = ["safety_checker"]
203
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
204
-
205
- def __init__(
206
- self,
207
- vae: AutoencoderKL,
208
- text_encoder: CLIPTextModel,
209
- tokenizer: CLIPTokenizer,
210
- unet: UNet2DConditionModel,
211
- controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
212
- scheduler: KarrasDiffusionSchedulers,
213
- safety_checker: StableDiffusionSafetyChecker,
214
- feature_extractor: CLIPImageProcessor,
215
- image_encoder: CLIPVisionModelWithProjection = None,
216
- requires_safety_checker: bool = True,
217
- ):
218
- super().__init__()
219
-
220
- if safety_checker is None and requires_safety_checker:
221
- logger.warning(
222
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
223
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
224
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
225
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
226
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
227
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
228
- )
229
-
230
- if safety_checker is not None and feature_extractor is None:
231
- raise ValueError(
232
- "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
233
- " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
234
- )
235
-
236
- if isinstance(controlnet, (list, tuple)):
237
- controlnet = MultiControlNetModel(controlnet)
238
-
239
- self.register_modules(
240
- vae=vae,
241
- text_encoder=text_encoder,
242
- tokenizer=tokenizer,
243
- unet=unet,
244
- controlnet=controlnet,
245
- scheduler=scheduler,
246
- safety_checker=safety_checker,
247
- feature_extractor=feature_extractor,
248
- image_encoder=image_encoder,
249
- )
250
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
251
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
252
- self.control_image_processor = VaeImageProcessor(
253
- vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
254
- )
255
- self.register_to_config(requires_safety_checker=requires_safety_checker)
256
-
257
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
258
- def _encode_prompt(
259
- self,
260
- prompt,
261
- device,
262
- num_images_per_prompt,
263
- do_classifier_free_guidance,
264
- negative_prompt=None,
265
- prompt_embeds: Optional[torch.Tensor] = None,
266
- negative_prompt_embeds: Optional[torch.Tensor] = None,
267
- lora_scale: Optional[float] = None,
268
- **kwargs,
269
- ):
270
- deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
271
- deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
272
-
273
- prompt_embeds_tuple = self.encode_prompt(
274
- prompt=prompt,
275
- device=device,
276
- num_images_per_prompt=num_images_per_prompt,
277
- do_classifier_free_guidance=do_classifier_free_guidance,
278
- negative_prompt=negative_prompt,
279
- prompt_embeds=prompt_embeds,
280
- negative_prompt_embeds=negative_prompt_embeds,
281
- lora_scale=lora_scale,
282
- **kwargs,
283
- )
284
-
285
- # concatenate for backwards comp
286
- prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
287
-
288
- return prompt_embeds
289
-
290
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
291
- def encode_prompt(
292
- self,
293
- prompt,
294
- device,
295
- num_images_per_prompt,
296
- do_classifier_free_guidance,
297
- negative_prompt=None,
298
- prompt_embeds: Optional[torch.Tensor] = None,
299
- negative_prompt_embeds: Optional[torch.Tensor] = None,
300
- lora_scale: Optional[float] = None,
301
- clip_skip: Optional[int] = None,
302
- ):
303
- r"""
304
- Encodes the prompt into text encoder hidden states.
305
-
306
- Args:
307
- prompt (`str` or `List[str]`, *optional*):
308
- prompt to be encoded
309
- device: (`torch.device`):
310
- torch device
311
- num_images_per_prompt (`int`):
312
- number of images that should be generated per prompt
313
- do_classifier_free_guidance (`bool`):
314
- whether to use classifier free guidance or not
315
- negative_prompt (`str` or `List[str]`, *optional*):
316
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
317
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
318
- less than `1`).
319
- prompt_embeds (`torch.Tensor`, *optional*):
320
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
321
- provided, text embeddings will be generated from `prompt` input argument.
322
- negative_prompt_embeds (`torch.Tensor`, *optional*):
323
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
324
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
325
- argument.
326
- lora_scale (`float`, *optional*):
327
- A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
328
- clip_skip (`int`, *optional*):
329
- Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
330
- the output of the pre-final layer will be used for computing the prompt embeddings.
331
- """
332
- # set lora scale so that monkey patched LoRA
333
- # function of text encoder can correctly access it
334
- if lora_scale is not None and isinstance(self, LoraLoaderMixin):
335
- self._lora_scale = lora_scale
336
-
337
- # dynamically adjust the LoRA scale
338
- if not USE_PEFT_BACKEND:
339
- adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
340
- else:
341
- scale_lora_layers(self.text_encoder, lora_scale)
342
-
343
- if prompt is not None and isinstance(prompt, str):
344
- batch_size = 1
345
- elif prompt is not None and isinstance(prompt, list):
346
- batch_size = len(prompt)
347
- else:
348
- batch_size = prompt_embeds.shape[0]
349
-
350
- if prompt_embeds is None:
351
- # textual inversion: process multi-vector tokens if necessary
352
- if isinstance(self, TextualInversionLoaderMixin):
353
- prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
354
-
355
- text_inputs = self.tokenizer(
356
- prompt,
357
- padding="max_length",
358
- max_length=self.tokenizer.model_max_length,
359
- truncation=True,
360
- return_tensors="pt",
361
- )
362
- text_input_ids = text_inputs.input_ids
363
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
364
-
365
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
366
- text_input_ids, untruncated_ids
367
- ):
368
- removed_text = self.tokenizer.batch_decode(
369
- untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
370
- )
371
- logger.warning(
372
- "The following part of your input was truncated because CLIP can only handle sequences up to"
373
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
374
- )
375
-
376
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
377
- attention_mask = text_inputs.attention_mask.to(device)
378
- else:
379
- attention_mask = None
380
-
381
- if clip_skip is None:
382
- prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
383
- prompt_embeds = prompt_embeds[0]
384
- else:
385
- prompt_embeds = self.text_encoder(
386
- text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
387
- )
388
- # Access the `hidden_states` first, that contains a tuple of
389
- # all the hidden states from the encoder layers. Then index into
390
- # the tuple to access the hidden states from the desired layer.
391
- prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
392
- # We also need to apply the final LayerNorm here to not mess with the
393
- # representations. The `last_hidden_states` that we typically use for
394
- # obtaining the final prompt representations passes through the LayerNorm
395
- # layer.
396
- prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
397
-
398
- if self.text_encoder is not None:
399
- prompt_embeds_dtype = self.text_encoder.dtype
400
- elif self.unet is not None:
401
- prompt_embeds_dtype = self.unet.dtype
402
- else:
403
- prompt_embeds_dtype = prompt_embeds.dtype
404
-
405
- prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
406
-
407
- bs_embed, seq_len, _ = prompt_embeds.shape
408
- # duplicate text embeddings for each generation per prompt, using mps friendly method
409
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
410
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
411
-
412
- # get unconditional embeddings for classifier free guidance
413
- if do_classifier_free_guidance and negative_prompt_embeds is None:
414
- uncond_tokens: List[str]
415
- if negative_prompt is None:
416
- uncond_tokens = [""] * batch_size
417
- elif prompt is not None and type(prompt) is not type(negative_prompt):
418
- raise TypeError(
419
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
420
- f" {type(prompt)}."
421
- )
422
- elif isinstance(negative_prompt, str):
423
- uncond_tokens = [negative_prompt]
424
- elif batch_size != len(negative_prompt):
425
- raise ValueError(
426
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
427
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
428
- " the batch size of `prompt`."
429
- )
430
- else:
431
- uncond_tokens = negative_prompt
432
-
433
- # textual inversion: process multi-vector tokens if necessary
434
- if isinstance(self, TextualInversionLoaderMixin):
435
- uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
436
-
437
- max_length = prompt_embeds.shape[1]
438
- uncond_input = self.tokenizer(
439
- uncond_tokens,
440
- padding="max_length",
441
- max_length=max_length,
442
- truncation=True,
443
- return_tensors="pt",
444
- )
445
-
446
- if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
447
- attention_mask = uncond_input.attention_mask.to(device)
448
- else:
449
- attention_mask = None
450
-
451
- negative_prompt_embeds = self.text_encoder(
452
- uncond_input.input_ids.to(device),
453
- attention_mask=attention_mask,
454
- )
455
- negative_prompt_embeds = negative_prompt_embeds[0]
456
-
457
- if do_classifier_free_guidance:
458
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
459
- seq_len = negative_prompt_embeds.shape[1]
460
-
461
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
462
-
463
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
464
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
465
-
466
- if self.text_encoder is not None:
467
- if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
468
- # Retrieve the original scale by scaling back the LoRA layers
469
- unscale_lora_layers(self.text_encoder, lora_scale)
470
-
471
- return prompt_embeds, negative_prompt_embeds
472
-
473
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
474
- def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
475
- dtype = next(self.image_encoder.parameters()).dtype
476
-
477
- if not isinstance(image, torch.Tensor):
478
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
479
-
480
- image = image.to(device=device, dtype=dtype)
481
- if output_hidden_states:
482
- image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
483
- image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
484
- uncond_image_enc_hidden_states = self.image_encoder(
485
- torch.zeros_like(image), output_hidden_states=True
486
- ).hidden_states[-2]
487
- uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
488
- num_images_per_prompt, dim=0
489
- )
490
- return image_enc_hidden_states, uncond_image_enc_hidden_states
491
- else:
492
- image_embeds = self.image_encoder(image).image_embeds
493
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
494
- uncond_image_embeds = torch.zeros_like(image_embeds)
495
-
496
- return image_embeds, uncond_image_embeds
497
-
498
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
499
- def prepare_ip_adapter_image_embeds(
500
- self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
501
- ):
502
- if ip_adapter_image_embeds is None:
503
- if not isinstance(ip_adapter_image, list):
504
- ip_adapter_image = [ip_adapter_image]
505
-
506
- if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
507
- raise ValueError(
508
- f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
509
- )
510
-
511
- image_embeds = []
512
- for single_ip_adapter_image, image_proj_layer in zip(
513
- ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
514
- ):
515
- output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
516
- single_image_embeds, single_negative_image_embeds = self.encode_image(
517
- single_ip_adapter_image, device, 1, output_hidden_state
518
- )
519
- single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
520
- single_negative_image_embeds = torch.stack(
521
- [single_negative_image_embeds] * num_images_per_prompt, dim=0
522
- )
523
-
524
- if do_classifier_free_guidance:
525
- single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
526
- single_image_embeds = single_image_embeds.to(device)
527
-
528
- image_embeds.append(single_image_embeds)
529
- else:
530
- repeat_dims = [1]
531
- image_embeds = []
532
- for single_image_embeds in ip_adapter_image_embeds:
533
- if do_classifier_free_guidance:
534
- single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
535
- single_image_embeds = single_image_embeds.repeat(
536
- num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
537
- )
538
- single_negative_image_embeds = single_negative_image_embeds.repeat(
539
- num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
540
- )
541
- single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
542
- else:
543
- single_image_embeds = single_image_embeds.repeat(
544
- num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
545
- )
546
- image_embeds.append(single_image_embeds)
547
-
548
- return image_embeds
549
-
550
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
551
- def run_safety_checker(self, image, device, dtype):
552
- if self.safety_checker is None:
553
- has_nsfw_concept = None
554
- else:
555
- if torch.is_tensor(image):
556
- feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
557
- else:
558
- feature_extractor_input = self.image_processor.numpy_to_pil(image)
559
- safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
560
- image, has_nsfw_concept = self.safety_checker(
561
- images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
562
- )
563
- return image, has_nsfw_concept
564
-
565
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
566
- def decode_latents(self, latents):
567
- deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
568
- deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
569
-
570
- latents = 1 / self.vae.config.scaling_factor * latents
571
- image = self.vae.decode(latents, return_dict=False)[0]
572
- image = (image / 2 + 0.5).clamp(0, 1)
573
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
574
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
575
- return image
576
-
577
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
578
- def prepare_extra_step_kwargs(self, generator, eta):
579
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
580
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
581
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
582
- # and should be between [0, 1]
583
-
584
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
585
- extra_step_kwargs = {}
586
- if accepts_eta:
587
- extra_step_kwargs["eta"] = eta
588
-
589
- # check if the scheduler accepts generator
590
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
591
- if accepts_generator:
592
- extra_step_kwargs["generator"] = generator
593
- return extra_step_kwargs
594
-
595
- def check_inputs(
596
- self,
597
- prompt,
598
- image,
599
- callback_steps,
600
- negative_prompt=None,
601
- prompt_embeds=None,
602
- negative_prompt_embeds=None,
603
- ip_adapter_image=None,
604
- ip_adapter_image_embeds=None,
605
- controlnet_conditioning_scale=1.0,
606
- control_guidance_start=0.0,
607
- control_guidance_end=1.0,
608
- callback_on_step_end_tensor_inputs=None,
609
- ):
610
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
611
- raise ValueError(
612
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
613
- f" {type(callback_steps)}."
614
- )
615
-
616
- if callback_on_step_end_tensor_inputs is not None and not all(
617
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
618
- ):
619
- raise ValueError(
620
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
621
- )
622
-
623
- if prompt is not None and prompt_embeds is not None:
624
- raise ValueError(
625
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
626
- " only forward one of the two."
627
- )
628
- elif prompt is None and prompt_embeds is None:
629
- raise ValueError(
630
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
631
- )
632
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
633
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
634
-
635
- if negative_prompt is not None and negative_prompt_embeds is not None:
636
- raise ValueError(
637
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
638
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
639
- )
640
-
641
- if prompt_embeds is not None and negative_prompt_embeds is not None:
642
- if prompt_embeds.shape != negative_prompt_embeds.shape:
643
- raise ValueError(
644
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
645
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
646
- f" {negative_prompt_embeds.shape}."
647
- )
648
-
649
- # Check `image`
650
- is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
651
- self.controlnet, torch._dynamo.eval_frame.OptimizedModule
652
- )
653
- if (
654
- isinstance(self.controlnet, ControlNetModel)
655
- or is_compiled
656
- and isinstance(self.controlnet._orig_mod, ControlNetModel)
657
- ):
658
- self.check_image(image, prompt, prompt_embeds)
659
- elif (
660
- isinstance(self.controlnet, MultiControlNetModel)
661
- or is_compiled
662
- and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
663
- ):
664
- if not isinstance(image, list):
665
- raise TypeError("For multiple controlnets: `image` must be type `list`")
666
-
667
- # When `image` is a nested list:
668
- # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
669
- elif any(isinstance(i, list) for i in image):
670
- transposed_image = [list(t) for t in zip(*image)]
671
- if len(transposed_image) != len(self.controlnet.nets):
672
- raise ValueError(
673
- f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets."
674
- )
675
- for image_ in transposed_image:
676
- self.check_image(image_, prompt, prompt_embeds)
677
- elif len(image) != len(self.controlnet.nets):
678
- raise ValueError(
679
- f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
680
- )
681
- else:
682
- for image_ in image:
683
- self.check_image(image_, prompt, prompt_embeds)
684
- else:
685
- assert False
686
-
687
- # Check `controlnet_conditioning_scale`
688
- if (
689
- isinstance(self.controlnet, ControlNetModel)
690
- or is_compiled
691
- and isinstance(self.controlnet._orig_mod, ControlNetModel)
692
- ):
693
- if not isinstance(controlnet_conditioning_scale, float):
694
- raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
695
- elif (
696
- isinstance(self.controlnet, MultiControlNetModel)
697
- or is_compiled
698
- and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
699
- ):
700
- if isinstance(controlnet_conditioning_scale, list):
701
- if any(isinstance(i, list) for i in controlnet_conditioning_scale):
702
- raise ValueError(
703
- "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. "
704
- "The conditioning scale must be fixed across the batch."
705
- )
706
- elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
707
- self.controlnet.nets
708
- ):
709
- raise ValueError(
710
- "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
711
- " the same length as the number of controlnets"
712
- )
713
- else:
714
- assert False
715
-
716
- if not isinstance(control_guidance_start, (tuple, list)):
717
- control_guidance_start = [control_guidance_start]
718
-
719
- if not isinstance(control_guidance_end, (tuple, list)):
720
- control_guidance_end = [control_guidance_end]
721
-
722
- if len(control_guidance_start) != len(control_guidance_end):
723
- raise ValueError(
724
- f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
725
- )
726
-
727
- if isinstance(self.controlnet, MultiControlNetModel):
728
- if len(control_guidance_start) != len(self.controlnet.nets):
729
- raise ValueError(
730
- f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
731
- )
732
-
733
- for start, end in zip(control_guidance_start, control_guidance_end):
734
- if start >= end:
735
- raise ValueError(
736
- f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
737
- )
738
- if start < 0.0:
739
- raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
740
- if end > 1.0:
741
- raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
742
-
743
- if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
744
- raise ValueError(
745
- "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
746
- )
747
-
748
- if ip_adapter_image_embeds is not None:
749
- if not isinstance(ip_adapter_image_embeds, list):
750
- raise ValueError(
751
- f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
752
- )
753
- elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
754
- raise ValueError(
755
- f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
756
- )
757
-
758
- def check_image(self, image, prompt, prompt_embeds):
759
- image_is_pil = isinstance(image, PIL.Image.Image)
760
- image_is_tensor = isinstance(image, torch.Tensor)
761
- image_is_np = isinstance(image, np.ndarray)
762
- image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
763
- image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
764
- image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
765
-
766
- if (
767
- not image_is_pil
768
- and not image_is_tensor
769
- and not image_is_np
770
- and not image_is_pil_list
771
- and not image_is_tensor_list
772
- and not image_is_np_list
773
- ):
774
- raise TypeError(
775
- f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
776
- )
777
-
778
- if image_is_pil:
779
- image_batch_size = 1
780
- else:
781
- image_batch_size = len(image)
782
-
783
- if prompt is not None and isinstance(prompt, str):
784
- prompt_batch_size = 1
785
- elif prompt is not None and isinstance(prompt, list):
786
- prompt_batch_size = len(prompt)
787
- elif prompt_embeds is not None:
788
- prompt_batch_size = prompt_embeds.shape[0]
789
-
790
- if image_batch_size != 1 and image_batch_size != prompt_batch_size:
791
- raise ValueError(
792
- f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
793
- )
794
-
795
- def prepare_image(
796
- self,
797
- image,
798
- width,
799
- height,
800
- batch_size,
801
- num_images_per_prompt,
802
- device,
803
- dtype,
804
- do_classifier_free_guidance=False,
805
- guess_mode=False,
806
- ):
807
- image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
808
- image_batch_size = image.shape[0]
809
-
810
- if image_batch_size == 1:
811
- repeat_by = batch_size
812
- else:
813
- # image batch size is the same as prompt batch size
814
- repeat_by = num_images_per_prompt
815
-
816
- image = image.repeat_interleave(repeat_by, dim=0)
817
-
818
- image = image.to(device=device, dtype=dtype)
819
-
820
- if do_classifier_free_guidance and not guess_mode:
821
- image = torch.cat([image] * 2)
822
-
823
- return image
824
-
825
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
826
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
827
- shape = (
828
- batch_size,
829
- num_channels_latents,
830
- int(height) // self.vae_scale_factor,
831
- int(width) // self.vae_scale_factor,
832
- )
833
- if isinstance(generator, list) and len(generator) != batch_size:
834
- raise ValueError(
835
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
836
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
837
- )
838
-
839
- if latents is None:
840
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
841
- else:
842
- latents = latents.to(device)
843
-
844
- # scale the initial noise by the standard deviation required by the scheduler
845
- latents = latents * self.scheduler.init_noise_sigma
846
- return latents
847
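Editorial aside (not part of the commit): the `shape` computed in `prepare_latents` above is the latent-space resolution, i.e. the pixel height and width divided by the VAE scale factor. A worked example under assumed values:

```python
# Worked example of the latent shape from prepare_latents, assuming a 768x768 image,
# batch_size=1, num_channels_latents=4, and vae_scale_factor=8 (typical for SD 1.x/2.x VAEs).
batch_size, num_channels_latents, height, width, vae_scale_factor = 1, 4, 768, 768, 8
shape = (batch_size, num_channels_latents, int(height) // vae_scale_factor, int(width) // vae_scale_factor)
print(shape)  # (1, 4, 96, 96)
```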
-
848
- # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
849
- def get_guidance_scale_embedding(
850
- self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
851
- ) -> torch.Tensor:
852
- """
853
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
854
-
855
- Args:
856
- w (`torch.Tensor`):
857
- Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
858
- embedding_dim (`int`, *optional*, defaults to 512):
859
- Dimension of the embeddings to generate.
860
- dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
861
- Data type of the generated embeddings.
862
-
863
- Returns:
864
- `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
865
- """
866
- assert len(w.shape) == 1
867
- w = w * 1000.0
868
-
869
- half_dim = embedding_dim // 2
870
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
871
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
872
- emb = w.to(dtype)[:, None] * emb[None, :]
873
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
874
- if embedding_dim % 2 == 1: # zero pad
875
- emb = torch.nn.functional.pad(emb, (0, 1))
876
- assert emb.shape == (w.shape[0], embedding_dim)
877
- return emb
878
-
879
- @property
880
- def guidance_scale(self):
881
- return self._guidance_scale
882
-
883
- @property
884
- def clip_skip(self):
885
- return self._clip_skip
886
-
887
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
888
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
889
- # corresponds to doing no classifier free guidance.
890
- @property
891
- def do_classifier_free_guidance(self):
892
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
893
-
894
- @property
895
- def cross_attention_kwargs(self):
896
- return self._cross_attention_kwargs
897
-
898
- @property
899
- def num_timesteps(self):
900
- return self._num_timesteps
901
-
902
- @torch.no_grad()
903
- @replace_example_docstring(EXAMPLE_DOC_STRING)
904
- def __call__(
905
- self,
906
- prompt: Union[str, List[str]] = None,
907
- image: PipelineImageInput = None,
908
- height: Optional[int] = None,
909
- width: Optional[int] = None,
910
- num_inference_steps: int = 50,
911
- timesteps: List[int] = None,
912
- sigmas: List[float] = None,
913
- guidance_scale: float = 7.5,
914
- negative_prompt: Optional[Union[str, List[str]]] = None,
915
- num_images_per_prompt: Optional[int] = 1,
916
- eta: float = 0.0,
917
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
918
- latents: Optional[torch.Tensor] = None,
919
- prompt_embeds: Optional[torch.Tensor] = None,
920
- negative_prompt_embeds: Optional[torch.Tensor] = None,
921
- ip_adapter_image: Optional[PipelineImageInput] = None,
922
- ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
923
- output_type: Optional[str] = "pil",
924
- return_dict: bool = True,
925
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
926
- controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
927
- guess_mode: bool = False,
928
- control_guidance_start: Union[float, List[float]] = 0.0,
929
- control_guidance_end: Union[float, List[float]] = 1.0,
930
- clip_skip: Optional[int] = None,
931
- callback_on_step_end: Optional[
932
- Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
933
- ] = None,
934
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
935
- **kwargs,
936
- ):
937
- r"""
938
- The call function to the pipeline for generation.
939
-
940
- Args:
941
- prompt (`str` or `List[str]`, *optional*):
942
- The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
943
- image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
944
- `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
945
- The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
946
- specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
947
- as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
948
- width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
949
- images must be passed as a list such that each element of the list can be correctly batched for input
950
- to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single
951
- ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple
952
- ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet.
953
- height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
954
- The height in pixels of the generated image.
955
- width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
956
- The width in pixels of the generated image.
957
- num_inference_steps (`int`, *optional*, defaults to 50):
958
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
959
- expense of slower inference.
960
- timesteps (`List[int]`, *optional*):
961
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
962
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
963
- passed will be used. Must be in descending order.
964
- sigmas (`List[float]`, *optional*):
965
- Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
966
- their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
967
- will be used.
968
- guidance_scale (`float`, *optional*, defaults to 7.5):
969
- A higher guidance scale value encourages the model to generate images closely linked to the text
970
- `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
971
- negative_prompt (`str` or `List[str]`, *optional*):
972
- The prompt or prompts to guide what to not include in image generation. If not defined, you need to
973
- pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
974
- num_images_per_prompt (`int`, *optional*, defaults to 1):
975
- The number of images to generate per prompt.
976
- eta (`float`, *optional*, defaults to 0.0):
977
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
978
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
979
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
980
- A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
981
- generation deterministic.
982
- latents (`torch.Tensor`, *optional*):
983
- Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
984
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
985
- tensor is generated by sampling using the supplied random `generator`.
986
- prompt_embeds (`torch.Tensor`, *optional*):
987
- Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
988
- provided, text embeddings are generated from the `prompt` input argument.
989
- negative_prompt_embeds (`torch.Tensor`, *optional*):
990
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
991
- not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
992
-            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
993
- ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
994
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
995
- IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
996
- contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
997
- provided, embeddings are computed from the `ip_adapter_image` input argument.
998
- output_type (`str`, *optional*, defaults to `"pil"`):
999
-                The output format of the generated image. Choose between `PIL.Image` and `np.array`.
1000
- return_dict (`bool`, *optional*, defaults to `True`):
1001
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1002
- plain tuple.
1003
- callback (`Callable`, *optional*):
1004
- A function that calls every `callback_steps` steps during inference. The function is called with the
1005
- following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
1006
- callback_steps (`int`, *optional*, defaults to 1):
1007
- The frequency at which the `callback` function is called. If not specified, the callback is called at
1008
- every step.
1009
- cross_attention_kwargs (`dict`, *optional*):
1010
- A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1011
- [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1012
- controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
1013
- The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
1014
- to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
1015
- the corresponding scale as a list.
1016
- guess_mode (`bool`, *optional*, defaults to `False`):
1017
- The ControlNet encoder tries to recognize the content of the input image even if you remove all
1018
- prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
1019
- control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
1020
- The percentage of total steps at which the ControlNet starts applying.
1021
- control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1022
- The percentage of total steps at which the ControlNet stops applying.
1023
- clip_skip (`int`, *optional*):
1024
- Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1025
- the output of the pre-final layer will be used for computing the prompt embeddings.
1026
- callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1027
- A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1028
-                each denoising step during inference, with the following arguments: `callback_on_step_end(self:
1029
- DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1030
- list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1031
- callback_on_step_end_tensor_inputs (`List`, *optional*):
1032
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1033
-                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the
1034
- `._callback_tensor_inputs` attribute of your pipeline class.
1035
-
1036
- Examples:
1037
-
1038
- Returns:
1039
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1040
- If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
1041
- otherwise a `tuple` is returned where the first element is a list with the generated images and the
1042
- second element is a list of `bool`s indicating whether the corresponding generated image contains
1043
- "not-safe-for-work" (nsfw) content.
1044
- """
1045
-
1046
- callback = kwargs.pop("callback", None)
1047
- callback_steps = kwargs.pop("callback_steps", None)
1048
-
1049
- if callback is not None:
1050
- deprecate(
1051
- "callback",
1052
- "1.0.0",
1053
- "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1054
- )
1055
- if callback_steps is not None:
1056
- deprecate(
1057
- "callback_steps",
1058
- "1.0.0",
1059
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1060
- )
1061
-
1062
- if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1063
- callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1064
-
1065
- controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1066
-
1067
- # align format for control guidance
1068
- if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1069
- control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1070
- elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1071
- control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1072
- elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1073
- mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1074
- control_guidance_start, control_guidance_end = (
1075
- mult * [control_guidance_start],
1076
- mult * [control_guidance_end],
1077
- )
1078
-
1079
- # 1. Check inputs. Raise error if not correct
1080
- self.check_inputs(
1081
- prompt,
1082
- image,
1083
- callback_steps,
1084
- negative_prompt,
1085
- prompt_embeds,
1086
- negative_prompt_embeds,
1087
- ip_adapter_image,
1088
- ip_adapter_image_embeds,
1089
- controlnet_conditioning_scale,
1090
- control_guidance_start,
1091
- control_guidance_end,
1092
- callback_on_step_end_tensor_inputs,
1093
- )
1094
-
1095
- self._guidance_scale = guidance_scale
1096
- self._clip_skip = clip_skip
1097
- self._cross_attention_kwargs = cross_attention_kwargs
1098
-
1099
- # 2. Define call parameters
1100
- if prompt is not None and isinstance(prompt, str):
1101
- batch_size = 1
1102
- elif prompt is not None and isinstance(prompt, list):
1103
- batch_size = len(prompt)
1104
- else:
1105
- batch_size = prompt_embeds.shape[0]
1106
-
1107
- device = self._execution_device
1108
-
1109
- if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1110
- controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1111
-
1112
- global_pool_conditions = (
1113
- controlnet.config.global_pool_conditions
1114
- if isinstance(controlnet, ControlNetModel)
1115
- else controlnet.nets[0].config.global_pool_conditions
1116
- )
1117
- guess_mode = guess_mode or global_pool_conditions
1118
-
1119
- # 3. Encode input prompt
1120
- text_encoder_lora_scale = (
1121
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1122
- )
1123
- prompt_embeds, negative_prompt_embeds = self.encode_prompt(
1124
- prompt,
1125
- device,
1126
- num_images_per_prompt,
1127
- self.do_classifier_free_guidance,
1128
- negative_prompt,
1129
- prompt_embeds=prompt_embeds,
1130
- negative_prompt_embeds=negative_prompt_embeds,
1131
- lora_scale=text_encoder_lora_scale,
1132
- clip_skip=self.clip_skip,
1133
- )
1134
-        # For classifier-free guidance we need both the unconditional and the
1135
-        # text-conditioned noise predictions, so the two sets of embeddings are
1136
-        # concatenated into a single batch and the UNet runs only once per step.
1137
- if self.do_classifier_free_guidance:
1138
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1139
-
1140
- if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1141
- image_embeds = self.prepare_ip_adapter_image_embeds(
1142
- ip_adapter_image,
1143
- ip_adapter_image_embeds,
1144
- device,
1145
- batch_size * num_images_per_prompt,
1146
- self.do_classifier_free_guidance,
1147
- )
1148
-
1149
- # 4. Prepare image
1150
- if isinstance(controlnet, ControlNetModel):
1151
- image = self.prepare_image(
1152
- image=image,
1153
- width=width,
1154
- height=height,
1155
- batch_size=batch_size * num_images_per_prompt,
1156
- num_images_per_prompt=num_images_per_prompt,
1157
- device=device,
1158
- dtype=controlnet.dtype,
1159
- do_classifier_free_guidance=self.do_classifier_free_guidance,
1160
- guess_mode=guess_mode,
1161
- )
1162
- height, width = image.shape[-2:]
1163
- elif isinstance(controlnet, MultiControlNetModel):
1164
- images = []
1165
-
1166
- # Nested lists as ControlNet condition
1167
- if isinstance(image[0], list):
1168
- # Transpose the nested image list
1169
- image = [list(t) for t in zip(*image)]
1170
-
1171
- for image_ in image:
1172
- image_ = self.prepare_image(
1173
- image=image_,
1174
- width=width,
1175
- height=height,
1176
- batch_size=batch_size * num_images_per_prompt,
1177
- num_images_per_prompt=num_images_per_prompt,
1178
- device=device,
1179
- dtype=controlnet.dtype,
1180
- do_classifier_free_guidance=self.do_classifier_free_guidance,
1181
- guess_mode=guess_mode,
1182
- )
1183
-
1184
- images.append(image_)
1185
-
1186
- image = images
1187
- height, width = image[0].shape[-2:]
1188
- else:
1189
- assert False
1190
-
1191
- # 5. Prepare timesteps
1192
- timesteps, num_inference_steps = retrieve_timesteps(
1193
- self.scheduler, num_inference_steps, device, timesteps, sigmas
1194
- )
1195
- self._num_timesteps = len(timesteps)
1196
-
1197
- # 6. Prepare latent variables
1198
- num_channels_latents = self.unet.config.in_channels
1199
- latents = self.prepare_latents(
1200
- batch_size * num_images_per_prompt,
1201
- num_channels_latents,
1202
- height,
1203
- width,
1204
- prompt_embeds.dtype,
1205
- device,
1206
- generator,
1207
- latents,
1208
- )
1209
-
1210
- # 6.5 Optionally get Guidance Scale Embedding
1211
- timestep_cond = None
1212
- if self.unet.config.time_cond_proj_dim is not None:
1213
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1214
- timestep_cond = self.get_guidance_scale_embedding(
1215
- guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1216
- ).to(device=device, dtype=latents.dtype)
1217
-
1218
- # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1219
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1220
-
1221
- # 7.1 Add image embeds for IP-Adapter
1222
- added_cond_kwargs = (
1223
- {"image_embeds": image_embeds}
1224
- if ip_adapter_image is not None or ip_adapter_image_embeds is not None
1225
- else None
1226
- )
1227
-
1228
- # 7.2 Create tensor stating which controlnets to keep
1229
- controlnet_keep = []
1230
- for i in range(len(timesteps)):
1231
- keeps = [
1232
- 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1233
- for s, e in zip(control_guidance_start, control_guidance_end)
1234
- ]
1235
- controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
1236
-
1237
- # 8. Denoising loop
1238
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1239
- is_unet_compiled = is_compiled_module(self.unet)
1240
- is_controlnet_compiled = is_compiled_module(self.controlnet)
1241
- is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
1242
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1243
- for i, t in enumerate(timesteps):
1244
- # Relevant thread:
1245
- # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1246
- if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
1247
- torch._inductor.cudagraph_mark_step_begin()
1248
- # expand the latents if we are doing classifier free guidance
1249
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1250
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1251
-
1252
- # controlnet(s) inference
1253
- if guess_mode and self.do_classifier_free_guidance:
1254
- # Infer ControlNet only for the conditional batch.
1255
- control_model_input = latents
1256
- control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1257
- controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1258
- else:
1259
- control_model_input = latent_model_input
1260
- controlnet_prompt_embeds = prompt_embeds
1261
-
1262
- if isinstance(controlnet_keep[i], list):
1263
- cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1264
- else:
1265
- controlnet_cond_scale = controlnet_conditioning_scale
1266
- if isinstance(controlnet_cond_scale, list):
1267
- controlnet_cond_scale = controlnet_cond_scale[0]
1268
- cond_scale = controlnet_cond_scale * controlnet_keep[i]
1269
-
1270
- down_block_res_samples, mid_block_res_sample = self.controlnet(
1271
- control_model_input,
1272
- t,
1273
- encoder_hidden_states=controlnet_prompt_embeds,
1274
- controlnet_cond=image,
1275
- conditioning_scale=cond_scale,
1276
- guess_mode=guess_mode,
1277
- return_dict=False,
1278
- )
1279
-
1280
- if guess_mode and self.do_classifier_free_guidance:
1281
-                    # Inferred ControlNet only for the conditional batch.
1282
- # To apply the output of ControlNet to both the unconditional and conditional batches,
1283
- # add 0 to the unconditional batch to keep it unchanged.
1284
- down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1285
- mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1286
-
1287
- # predict the noise residual
1288
- noise_pred = self.unet(
1289
- latent_model_input,
1290
- t,
1291
- encoder_hidden_states=prompt_embeds,
1292
- timestep_cond=timestep_cond,
1293
- cross_attention_kwargs=self.cross_attention_kwargs,
1294
- down_block_additional_residuals=down_block_res_samples,
1295
- mid_block_additional_residual=mid_block_res_sample,
1296
- added_cond_kwargs=added_cond_kwargs,
1297
- return_dict=False,
1298
- )[0]
1299
-
1300
- # perform guidance
1301
- if self.do_classifier_free_guidance:
1302
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1303
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1304
-
1305
- # compute the previous noisy sample x_t -> x_t-1
1306
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1307
-
1308
- if callback_on_step_end is not None:
1309
- callback_kwargs = {}
1310
- for k in callback_on_step_end_tensor_inputs:
1311
- callback_kwargs[k] = locals()[k]
1312
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1313
-
1314
- latents = callback_outputs.pop("latents", latents)
1315
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1316
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1317
-
1318
- # call the callback, if provided
1319
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1320
- progress_bar.update()
1321
- if callback is not None and i % callback_steps == 0:
1322
- step_idx = i // getattr(self.scheduler, "order", 1)
1323
- callback(step_idx, t, latents)
1324
-
1325
- # If we do sequential model offloading, let's offload unet and controlnet
1326
- # manually for max memory savings
1327
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1328
- self.unet.to("cpu")
1329
- self.controlnet.to("cpu")
1330
- torch.cuda.empty_cache()
1331
-
1332
- if not output_type == "latent":
1333
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1334
- 0
1335
- ]
1336
- image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1337
- else:
1338
- image = latents
1339
- has_nsfw_concept = None
1340
-
1341
- if has_nsfw_concept is None:
1342
- do_denormalize = [True] * image.shape[0]
1343
- else:
1344
- do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1345
-
1346
- image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1347
-
1348
- # Offload all models
1349
- self.maybe_free_model_hooks()
1350
-
1351
- if not return_dict:
1352
- return (image, has_nsfw_concept)
1353
-
1354
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
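
The removed `__call__` above windows the ControlNet influence with `control_guidance_start` / `control_guidance_end` by building a per-step `controlnet_keep` multiplier, which then scales `controlnet_conditioning_scale` before the residuals are added to the UNet. The snippet below is a minimal, self-contained sketch of that scheduling logic; the function and variable names are illustrative and not part of this repository.

```python
# Minimal sketch (illustrative names, not from this repository) of the
# control-guidance windowing in the removed __call__: the ControlNet residual
# is kept only while the current step falls inside the
# [control_guidance_start, control_guidance_end] fraction of the schedule.

def controlnet_keep_schedule(num_steps: int, start: float, end: float) -> list:
    """Return one multiplier per denoising step: 1.0 inside the window, 0.0 outside."""
    keeps = []
    for i in range(num_steps):
        outside = (i / num_steps) < start or ((i + 1) / num_steps) > end
        keeps.append(0.0 if outside else 1.0)
    return keeps


if __name__ == "__main__":
    # With 10 steps and a window of [0.0, 0.5], only the first half of the
    # denoising steps receives the ControlNet residuals.
    schedule = controlnet_keep_schedule(10, start=0.0, end=0.5)
    print(schedule)  # [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]

    # The per-step conditioning scale is the base scale times this multiplier,
    # mirroring `cond_scale = controlnet_cond_scale * controlnet_keep[i]`.
    per_step_scales = [1.0 * keep for keep in schedule]
```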
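
Two details of the deleted denoising loop that are easy to misread are the guess-mode handling (ControlNet runs only on the conditional half of the doubled batch, and zeros are prepended so the unconditional half is left unchanged) and the classifier-free-guidance blend of the UNet prediction. The following is a hedged, standalone sketch of both steps using plain tensors; the helper names are illustrative only.

```python
# Hedged, standalone sketch (plain torch tensors; names are illustrative) of two
# steps in the removed denoising loop: padding guess-mode ControlNet residuals
# and applying classifier-free guidance to the UNet output.
import torch


def pad_guess_mode_residual(residual: torch.Tensor) -> torch.Tensor:
    """ControlNet ran only on the conditional half of the doubled batch; prepend
    zeros so that adding the residual leaves the unconditional half unchanged."""
    return torch.cat([torch.zeros_like(residual), residual])


def apply_classifier_free_guidance(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    """Split the doubled batch into unconditional/text halves and blend them."""
    noise_uncond, noise_text = noise_pred.chunk(2)
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)


if __name__ == "__main__":
    cond_residual = torch.randn(1, 320, 64, 64)       # residual from the conditional half only
    padded = pad_guess_mode_residual(cond_residual)   # shape (2, 320, 64, 64)

    doubled_noise_pred = torch.randn(2, 4, 64, 64)    # UNet output for [uncond, text]
    guided = apply_classifier_free_guidance(doubled_noise_pred, guidance_scale=7.5)
    print(padded.shape, guided.shape)                 # torch.Size([2, 320, 64, 64]) torch.Size([1, 4, 64, 64])
```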